Add a cache around server ACL checking (#16360)

* Pre-compiles the server ACLs onto an object per room and
  invalidates them when new events come in.
* Converts the server ACL checking into Rust.
pull/16389/head
Patrick Cloke 2023-09-26 11:57:50 -04:00 committed by GitHub
parent 17800a0e97
commit f84da3c32e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 235 additions and 85 deletions

1
changelog.d/16360.misc Normal file
View File

@ -0,0 +1 @@
Cache server ACL checking.

102
rust/src/acl/mod.rs Normal file
View File

@ -0,0 +1,102 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! An implementation of Matrix server ACL rules.
use std::net::Ipv4Addr;
use std::str::FromStr;
use anyhow::Error;
use pyo3::prelude::*;
use regex::Regex;
use crate::push::utils::{glob_to_regex, GlobMatchType};
/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let child_module = PyModule::new(py, "acl")?;
child_module.add_class::<ServerAclEvaluator>()?;
m.add_submodule(child_module)?;
// We need to manually add the module to sys.modules to make `from
// synapse.synapse_rust import acl` work.
py.import("sys")?
.getattr("modules")?
.set_item("synapse.synapse_rust.acl", child_module)?;
Ok(())
}
#[derive(Debug, Clone)]
#[pyclass(frozen)]
pub struct ServerAclEvaluator {
allow_ip_literals: bool,
allow: Vec<Regex>,
deny: Vec<Regex>,
}
#[pymethods]
impl ServerAclEvaluator {
#[new]
pub fn py_new(
allow_ip_literals: bool,
allow: Vec<&str>,
deny: Vec<&str>,
) -> Result<Self, Error> {
let allow = allow
.iter()
.map(|s| glob_to_regex(s, GlobMatchType::Whole))
.collect::<Result<_, _>>()?;
let deny = deny
.iter()
.map(|s| glob_to_regex(s, GlobMatchType::Whole))
.collect::<Result<_, _>>()?;
Ok(ServerAclEvaluator {
allow_ip_literals,
allow,
deny,
})
}
pub fn server_matches_acl_event(&self, server_name: &str) -> bool {
// first of all, check if literal IPs are blocked, and if so, whether the
// server name is a literal IP
if !self.allow_ip_literals {
// check for ipv6 literals. These start with '['.
if server_name.starts_with('[') {
return false;
}
// check for ipv4 literals. We can just lift the routine from std::net.
if Ipv4Addr::from_str(server_name).is_ok() {
return false;
}
}
// next, check the deny list
if self.deny.iter().any(|e| e.is_match(server_name)) {
return false;
}
// then the allow list.
if self.allow.iter().any(|e| e.is_match(server_name)) {
return true;
}
// everything else should be rejected.
false
}
}

View File

