Add a cache around server ACL checking (#16360)

* Pre-compiles the server ACLs onto an object per room and
  invalidates them when new events come in.
* Converts the server ACL checking into Rust.
pull/16389/head
Patrick Cloke 2023-09-26 11:57:50 -04:00 committed by GitHub
parent 17800a0e97
commit f84da3c32e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 235 additions and 85 deletions

1
changelog.d/16360.misc Normal file
View File

@ -0,0 +1 @@
Cache server ACL checking.

102
rust/src/acl/mod.rs Normal file
View File

@ -0,0 +1,102 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! An implementation of Matrix server ACL rules.
use std::net::Ipv4Addr;
use std::str::FromStr;
use anyhow::Error;
use pyo3::prelude::*;
use regex::Regex;
use crate::push::utils::{glob_to_regex, GlobMatchType};
/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let child_module = PyModule::new(py, "acl")?;
child_module.add_class::<ServerAclEvaluator>()?;
m.add_submodule(child_module)?;
// We need to manually add the module to sys.modules to make `from
// synapse.synapse_rust import acl` work.
py.import("sys")?
.getattr("modules")?
.set_item("synapse.synapse_rust.acl", child_module)?;
Ok(())
}
#[derive(Debug, Clone)]
#[pyclass(frozen)]
pub struct ServerAclEvaluator {
allow_ip_literals: bool,
allow: Vec<Regex>,
deny: Vec<Regex>,
}
#[pymethods]
impl ServerAclEvaluator {
#[new]
pub fn py_new(
allow_ip_literals: bool,
allow: Vec<&str>,
deny: Vec<&str>,
) -> Result<Self, Error> {
let allow = allow
.iter()
.map(|s| glob_to_regex(s, GlobMatchType::Whole))
.collect::<Result<_, _>>()?;
let deny = deny
.iter()
.map(|s| glob_to_regex(s, GlobMatchType::Whole))
.collect::<Result<_, _>>()?;
Ok(ServerAclEvaluator {
allow_ip_literals,
allow,
deny,
})
}
pub fn server_matches_acl_event(&self, server_name: &str) -> bool {
// first of all, check if literal IPs are blocked, and if so, whether the
// server name is a literal IP
if !self.allow_ip_literals {
// check for ipv6 literals. These start with '['.
if server_name.starts_with('[') {
return false;
}
// check for ipv4 literals. We can just lift the routine from std::net.
if Ipv4Addr::from_str(server_name).is_ok() {
return false;
}
}
// next, check the deny list
if self.deny.iter().any(|e| e.is_match(server_name)) {
return false;
}
// then the allow list.
if self.allow.iter().any(|e| e.is_match(server_name)) {
return true;
}
// everything else should be rejected.
false
}
}

View File

