add an expiring cache to `_introspect_token`

pull/16117/head
H. Shay 2023-08-15 14:27:39 -07:00
parent 4ce32ade5a
commit 9db3a90782
1 changed files with 69 additions and 35 deletions

View File

@ -39,6 +39,7 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.types import Requester, UserID, create_requester from synapse.types import Requester, UserID, create_requester
from synapse.util import json_decoder from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.caches.expiringcache import ExpiringCache
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -106,6 +107,14 @@ class MSC3861DelegatedAuth(BaseAuth):
self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata) self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata)
self._clock = hs.get_clock()
self._token_cache: ExpiringCache[str, IntrospectionToken] = ExpiringCache(
cache_name="introspection_token_cache",
clock=self._clock,
max_len=10000,
expiry_ms=5 * 60 * 1000,
)
if isinstance(auth_method, PrivateKeyJWTWithKid): if isinstance(auth_method, PrivateKeyJWTWithKid):
# Use the JWK as the client secret when using the private_key_jwt method # Use the JWK as the client secret when using the private_key_jwt method
assert self._config.jwk, "No JWK provided" assert self._config.jwk, "No JWK provided"
@ -144,50 +153,75 @@ class MSC3861DelegatedAuth(BaseAuth):
Returns: Returns:
The introspection response The introspection response
""" """
metadata = await self._issuer_metadata.get() # check the cache before doing a request
introspection_endpoint = metadata.get("introspection_endpoint") introspection_token = self._token_cache.get(token, None)
raw_headers: Dict[str, str] = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": str(self._http_client.user_agent, "utf-8"),
"Accept": "application/json",
}
args = {"token": token, "token_type_hint": "access_token"} expired = False
body = urlencode(args, True) if introspection_token:
# check the expiration field of the token (if it exists)
exp = introspection_token.get("exp", None)
if exp:
time_now = self._clock.time_msec()
expired = time_now > exp
# Fill the body/headers with credentials if not introspection_token or expired:
uri, raw_headers, body = self._client_auth.prepare( metadata = await self._issuer_metadata.get()
method="POST", uri=introspection_endpoint, headers=raw_headers, body=body introspection_endpoint = metadata.get("introspection_endpoint")
) raw_headers: Dict[str, str] = {
headers = Headers({k: [v] for (k, v) in raw_headers.items()}) "Content-Type": "application/x-www-form-urlencoded",
"User-Agent": str(self._http_client.user_agent, "utf-8"),
"Accept": "application/json",
}
# Do the actual request args = {"token": token, "token_type_hint": "access_token"}
# We're not using the SimpleHttpClient util methods as we don't want to body = urlencode(args, True)
# check the HTTP status code, and we do the body encoding ourselves.
response = await self._http_client.request(
method="POST",
uri=uri,
data=body.encode("utf-8"),
headers=headers,
)
resp_body = await make_deferred_yieldable(readBody(response)) # Fill the body/headers with credentials
uri, raw_headers, body = self._client_auth.prepare(
method="POST",
uri=introspection_endpoint,
headers=raw_headers,
body=body,
)
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
if response.code < 200 or response.code >= 300: # Do the actual request
raise HttpResponseException( # We're not using the SimpleHttpClient util methods as we don't want to
response.code, # check the HTTP status code, and we do the body encoding ourselves.
response.phrase.decode("ascii", errors="replace"), response = await self._http_client.request(
resp_body, method="POST",
uri=uri,
data=body.encode("utf-8"),
headers=headers,
) )
resp = json_decoder.decode(resp_body.decode("utf-8")) resp_body = await make_deferred_yieldable(readBody(response))
if not isinstance(resp, dict): if response.code < 200 or response.code >= 300:
raise ValueError( raise HttpResponseException(
"The introspection endpoint returned an invalid JSON response." response.code,
) response.phrase.decode("ascii", errors="replace"),
resp_body,
)
return IntrospectionToken(**resp) resp = json_decoder.decode(resp_body.decode("utf-8"))
if not isinstance(resp, dict):
raise ValueError(
"The introspection endpoint returned an invalid JSON response."
)
expiration = resp.get("exp", None)
if expiration:
if self._clock.time_msec() > expiration:
raise InvalidClientTokenError("Token is expired.")
introspection_token = IntrospectionToken(**resp)
# add token to cache
self._token_cache[token] = introspection_token
return introspection_token
async def is_server_admin(self, requester: Requester) -> bool: async def is_server_admin(self, requester: Requester) -> bool:
return "urn:synapse:admin:*" in requester.scope return "urn:synapse:admin:*" in requester.scope