@ -2,6 +2,7 @@ use lazy_static::lazy_static;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3_log::ResetHandle; use pyo3_log::ResetHandle;
pub mod acl;
pub mod push; pub mod push;
lazy_static! { lazy_static! {
@ -38,6 +39,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?; m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?;
m.add_function(wrap_pyfunction!(reset_logging_config, m)?)?; m.add_function(wrap_pyfunction!(reset_logging_config, m)?)?;
acl::register_module(py, m)?;
push::register_module(py, m)?; push::register_module(py, m)?;
Ok(()) Ok(())

View File

@ -0,0 +1,21 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
class ServerAclEvaluator:
def __init__(
self, allow_ip_literals: bool, allow: List[str], deny: List[str]
) -> None: ...
def server_matches_acl_event(self, server_name: str) -> bool: ...

View File

@ -39,9 +39,9 @@ from synapse.events.utils import (
CANONICALJSON_MIN_INT, CANONICALJSON_MIN_INT,
validate_canonicaljson, validate_canonicaljson,
) )
from synapse.federation.federation_server import server_matches_acl_event
from synapse.http.servlet import validate_json_object from synapse.http.servlet import validate_json_object
from synapse.rest.models import RequestBodyModel from synapse.rest.models import RequestBodyModel
from synapse.storage.controllers.state import server_acl_evaluator_from_event
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID
@ -106,7 +106,10 @@ class EventValidator:
self._validate_retention(event) self._validate_retention(event)
elif event.type == EventTypes.ServerACL: elif event.type == EventTypes.ServerACL:
if not server_matches_acl_event(config.server.server_name, event): server_acl_evaluator = server_acl_evaluator_from_event(event)
if not server_acl_evaluator.server_matches_acl_event(
config.server.server_name
):
raise SynapseError( raise SynapseError(
400, "Can't create an ACL event that denies the local server" 400, "Can't create an ACL event that denies the local server"
) )

View File

@ -29,10 +29,8 @@ from typing import (
Union, Union,
) )
from matrix_common.regex import glob_to_regex
from prometheus_client import Counter, Gauge, Histogram from prometheus_client import Counter, Gauge, Histogram
from twisted.internet.abstract import isIPAddress
from twisted.python import failure from twisted.python import failure
from synapse.api.constants import ( from synapse.api.constants import (
@ -1324,75 +1322,13 @@ class FederationServer(FederationBase):
Raises: Raises:
AuthError if the server does not match the ACL AuthError if the server does not match the ACL
""" """
acl_event = await self._storage_controllers.state.get_current_state_event( server_acl_evaluator = (
room_id, EventTypes.ServerACL, "" await self._storage_controllers.state.get_server_acl_for_room(room_id)
) )
if not acl_event or server_matches_acl_event(server_name, acl_event): if server_acl_evaluator and not server_acl_evaluator.server_matches_acl_event(
return server_name
):
raise AuthError(code=403, msg="Server is banned from room") raise AuthError(code=403, msg="Server is banned from room")
def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
"""Check if the given server is allowed by the ACL event
Args:
server_name: name of server, without any port part
acl_event: m.room.server_acl event
Returns:
True if this server is allowed by the ACLs
"""
logger.debug("Checking %s against acl %s", server_name, acl_event.content)
# first of all, check if literal IPs are blocked, and if so, whether the
# server name is a literal IP
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
logger.warning("Ignoring non-bool allow_ip_literals flag")
allow_ip_literals = True
if not allow_ip_literals:
# check for ipv6 literals. These start with '['.
if server_name[0] == "[":
return False
# check for ipv4 literals. We can just lift the routine from twisted.
if isIPAddress(server_name):
return False
# next, check the deny list
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
logger.warning("Ignoring non-list deny ACL %s", deny)
deny = []
for e in deny:
if _acl_entry_matches(server_name, e):
# logger.info("%s matched deny rule %s", server_name, e)
return False
# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
logger.warning("Ignoring non-list allow ACL %s", allow)
allow = []
for e in allow:
if _acl_entry_matches(server_name, e):
# logger.info("%s matched allow rule %s", server_name, e)
return True
# everything else should be rejected.
# logger.info("%s fell through", server_name)
return False
def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool:
if not isinstance(acl_entry, str):
logger.warning(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
)
return False
regex = glob_to_regex(acl_entry)
return bool(regex.match(server_name))
class FederationHandlerRegistry: class FederationHandlerRegistry:

View File

@ -2342,6 +2342,12 @@ class FederationEventHandler:
# TODO retrieve the previous state, and exclude join -> join transitions # TODO retrieve the previous state, and exclude join -> join transitions
self._notifier.notify_user_joined_room(event.event_id, event.room_id) self._notifier.notify_user_joined_room(event.event_id, event.room_id)
# If this is a server ACL event, clear the cache in the storage controller.
if event.type == EventTypes.ServerACL:
self._state_storage_controller.get_server_acl_for_room.invalidate(
(event.room_id,)
)
def _sanity_check_event(self, ev: EventBase) -> None: def _sanity_check_event(self, ev: EventBase) -> None:
""" """
Do some early sanity checks of a received event Do some early sanity checks of a received event

View File

@ -1730,6 +1730,11 @@ class EventCreationHandler:
event.event_id, event.room_id event.event_id, event.room_id
) )
if event.type == EventTypes.ServerACL:
self._storage_controllers.state.get_server_acl_for_room.invalidate(
(event.room_id,)
)
await self._maybe_kick_guest_users(event, context) await self._maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias: if event.type == EventTypes.CanonicalAlias:

View File

@ -205,6 +205,12 @@ class ReplicationDataHandler:
self.notifier.notify_user_joined_room( self.notifier.notify_user_joined_room(
row.data.event_id, row.data.room_id row.data.event_id, row.data.room_id
) )
# If this is a server ACL event, clear the cache in the storage controller.
if row.data.type == EventTypes.ServerACL:
self._state_storage_controller.get_server_acl_for_room.invalidate(
(row.data.room_id,)
)
elif stream_name == UnPartialStatedRoomStream.NAME: elif stream_name == UnPartialStatedRoomStream.NAME:
for row in rows: for row in rows:
assert isinstance(row, UnPartialStatedRoomStreamRow) assert isinstance(row, UnPartialStatedRoomStreamRow)

View File

