99 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			99 lines
		
	
	
		
			3.4 KiB
		
	
	
	
		
			Python
		
	
	
# Copyright 2019 The Matrix.org Foundation C.I.C.
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
from typing import Optional
 | 
						|
 | 
						|
from synapse.config.homeserver import HomeServerConfig
 | 
						|
from synapse.util.ratelimitutils import FederationRateLimiter
 | 
						|
 | 
						|
from tests.server import get_clock
 | 
						|
from tests.unittest import TestCase
 | 
						|
from tests.utils import default_config
 | 
						|
 | 
						|
 | 
						|
class FederationRateLimiterTestCase(TestCase):
 | 
						|
    def test_ratelimit(self):
 | 
						|
        """A simple test with the default values"""
 | 
						|
        reactor, clock = get_clock()
 | 
						|
        rc_config = build_rc_config()
 | 
						|
        ratelimiter = FederationRateLimiter(clock, rc_config)
 | 
						|
 | 
						|
        with ratelimiter.ratelimit("testhost") as d1:
 | 
						|
            # shouldn't block
 | 
						|
            self.successResultOf(d1)
 | 
						|
 | 
						|
    def test_concurrent_limit(self):
 | 
						|
        """Test what happens when we hit the concurrent limit"""
 | 
						|
        reactor, clock = get_clock()
 | 
						|
        rc_config = build_rc_config({"rc_federation": {"concurrent": 2}})
 | 
						|
        ratelimiter = FederationRateLimiter(clock, rc_config)
 | 
						|
 | 
						|
        with ratelimiter.ratelimit("testhost") as d1:
 | 
						|
            # shouldn't block
 | 
						|
            self.successResultOf(d1)
 | 
						|
 | 
						|
            cm2 = ratelimiter.ratelimit("testhost")
 | 
						|
            d2 = cm2.__enter__()
 | 
						|
            # also shouldn't block
 | 
						|
            self.successResultOf(d2)
 | 
						|
 | 
						|
            cm3 = ratelimiter.ratelimit("testhost")
 | 
						|
            d3 = cm3.__enter__()
 | 
						|
            # this one should block, though ...
 | 
						|
            self.assertNoResult(d3)
 | 
						|
 | 
						|
            # ... until we complete an earlier request
 | 
						|
            cm2.__exit__(None, None, None)
 | 
						|
            self.successResultOf(d3)
 | 
						|
 | 
						|
    def test_sleep_limit(self):
 | 
						|
        """Test what happens when we hit the sleep limit"""
 | 
						|
        reactor, clock = get_clock()
 | 
						|
        rc_config = build_rc_config(
 | 
						|
            {"rc_federation": {"sleep_limit": 2, "sleep_delay": 500}}
 | 
						|
        )
 | 
						|
        ratelimiter = FederationRateLimiter(clock, rc_config)
 | 
						|
 | 
						|
        with ratelimiter.ratelimit("testhost") as d1:
 | 
						|
            # shouldn't block
 | 
						|
            self.successResultOf(d1)
 | 
						|
 | 
						|
        with ratelimiter.ratelimit("testhost") as d2:
 | 
						|
            # nor this
 | 
						|
            self.successResultOf(d2)
 | 
						|
 | 
						|
        with ratelimiter.ratelimit("testhost") as d3:
 | 
						|
            # this one should block, though ...
 | 
						|
            self.assertNoResult(d3)
 | 
						|
            sleep_time = _await_resolution(reactor, d3)
 | 
						|
            self.assertAlmostEqual(sleep_time, 500, places=3)
 | 
						|
 | 
						|
 | 
						|
def _await_resolution(reactor, d):
 | 
						|
    """advance the clock until the deferred completes.
 | 
						|
 | 
						|
    Returns the number of milliseconds it took to complete.
 | 
						|
    """
 | 
						|
    start_time = reactor.seconds()
 | 
						|
    while not d.called:
 | 
						|
        reactor.advance(0.01)
 | 
						|
    return (reactor.seconds() - start_time) * 1000
 | 
						|
 | 
						|
 | 
						|
def build_rc_config(settings: Optional[dict] = None):
 | 
						|
    config_dict = default_config("test")
 | 
						|
    config_dict.update(settings or {})
 | 
						|
    config = HomeServerConfig()
 | 
						|
    config.parse_config_dict(config_dict, "", "")
 | 
						|
    return config.ratelimiting.rc_federation
 |