303 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
			
		
		
	
	
			303 lines
		
	
	
		
			9.5 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
| #!/usr/bin/env python
 | |
| # Copyright 2022-2023 The Matrix.org Foundation C.I.C.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| import argparse
 | |
| import logging
 | |
| import re
 | |
| from collections import defaultdict
 | |
| from dataclasses import dataclass
 | |
| from typing import Dict, Iterable, Optional, Pattern, Set, Tuple
 | |
| 
 | |
| import yaml
 | |
| 
 | |
| from synapse.config.homeserver import HomeServerConfig
 | |
| from synapse.federation.transport.server import (
 | |
|     TransportLayerServer,
 | |
|     register_servlets as register_federation_servlets,
 | |
| )
 | |
| from synapse.http.server import HttpServer, ServletCallback
 | |
| from synapse.rest import ClientRestResource
 | |
| from synapse.rest.key.v2 import RemoteKey
 | |
| from synapse.server import HomeServer
 | |
| from synapse.storage import DataStore
 | |
| 
 | |
| logger = logging.getLogger("generate_workers_map")
 | |
| 
 | |
| 
 | |
| class MockHomeserver(HomeServer):
 | |
|     DATASTORE_CLASS = DataStore  # type: ignore
 | |
| 
 | |
|     def __init__(self, config: HomeServerConfig, worker_app: Optional[str]) -> None:
 | |
|         super().__init__(config.server.server_name, config=config)
 | |
|         self.config.worker.worker_app = worker_app
 | |
| 
 | |
| 
 | |
| GROUP_PATTERN = re.compile(r"\(\?P<[^>]+?>(.+?)\)")
 | |
| 
 | |
| 
 | |
| @dataclass
 | |
| class EndpointDescription:
 | |
|     """
 | |
|     Describes an endpoint and how it should be routed.
 | |
|     """
 | |
| 
 | |
|     # The servlet class that handles this endpoint
 | |
|     servlet_class: object
 | |
| 
 | |
|     # The category of this endpoint. Is read from the `CATEGORY` constant in the servlet
 | |
|     # class.
 | |
|     category: Optional[str]
 | |
| 
 | |
|     # TODO:
 | |
|     #  - does it need to be routed based on a stream writer config?
 | |
|     #  - does it benefit from any optimised, but optional, routing?
 | |
|     #  - what 'opinionated synapse worker class' (event_creator, synchrotron, etc) does
 | |
|     #    it go in?
 | |
| 
 | |
| 
 | |
| class EnumerationResource(HttpServer):
 | |
|     """
 | |
|     Accepts servlet registrations for the purposes of building up a description of
 | |
|     all endpoints.
 | |
|     """
 | |
| 
 | |
|     def __init__(self, is_worker: bool) -> None:
 | |
|         self.registrations: Dict[Tuple[str, str], EndpointDescription] = {}
 | |
|         self._is_worker = is_worker
 | |
| 
 | |
|     def register_paths(
 | |
|         self,
 | |
|         method: str,
 | |
|         path_patterns: Iterable[Pattern],
 | |
|         callback: ServletCallback,
 | |
|         servlet_classname: str,
 | |
|     ) -> None:
 | |
|         # federation servlet callbacks are wrapped, so unwrap them.
 | |
|         callback = getattr(callback, "__wrapped__", callback)
 | |
| 
 | |
|         # fish out the servlet class
 | |
|         servlet_class = callback.__self__.__class__  # type: ignore
 | |
| 
 | |
|         if self._is_worker and method in getattr(
 | |
|             servlet_class, "WORKERS_DENIED_METHODS", ()
 | |
|         ):
 | |
|             # This endpoint would cause an error if called on a worker, so pretend it
 | |
|             # was never registered!
 | |
|             return
 | |
| 
 | |
|         sd = EndpointDescription(
 | |
|             servlet_class=servlet_class,
 | |
|             category=getattr(servlet_class, "CATEGORY", None),
 | |
|         )
 | |
| 
 | |
|         for pat in path_patterns:
 | |
|             self.registrations[(method, pat.pattern)] = sd
 | |
| 
 | |
| 
 | |
| def get_registered_paths_for_hs(
 | |
|     hs: HomeServer,
 | |
| ) -> Dict[Tuple[str, str], EndpointDescription]:
 | |