@ -2,6 +2,7 @@ use lazy_static::lazy_static;
use pyo3::prelude::*;
use pyo3_log::ResetHandle;
pub mod acl;
pub mod push;
lazy_static! {
@ -38,6 +39,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?;
m.add_function(wrap_pyfunction!(reset_logging_config, m)?)?;
acl::register_module(py, m)?;
push::register_module(py, m)?;
Ok(())

View File

@ -0,0 +1,21 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
class ServerAclEvaluator:
def __init__(
self, allow_ip_literals: bool, allow: List[str], deny: List[str]
) -> None: ...
def server_matches_acl_event(self, server_name: str) -> bool: ...

View File

@ -39,9 +39,9 @@ from synapse.events.utils import (
CANONICALJSON_MIN_INT,
validate_canonicaljson,
)
from synapse.federation.federation_server import server_matches_acl_event
from synapse.http.servlet import validate_json_object
from synapse.rest.models import RequestBodyModel
from synapse.storage.controllers.state import server_acl_evaluator_from_event
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID
@ -106,7 +106,10 @@ class EventValidator:
self._validate_retention(event)
elif event.type == EventTypes.ServerACL:
if not server_matches_acl_event(config.server.server_name, event):
server_acl_evaluator = server_acl_evaluator_from_event(event)
if not server_acl_evaluator.server_matches_acl_event(
config.server.server_name
):
raise SynapseError(
400, "Can't create an ACL event that denies the local server"
)

View File

@ -29,10 +29,8 @@ from typing import (
Union,
)
from matrix_common.regex import glob_to_regex
from prometheus_client import Counter, Gauge, Histogram
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
from synapse.api.constants import (
@ -1324,75 +1322,13 @@ class FederationServer(FederationBase):
Raises:
AuthError if the server does not match the ACL
"""
acl_event = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.ServerACL, ""
server_acl_evaluator = (
await self._storage_controllers.state.get_server_acl_for_room(room_id)
)
if not acl_event or server_matches_acl_event(server_name, acl_event):
return
raise AuthError(code=403, msg="Server is banned from room")
def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
"""Check if the given server is allowed by the ACL event
Args:
server_name: name of server, without any port part
acl_event: m.room.server_acl event
Returns:
True if this server is allowed by the ACLs
"""
logger.debug("Checking %s against acl %s", server_name, acl_event.content)
# first of all, check if literal IPs are blocked, and if so, whether the
# server name is a literal IP
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
logger.warning("Ignoring non-bool allow_ip_literals flag")
allow_ip_literals = True
if not allow_ip_literals:
# check for ipv6 literals. These start with '['.
if server_name[0] == "[":
return False
# check for ipv4 literals. We can just lift the routine from twisted.
if isIPAddress(server_name):
return False
# next, check the deny list
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
logger.warning("Ignoring non-list deny ACL %s", deny)
deny = []
for e in deny:
if _acl_entry_matches(server_name, e):
# logger.info("%s matched deny rule %s", server_name, e)
return False
# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
logger.warning("Ignoring non-list allow ACL %s", allow)
allow = []
for e in allow:
if _acl_entry_matches(server_name, e):
# logger.info("%s matched allow rule %s", server_name, e)
return True
# everything else should be rejected.
# logger.info("%s fell through", server_name)
return False
def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool:
if not isinstance(acl_entry, str):
logger.warning(
"Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
)
return False
regex = glob_to_regex(acl_entry)
return bool(regex.match(server_name))
if server_acl_evaluator and not server_acl_evaluator.server_matches_acl_event(
server_name
):
raise AuthError(code=403, msg="Server is banned from room")
class FederationHandlerRegistry:

View File

@ -2342,6 +2342,12 @@ class FederationEventHandler:
# TODO retrieve the previous state, and exclude join -> join transitions
self._notifier.notify_user_joined_room(event.event_id, event.room_id)
# If this is a server ACL event, clear the cache in the storage controller.
if event.type == EventTypes.ServerACL:
self._state_storage_controller.get_server_acl_for_room.invalidate(
(event.room_id,)
)
def _sanity_check_event(self, ev: EventBase) -> None:
"""
Do some early sanity checks of a received event

View File

@ -1730,6 +1730,11 @@ class EventCreationHandler:
event.event_id, event.room_id
)
if event.type == EventTypes.ServerACL:
self._storage_controllers.state.get_server_acl_for_room.invalidate(
(event.room_id,)
)
await self._maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:

View File

@ -205,6 +205,12 @@ class ReplicationDataHandler:
self.notifier.notify_user_joined_room(
row.data.event_id, row.data.room_id
)
# If this is a server ACL event, clear the cache in the storage controller.
if row.data.type == EventTypes.ServerACL:
self._state_storage_controller.get_server_acl_for_room.invalidate(
(row.data.room_id,)
)
elif stream_name == UnPartialStatedRoomStream.NAME:
for row in rows:
assert isinstance(row, UnPartialStatedRoomStreamRow)

View File

@ -37,6 +37,7 @@ from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker,
PartialStateEventsTracker,
)
from synapse.synapse_rust.acl import ServerAclEvaluator
from synapse.types import MutableStateMap, StateMap, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
@ -501,6 +502,31 @@ class StateStorageController:
return event.content.get("alias")
@cached()
async def get_server_acl_for_room(
self, room_id: str
) -> Optional[ServerAclEvaluator]:
"""Get the server ACL evaluator for room, if any
This does up-front parsing of the content to ignore bad data and pre-compile
regular expressions.
Args:
room_id: The room ID
Returns:
The server ACL evaluator, if any
"""
acl_event = await self.get_current_state_event(
room_id, EventTypes.ServerACL, ""
)
if not acl_event:
return None
return server_acl_evaluator_from_event(acl_event)
@trace
@tag_args
async def get_current_state_deltas(
@ -760,3 +786,36 @@ class StateStorageController:
cache.state_group = object()
return frozenset(cache.hosts_to_joined_users)
def server_acl_evaluator_from_event(acl_event: EventBase) -> "ServerAclEvaluator":
"""
Create a ServerAclEvaluator from a m.room.server_acl event's content.
This does up-front parsing of the content to ignore bad data. It then creates
the ServerAclEvaluator which will pre-compile regular expressions from the globs.
"""
# first of all, parse if literal IPs are blocked.
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
logger.warning("Ignoring non-bool allow_ip_literals flag")
allow_ip_literals = True
# next, parse the deny list by ignoring any non-strings.
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
logger.warning("Ignoring non-list deny ACL %s", deny)
deny = []
else:
deny = [s for s in deny if isinstance(s, str)]
# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
logger.warning("Ignoring non-list allow ACL %s", allow)
allow = []
else:
allow = [s for s in allow if isinstance(s, str)]
return ServerAclEvaluator(allow_ip_literals, allow, deny)

View File

@ -22,10 +22,10 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_server import server_matches_acl_event
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.controllers.state import server_acl_evaluator_from_event
from synapse.types import JsonDict
from synapse.util import Clock
@ -67,37 +67,46 @@ class ServerACLsTestCase(unittest.TestCase):
e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]})
logging.info("ACL event: %s", e.content)
self.assertFalse(server_matches_acl_event("evil.com", e))
self.assertFalse(server_matches_acl_event("EVIL.COM", e))
server_acl_evalutor = server_acl_evaluator_from_event(e)
self.assertTrue(server_matches_acl_event("evil.com.au", e))
self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e))
self.assertFalse(server_acl_evalutor.server_matches_acl_event("evil.com"))
self.assertFalse(server_acl_evalutor.server_matches_acl_event("EVIL.COM"))
self.assertTrue(server_acl_evalutor.server_matches_acl_event("evil.com.au"))
self.assertTrue(
server_acl_evalutor.server_matches_acl_event("honestly.not.evil.com")
)
def test_block_ip_literals(self) -> None:
e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]})
logging.info("ACL event: %s", e.content)
self.assertFalse(server_matches_acl_event("1.2.3.4", e))
self.assertTrue(server_matches_acl_event("1a.2.3.4", e))
self.assertFalse(server_matches_acl_event("[1:2::]", e))
self.assertTrue(server_matches_acl_event("1:2:3:4", e))
server_acl_evalutor = server_acl_evaluator_from_event(e)
self.assertFalse(server_acl_evalutor.server_matches_acl_event("1.2.3.4"))
self.assertTrue(server_acl_evalutor.server_matches_acl_event("1a.2.3.4"))
self.assertFalse(server_acl_evalutor.server_matches_acl_event("[1:2::]"))
self.assertTrue(server_acl_evalutor.server_matches_acl_event("1:2:3:4"))
def test_wildcard_matching(self) -> None:
e = _create_acl_event({"allow": ["good*.com"]})
server_acl_evalutor = server_acl_evaluator_from_event(e)
self.assertTrue(
server_matches_acl_event("good.com", e),
server_acl_evalutor.server_matches_acl_event("good.com"),
"* matches 0 characters",
)
self.assertTrue(
server_matches_acl_event("GOOD.COM", e),
server_acl_evalutor.server_matches_acl_event("GOOD.COM"),
"pattern is case-insensitive",
)
self.assertTrue(
server_matches_acl_event("good.aa.com", e),
server_acl_evalutor.server_matches_acl_event("good.aa.com"),
"* matches several characters, including '.'",
)
self.assertFalse(
server_matches_acl_event("ishgood.com", e),
server_acl_evalutor.server_matches_acl_event("ishgood.com"),
"pattern does not allow prefixes",
)