Merge pull request #2929 from matrix-org/erikj/split_regististration_store
Split registration storepull/2948/head
commit
f394f5574d
|
@ -14,20 +14,8 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
from synapse.storage import DataStore
|
from synapse.storage.registration import RegistrationWorkerStore
|
||||||
from synapse.storage.registration import RegistrationStore
|
|
||||||
|
|
||||||
|
|
||||||
class SlavedRegistrationStore(BaseSlavedStore):
|
class SlavedRegistrationStore(RegistrationWorkerStore, BaseSlavedStore):
|
||||||
def __init__(self, db_conn, hs):
|
pass
|
||||||
super(SlavedRegistrationStore, self).__init__(db_conn, hs)
|
|
||||||
|
|
||||||
# TODO: use the cached version and invalidate deleted tokens
|
|
||||||
get_user_by_access_token = RegistrationStore.__dict__[
|
|
||||||
"get_user_by_access_token"
|
|
||||||
]
|
|
||||||
|
|
||||||
_query_for_auth = DataStore._query_for_auth.__func__
|
|
||||||
get_user_by_id = RegistrationStore.__dict__[
|
|
||||||
"get_user_by_id"
|
|
||||||
]
|
|
||||||
|
|
|
@ -19,10 +19,70 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import StoreError, Codes
|
from synapse.api.errors import StoreError, Codes
|
||||||
from synapse.storage import background_updates
|
from synapse.storage import background_updates
|
||||||
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
|
||||||
|
|
||||||
|
|
||||||
class RegistrationStore(background_updates.BackgroundUpdateStore):
|
class RegistrationWorkerStore(SQLBaseStore):
|
||||||
|
@cached()
|
||||||
|
def get_user_by_id(self, user_id):
|
||||||
|
return self._simple_select_one(
|
||||||
|
table="users",
|
||||||
|
keyvalues={
|
||||||
|
"name": user_id,
|
||||||
|
},
|
||||||
|
retcols=["name", "password_hash", "is_guest"],
|
||||||
|
allow_none=True,
|
||||||
|
desc="get_user_by_id",
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
def get_user_by_access_token(self, token):
|
||||||
|
"""Get a user from the given access token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token (str): The access token of a user.
|
||||||
|
Returns:
|
||||||
|
defer.Deferred: None, if the token did not match, otherwise dict
|
||||||
|
including the keys `name`, `is_guest`, `device_id`, `token_id`.
|
||||||
|
"""
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_user_by_access_token",
|
||||||
|
self._query_for_auth,
|
||||||
|
token
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def is_server_admin(self, user):
|
||||||
|
res = yield self._simple_select_one_onecol(
|
||||||
|
table="users",
|
||||||
|
keyvalues={"name": user.to_string()},
|
||||||
|
retcol="admin",
|
||||||
|
allow_none=True,
|
||||||
|
desc="is_server_admin",
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(res if res else False)
|
||||||
|
|
||||||
|
def _query_for_auth(self, txn, token):
|
||||||
|
sql = (
|
||||||
|
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
|
||||||
|
" access_tokens.device_id"
|
||||||
|
" FROM users"
|
||||||
|
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
|
||||||
|
" WHERE token = ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(sql, (token,))
|
||||||
|
rows = self.cursor_to_dict(txn)
|
||||||
|
if rows:
|
||||||
|
return rows[0]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationStore(RegistrationWorkerStore,
|
||||||
|
background_updates.BackgroundUpdateStore):
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(RegistrationStore, self).__init__(db_conn, hs)
|
super(RegistrationStore, self).__init__(db_conn, hs)
|
||||||
|
@ -187,18 +247,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
)
|
)
|
||||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||||
|
|
||||||
@cached()
|
|
||||||
def get_user_by_id(self, user_id):
|
|
||||||
return self._simple_select_one(
|
|
||||||
table="users",
|
|
||||||
keyvalues={
|
|
||||||
"name": user_id,
|
|
||||||
},
|
|
||||||
retcols=["name", "password_hash", "is_guest"],
|
|
||||||
allow_none=True,
|
|
||||||
desc="get_user_by_id",
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_users_by_id_case_insensitive(self, user_id):
|
def get_users_by_id_case_insensitive(self, user_id):
|
||||||
"""Gets users that match user_id case insensitively.
|
"""Gets users that match user_id case insensitively.
|
||||||
Returns a mapping of user_id -> password_hash.
|
Returns a mapping of user_id -> password_hash.
|
||||||
|
@ -304,34 +352,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
|
|
||||||
return self.runInteraction("delete_access_token", f)
|
return self.runInteraction("delete_access_token", f)
|
||||||
|
|
||||||
@cached()
|
|
||||||
def get_user_by_access_token(self, token):
|
|
||||||
"""Get a user from the given access token.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token (str): The access token of a user.
|
|
||||||
Returns:
|
|
||||||
defer.Deferred: None, if the token did not match, otherwise dict
|
|
||||||
including the keys `name`, `is_guest`, `device_id`, `token_id`.
|
|
||||||
"""
|
|
||||||
return self.runInteraction(
|
|
||||||
"get_user_by_access_token",
|
|
||||||
self._query_for_auth,
|
|
||||||
token
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def is_server_admin(self, user):
|
|
||||||
res = yield self._simple_select_one_onecol(
|
|
||||||
table="users",
|
|
||||||
keyvalues={"name": user.to_string()},
|
|
||||||
retcol="admin",
|
|
||||||
allow_none=True,
|
|
||||||
desc="is_server_admin",
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue(res if res else False)
|
|
||||||
|
|
||||||
@cachedInlineCallbacks()
|
@cachedInlineCallbacks()
|
||||||
def is_guest(self, user_id):
|
def is_guest(self, user_id):
|
||||||
res = yield self._simple_select_one_onecol(
|
res = yield self._simple_select_one_onecol(
|
||||||
|
@ -344,22 +364,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
|
|
||||||
defer.returnValue(res if res else False)
|
defer.returnValue(res if res else False)
|
||||||
|
|
||||||
def _query_for_auth(self, txn, token):
|
|
||||||
sql = (
|
|
||||||
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
|
|
||||||
" access_tokens.device_id"
|
|
||||||
" FROM users"
|
|
||||||
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
|
|
||||||
" WHERE token = ?"
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(sql, (token,))
|
|
||||||
rows = self.cursor_to_dict(txn)
|
|
||||||
if rows:
|
|
||||||
return rows[0]
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
|
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
|
||||||
yield self._simple_upsert("user_threepids", {
|
yield self._simple_upsert("user_threepids", {
|
||||||
|
|
Loading…
Reference in New Issue