Add basic function to get all data for a user out of synapse

pull/5589/head
Erik Johnston 2019-07-01 17:55:11 +01:00
parent b4914681a5
commit 8ee69f299c
4 changed files with 478 additions and 1 deletions

View File

@ -14,9 +14,17 @@
# limitations under the License.
import logging
import os
import tempfile
from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.types import RoomStreamToken
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
logger = logging.getLogger(__name__)
@ -89,3 +97,242 @@ class AdminHandler(BaseHandler):
ret = yield self.store.search_users(term)
defer.returnValue(ret)
@defer.inlineCallbacks
def exfiltrate_user_data(self, user_id, writer):
    """Write all data we have on the user to the given writer.

    For every room returned by the membership query below, the events the
    user is allowed to see are passed to `writer.write_events`, the state at
    each gap in those events is passed to `writer.write_state`, and invites
    to rooms the user never joined are passed to `writer.write_invite`.

    Args:
        user_id (str)
        writer (ExfiltrationWriter)

    Returns:
        defer.Deferred: Resolves to whatever `writer.finished()` returns.
    """
    # Get all rooms the user is in or has been in
    rooms = yield self.store.get_rooms_for_user_where_membership_is(
        user_id,
        membership_list=(
            Membership.JOIN,
            Membership.LEAVE,
            Membership.BAN,
            Membership.INVITE,
        ),
    )

    # We only try and fetch events for rooms the user has been in. If
    # they've been e.g. invited to a room without joining then we handle
    # those separately.
    rooms_user_has_been_in = yield self.store.get_rooms_user_has_been_in(user_id)

    for index, room in enumerate(rooms):
        room_id = room.room_id

        logger.info(
            "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms)
        )

        forgotten = yield self.store.did_forget(user_id, room_id)
        if forgotten:
            # Both placeholders need an argument, and room IDs are strings,
            # so this must be "%s" rather than "%d".
            logger.info("[%s] User forgot room %s, ignoring", user_id, room_id)
            continue

        if room_id not in rooms_user_has_been_in:
            # If we haven't been in the rooms then the filtering code below
            # won't return anything, so we need to handle these cases
            # explicitly.

            if room.membership == Membership.INVITE:
                event_id = room.event_id
                invite = yield self.store.get_event(event_id, allow_none=True)
                if invite:
                    invited_state = invite.unsigned["invite_room_state"]
                    writer.write_invite(room_id, invite, invited_state)

            continue

        # We only want to bother fetching events up to the last time they
        # were joined. We estimate that point by looking at the
        # stream_ordering of the last membership if it wasn't a join.
        if room.membership == Membership.JOIN:
            stream_ordering = yield self.store.get_room_max_stream_ordering()
        else:
            stream_ordering = room.stream_ordering

        from_key = str(RoomStreamToken(0, 0))
        to_key = str(RoomStreamToken(None, stream_ordering))

        written_events = set()  # Events that we've processed in this room

        # We need to track gaps in the events stream so that we can then
        # write out the state at those events. We do this by keeping track
        # of events whose prev events we haven't seen.

        # Map from event ID to prev events that haven't been processed,
        # dict[str, set[str]].
        event_to_unseen_prevs = {}

        # The reverse mapping to above, i.e. map from unseen event to parent
        # events. dict[str, set[str]]
        unseen_event_to_parents = {}

        # We fetch events in the room the user could see by fetching *all*
        # events that we have and then filtering, this isn't the most
        # efficient method perhaps but it does guarantee we get everything.
        while True:
            events, _ = yield self.store.paginate_room_events(
                room_id, from_key, to_key, limit=100, direction="f"
            )
            if not events:
                break

            # Advance the pagination token *before* filtering, so that a
            # page where everything is filtered out still makes progress.
            from_key = events[-1].internal_metadata.after

            events = yield filter_events_for_client(self.store, user_id, events)

            writer.write_events(room_id, events)

            # Update the extremity tracking dicts
            for event in events:
                # Check if we have any prev events that haven't been
                # processed yet, and add those to the appropriate dicts.
                unseen_events = set(event.prev_event_ids()) - written_events
                if unseen_events:
                    event_to_unseen_prevs[event.event_id] = unseen_events
                    for unseen in unseen_events:
                        unseen_event_to_parents.setdefault(unseen, set()).add(
                            event.event_id
                        )

                # Now check if this event is an unseen prev event, if so
                # then we remove this event from the appropriate dicts.
                for parent_id in unseen_event_to_parents.pop(event.event_id, []):
                    event_to_unseen_prevs.get(parent_id, set()).discard(
                        event.event_id
                    )

                written_events.add(event.event_id)

            logger.info(
                "Written %d events in room %s", len(written_events), room_id
            )

        # Extremities are the events who have at least one unseen prev event.
        # Materialise the list up front so we never iterate the dict lazily
        # across the yields below.
        extremities = [
            event_id
            for event_id, unseen_prevs in event_to_unseen_prevs.items()
            if unseen_prevs
        ]
        for event_id in extremities:
            state = yield self.store.get_state_for_event(event_id)
            writer.write_state(room_id, event_id, state)

    defer.returnValue(writer.finished())