|     """
 | |
|     Given a homeserver, get all registered endpoints and their descriptions.
 | |
|     """
 | |
| 
 | |
|     enumerator = EnumerationResource(is_worker=hs.config.worker.worker_app is not None)
 | |
|     ClientRestResource.register_servlets(enumerator, hs)
 | |
|     federation_server = TransportLayerServer(hs)
 | |
| 
 | |
|     # we can't use `federation_server.register_servlets` but this line does the
 | |
|     # same thing, only it uses this enumerator
 | |
|     register_federation_servlets(
 | |
|         federation_server.hs,
 | |
|         resource=enumerator,
 | |
|         ratelimiter=federation_server.ratelimiter,
 | |
|         authenticator=federation_server.authenticator,
 | |
|         servlet_groups=federation_server.servlet_groups,
 | |
|     )
 | |
| 
 | |
|     # the key server endpoints are separate again
 | |
|     RemoteKey(hs).register(enumerator)
 | |
| 
 | |
|     return enumerator.registrations
 | |
| 
 | |
| 
 | |
| def get_registered_paths_for_default(
 | |
|     worker_app: Optional[str], base_config: HomeServerConfig
 | |
| ) -> Dict[Tuple[str, str], EndpointDescription]:
 | |
|     """
 | |
|     Given the name of a worker application and a base homeserver configuration,
 | |
|     returns:
 | |
| 
 | |
|         Dict from (method, path) to EndpointDescription
 | |
| 
 | |
|     TODO Don't require passing in a config
 | |
|     """
 | |
| 
 | |
|     hs = MockHomeserver(base_config, worker_app)
 | |
|     # TODO We only do this to avoid an error, but don't need the database etc
 | |
|     hs.setup()
 | |
|     return get_registered_paths_for_hs(hs)
 | |
| 
 | |
| 
 | |
| def elide_http_methods_if_unconflicting(
 | |
|     registrations: Dict[Tuple[str, str], EndpointDescription],
 | |
|     all_possible_registrations: Dict[Tuple[str, str], EndpointDescription],
 | |
| ) -> Dict[Tuple[str, str], EndpointDescription]:
 | |
|     """
 | |
|     Elides HTTP methods (by replacing them with `*`) if all possible registered methods
 | |
|     can be handled by the worker whose registration map is `registrations`.
 | |
| 
 | |
|     i.e. the only endpoints left with methods (other than `*`) should be the ones where
 | |
|     the worker can't handle all possible methods for that path.
 | |
|     """
 | |
| 
 | |
|     def paths_to_methods_dict(
 | |
|         methods_and_paths: Iterable[Tuple[str, str]]
 | |
|     ) -> Dict[str, Set[str]]:
 | |
|         """
 | |
|         Given (method, path) pairs, produces a dict from path to set of methods
 | |
|         available at that path.
 | |
|         """
 | |
|         result: Dict[str, Set[str]] = {}
 | |
|         for method, path in methods_and_paths:
 | |
|             result.setdefault(path, set()).add(method)
 | |
|         return result
 | |
| 
 | |
|     all_possible_reg_methods = paths_to_methods_dict(all_possible_registrations)
 | |
|     reg_methods = paths_to_methods_dict(registrations)
 | |
| 
 | |
|     output = {}
 | |
| 
 | |
|     for path, handleable_methods in reg_methods.items():
 | |
|         if handleable_methods == all_possible_reg_methods[path]:
 | |
|             any_method = next(iter(handleable_methods))
 | |
|             # TODO This assumes that all methods have the same servlet.
 | |
|             #      I suppose that's possibly dubious?
 | |
|             output[("*", path)] = registrations[(any_method, path)]
 | |
|         else:
 | |
|             for method in handleable_methods:
 | |
|                 output[(method, path)] = registrations[(method, path)]
 | |
| 
 | |
|     return output
 | |
| 
 | |
| 
 | |
| def simplify_path_regexes(
 | |
|     registrations: Dict[Tuple[str, str], EndpointDescription]
 | |
| ) -> Dict[Tuple[str, str], EndpointDescription]:
 | |
