Merge pull request #2929 from matrix-org/erikj/split_regististration_store
Split registration storepull/2948/head
						commit
						f394f5574d
					
				|  | @ -14,20 +14,8 @@ | |||
| # limitations under the License. | ||||
| 
 | ||||
| from ._base import BaseSlavedStore | ||||
| from synapse.storage import DataStore | ||||
| from synapse.storage.registration import RegistrationStore | ||||
| from synapse.storage.registration import RegistrationWorkerStore | ||||
| 
 | ||||
| 
 | ||||
| class SlavedRegistrationStore(BaseSlavedStore): | ||||
|     def __init__(self, db_conn, hs): | ||||
|         super(SlavedRegistrationStore, self).__init__(db_conn, hs) | ||||
| 
 | ||||
|     # TODO: use the cached version and invalidate deleted tokens | ||||
|     get_user_by_access_token = RegistrationStore.__dict__[ | ||||
|         "get_user_by_access_token" | ||||
|     ] | ||||
| 
 | ||||
|     _query_for_auth = DataStore._query_for_auth.__func__ | ||||
|     get_user_by_id = RegistrationStore.__dict__[ | ||||
|         "get_user_by_id" | ||||
|     ] | ||||
| class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore): | ||||
|     pass | ||||
|  |  | |||
|  | @ -19,10 +19,70 @@ from twisted.internet import defer | |||
| 
 | ||||
| from synapse.api.errors import StoreError, Codes | ||||
| from synapse.storage import background_updates | ||||
| from synapse.storage._base import SQLBaseStore | ||||
| from synapse.util.caches.descriptors import cached, cachedInlineCallbacks | ||||
| 
 | ||||
| 
 | ||||
| class RegistrationStore(background_updates.BackgroundUpdateStore): | ||||
| class RegistrationWorkerStore(SQLBaseStore): | ||||
|     @cached() | ||||
|     def get_user_by_id(self, user_id): | ||||
|         return self._simple_select_one( | ||||
|             table="users", | ||||
|             keyvalues={ | ||||
|                 "name": user_id, | ||||
|             }, | ||||
|             retcols=["name", "password_hash", "is_guest"], | ||||
|             allow_none=True, | ||||
|             desc="get_user_by_id", | ||||
|         ) | ||||
| 
 | ||||
|     @cached() | ||||
|     def get_user_by_access_token(self, token): | ||||
|         """Get a user from the given access token. | ||||
| 
 | ||||
|         Args: | ||||
|             token (str): The access token of a user. | ||||
|         Returns: | ||||
|             defer.Deferred: None, if the token did not match, otherwise dict | ||||
|                 including the keys `name`, `is_guest`, `device_id`, `token_id`. | ||||
|         """ | ||||
|         return self.runInteraction( | ||||
|             "get_user_by_access_token", | ||||
|             self._query_for_auth, | ||||
|             token | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def is_server_admin(self, user): | ||||
|         res = yield self._simple_select_one_onecol( | ||||
|             table="users", | ||||
|             keyvalues={"name": user.to_string()}, | ||||
|             retcol="admin", | ||||
|             allow_none=True, | ||||
|             desc="is_server_admin", | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue(res if res else False) | ||||
| 
 | ||||
|     def _query_for_auth(self, txn, token): | ||||
|         sql = ( | ||||
|             "SELECT users.name, users.is_guest, access_tokens.id as token_id," | ||||
|             " access_tokens.device_id" | ||||
|             " FROM users" | ||||
|             " INNER JOIN access_tokens on users.name = access_tokens.user_id" | ||||
|             " WHERE token = ?" | ||||
|         ) | ||||
| 
 | ||||
|         txn.execute(sql, (token,)) | ||||
|         rows = self.cursor_to_dict(txn) | ||||
|         if rows: | ||||
|             return rows[0] | ||||
| 
 | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| class RegistrationStore(RegistrationWorkerStore, | ||||
|                         background_updates.BackgroundUpdateStore): | ||||
| 
 | ||||
|     def __init__(self, db_conn, hs): | ||||
|         super(RegistrationStore, self).__init__(db_conn, hs) | ||||
|  | @ -187,18 +247,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): | |||
|         ) | ||||
|         txn.call_after(self.is_guest.invalidate, (user_id,)) | ||||
| 
 | ||||
|     @cached() | ||||
|     def get_user_by_id(self, user_id): | ||||
|         return self._simple_select_one( | ||||
|             table="users", | ||||
|             keyvalues={ | ||||
|                 "name": user_id, | ||||
|             }, | ||||
|             retcols=["name", "password_hash", "is_guest"], | ||||
|             allow_none=True, | ||||
|             desc="get_user_by_id", | ||||
|         ) | ||||
| 
 | ||||
|     def get_users_by_id_case_insensitive(self, user_id): | ||||
|         """Gets users that match user_id case insensitively. | ||||
|         Returns a mapping of user_id -> password_hash. | ||||
|  | @ -304,34 +352,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): | |||
| 
 | ||||
|         return self.runInteraction("delete_access_token", f) | ||||
| 
 | ||||
|     @cached() | ||||
|     def get_user_by_access_token(self, token): | ||||
|         """Get a user from the given access token. | ||||
| 
 | ||||
|         Args: | ||||
|             token (str): The access token of a user. | ||||
|         Returns: | ||||
|             defer.Deferred: None, if the token did not match, otherwise dict | ||||
|                 including the keys `name`, `is_guest`, `device_id`, `token_id`. | ||||
|         """ | ||||
|         return self.runInteraction( | ||||
|             "get_user_by_access_token", | ||||
|             self._query_for_auth, | ||||
|             token | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def is_server_admin(self, user): | ||||
|         res = yield self._simple_select_one_onecol( | ||||
|             table="users", | ||||
|             keyvalues={"name": user.to_string()}, | ||||
|             retcol="admin", | ||||
|             allow_none=True, | ||||
|             desc="is_server_admin", | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue(res if res else False) | ||||
| 
 | ||||
|     @cachedInlineCallbacks() | ||||
|     def is_guest(self, user_id): | ||||
|         res = yield self._simple_select_one_onecol( | ||||
|  | @ -344,22 +364,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): | |||
| 
 | ||||
|         defer.returnValue(res if res else False) | ||||
| 
 | ||||
|     def _query_for_auth(self, txn, token): | ||||
|         sql = ( | ||||
|             "SELECT users.name, users.is_guest, access_tokens.id as token_id," | ||||
|             " access_tokens.device_id" | ||||
|             " FROM users" | ||||
|             " INNER JOIN access_tokens on users.name = access_tokens.user_id" | ||||
|             " WHERE token = ?" | ||||
|         ) | ||||
| 
 | ||||
|         txn.execute(sql, (token,)) | ||||
|         rows = self.cursor_to_dict(txn) | ||||
|         if rows: | ||||
|             return rows[0] | ||||
| 
 | ||||
|         return None | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def user_add_threepid(self, user_id, medium, address, validated_at, added_at): | ||||
|         yield self._simple_upsert("user_threepids", { | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	 Erik Johnston
						Erik Johnston