@ -37,6 +37,7 @@ from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker, PartialCurrentStateTracker,
PartialStateEventsTracker, PartialStateEventsTracker,
) )
from synapse.synapse_rust.acl import ServerAclEvaluator
from synapse.types import MutableStateMap, StateMap, get_domain_from_id from synapse.types import MutableStateMap, StateMap, get_domain_from_id
from synapse.types.state import StateFilter from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
@ -501,6 +502,31 @@ class StateStorageController:
return event.content.get("alias") return event.content.get("alias")
@cached()
async def get_server_acl_for_room(
self, room_id: str
) -> Optional[ServerAclEvaluator]:
"""Get the server ACL evaluator for room, if any
This does up-front parsing of the content to ignore bad data and pre-compile
regular expressions.
Args:
room_id: The room ID
Returns:
The server ACL evaluator, if any
"""
acl_event = await self.get_current_state_event(
room_id, EventTypes.ServerACL, ""
)
if not acl_event:
return None
return server_acl_evaluator_from_event(acl_event)
@trace @trace
@tag_args @tag_args
async def get_current_state_deltas( async def get_current_state_deltas(
@ -760,3 +786,36 @@ class StateStorageController:
cache.state_group = object() cache.state_group = object()
return frozenset(cache.hosts_to_joined_users) return frozenset(cache.hosts_to_joined_users)
def server_acl_evaluator_from_event(acl_event: EventBase) -> "ServerAclEvaluator":
"""
Create a ServerAclEvaluator from a m.room.server_acl event's content.
This does up-front parsing of the content to ignore bad data. It then creates
the ServerAclEvaluator which will pre-compile regular expressions from the globs.
"""
# first of all, parse if literal IPs are blocked.
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
logger.warning("Ignoring non-bool allow_ip_literals flag")
allow_ip_literals = True
# next, parse the deny list by ignoring any non-strings.
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
logger.warning("Ignoring non-list deny ACL %s", deny)
deny = []
else:
deny = [s for s in deny if isinstance(s, str)]
# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
logger.warning("Ignoring non-list allow ACL %s", allow)
allow = []
else:
allow = [s for s in allow if isinstance(s, str)]
return ServerAclEvaluator(allow_ip_literals, allow, deny)

View File

@ -22,10 +22,10 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.events import EventBase, make_event_from_dict from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_server import server_matches_acl_event
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.controllers.state import server_acl_evaluator_from_event
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
@ -67,37 +67,46 @@ class ServerACLsTestCase(unittest.TestCase):
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]}) e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
logging.info("ACL event: %s", e.content) logging.info("ACL event: %s", e.content)
self.assertFalse(server_matches_acl_event("evil.com", e)) server_acl_evalutor = server_acl_evaluator_from_event(e)
self.assertFalse(server_matches_acl_event("EVIL.COM", e))
self.assertTrue(server_matches_acl_event("evil.com.au", e)) self.assertFalse(server_acl_evalutor.server_matches_acl_event("evil.com"))
self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e)) self.assertFalse(server_acl_evalutor.server_matches_acl_event("EVIL.COM"))
self.assertTrue(server_acl_evalutor.server_matches_acl_event("evil.com.au"))
self.assertTrue(
server_acl_evalutor.server_matches_acl_event("honestly.not.evil.com")
)
def test_block_ip_literals(self) -> None: def test_block_ip_literals(self) -> None:
e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]}) e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]})
logging.info("ACL event: %s", e.content) logging.info("ACL event: %s", e.content)
self.assertFalse(server_matches_acl_event("1.2.3.4", e)) server_acl_evalutor = server_acl_evaluator_from_event(e)
self.assertTrue(server_matches_acl_event("1a.2.3.4", e))
self.assertFalse(server_matches_acl_event("[1:2::]", e)) self.assertFalse(server_acl_evalutor.server_matches_acl_event("1.2.3.4"))
self.assertTrue(server_matches_acl_event("1:2:3:4", e)) self.assertTrue(server_acl_evalutor.server_matches_acl_event("1a.2.3.4"))
self.assertFalse(server_acl_evalutor.server_matches_acl_event("[1:2::]"))
self.assertTrue(server_acl_evalutor.server_matches_acl_event("1:2:3:4"))
def test_wildcard_matching(self) -> None: def test_wildcard_matching(self) -> None:
e = _create_acl_event({"allow": ["good*.com"]}) e = _create_acl_event({"allow": ["good*.com"]})
server_acl_evalutor = server_acl_evaluator_from_event(e)
self.assertTrue( self.assertTrue(
server_matches_acl_event("good.com", e), server_acl_evalutor.server_matches_acl_event("good.com"),
"* matches 0 characters", "* matches 0 characters",
) )
self.assertTrue( self.assertTrue(
server_matches_acl_event("GOOD.COM", e), server_acl_evalutor.server_matches_acl_event("GOOD.COM"),
"pattern is case-insensitive", "pattern is case-insensitive",
) )
self.assertTrue( self.assertTrue(
server_matches_acl_event("good.aa.com", e), server_acl_evalutor.server_matches_acl_event("good.aa.com"),
"* matches several characters, including '.'", "* matches several characters, including '.'",
) )
self.assertFalse( self.assertFalse(
server_matches_acl_event("ishgood.com", e), server_acl_evalutor.server_matches_acl_event("ishgood.com"),
"pattern does not allow prefixes", "pattern does not allow prefixes",
) )