Merge branch 'develop' of github.com:matrix-org/synapse into erikj/join_perf
						commit
						ebfdd2eb5b
					
				| 
						 | 
				
			
			@ -223,7 +223,7 @@ class FederationClient(FederationBase):
 | 
			
		|||
                        for p in transaction_data["pdus"]
 | 
			
		||||
                    ]
 | 
			
		||||
 | 
			
		||||
                    if pdu_list:
 | 
			
		||||
                    if pdu_list and pdu_list[0]:
 | 
			
		||||
                        pdu = pdu_list[0]
 | 
			
		||||
 | 
			
		||||
                        # Check signatures are correct.
 | 
			
		||||
| 
						 | 
				
			
			@ -256,7 +256,7 @@ class FederationClient(FederationBase):
 | 
			
		|||
                )
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
        if self._get_pdu_cache is not None:
 | 
			
		||||
        if self._get_pdu_cache is not None and pdu:
 | 
			
		||||
            self._get_pdu_cache[event_id] = pdu
 | 
			
		||||
 | 
			
		||||
        defer.returnValue(pdu)
 | 
			
		||||
| 
						 | 
				
			
			@ -566,7 +566,7 @@ class FederationClient(FederationBase):
 | 
			
		|||
 | 
			
		||||
            res = yield defer.DeferredList(deferreds, consumeErrors=True)
 | 
			
		||||
            for (result, val), (e_id, _) in zip(res, ordered_missing):
 | 
			
		||||
                if result:
 | 
			
		||||
                if result and val:
 | 
			
		||||
                    signed_events.append(val)
 | 
			
		||||
                else:
 | 
			
		||||
                    failed_to_fetch.add(e_id)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -51,7 +51,7 @@ logger = logging.getLogger(__name__)
 | 
			
		|||
 | 
			
		||||
# Remember to update this number every time a change is made to database
 | 
			
		||||
# schema files, so the users will be informed on server restarts.
 | 
			
		||||
SCHEMA_VERSION = 18
 | 
			
		||||
SCHEMA_VERSION = 19
 | 
			
		||||
 | 
			
		||||
dir_path = os.path.abspath(os.path.dirname(__file__))
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -25,6 +25,7 @@ from util.id_generators import IdGenerator, StreamIdGenerator
 | 
			
		|||
from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
from collections import namedtuple, OrderedDict
 | 
			
		||||
 | 
			
		||||
import functools
 | 
			
		||||
import sys
 | 
			
		||||
import time
 | 
			
		||||
| 
						 | 
				
			
			@ -45,7 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
 | 
			
		|||
 | 
			
		||||
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
 | 
			
		||||
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
 | 
			
		||||
sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"])
 | 
			
		||||
 | 
			
		||||
caches_by_name = {}
 | 
			
		||||
cache_counter = metrics.register_cache(
 | 
			
		||||
| 
						 | 
				
			
			@ -298,6 +298,12 @@ class SQLBaseStore(object):
 | 
			
		|||
        self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
 | 
			
		||||
                                      max_entries=hs.config.event_cache_size)
 | 
			
		||||
 | 
			
		||||
        self._event_fetch_lock = threading.Condition()
 | 
			
		||||
        self._event_fetch_list = []
 | 
			
		||||
        self._event_fetch_ongoing = 0
 | 
			
		||||
 | 
			
		||||
        self._pending_ds = []
 | 
			
		||||
 | 
			
		||||
        self.database_engine = hs.database_engine
 | 
			
		||||
 | 
			
		||||
        self._stream_id_gen = StreamIdGenerator()
 | 
			
		||||
| 
						 | 
				
			
			@ -337,6 +343,75 @@ class SQLBaseStore(object):
 | 
			
		|||
 | 
			
		||||
        self._clock.looping_call(loop, 10000)
 | 
			
		||||
 | 
			
		||||
    def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs):
 | 
			
		||||
        start = time.time() * 1000
 | 
			
		||||
        txn_id = self._TXN_ID
 | 
			
		||||
 | 
			
		||||
        # We don't really need these to be unique, so lets stop it from
 | 
			
		||||
        # growing really large.
 | 
			
		||||
        self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
 | 
			
		||||
 | 
			
		||||
        name = "%s-%x" % (desc, txn_id, )
 | 
			
		||||
 | 
			
		||||
        transaction_logger.debug("[TXN START] {%s}", name)
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            i = 0
 | 
			
		||||
            N = 5
 | 
			
		||||
            while True:
 | 
			
		||||
                try:
 | 
			
		||||
                    txn = conn.cursor()
 | 
			
		||||
                    txn = LoggingTransaction(
 | 
			
		||||
                        txn, name, self.database_engine, after_callbacks
 | 
			
		||||
                    )
 | 
			
		||||
                    r = func(txn, *args, **kwargs)
 | 
			
		||||
                    conn.commit()
 | 
			
		||||
                    return r
 | 
			
		||||
                except self.database_engine.module.OperationalError as e:
 | 
			
		||||
                    # This can happen if the database disappears mid
 | 
			
		||||
                    # transaction.
 | 
			
		||||
                    logger.warn(
 | 
			
		||||
                        "[TXN OPERROR] {%s} %s %d/%d",
 | 
			
		||||
                        name, e, i, N
 | 
			
		||||
                    )
 | 
			
		||||
                    if i < N:
 | 
			
		||||
                        i += 1
 | 
			
		||||
                        try:
 | 
			
		||||
                            conn.rollback()
 | 
			
		||||
                        except self.database_engine.module.Error as e1:
 | 
			
		||||
                            logger.warn(
 | 
			
		||||
                                "[TXN EROLL] {%s} %s",
 | 
			
		||||
                                name, e1,
 | 
			
		||||
                            )
 | 
			
		||||
                        continue
 | 
			
		||||
                    raise
 | 
			
		||||
                except self.database_engine.module.DatabaseError as e:
 | 
			
		||||
                    if self.database_engine.is_deadlock(e):
 | 
			
		||||
                        logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
 | 
			
		||||
                        if i < N:
 | 
			
		||||
                            i += 1
 | 
			
		||||
                            try:
 | 
			
		||||
                                conn.rollback()
 | 
			
		||||
                            except self.database_engine.module.Error as e1:
 | 
			
		||||
                                logger.warn(
 | 
			
		||||
                                    "[TXN EROLL] {%s} %s",
 | 
			
		||||
                                    name, e1,
 | 
			
		||||
                                )
 | 
			
		||||
                            continue
 | 
			
		||||
                    raise
 | 
			
		||||
        except Exception as e:
 | 
			
		||||
            logger.debug("[TXN FAIL] {%s} %s", name, e)
 | 
			
		||||
            raise
 | 
			
		||||
        finally:
 | 
			
		||||
            end = time.time() * 1000
 | 
			
		||||
            duration = end - start
 | 
			
		||||
 | 
			
		||||
            transaction_logger.debug("[TXN END] {%s} %f", name, duration)
 | 
			
		||||
 | 
			
		||||
            self._current_txn_total_time += duration
 | 
			
		||||
            self._txn_perf_counters.update(desc, start, end)
 | 
			
		||||
            sql_txn_timer.inc_by(duration, desc)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def runInteraction(self, desc, func, *args, **kwargs):
 | 
			
		||||
        """Wraps the .runInteraction() method on the underlying db_pool."""
 | 
			
		||||