class ExfiltrationWriter(object):
    """Interface describing how exfiltrated user data gets written out.

    Every method here is a no-op; concrete writers override them.
    """

    def write_events(self, room_id, events):
        """Write one batch of events belonging to a room.

        Args:
            room_id (str)
            events (list[FrozenEvent])
        """

    def write_state(self, room_id, event_id, state):
        """Write the room state at the given event.

        This is only invoked for backward extremities, not for every event.

        Args:
            room_id (str)
            event_id (str)
            state: The state at the event, as handed over by the caller.
        """

    def write_invite(self, room_id, event, state):
        """Write an invite to a room together with its invite state.

        Args:
            room_id (str)
            event (FrozenEvent): The invite event.
            state: A subset of the state at the invite, where each entry
                holds a subset of the event keys (type, state_key, content
                and sender).
        """

    def finished(self):
        """Called once exfiltration is complete.

        The return value is passed back to the requester.
        """
        return None
class FileExfiltrationWriter(ExfiltrationWriter):
    """An ExfiltrationWriter that writes the user's data to a directory.

    Returns the directory location on completion.

    Layout under the base directory:
        rooms/<room_id>/events          one JSON event per line
        rooms/<room_id>/state/<event>   state at an extremity, one per line
        rooms/<room_id>/invite_state    invite state entries, one per line

    Args:
        user_id (str): The user whose data is being exfiltrated.
        directory (str|None): The directory to write the data to, if None then
            will write to a temporary directory.

    Raises:
        Exception: If the target directory already contains entries.
    """

    def __init__(self, user_id, directory=None):
        self.user_id = user_id

        if directory:
            self.base_directory = directory
        else:
            self.base_directory = tempfile.mkdtemp(
                prefix="synapse-exfiltrate__%s__" % (user_id,)
            )

        os.makedirs(self.base_directory, exist_ok=True)
        if os.listdir(self.base_directory):
            raise Exception("Directory must be empty")

    def _room_directory(self, room_id):
        """Return the directory for a room, creating it if necessary."""
        room_directory = os.path.join(self.base_directory, "rooms", room_id)
        os.makedirs(room_directory, exist_ok=True)
        return room_directory

    def write_events(self, room_id, events):
        events_file = os.path.join(self._room_directory(room_id), "events")

        # Append, since this is called once per batch for the same room.
        with open(events_file, "a") as f:
            for event in events:
                print(json.dumps(event.get_pdu_json()), file=f)

    def write_state(self, room_id, event_id, state):
        state_directory = os.path.join(self._room_directory(room_id), "state")
        os.makedirs(state_directory, exist_ok=True)

        event_file = os.path.join(state_directory, event_id)

        with open(event_file, "a") as f:
            for event in state.values():
                print(json.dumps(event.get_pdu_json()), file=f)

    def write_invite(self, room_id, event, state):
        self.write_events(room_id, [event])

        # We write the invite state somewhere else as they aren't full events
        # and are only a subset of the state at the event.
        invite_state = os.path.join(self._room_directory(room_id), "invite_state")

        # The ExfiltrationWriter interface documents `state` as a list of
        # dicts, whereas this method used to require a mapping. Accept both
        # so that a plain list no longer raises AttributeError.
        values = getattr(state, "values", None)
        stripped_events = values() if callable(values) else state

        with open(invite_state, "a") as f:
            for stripped_event in stripped_events:
                print(json.dumps(stripped_event), file=f)

    def finished(self):
        return self.base_directory