|     """
 | |
|     Simplify all the path regexes for the dict of endpoint descriptions,
 | |
|     so that we don't use the Python-specific regex extensions
 | |
|     (and also to remove needlessly specific detail).
 | |
|     """
 | |
| 
 | |
|     def simplify_path_regex(path: str) -> str:
 | |
|         """
 | |
|         Given a regex pattern, replaces all named capturing groups (e.g. `(?P<blah>xyz)`)
 | |
|         with a simpler version available in more common regex dialects (e.g. `.*`).
 | |
|         """
 | |
| 
 | |
|         # TODO it's hard to choose between these two;
 | |
|         #      `.*` is a vague simplification
 | |
|         # return GROUP_PATTERN.sub(r"\1", path)
 | |
|         return GROUP_PATTERN.sub(r".*", path)
 | |
| 
 | |
|     return {(m, simplify_path_regex(p)): v for (m, p), v in registrations.items()}
 | |
| 
 | |
| 
 | |
| def main() -> None:
 | |
|     parser = argparse.ArgumentParser(
 | |
|         description=(
 | |
|             "Updates a synapse database to the latest schema and optionally runs background updates"
 | |
|             " on it."
 | |
|         )
 | |
|     )
 | |
|     parser.add_argument("-v", action="store_true")
 | |
|     parser.add_argument(
 | |
|         "--config-path",
 | |
|         type=argparse.FileType("r"),
 | |
|         required=True,
 | |
|         help="Synapse configuration file",
 | |
|     )
 | |
| 
 | |
|     args = parser.parse_args()
 | |
| 
 | |
|     # TODO
 | |
|     # logging.basicConfig(**logging_config)
 | |
| 
 | |
|     # Load, process and sanity-check the config.
 | |
|     hs_config = yaml.safe_load(args.config_path)
 | |
| 
 | |
|     config = HomeServerConfig()
 | |
|     config.parse_config_dict(hs_config, "", "")
 | |
| 
 | |
|     master_paths = get_registered_paths_for_default(None, config)
 | |
|     worker_paths = get_registered_paths_for_default(
 | |
|         "synapse.app.generic_worker", config
 | |
|     )
 | |
| 
 | |
|     all_paths = {**master_paths, **worker_paths}
 | |
| 
 | |
|     elided_worker_paths = elide_http_methods_if_unconflicting(worker_paths, all_paths)
 | |
|     elide_http_methods_if_unconflicting(master_paths, all_paths)
 | |
| 
 | |
|     # TODO SSO endpoints (pick_idp etc) NOT REGISTERED BY THIS SCRIPT
 | |
| 
 | |
|     categories_to_methods_and_paths: Dict[
 | |
|         Optional[str], Dict[Tuple[str, str], EndpointDescription]
 | |
|     ] = defaultdict(dict)
 | |
| 
 | |
|     for (method, path), desc in elided_worker_paths.items():
 | |
|         categories_to_methods_and_paths[desc.category][method, path] = desc
 | |
| 
 | |
|     for category, contents in categories_to_methods_and_paths.items():
 | |
|         print_category(category, contents)
 | |
| 
 | |
| 
 | |
| def print_category(
 | |
|     category_name: Optional[str],
 | |
|     elided_worker_paths: Dict[Tuple[str, str], EndpointDescription],
 | |
| ) -> None:
 | |
|     """
 | |
|     Prints out a category, in documentation page style.
 | |
| 
 | |
|     Example:
 | |
|     ```
 | |
|     # Category name
 | |
|     /path/xyz
 | |
| 
 | |
|     GET /path/abc
 | |
|     ```
 | |
|     """
 | |
| 
 | |
|     if category_name:
 | |
|         print(f"# {category_name}")
 | |
|     else:
 | |
|         print("# (Uncategorised requests)")
 | |
| 
 | |
|     for ln in sorted(
 | |
|         p for m, p in simplify_path_regexes(elided_worker_paths) if m == "*"
 | |
|     ):
 | |
|         print(ln)
 | |
|     print()
 | |
|     for ln in sorted(
 | |
|         f"{m:6} {p}" for m, p in simplify_path_regexes(elided_worker_paths) if m != "*"
 | |
|     ):
 | |
|         print(ln)
 | |
|     print()
 | |
| 
 | |
| 
 | |
| if __name__ == "__main__":
 | |
|     main()
 |