| 
						 | 
				
			
			@ -348,75 +423,16 @@ class SQLBaseStore(object):
 | 
			
		|||
 | 
			
		||||
        def inner_func(conn, *args, **kwargs):
 | 
			
		||||
            with LoggingContext("runInteraction") as context:
 | 
			
		||||
                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
 | 
			
		||||
 | 
			
		||||
                if self.database_engine.is_connection_closed(conn):
 | 
			
		||||
                    logger.debug("Reconnecting closed database connection")
 | 
			
		||||
                    conn.reconnect()
 | 
			
		||||
 | 
			
		||||
                current_context.copy_to(context)
 | 
			
		||||
                start = time.time() * 1000
 | 
			
		||||
                txn_id = self._TXN_ID
 | 
			
		||||
 | 
			
		||||
                # We don't really need these to be unique, so lets stop it from
 | 
			
		||||
                # growing really large.
 | 
			
		||||
                self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
 | 
			
		||||
 | 
			
		||||
                name = "%s-%x" % (desc, txn_id, )
 | 
			
		||||
 | 
			
		||||
                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
 | 
			
		||||
                transaction_logger.debug("[TXN START] {%s}", name)
 | 
			
		||||
                try:
 | 
			
		||||
                    i = 0
 | 
			
		||||
                    N = 5
 | 
			
		||||
                    while True:
 | 
			
		||||
                        try:
 | 
			
		||||
                            txn = conn.cursor()
 | 
			
		||||
                            txn = LoggingTransaction(
 | 
			
		||||
                                txn, name, self.database_engine, after_callbacks
 | 
			
		||||
                            )
 | 
			
		||||
                            return func(txn, *args, **kwargs)
 | 
			
		||||
                        except self.database_engine.module.OperationalError as e:
 | 
			
		||||
                            # This can happen if the database disappears mid
 | 
			
		||||
                            # transaction.
 | 
			
		||||
                            logger.warn(
 | 
			
		||||
                                "[TXN OPERROR] {%s} %s %d/%d",
 | 
			
		||||
                                name, e, i, N
 | 
			
		||||
                            )
 | 
			
		||||
                            if i < N:
 | 
			
		||||
                                i += 1
 | 
			
		||||
                                try:
 | 
			
		||||
                                    conn.rollback()
 | 
			
		||||
                                except self.database_engine.module.Error as e1:
 | 
			
		||||
                                    logger.warn(
 | 
			
		||||
                                        "[TXN EROLL] {%s} %s",
 | 
			
		||||
                                        name, e1,
 | 
			
		||||
                                    )
 | 
			
		||||
                                continue
 | 
			
		||||
                        except self.database_engine.module.DatabaseError as e:
 | 
			
		||||
                            if self.database_engine.is_deadlock(e):
 | 
			
		||||
                                logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
 | 
			
		||||
                                if i < N:
 | 
			
		||||
                                    i += 1
 | 
			
		||||
                                    try:
 | 
			
		||||
                                        conn.rollback()
 | 
			
		||||
                                    except self.database_engine.module.Error as e1:
 | 
			
		||||
                                        logger.warn(
 | 
			
		||||
                                            "[TXN EROLL] {%s} %s",
 | 
			
		||||
                                            name, e1,
 | 
			
		||||
                                        )
 | 
			
		||||
                                    continue
 | 
			
		||||
                            raise
 | 
			
		||||
                except Exception as e:
 | 
			
		||||
                    logger.debug("[TXN FAIL] {%s} %s", name, e)
 | 
			
		||||
                    raise
 | 
			
		||||
                finally:
 | 
			
		||||
                    end = time.time() * 1000
 | 
			
		||||
                    duration = end - start
 | 
			
		||||
 | 
			
		||||
                    transaction_logger.debug("[TXN END] {%s} %f", name, duration)
 | 
			
		||||
 | 
			
		||||
                    self._current_txn_total_time += duration
 | 
			
		||||
                    self._txn_perf_counters.update(desc, start, end)
 | 
			
		||||
                    sql_txn_timer.inc_by(duration, desc)
 | 
			
		||||
                return self._new_transaction(
 | 
			
		||||
                    conn, desc, after_callbacks, func, *args, **kwargs
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        result = yield preserve_context_over_fn(
 | 
			
		||||
            self._db_pool.runWithConnection,
 | 
			
		||||
| 
						 | 
				
			
			@ -427,6 +443,32 @@ class SQLBaseStore(object):
 | 
			
		|||
            after_callback(*after_args)
 | 
			
		||||
        defer.returnValue(result)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def runWithConnection(self, func, *args, **kwargs):
 | 
			
		||||
        """Wraps the .runInteraction() method on the underlying db_pool."""
 | 
			
		||||
        current_context = LoggingContext.current_context()
 | 
			
		||||
 | 
			
		||||
        start_time = time.time() * 1000
 | 
			
		||||
 | 
			
		||||
        def inner_func(conn, *args, **kwargs):
 | 
			
		||||
            with LoggingContext("runWithConnection") as context:
 | 
			
		||||
                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
 | 
			
		||||
 | 
			
		||||
                if self.database_engine.is_connection_closed(conn):
 | 
			
		||||
                    logger.debug("Reconnecting closed database connection")
 | 
			
		||||
                    conn.reconnect()
 | 
			
		||||
 | 
			
		||||
                current_context.copy_to(context)
 | 
			
		||||
 | 
			
		||||
                return func(conn, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
        result = yield preserve_context_over_fn(
 | 
			
		||||
            self._db_pool.runWithConnection,
 | 
			
		||||
            inner_func, *args, **kwargs
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        defer.returnValue(result)
 | 
			
		||||
 | 
			
		||||
    def cursor_to_dict(self, cursor):
 | 
			
		||||
        """Converts a SQL cursor into an list of dicts.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,6 +19,8 @@ from ._base import IncorrectDatabaseSetup
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class PostgresEngine(object):
 | 
			
		||||
    single_threaded = False
 | 
			
		||||
 | 
			
		||||
    def __init__(self, database_module):
 | 
			
		||||
        self.module = database_module
 | 
			
		||||
        self.module.extensions.register_type(self.module.extensions.UNICODE)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -17,6 +17,8 @@ from synapse.storage import prepare_database, prepare_sqlite3_database
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
class Sqlite3Engine(object):
 | 
			
		||||
    single_threaded = True
 | 
			
		||||
 | 
			
		||||
    def __init__(self, database_module):
 | 
			
		||||
        self.module = database_module
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -13,6 +13,8 @@
 | 
			
		|||
# See the License for the specific language governing permissions and
 | 
			
		||||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
 | 
			
		||||
from ._base import SQLBaseStore, cached
 | 
			
		||||
from syutil.base64util import encode_base64
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -33,16 +35,7 @@ class EventFederationStore(SQLBaseStore):
 | 
			
		|||
    """
 | 
			
		||||
 | 
			
		||||
    def get_auth_chain(self, event_ids):
 | 
			
		||||
        return self.runInteraction(
 | 
			
		||||
            "get_auth_chain",
 | 
			
		||||
            self._get_auth_chain_txn,
 | 
			
		||||
            event_ids
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _get_auth_chain_txn(self, txn, event_ids):
 | 
			
		||||
        results = self._get_auth_chain_ids_txn(txn, event_ids)
 | 
			
		||||
 | 
			
		||||
        return self._get_events_txn(txn, results)
 | 
			
		||||
        return self.get_auth_chain_ids(event_ids).addCallback(self._get_events)
 | 
			
		||||
 | 
			
		||||
    def get_auth_chain_ids(self, event_ids):
 | 
			
		||||
        return self.runInteraction(
 | 
			
		||||
| 
						 | 
				
			
			@ -370,7 +363,7 @@ class EventFederationStore(SQLBaseStore):
 | 
			
		|||
        return self.runInteraction(
 | 
			
		||||
            "get_backfill_events",
 | 
			
		||||
            self._get_backfill_events, room_id, event_list, limit
 | 
			
		||||
        )
 | 
			
		||||
        ).addCallback(self._get_events)
 | 
			
		||||
 | 
			
		||||
    def _get_backfill_events(self, txn, room_id, event_list, limit):
 | 
			
		||||
        logger.debug(
 | 
			
		||||
| 
						 | 
				
			
			@ -416,16 +409,26 @@ class EventFederationStore(SQLBaseStore):
 | 
			
		|||
            front = new_front
 | 
			
		||||
            event_results += new_front
 | 
			
		||||
 | 
			
		||||
        return self._get_events_txn(txn, event_results)
 | 
			
		||||
        return event_results
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def get_missing_events(self, room_id, earliest_events, latest_events,
 | 
			
		||||
                           limit, min_depth):
 | 
			
		||||
        return self.runInteraction(
 | 
			
		||||
        ids = yield self.runInteraction(
 | 
			
		||||
            "get_missing_events",
 | 
			
		||||
            self._get_missing_events,
 | 
			
		||||
            room_id, earliest_events, latest_events, limit, min_depth
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        events = yield self._get_events(ids)
 | 
			
		||||
 | 
			
		||||
        events = sorted(
 | 
			
		||||
            [ev for ev in events if ev.depth >= min_depth],
 | 
			
		||||
            key=lambda e: e.depth,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        defer.returnValue(events[:limit])
 | 
			
		||||
 | 
			
		||||
    def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
 | 
			
		||||
                            limit, min_depth):
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -457,14 +460,7 @@ class EventFederationStore(SQLBaseStore):
 | 
			
		|||
            front = new_front
 | 
			
		||||
            event_results |= new_front
 | 
			
		||||
 | 
			
		||||
        events = self._get_events_txn(txn, event_results)
 | 
			
		||||
 | 
			
		||||
        events = sorted(
 | 
			
		||||
            [ev for ev in events if ev.depth >= min_depth],
 | 
			
		||||
            key=lambda e: e.depth,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return events[:limit]
 | 
			
		||||
        return event_results
 | 
			
		||||
 | 
			
		||||
    def clean_room_for_join(self, room_id):
 | 
			
		||||
        return self.runInteraction(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -15,11 +15,12 @@
 | 
			
		|||
 | 
			
		||||
from _base import SQLBaseStore, _RollbackButIsFineException
 | 
			
		||||
 | 
			
		||||
from twisted.internet import defer
 | 
			
		||||
from twisted.internet import defer, reactor
 | 
			
		||||
 | 
			
		||||
from synapse.events import FrozenEvent
 | 
			
		||||
from synapse.events.utils import prune_event
 | 
			
		||||
 | 
			
		||||
from synapse.util.logcontext import preserve_context_over_deferred
 | 
			
		||||
from synapse.util.logutils import log_function
 | 
			
		||||
from synapse.api.constants import EventTypes
 | 
			
		||||
from synapse.crypto.event_signing import compute_event_reference_hash
 | 
			
		||||
| 
						 | 
				
			
			@ -34,6 +35,16 @@ import simplejson as json
 | 
			
		|||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# These values are used in the `enqueus_event` and `_do_fetch` methods to
 | 
			
		||||
# control how we batch/bulk fetch events from the database.
 | 
			
		||||
# The values are plucked out of thing air to make initial sync run faster
 | 
			
		||||
# on jki.re
 | 
			
		||||
# TODO: Make these configurable.
 | 
			
		||||
EVENT_QUEUE_THREADS = 3  # Max number of threads that will fetch events
 | 
			
		||||
EVENT_QUEUE_ITERATIONS = 3  # No. times we block waiting for requests for events
 | 
			
		||||
EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class EventsStore(SQLBaseStore):
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    @log_function
 | 
			
		||||
| 
						 | 
				
			
			@ -91,18 +102,17 @@ class EventsStore(SQLBaseStore):
 | 
			
		|||
        Returns:
 | 
			
		||||
            Deferred : A FrozenEvent.
 | 
			
		||||
        """
 | 
			
		||||
        event = yield self.runInteraction(
 | 
			
		||||
            "get_event", self._get_event_txn,
 | 
			
		||||
            event_id,
 | 
			
		||||
        events = yield self._get_events(
 | 
			
		||||
            [event_id],
 | 
			
		||||
            check_redacted=check_redacted,
 | 
			
		||||
            get_prev_content=get_prev_content,
 | 
			
		||||
            allow_rejected=allow_rejected,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if not event and not allow_none:
 | 
			
		||||
        if not events and not allow_none:
 | 
			
		||||
            raise RuntimeError("Could not find event %s" % (event_id,))
 | 
			
		||||
 | 
			
		||||
        defer.returnValue(event)
 | 
			
		||||
        defer.returnValue(events[0] if events else None)
 | 
			
		||||
 | 
			
		||||
    @log_function
 | 
			
		||||
    def _persist_event_txn(self, txn, event, context, backfilled,
 | 
			
		||||
| 
						 | 
				
			
			@ -401,28 +411,75 @@ class EventsStore(SQLBaseStore):
 | 
			
		|||
            "have_events", f,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def _get_events(self, event_ids, check_redacted=True,
 | 
			
		||||
                    get_prev_content=False):
 | 
			
		||||
        return self.runInteraction(
 | 
			
		||||
            "_get_events", self._get_events_txn, event_ids,
 | 
			
		||||
            check_redacted=check_redacted, get_prev_content=get_prev_content,
 | 
			
		||||
                    get_prev_content=False, allow_rejected=False):
 | 
			
		||||
        if not event_ids:
 | 
			
		||||
            defer.returnValue([])
 | 
			
		||||
 | 
			
		||||
        event_map = self._get_events_from_cache(
 | 
			
		||||
            event_ids,
 | 
			
		||||
            check_redacted=check_redacted,
 | 
			
		||||
            get_prev_content=get_prev_content,
 | 
			
		||||
            allow_rejected=allow_rejected,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        missing_events_ids = [e for e in event_ids if e not in event_map]
 | 
			
		||||
 | 
			
		||||
        if not missing_events_ids:
 | 
			
		||||
            defer.returnValue([
 | 
			
		||||
                event_map[e_id] for e_id in event_ids
 | 
			
		||||
                if e_id in event_map and event_map[e_id]
 | 
			
		||||
            ])
 | 
			
		||||
 | 
			
		||||
        missing_events = yield self._enqueue_events(
 | 
			
		||||
            missing_events_ids,
 | 
			
		||||
            check_redacted=check_redacted,
 | 
			
		||||
            get_prev_content=get_prev_content,
 | 
			
		||||
            allow_rejected=allow_rejected,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        event_map.update(missing_events)
 | 
			
		||||
 | 
			
		||||
        defer.returnValue([
 | 
			
		||||
            event_map[e_id] for e_id in event_ids
 | 
			
		||||
            if e_id in event_map and event_map[e_id]
 | 
			
		||||
        ])
 | 
			
		||||
 | 
			
		||||
    def _get_events_txn(self, txn, event_ids, check_redacted=True,
 | 
			
		||||
                        get_prev_content=False):
 | 
			
		||||
                        get_prev_content=False, allow_rejected=False):
 | 
			
		||||
        if not event_ids:
 | 
			
		||||
            return []
 | 
			
		||||
 | 
			
		||||
        events = [
 | 
			
		||||
            self._get_event_txn(
 | 
			
		||||
                txn, event_id,
 | 
			
		||||
                check_redacted=check_redacted,
 | 
			
		||||
                get_prev_content=get_prev_content
 | 
			
		||||
            )
 | 
			
		||||
            for event_id in event_ids
 | 
			
		||||
        ]
 | 
			
		||||
        event_map = self._get_events_from_cache(
 | 
			
		||||
            event_ids,
 | 
			
		||||
            check_redacted=check_redacted,
 | 
			
		||||
            get_prev_content=get_prev_content,
 | 
			
		||||
            allow_rejected=allow_rejected,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return [e for e in events if e]
 | 
			
		||||
        missing_events_ids = [e for e in event_ids if e not in event_map]
 | 
			
		||||
 | 
			
		||||
        if not missing_events_ids:
 | 
			
		||||
            return [
 | 
			
		||||
                event_map[e_id] for e_id in event_ids
 | 
			
		||||
                if e_id in event_map and event_map[e_id]
 | 
			
		||||
            ]
 | 
			
		||||
 | 
			
		||||
        missing_events = self._fetch_events_txn(
 | 
			
		||||
            txn,
 | 
			
		||||
            missing_events_ids,
 | 
			
		||||
            check_redacted=check_redacted,
 | 
			
		||||
            get_prev_content=get_prev_content,
 | 
			
		||||
            allow_rejected=allow_rejected,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        event_map.update(missing_events)
 | 
			
		||||
 | 
			
		||||
        return [
 | 
			
		||||
            event_map[e_id] for e_id in event_ids
 | 
			
		||||
            if e_id in event_map and event_map[e_id]
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    def _invalidate_get_event_cache(self, event_id):
 | 
			
		||||
        for check_redacted in (False, True):
 | 
			
		||||
| 
						 | 
				
			
			@ -433,54 +490,217 @@ class EventsStore(SQLBaseStore):
 | 
			
		|||
    def _get_event_txn(self, txn, event_id, check_redacted=True,
 | 
			
		||||
                       get_prev_content=False, allow_rejected=False):
 | 
			
		||||
 | 
			
		||||
        try:
 | 
			
		||||
            ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
 | 
			
		||||
 | 
			
		||||
            if allow_rejected or not ret.rejected_reason:
 | 
			
		||||
                return ret
 | 
			
		||||
            else:
 | 
			
		||||
                return None
 | 
			
		||||
        except KeyError:
 | 
			
		||||
            pass
 | 
			
		||||
 | 
			
		||||
        sql = (
 | 
			
		||||
            "SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
 | 
			
		||||
            "FROM event_json as e "
 | 
			
		||||
            "LEFT JOIN redactions as r ON e.event_id = r.redacts "
 | 
			
		||||
            "LEFT JOIN rejections as rej on rej.event_id = e.event_id  "
 | 
			
		||||
            "WHERE e.event_id = ? "
 | 
			
		||||
            "LIMIT 1 "
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        txn.execute(sql, (event_id,))
 | 
			
		||||
 | 
			
		||||
        res = txn.fetchone()
 | 
			
		||||
 | 
			
		||||
        if not res:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        internal_metadata, js, redacted, rejected_reason = res
 | 
			
		||||
 | 
			
		||||
        result = self._get_event_from_row_txn(
 | 
			
		||||
            txn, internal_metadata, js, redacted,
 | 
			
		||||
        events = self._get_events_txn(
 | 
			
		||||
            txn, [event_id],
 | 
			
		||||
            check_redacted=check_redacted,
 | 
			
		||||
            get_prev_content=get_prev_content,
 | 
			
		||||
            rejected_reason=rejected_reason,
 | 
			
		||||
            allow_rejected=allow_rejected,
 | 
			
		||||
        )
 | 
			
		||||
        self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
 | 
			
		||||
 | 
			
		||||
        if allow_rejected or not rejected_reason:
 | 
			
		||||
            return result
 | 
			
		||||
        else:
 | 
			
		||||
            return None
 | 
			
		||||
        return events[0] if events else None
 | 
			
		||||
 | 
			
		||||
    def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
 | 
			
		||||
                                check_redacted=True, get_prev_content=False,
 | 
			
		||||
                                rejected_reason=None):
 | 
			
		||||
    def _get_events_from_cache(self, events, check_redacted, get_prev_content,
 | 
			
		||||
                               allow_rejected):
 | 
			
		||||
        event_map = {}
 | 
			
		||||
 | 
			
		||||
        for event_id in events:
 | 
			
		||||
            try:
 | 
			
		||||
                ret = self._get_event_cache.get(
 | 
			
		||||
                    event_id, check_redacted, get_prev_content
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                if allow_rejected or not ret.rejected_reason:
 | 
			
		||||
                    event_map[event_id] = ret
 | 
			
		||||
                else:
 | 
			
		||||
                    event_map[event_id] = None
 | 
			
		||||
            except KeyError:
 | 
			
		||||
                pass
 | 
			
		||||
 | 
			
		||||
        return event_map
 | 
			
		||||
 | 
			
		||||
    def _do_fetch(self, conn):
 | 
			
		||||
        """Takes a database connection and waits for requests for events from
 | 
			
		||||
        the _event_fetch_list queue.
 | 
			
		||||
        """
 | 
			
		||||
        event_list = []
 | 
			
		||||
        i = 0
 | 
			
		||||
        while True:
 | 
			
		||||
            try:
 | 
			
		||||
                with self._event_fetch_lock:
 | 
			
		||||
                    event_list = self._event_fetch_list
 | 
			
		||||
                    self._event_fetch_list = []
 | 
			
		||||
 | 
			
		||||
                    if not event_list:
 | 
			
		||||
                        single_threaded = self.database_engine.single_threaded
 | 
			
		||||
                        if single_threaded or i > EVENT_QUEUE_ITERATIONS:
 | 
			
		||||
                            self._event_fetch_ongoing -= 1
 | 
			
		||||
                            return
 | 
			
		||||
                        else:
 | 
			
		||||
                            self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
 | 
			
		||||
                            i += 1
 | 
			
		||||
                            continue
 | 
			
		||||
                    i = 0
 | 
			
		||||
 | 
			
		||||
                event_id_lists = zip(*event_list)[0]
 | 
			
		||||
                event_ids = [
 | 
			
		||||
                    item for sublist in event_id_lists for item in sublist
 | 
			
		||||
                ]
 | 
			
		||||
 | 
			
		||||
                rows = self._new_transaction(
 | 
			
		||||
                    conn, "do_fetch", [], self._fetch_event_rows, event_ids
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                row_dict = {
 | 
			
		||||
                    r["event_id"]: r
 | 
			
		||||
                    for r in rows
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                # We only want to resolve deferreds from the main thread
 | 
			
		||||
                def fire(lst, res):
 | 
			
		||||
                    for ids, d in lst:
 | 
			
		||||
                        if not d.called:
 | 
			
		||||
                            try:
 | 
			
		||||
                                d.callback([
 | 
			
		||||
                                    res[i]
 | 
			
		||||
                                    for i in ids
 | 
			
		||||
                                    if i in res
 | 
			
		||||
                                ])
 | 
			
		||||
                            except:
 | 
			
		||||
                                logger.exception("Failed to callback")
 | 
			
		||||
                reactor.callFromThread(fire, event_list, row_dict)
 | 
			
		||||
            except Exception as e:
 | 
			
		||||
                logger.exception("do_fetch")
 | 
			
		||||
 | 
			
		||||
                # We only want to resolve deferreds from the main thread
 | 
			
		||||
                def fire(evs):
 | 
			
		||||
                    for _, d in evs:
 | 
			
		||||
                        if not d.called:
 | 
			
		||||
                            d.errback(e)
 | 
			
		||||
 | 
			
		||||
                if event_list:
 | 
			
		||||
                    reactor.callFromThread(fire, event_list)
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def _enqueue_events(self, events, check_redacted=True,
 | 
			
		||||
                        get_prev_content=False, allow_rejected=False):
 | 
			
		||||
        """Fetches events from the database using the _event_fetch_list. This
 | 
			
		||||
        allows batch and bulk fetching of events - it allows us to fetch events
 | 
			
		||||
        without having to create a new transaction for each request for events.
 | 
			
		||||
        """
 | 
			
		||||
        if not events:
 | 
			
		||||
            defer.returnValue({})
 | 
			
		||||
 | 
			
		||||
        events_d = defer.Deferred()
 | 
			
		||||
        with self._event_fetch_lock:
 | 
			
		||||
            self._event_fetch_list.append(
 | 
			
		||||
                (events, events_d)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            self._event_fetch_lock.notify()
 | 
			
		||||
 | 
			
		||||
            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
 | 
			
		||||
                self._event_fetch_ongoing += 1
 | 
			
		||||
                should_start = True
 | 
			
		||||
            else:
 | 
			
		||||
                should_start = False
 | 
			
		||||
 | 
			
		||||
        if should_start:
 | 
			
		||||
            self.runWithConnection(
 | 
			
		||||
                self._do_fetch
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        rows = yield preserve_context_over_deferred(events_d)
 | 
			
		||||
 | 
			
		||||
        if not allow_rejected:
 | 
			
		||||
            rows[:] = [r for r in rows if not r["rejects"]]
 | 
			
		||||
 | 
			
		||||
        res = yield defer.gatherResults(
 | 
			
		||||
            [
 | 
			
		||||
                self._get_event_from_row(
 | 
			
		||||
                    row["internal_metadata"], row["json"], row["redacts"],
 | 
			
		||||
                    check_redacted=check_redacted,
 | 
			
		||||
                    get_prev_content=get_prev_content,
 | 
			
		||||
                    rejected_reason=row["rejects"],
 | 
			
		||||
                )
 | 
			
		||||
                for row in rows
 | 
			
		||||
            ],
 | 
			
		||||
            consumeErrors=True
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        defer.returnValue({
 | 
			
		||||
            e.event_id: e
 | 
			
		||||
            for e in res if e
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
    def _fetch_event_rows(self, txn, events):
 | 
			
		||||
        rows = []
 | 
			
		||||
        N = 200
 | 
			
		||||
        for i in range(1 + len(events) / N):
 | 
			
		||||
            evs = events[i*N:(i + 1)*N]
 | 
			
		||||
            if not evs:
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
            sql = (
 | 
			
		||||
                "SELECT "
 | 
			
		||||
                " e.event_id as event_id, "
 | 
			
		||||
                " e.internal_metadata,"
 | 
			
		||||
                " e.json,"
 | 
			
		||||
                " r.redacts as redacts,"
 | 
			
		||||
                " rej.event_id as rejects "
 | 
			
		||||
                " FROM event_json as e"
 | 
			
		||||
                " LEFT JOIN rejections as rej USING (event_id)"
 | 
			
		||||
                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
 | 
			
		||||
                " WHERE e.event_id IN (%s)"
 | 
			
		||||
            ) % (",".join(["?"]*len(evs)),)
 | 
			
		||||
 | 
			
		||||
            txn.execute(sql, evs)
 | 
			
		||||
            rows.extend(self.cursor_to_dict(txn))
 | 
			
		||||
 | 
			
		||||
        return rows
 | 
			
		||||
 | 
			
		||||
    def _fetch_events_txn(self, txn, events, check_redacted=True,
 | 
			
		||||
                          get_prev_content=False, allow_rejected=False):
 | 
			
		||||
        if not events:
 | 
			
		||||
            return {}
 | 
			
		||||
 | 
			
		||||
        rows = self._fetch_event_rows(
 | 
			
		||||
            txn, events,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if not allow_rejected:
 | 
			
		||||
            rows[:] = [r for r in rows if not r["rejects"]]
 | 
			
		||||
 | 
			
		||||
        res = [
 | 
			
		||||
            self._get_event_from_row_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                row["internal_metadata"], row["json"], row["redacts"],
 | 
			
		||||
                check_redacted=check_redacted,
 | 
			
		||||
                get_prev_content=get_prev_content,
 | 
			
		||||
                rejected_reason=row["rejects"],
 | 
			
		||||
            )
 | 
			
		||||
            for row in rows
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
        return {
 | 
			
		||||
            r.event_id: r
 | 
			
		||||
            for r in res
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def _get_event_from_row(self, internal_metadata, js, redacted,
 | 
			
		||||
                            check_redacted=True, get_prev_content=False,
 | 
			
		||||
                            rejected_reason=None):
 | 
			
		||||
        d = json.loads(js)
 | 
			
		||||
        internal_metadata = json.loads(internal_metadata)
 | 
			
		||||
 | 
			
		||||
        if rejected_reason:
 | 
			
		||||
            rejected_reason = yield self._simple_select_one_onecol(
 | 
			
		||||
                table="rejections",
 | 
			
		||||
                keyvalues={"event_id": rejected_reason},
 | 
			
		||||
                retcol="reason",
 | 
			
		||||
                desc="_get_event_from_row",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        ev = FrozenEvent(
 | 
			
		||||
            d,
 | 
			
		||||
            internal_metadata_dict=internal_metadata,
 | 
			
		||||
| 
						 | 
				
			
			@ -490,12 +710,74 @@ class EventsStore(SQLBaseStore):
 | 
			
		|||
        if check_redacted and redacted:
 | 
			
		||||
            ev = prune_event(ev)
 | 
			
		||||
 | 
			
		||||
            ev.unsigned["redacted_by"] = redacted
 | 
			
		||||
            redaction_id = yield self._simple_select_one_onecol(
 | 
			
		||||
                table="redactions",
 | 
			
		||||
                keyvalues={"redacts": ev.event_id},
 | 
			
		||||
                retcol="event_id",
 | 
			
		||||
                desc="_get_event_from_row",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            ev.unsigned["redacted_by"] = redaction_id
 | 
			
		||||
            # Get the redaction event.
 | 
			
		||||
 | 
			
		||||
            because = yield self.get_event(
 | 
			
		||||
                redaction_id,
 | 
			
		||||
                check_redacted=False
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            if because:
 | 
			
		||||
                ev.unsigned["redacted_because"] = because
 | 
			
		||||
 | 
			
		||||
        if get_prev_content and "replaces_state" in ev.unsigned:
 | 
			
		||||
            prev = yield self.get_event(
 | 
			
		||||
                ev.unsigned["replaces_state"],
 | 
			
		||||
                get_prev_content=False,
 | 
			
		||||
            )
 | 
			
		||||
            if prev:
 | 
			
		||||
                ev.unsigned["prev_content"] = prev.get_dict()["content"]
 | 
			
		||||
 | 
			
		||||
        self._get_event_cache.prefill(
 | 
			
		||||
            ev.event_id, check_redacted, get_prev_content, ev
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        defer.returnValue(ev)
 | 
			
		||||
 | 
			
		||||
    def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
 | 
			
		||||
                                check_redacted=True, get_prev_content=False,
 | 
			
		||||
                                rejected_reason=None):
 | 
			
		||||
        d = json.loads(js)
 | 
			
		||||
        internal_metadata = json.loads(internal_metadata)
 | 
			
		||||
 | 
			
		||||
        if rejected_reason:
 | 
			
		||||
            rejected_reason = self._simple_select_one_onecol_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                table="rejections",
 | 
			
		||||
                keyvalues={"event_id": rejected_reason},
 | 
			
		||||
                retcol="reason",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        ev = FrozenEvent(
 | 
			
		||||
            d,
 | 
			
		||||
            internal_metadata_dict=internal_metadata,
 | 
			
		||||
            rejected_reason=rejected_reason,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if check_redacted and redacted:
 | 
			
		||||
            ev = prune_event(ev)
 | 
			
		||||
 | 
			
		||||
            redaction_id = self._simple_select_one_onecol_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                table="redactions",
 | 
			
		||||
                keyvalues={"redacts": ev.event_id},
 | 
			
		||||
                retcol="event_id",
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            ev.unsigned["redacted_by"] = redaction_id
 | 
			
		||||
            # Get the redaction event.
 | 
			
		||||
 | 
			
		||||
            because = self._get_event_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                redacted,
 | 
			
		||||
                redaction_id,
 | 
			
		||||
                check_redacted=False
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -511,6 +793,10 @@ class EventsStore(SQLBaseStore):
 | 
			
		|||
            if prev:
 | 
			
		||||
                ev.unsigned["prev_content"] = prev.get_dict()["content"]
 | 
			
		||||
 | 
			
		||||
        self._get_event_cache.prefill(
 | 
			
		||||
            ev.event_id, check_redacted, get_prev_content, ev
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return ev
 | 
			
		||||
 | 
			
		||||
    def _parse_events(self, rows):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -76,16 +76,16 @@ class RoomMemberStore(SQLBaseStore):
 | 
			
		|||
        Returns:
 | 
			
		||||
            Deferred: Results in a MembershipEvent or None.
 | 
			
		||||
        """
 | 
			
		||||
        def f(txn):
 | 
			
		||||
            events = self._get_members_events_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                room_id,
 | 
			
		||||
                user_id=user_id,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            return events[0] if events else None
 | 
			
		||||
 | 
			
		||||
        return self.runInteraction("get_room_member", f)
 | 
			
		||||
        return self.runInteraction(
 | 
			
		||||
            "get_room_member",
 | 
			
		||||
            self._get_members_events_txn,
 | 
			
		||||
            room_id,
 | 
			
		||||
            user_id=user_id,
 | 
			
		||||
        ).addCallback(
 | 
			
		||||
            self._get_events
 | 
			
		||||
        ).addCallback(
 | 
			
		||||
            lambda events: events[0] if events else None
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def get_users_in_room(self, room_id):
 | 
			
		||||
        def f(txn):
 | 
			
		||||
| 
						 | 
				
			
			@ -110,15 +110,12 @@ class RoomMemberStore(SQLBaseStore):
 | 
			
		|||
        Returns:
 | 
			
		||||
            list of namedtuples representing the members in this room.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        def f(txn):
 | 
			
		||||
            return self._get_members_events_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                room_id,
 | 
			
		||||
                membership=membership,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        return self.runInteraction("get_room_members", f)
 | 
			
		||||
        return self.runInteraction(
 | 
			
		||||
            "get_room_members",
 | 
			
		||||
            self._get_members_events_txn,
 | 
			
		||||
            room_id,
 | 
			
		||||
            membership=membership,
 | 
			
		||||
        ).addCallback(self._get_events)
 | 
			
		||||
 | 
			
		||||
    def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
 | 
			
		||||
        """ Get all the rooms for this user where the membership for this user
 | 
			
		||||
| 
						 | 
				
			
			@ -190,14 +187,14 @@ class RoomMemberStore(SQLBaseStore):
 | 
			
		|||
        return self.runInteraction(
 | 
			
		||||
            "get_members_query", self._get_members_events_txn,
 | 
			
		||||
            where_clause, where_values
 | 
			
		||||
        )
 | 
			
		||||
        ).addCallbacks(self._get_events)
 | 
			
		||||
 | 
			
		||||
    def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
 | 
			
		||||
        rows = self._get_members_rows_txn(
 | 
			
		||||
            txn,
 | 
			
		||||
            room_id, membership, user_id,
 | 
			
		||||
        )
 | 
			
		||||
        return self._get_events_txn(txn, [r["event_id"] for r in rows])
 | 
			
		||||
        return [r["event_id"] for r in rows]
 | 
			
		||||
 | 
			
		||||
    def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
 | 
			
		||||
        where_clause = "c.room_id = ?"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,19 @@
 | 
			
		|||
/* Copyright 2015 OpenMarket Ltd
 | 
			
		||||
 *
 | 
			
		||||
 * Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
 * you may not use this file except in compliance with the License.
 | 
			
		||||
 * You may obtain a copy of the License at
 | 
			
		||||
 *
 | 
			
		||||
 *    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 *
 | 
			
		||||
 * Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
 * distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
 * See the License for the specific language governing permissions and
 | 
			
		||||
 * limitations under the License.
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
CREATE INDEX events_order_topo_stream_room ON events(
 | 
			
		||||
    topological_ordering, stream_ordering, room_id
 | 
			
		||||
);
 | 
			
		||||
| 
						 | 
				
			
			@ -43,6 +43,7 @@ class StateStore(SQLBaseStore):
 | 
			
		|||
      * `state_groups_state`: Maps state group to state events.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def get_state_groups(self, event_ids):
 | 
			
		||||
        """ Get the state groups for the given list of event_ids
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -71,17 +72,29 @@ class StateStore(SQLBaseStore):
 | 
			
		|||
                    retcol="event_id",
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                state = self._get_events_txn(txn, state_ids)
 | 
			
		||||
 | 
			
		||||
                res[group] = state
 | 
			
		||||
                res[group] = state_ids
 | 
			
		||||
 | 
			
		||||
            return res
 | 
			
		||||
 | 
			
		||||
        return self.runInteraction(
 | 
			
		||||
        states = yield self.runInteraction(
 | 
			
		||||
            "get_state_groups",
 | 
			
		||||
            f,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        @defer.inlineCallbacks
 | 
			
		||||
        def c(vals):
 | 
			
		||||
            vals[:] = yield self._get_events(vals, get_prev_content=False)
 | 
			
		||||
 | 
			
		||||
        yield defer.gatherResults(
 | 
			
		||||
            [
 | 
			
		||||
                c(vals)
 | 
			
		||||
                for vals in states.values()
 | 
			
		||||
            ],
 | 
			
		||||
            consumeErrors=True,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        defer.returnValue(states)
 | 
			
		||||
 | 
			
		||||
    def _store_state_groups_txn(self, txn, event, context):
 | 
			
		||||
        if context.current_state is None:
 | 
			
		||||
            return
 | 
			
		||||
| 
						 | 
				
			
			@ -146,11 +159,12 @@ class StateStore(SQLBaseStore):
 | 
			
		|||
                args = (room_id, )
 | 
			
		||||
 | 
			
		||||
            txn.execute(sql, args)
 | 
			
		||||
            results = self.cursor_to_dict(txn)
 | 
			
		||||
            results = txn.fetchall()
 | 
			
		||||
 | 
			
		||||
            return self._parse_events_txn(txn, results)
 | 
			
		||||
            return [r[0] for r in results]
 | 
			
		||||
 | 
			
		||||
        events = yield self.runInteraction("get_current_state", f)
 | 
			
		||||
        event_ids = yield self.runInteraction("get_current_state", f)
 | 
			
		||||
        events = yield self._get_events(event_ids, get_prev_content=False)
 | 
			
		||||
        defer.returnValue(events)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -224,7 +224,7 @@ class StreamStore(SQLBaseStore):
 | 
			
		|||
 | 
			
		||||
        return self.runInteraction("get_room_events_stream", f)
 | 
			
		||||
 | 
			
		||||
    @log_function
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def paginate_room_events(self, room_id, from_key, to_key=None,
 | 
			
		||||
                             direction='b', limit=-1,
 | 
			
		||||
                             with_feedback=False):
 | 
			
		||||
| 
						 | 
				
			
			@ -286,18 +286,20 @@ class StreamStore(SQLBaseStore):
 | 
			
		|||
                # TODO (erikj): We should work out what to do here instead.
 | 
			
		||||
                next_token = to_key if to_key else from_key
 | 
			
		||||
 | 
			
		||||
            events = self._get_events_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                [r["event_id"] for r in rows],
 | 
			
		||||
                get_prev_content=True
 | 
			
		||||
            )
 | 
			
		||||
            return rows, next_token,
 | 
			
		||||
 | 
			
		||||
            self._set_before_and_after(events, rows)
 | 
			
		||||
        rows, token = yield self.runInteraction("paginate_room_events", f)
 | 
			
		||||
 | 
			
		||||
            return events, next_token,
 | 
			
		||||
        events = yield self._get_events(
 | 
			
		||||
            [r["event_id"] for r in rows],
 | 
			
		||||
            get_prev_content=True
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        return self.runInteraction("paginate_room_events", f)
 | 
			
		||||
        self._set_before_and_after(events, rows)
 | 
			
		||||
 | 
			
		||||
        defer.returnValue((events, token))
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def get_recent_events_for_room(self, room_id, limit, end_token,
 | 
			
		||||
                                   with_feedback=False, from_token=None):
 | 
			
		||||
        # TODO (erikj): Handle compressed feedback
 | 
			
		||||
| 
						 | 
				
			
			@ -349,20 +351,23 @@ class StreamStore(SQLBaseStore):
 | 
			
		|||
            else:
 | 
			
		||||
                token = (str(end_token), str(end_token))
 | 
			
		||||
 | 
			
		||||
            events = self._get_events_txn(
 | 
			
		||||
                txn,
 | 
			
		||||
                [r["event_id"] for r in rows],
 | 
			
		||||
                get_prev_content=True
 | 
			
		||||
            )
 | 
			
		||||
            return rows, token
 | 
			
		||||
 | 
			
		||||
            self._set_before_and_after(events, rows)
 | 
			
		||||
 | 
			
		||||
            return events, token
 | 
			
		||||
 | 
			
		||||
        return self.runInteraction(
 | 
			
		||||
        rows, token = yield self.runInteraction(
 | 
			
		||||
            "get_recent_events_for_room", get_recent_events_for_room_txn
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        logger.debug("stream before")
 | 
			
		||||
        events = yield self._get_events(
 | 
			
		||||
            [r["event_id"] for r in rows],
 | 
			
		||||
            get_prev_content=True
 | 
			
		||||
        )
 | 
			
		||||
        logger.debug("stream after")
 | 
			
		||||
 | 
			
		||||
        self._set_before_and_after(events, rows)
 | 
			
		||||
 | 
			
		||||
        defer.returnValue((events, token))
 | 
			
		||||
 | 
			
		||||
    @defer.inlineCallbacks
 | 
			
		||||
    def get_room_events_max_id(self, direction='f'):
 | 
			
		||||
        token = yield self._stream_id_gen.get_max_token(self)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -29,6 +29,34 @@ def unwrapFirstError(failure):
 | 
			
		|||
    return failure.value.subFailure
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def unwrap_deferred(d):
 | 
			
		||||
    """Given a deferred that we know has completed, return its value or raise
 | 
			
		||||
    the failure as an exception
 | 
			
		||||
    """
 | 
			
		||||
    if not d.called:
 | 
			
		||||
        raise RuntimeError("deferred has not finished")
 | 
			
		||||
 | 
			
		||||
    res = []
 | 
			
		||||
 | 
			
		||||
    def f(r):
 | 
			
		||||
        res.append(r)
 | 
			
		||||
        return r
 | 
			
		||||
    d.addCallback(f)
 | 
			
		||||
 | 
			
		||||
    if res:
 | 
			
		||||
        return res[0]
 | 
			
		||||
 | 
			
		||||
    def f(r):
 | 
			
		||||
        res.append(r)
 | 
			
		||||
        return r
 | 
			
		||||
    d.addErrback(f)
 | 
			
		||||
 | 
			
		||||
    if res:
 | 
			
		||||
        res[0].raiseException()
 | 
			
		||||
    else:
 | 
			
		||||
        raise RuntimeError("deferred did not call callbacks")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Clock(object):
 | 
			
		||||
    """A small utility that obtains current time-of-day so that time may be
 | 
			
		||||
    mocked during unit-tests.
 | 
			
		||||
| 
						 | 
				
			
			@ -52,16 +80,16 @@ class Clock(object):
 | 
			
		|||
    def stop_looping_call(self, loop):
 | 
			
		||||
        loop.stop()
 | 
			
		||||
 | 
			
		||||
    def call_later(self, delay, callback):
 | 
			
		||||
    def call_later(self, delay, callback, *args, **kwargs):
 | 
			
		||||
        current_context = LoggingContext.current_context()
 | 
			
		||||
 | 
			
		||||
        def wrapped_callback():
 | 
			
		||||
        def wrapped_callback(*args, **kwargs):
 | 
			
		||||
            with PreserveLoggingContext():
 | 
			
		||||
                LoggingContext.thread_local.current_context = current_context
 | 
			
		||||
                callback()
 | 
			
		||||
                callback(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
        with PreserveLoggingContext():
 | 
			
		||||
            return reactor.callLater(delay, wrapped_callback)
 | 
			
		||||
            return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def cancel_call_later(self, timer):
 | 
			
		||||
        timer.cancel()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -33,8 +33,9 @@ class SQLBaseStoreTestCase(unittest.TestCase):
 | 
			
		|||
    def setUp(self):
 | 
			
		||||
        self.db_pool = Mock(spec=["runInteraction"])
 | 
			
		||||
        self.mock_txn = Mock()
 | 
			
		||||
        self.mock_conn = Mock(spec_set=["cursor"])
 | 
			
		||||
        self.mock_conn = Mock(spec_set=["cursor", "rollback", "commit"])
 | 
			
		||||
        self.mock_conn.cursor.return_value = self.mock_txn
 | 
			
		||||
        self.mock_conn.rollback.return_value = None
 | 
			
		||||
        # Our fake runInteraction just runs synchronously inline
 | 
			
		||||
 | 
			
		||||
        def runInteraction(func, *args, **kwargs):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue