102 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			102 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
#  Copyright 2021 The Matrix.org Foundation C.I.C.
 | 
						|
#
 | 
						|
#  Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
#  you may not use this file except in compliance with the License.
 | 
						|
#  You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#      http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
#  Unless required by applicable law or agreed to in writing, software
 | 
						|
#  distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
#  See the License for the specific language governing permissions and
 | 
						|
#  limitations under the License.
 | 
						|
 | 
						|
from io import BytesIO
 | 
						|
 | 
						|
from mock import Mock
 | 
						|
 | 
						|
from twisted.python.failure import Failure
 | 
						|
from twisted.web.client import ResponseDone
 | 
						|
 | 
						|
from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
 | 
						|
 | 
						|
from tests.unittest import TestCase
 | 
						|
 | 
						|
 | 
						|
class ReadBodyWithMaxSizeTests(TestCase):
 | 
						|
    def setUp(self):
 | 
						|
        """Start reading the body, returns the response, result and proto"""
 | 
						|
        self.response = Mock()
 | 
						|
        self.result = BytesIO()
 | 
						|
        self.deferred = read_body_with_max_size(self.response, self.result, 6)
 | 
						|
 | 
						|
        # Fish the protocol out of the response.
 | 
						|
        self.protocol = self.response.deliverBody.call_args[0][0]
 | 
						|
        self.protocol.transport = Mock()
 | 
						|
 | 
						|
    def _cleanup_error(self):
 | 
						|
        """Ensure that the error in the Deferred is handled gracefully."""
 | 
						|
        called = [False]
 | 
						|
 | 
						|
        def errback(f):
 | 
						|
            called[0] = True
 | 
						|
 | 
						|
        self.deferred.addErrback(errback)
 | 
						|
        self.assertTrue(called[0])
 | 
						|
 | 
						|
    def test_no_error(self):
 | 
						|
        """A response that is NOT too large."""
 | 
						|
 | 
						|
        # Start sending data.
 | 
						|
        self.protocol.dataReceived(b"12345")
 | 
						|
        # Close the connection.
 | 
						|
        self.protocol.connectionLost(Failure(ResponseDone()))
 | 
						|
 | 
						|
        self.assertEqual(self.result.getvalue(), b"12345")
 | 
						|
        self.assertEqual(self.deferred.result, 5)
 | 
						|
 | 
						|
    def test_too_large(self):
 | 
						|
        """A response which is too large raises an exception."""
 | 
						|
 | 
						|
        # Start sending data.
 | 
						|
        self.protocol.dataReceived(b"1234567890")
 | 
						|
        # Close the connection.
 | 
						|
        self.protocol.connectionLost(Failure(ResponseDone()))
 | 
						|
 | 
						|
        self.assertEqual(self.result.getvalue(), b"1234567890")
 | 
						|
        self.assertIsInstance(self.deferred.result, Failure)
 | 
						|
        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
 | 
						|
        self._cleanup_error()
 | 
						|
 | 
						|
    def test_multiple_packets(self):
 | 
						|
        """Data should be accummulated through mutliple packets."""
 | 
						|
 | 
						|
        # Start sending data.
 | 
						|
        self.protocol.dataReceived(b"12")
 | 
						|
        self.protocol.dataReceived(b"34")
 | 
						|
        # Close the connection.
 | 
						|
        self.protocol.connectionLost(Failure(ResponseDone()))
 | 
						|
 | 
						|
        self.assertEqual(self.result.getvalue(), b"1234")
 | 
						|
        self.assertEqual(self.deferred.result, 4)
 | 
						|
 | 
						|
    def test_additional_data(self):
 | 
						|
        """A connection can receive data after being closed."""
 | 
						|
 | 
						|
        # Start sending data.
 | 
						|
        self.protocol.dataReceived(b"1234567890")
 | 
						|
        self.assertIsInstance(self.deferred.result, Failure)
 | 
						|
        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
 | 
						|
        self.protocol.transport.loseConnection.assert_called_once()
 | 
						|
 | 
						|
        # More data might have come in.
 | 
						|
        self.protocol.dataReceived(b"1234567890")
 | 
						|
        # Close the connection.
 | 
						|
        self.protocol.connectionLost(Failure(ResponseDone()))
 | 
						|
 | 
						|
        self.assertEqual(self.result.getvalue(), b"1234567890")
 | 
						|
        self.assertIsInstance(self.deferred.result, Failure)
 | 
						|
        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
 | 
						|
        self._cleanup_error()
 |