Add more key storage funcs into slave store

pull/967/head
Erik Johnston 2016-07-27 15:51:43 +01:00
parent aede7248ab
commit 6ede23ff1b
2 changed files with 26 additions and 24 deletions

View File

@ -13,17 +13,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.keys import KeyStore from synapse.storage.keys import KeyStore
class SlavedKeyStore(BaseSlavedStore): class SlavedKeyStore(BaseSlavedStore):
# TODO: use the cached version and invalidate deleted tokens _get_server_verify_key = KeyStore.__dict__[
get_all_server_verify_keys = defer.inlineCallbacks(KeyStore.__dict__[ "_get_server_verify_key"
"get_all_server_verify_keys" ]
].orig)
get_server_verify_keys = DataStore.get_server_verify_keys.__func__ get_server_verify_keys = DataStore.get_server_verify_keys.__func__
store_server_verify_key = DataStore.store_server_verify_key.__func__
get_server_certificate = DataStore.get_server_certificate.__func__
store_server_certificate = DataStore.store_server_certificate.__func__
get_server_keys_json = DataStore.get_server_keys_json.__func__
store_server_keys_json = DataStore.store_server_keys_json.__func__

View File

@ -78,22 +78,22 @@ class KeyStore(SQLBaseStore):
) )
@cachedInlineCallbacks() @cachedInlineCallbacks()
def get_all_server_verify_keys(self, server_name): def _get_server_verify_key(self, server_name, key_id):
rows = yield self._simple_select_list( verify_key_bytes = yield self._simple_select_one_onecol(
table="server_signature_keys", table="server_signature_keys",
keyvalues={ keyvalues={
"server_name": server_name, "server_name": server_name,
"key_id": key_id,
}, },
retcols=["key_id", "verify_key"], retcol="verify_key",
desc="get_all_server_verify_keys", desc="_get_server_verify_key",
allow_none=True,
) )
defer.returnValue({ if verify_key_bytes:
row["key_id"]: decode_verify_key_bytes( defer.returnValue(decode_verify_key_bytes(
row["key_id"], str(row["verify_key"]) key_id, str(verify_key_bytes)
) ))
for row in rows
})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids): def get_server_verify_keys(self, server_name, key_ids):
@ -105,12 +105,12 @@ class KeyStore(SQLBaseStore):
Returns: Returns:
(list of VerifyKey): The verification keys. (list of VerifyKey): The verification keys.
""" """
keys = yield self.get_all_server_verify_keys(server_name) keys = {}
defer.returnValue({ for key_id in key_ids:
k: keys[k] key = yield self._get_server_verify_key(server_name, key_id)
for k in key_ids if key:
if k in keys and keys[k] keys[key_id] = key
}) defer.returnValue(keys)
@defer.inlineCallbacks @defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms, def store_server_verify_key(self, server_name, from_server, time_now_ms,
@ -137,8 +137,6 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key", desc="store_server_verify_key",
) )
self.get_all_server_verify_keys.invalidate((server_name,))
def store_server_keys_json(self, server_name, key_id, from_server, def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes): ts_now_ms, ts_expires_ms, key_json_bytes):
"""Stores the JSON bytes for a set of keys from a server """Stores the JSON bytes for a set of keys from a server