Fix backfill to work. Add auth to backfill request

pull/12/head
Erik Johnston 2014-11-10 11:59:51 +00:00
parent 65f846ade0
commit 6447db063a
6 changed files with 56 additions and 18 deletions

View File

@ -104,6 +104,12 @@ class Auth(object):
pass pass
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
joined_hosts = yield self.store.get_joined_hosts_for_room(room_id)
defer.returnValue(host in joined_hosts)
def check_event_sender_in_room(self, event): def check_event_sender_in_room(self, event):
key = (RoomMemberEvent.TYPE, event.user_id, ) key = (RoomMemberEvent.TYPE, event.user_id, )
member_event = event.state_events.get(key) member_event = event.state_events.get(key)

View File

@ -205,7 +205,7 @@ class ReplicationLayer(object):
pdus = [Pdu(outlier=False, **p) for p in transaction.pdus] pdus = [Pdu(outlier=False, **p) for p in transaction.pdus]
for pdu in pdus: for pdu in pdus:
yield self._handle_new_pdu(pdu, backfilled=True) yield self._handle_new_pdu(dest, pdu, backfilled=True)
defer.returnValue(pdus) defer.returnValue(pdus)
@ -274,9 +274,9 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_backfill_request(self, context, versions, limit): def on_backfill_request(self, origin, context, versions, limit):
pdus = yield self.handler.on_backfill_request( pdus = yield self.handler.on_backfill_request(
context, versions, limit origin, context, versions, limit
) )
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@ -408,13 +408,22 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_make_join_request(self, context, user_id): def on_make_join_request(self, context, user_id):
pdu = yield self.handler.on_make_join_request(context, user_id) pdu = yield self.handler.on_make_join_request(context, user_id)
defer.returnValue(pdu.get_dict()) defer.returnValue({
"event": pdu.get_dict(),
})
@defer.inlineCallbacks @defer.inlineCallbacks
def on_invite_request(self, origin, content): def on_invite_request(self, origin, content):
pdu = Pdu(**content) pdu = Pdu(**content)
ret_pdu = yield self.handler.on_invite_request(origin, pdu) ret_pdu = yield self.handler.on_invite_request(origin, pdu)
defer.returnValue((200, ret_pdu.get_dict())) defer.returnValue(
(
200,
{
"event": ret_pdu.get_dict(),
}
)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_send_join_request(self, origin, content): def on_send_join_request(self, origin, content):
@ -429,16 +438,25 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_event_auth(self, origin, context, event_id): def on_event_auth(self, origin, context, event_id):
auth_pdus = yield self.handler.on_event_auth(event_id) auth_pdus = yield self.handler.on_event_auth(event_id)
defer.returnValue((200, [a.get_dict() for a in auth_pdus])) defer.returnValue(
(
200,
{
"auth_chain": [a.get_dict() for a in auth_pdus],
}
)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def make_join(self, destination, context, user_id): def make_join(self, destination, context, user_id):
pdu_dict = yield self.transport_layer.make_join( ret = yield self.transport_layer.make_join(
destination=destination, destination=destination,
context=context, context=context,
user_id=user_id, user_id=user_id,
) )
pdu_dict = ret["event"]
logger.debug("Got response to make_join: %s", pdu_dict) logger.debug("Got response to make_join: %s", pdu_dict)
defer.returnValue(Pdu(**pdu_dict)) defer.returnValue(Pdu(**pdu_dict))
@ -467,13 +485,15 @@ class ReplicationLayer(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def send_invite(self, destination, context, event_id, pdu): def send_invite(self, destination, context, event_id, pdu):
code, pdu_dict = yield self.transport_layer.send_invite( code, content = yield self.transport_layer.send_invite(
destination=destination, destination=destination,
context=context, context=context,
event_id=event_id, event_id=event_id,
content=pdu.get_dict(), content=pdu.get_dict(),
) )
pdu_dict = content["event"]
logger.debug("Got response to send_invite: %s", pdu_dict) logger.debug("Got response to send_invite: %s", pdu_dict)
defer.returnValue(Pdu(**pdu_dict)) defer.returnValue(Pdu(**pdu_dict))

View File

@ -413,7 +413,7 @@ class TransportLayer(object):
self._with_authentication( self._with_authentication(
lambda origin, content, query, context: lambda origin, content, query, context:
self._on_backfill_request( self._on_backfill_request(
context, query["v"], query["limit"] origin, context, query["v"], query["limit"]
) )
) )
) )
@ -552,7 +552,7 @@ class TransportLayer(object):
defer.returnValue(data) defer.returnValue(data)
@log_function @log_function
def _on_backfill_request(self, context, v_list, limits): def _on_backfill_request(self, origin, context, v_list, limits):
if not limits: if not limits:
return defer.succeed( return defer.succeed(
(400, {"error": "Did not include limit param"}) (400, {"error": "Did not include limit param"})
@ -563,7 +563,7 @@ class TransportLayer(object):
versions = v_list versions = v_list
return self.request_handler.on_backfill_request( return self.request_handler.on_backfill_request(
context, versions, limit origin, context, versions, limit
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -193,10 +193,7 @@ class FederationHandler(BaseHandler):
dest, dest,
room_id, room_id,
limit, limit,
extremities=[ extremities=extremities,
self.pdu_codec.decode_event_id(e)
for e in extremities
]
) )
events = [] events = []
@ -473,7 +470,10 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_backfill_request(self, context, pdu_list, limit): def on_backfill_request(self, origin, context, pdu_list, limit):
in_room = yield self.auth.check_host_in_room(context, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
events = yield self.store.get_backfill_events( events = yield self.store.get_backfill_events(
context, context,

View File

@ -447,6 +447,18 @@ class SQLBaseStore(object):
**d **d
) )
def _get_events_txn(self, txn, event_ids):
# FIXME (erikj): This should be batched?
sql = "SELECT * FROM events WHERE event_id = ?"
event_rows = []
for e_id in event_ids:
c = txn.execute(sql, (e_id,))
event_rows.extend(self.cursor_to_dict(c))
return self._parse_events_txn(txn, event_rows)
def _parse_events(self, rows): def _parse_events(self, rows):
return self.runInteraction( return self.runInteraction(
"_parse_events", self._parse_events_txn, rows "_parse_events", self._parse_events_txn, rows

View File

@ -371,10 +371,10 @@ class EventFederationStore(SQLBaseStore):
"_backfill_interaction: got id=%s", "_backfill_interaction: got id=%s",
*row *row
) )
new_front.append(row) new_front.append(row[0])
front = new_front front = new_front
event_results += new_front event_results += new_front
# We also want to update the `prev_pdus` attributes before returning. # We also want to update the `prev_pdus` attributes before returning.
return self._get_pdu_tuples(txn, event_results) return self._get_events_txn(txn, event_results)