View File

@ -575,6 +575,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
count = yield self.runInteraction("did_forget_membership", f)
defer.returnValue(count == 0)
@defer.inlineCallbacks
def get_rooms_user_has_been_in(self, user_id):
    """Fetch the rooms for which the user has a "join" membership row.

    Looks up `room_memberships` rows matching the user with membership
    `Membership.JOIN` and returns the distinct room IDs.

    Args:
        user_id (str)

    Returns:
        Deferred[set[str]]: Set of room IDs.
    """
    join_rows_filter = {"membership": Membership.JOIN, "user_id": user_id}

    joined_room_ids = yield self._simple_select_onecol(
        table="room_memberships",
        keyvalues=join_rows_filter,
        retcol="room_id",
        desc="get_rooms_user_has_been_in",
    )

    # The same room can appear once per join, so collapse duplicates.
    return set(joined_room_ids)
class RoomMemberStore(RoomMemberWorkerStore):
def __init__(self, db_conn, hs):

View File

@ -0,0 +1,210 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import Counter
from mock import Mock
import synapse.api.errors
import synapse.handlers.admin
import synapse.rest.admin
import synapse.storage
from synapse.api.constants import EventTypes
from synapse.rest.client.v1 import login, room
from tests import unittest
class ExfiltrateData(unittest.HomeserverTestCase):
    """Tests for `AdminHandler.exfiltrate_user_data` using a Mock writer."""

    servlets = [
        synapse.rest.admin.register_servlets_for_client_rest_resource,
        login.register_servlets,
        room.register_servlets,
    ]

    def prepare(self, reactor, clock, hs):
        self.admin_handler = hs.get_handlers().admin_handler

        self.user1 = self.register_user("user1", "password")
        self.token1 = self.login("user1", "password")

        self.user2 = self.register_user("user2", "password")
        self.token2 = self.login("user2", "password")

    def _exfiltrate(self, user_id):
        """Run the exfiltration for `user_id` and return the Mock writer used."""
        writer = Mock()
        self.get_success(self.admin_handler.exfiltrate_user_data(user_id, writer))
        return writer

    def _count_written_events(self, writer, room_id):
        """Tally the events passed to `writer.write_events`.

        Asserts that every call was for `room_id`, then returns a Counter
        keyed on (event type, state_key or None).
        """
        written_events = []
        # Each entry unpacks as (positional args, keyword args).
        for (called_room_id, events), _ in writer.write_events.call_args_list:
            self.assertEqual(called_room_id, room_id)
            written_events.extend(events)

        return Counter(
            (event.type, getattr(event, "state_key", None)) for event in written_events
        )

    def test_single_public_joined_room(self):
        """Test that we write *all* events for a public room
        """
        room_id = self.helper.create_room_as(
            self.user1, tok=self.token1, is_public=True
        )
        self.helper.send(room_id, body="Hello!", tok=self.token1)
        self.helper.join(room_id, self.user2, tok=self.token2)
        self.helper.send(room_id, body="Hello again!", tok=self.token1)

        writer = self._exfiltrate(self.user2)

        writer.write_events.assert_called()

        # Since we can see all events there shouldn't be any extremities, so no
        # state should be written
        writer.write_state.assert_not_called()

        # Check that the right number of events were written
        counter = self._count_written_events(writer, room_id)
        self.assertEqual(counter[(EventTypes.Message, None)], 2)
        self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
        self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)

    def test_single_private_joined_room(self):
        """Tests that we correctly write state when we can't see all events in
        a room.
        """
        room_id = self.helper.create_room_as(self.user1, tok=self.token1)
        self.helper.send_state(
            room_id,
            EventTypes.RoomHistoryVisibility,
            body={"history_visibility": "joined"},
            tok=self.token1,
        )
        self.helper.send(room_id, body="Hello!", tok=self.token1)
        self.helper.join(room_id, self.user2, tok=self.token2)
        self.helper.send(room_id, body="Hello again!", tok=self.token1)

        writer = self._exfiltrate(self.user2)

        writer.write_events.assert_called()

        # Since we can't see all events there should be one extremity.
        writer.write_state.assert_called_once()

        # Check that the right number of events were written
        counter = self._count_written_events(writer, room_id)
        self.assertEqual(counter[(EventTypes.Message, None)], 1)
        self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
        self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)

    def test_single_left_room(self):
        """Tests that we don't see events in the room after we leave.
        """
        room_id = self.helper.create_room_as(self.user1, tok=self.token1)
        self.helper.send(room_id, body="Hello!", tok=self.token1)
        self.helper.join(room_id, self.user2, tok=self.token2)
        self.helper.send(room_id, body="Hello again!", tok=self.token1)
        self.helper.leave(room_id, self.user2, tok=self.token2)
        self.helper.send(room_id, body="Helloooooo!", tok=self.token1)

        writer = self._exfiltrate(self.user2)

        writer.write_events.assert_called()

        # Since we can see all events there shouldn't be any extremities, so no
        # state should be written
        writer.write_state.assert_not_called()

        # Check that the right number of events were written
        counter = self._count_written_events(writer, room_id)
        self.assertEqual(counter[(EventTypes.Message, None)], 2)
        self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
        self.assertEqual(counter[(EventTypes.Member, self.user2)], 2)

    def test_single_left_rejoined_private_room(self):
        """Tests that we see the correct events in private rooms when we
        repeatedly join and leave.
        """
        room_id = self.helper.create_room_as(self.user1, tok=self.token1)
        self.helper.send_state(
            room_id,
            EventTypes.RoomHistoryVisibility,
            body={"history_visibility": "joined"},
            tok=self.token1,
        )
        self.helper.send(room_id, body="Hello!", tok=self.token1)
        self.helper.join(room_id, self.user2, tok=self.token2)
        self.helper.send(room_id, body="Hello again!", tok=self.token1)
        self.helper.leave(room_id, self.user2, tok=self.token2)
        self.helper.send(room_id, body="Helloooooo!", tok=self.token1)
        self.helper.join(room_id, self.user2, tok=self.token2)
        self.helper.send(room_id, body="Helloooooo!!", tok=self.token1)

        writer = self._exfiltrate(self.user2)

        writer.write_events.assert_called_once()

        # Since we joined/left/joined again we expect there to be two gaps.
        self.assertEqual(writer.write_state.call_count, 2)

        # Check that the right number of events were written
        counter = self._count_written_events(writer, room_id)
        self.assertEqual(counter[(EventTypes.Message, None)], 2)
        self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
        self.assertEqual(counter[(EventTypes.Member, self.user2)], 3)

    def test_invite(self):
        """Tests that pending invites get handled correctly.
        """
        room_id = self.helper.create_room_as(self.user1, tok=self.token1)
        self.helper.send(room_id, body="Hello!", tok=self.token1)
        self.helper.invite(room_id, self.user1, self.user2, tok=self.token1)

        writer = self._exfiltrate(self.user2)

        writer.write_events.assert_not_called()
        writer.write_state.assert_not_called()
        writer.write_invite.assert_called_once()

        # Positional args of the single call: (room_id, invite event, state)
        args = writer.write_invite.call_args[0]
        self.assertEqual(args[0], room_id)
        self.assertEqual(args[1].content["membership"], "invite")
        self.assertTrue(args[2])  # Assert there is at least one bit of state

View File

@ -443,7 +443,7 @@ class HomeserverTestCase(TestCase):
"POST", "/_matrix/client/r0/admin/register", body.encode("utf8")
)
self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.code, 200, channel.json_body)
user_id = channel.json_body["user_id"]
return user_id