Implement updating devices
You can update the displayname of devices now.pull/949/head
							parent
							
								
									436bffd15f
								
							
						
					
					
						commit
						012b4c1913
					
				| 
						 | 
					@ -141,6 +141,30 @@ class DeviceHandler(BaseHandler):
 | 
				
			||||||
        yield self.store.user_delete_access_tokens(user_id,
 | 
					        yield self.store.user_delete_access_tokens(user_id,
 | 
				
			||||||
                                                   device_id=device_id)
 | 
					                                                   device_id=device_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @defer.inlineCallbacks
 | 
				
			||||||
 | 
					    def update_device(self, user_id, device_id, content):
 | 
				
			||||||
 | 
					        """ Update the given device
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            user_id (str):
 | 
				
			||||||
 | 
					            device_id (str):
 | 
				
			||||||
 | 
					            content (dict): body of update request
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            defer.Deferred:
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            yield self.store.update_device(
 | 
				
			||||||
 | 
					                user_id,
 | 
				
			||||||
 | 
					                device_id,
 | 
				
			||||||
 | 
					                new_display_name=content.get("display_name")
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        except errors.StoreError, e:
 | 
				
			||||||
 | 
					            if e.code == 404:
 | 
				
			||||||
 | 
					                raise errors.NotFoundError()
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _update_device_from_client_ips(device, client_ips):
 | 
					def _update_device_from_client_ips(device, client_ips):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -13,19 +13,17 @@
 | 
				
			||||||
# See the License for the specific language governing permissions and
 | 
					# See the License for the specific language governing permissions and
 | 
				
			||||||
# limitations under the License.
 | 
					# limitations under the License.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from twisted.internet import defer
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from synapse.http.servlet import RestServlet
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
from ._base import client_v2_patterns
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from twisted.internet import defer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from synapse.http import servlet
 | 
				
			||||||
 | 
					from ._base import client_v2_patterns
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.getLogger(__name__)
 | 
					logger = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DevicesRestServlet(RestServlet):
 | 
					class DevicesRestServlet(servlet.RestServlet):
 | 
				
			||||||
    PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
 | 
					    PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, hs):
 | 
					    def __init__(self, hs):
 | 
				
			||||||
| 
						 | 
					@ -47,7 +45,7 @@ class DevicesRestServlet(RestServlet):
 | 
				
			||||||
        defer.returnValue((200, {"devices": devices}))
 | 
					        defer.returnValue((200, {"devices": devices}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DeviceRestServlet(RestServlet):
 | 
					class DeviceRestServlet(servlet.RestServlet):
 | 
				
			||||||
    PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
 | 
					    PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
 | 
				
			||||||
                                  releases=[], v2_alpha=False)
 | 
					                                  releases=[], v2_alpha=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -84,6 +82,18 @@ class DeviceRestServlet(RestServlet):
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        defer.returnValue((200, {}))
 | 
					        defer.returnValue((200, {}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @defer.inlineCallbacks
 | 
				
			||||||
 | 
					    def on_PUT(self, request, device_id):
 | 
				
			||||||
 | 
					        requester = yield self.auth.get_user_by_req(request)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        body = servlet.parse_json_object_from_request(request)
 | 
				
			||||||
 | 
					        yield self.device_handler.update_device(
 | 
				
			||||||
 | 
					            requester.user.to_string(),
 | 
				
			||||||
 | 
					            device_id,
 | 
				
			||||||
 | 
					            body
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        defer.returnValue((200, {}))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def register_servlets(hs, http_server):
 | 
					def register_servlets(hs, http_server):
 | 
				
			||||||
    DevicesRestServlet(hs).register(http_server)
 | 
					    DevicesRestServlet(hs).register(http_server)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -81,7 +81,7 @@ class DeviceStore(SQLBaseStore):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Args:
 | 
					        Args:
 | 
				
			||||||
            user_id (str): The ID of the user which owns the device
 | 
					            user_id (str): The ID of the user which owns the device
 | 
				
			||||||
            device_id (str): The ID of the device to retrieve
 | 
					            device_id (str): The ID of the device to delete
 | 
				
			||||||
        Returns:
 | 
					        Returns:
 | 
				
			||||||
            defer.Deferred
 | 
					            defer.Deferred
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -91,6 +91,31 @@ class DeviceStore(SQLBaseStore):
 | 
				
			||||||
            desc="delete_device",
 | 
					            desc="delete_device",
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def update_device(self, user_id, device_id, new_display_name=None):
 | 
				
			||||||
 | 
					        """Update a device.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            user_id (str): The ID of the user which owns the device
 | 
				
			||||||
 | 
					            device_id (str): The ID of the device to update
 | 
				
			||||||
 | 
					            new_display_name (str|None): new displayname for device; None
 | 
				
			||||||
 | 
					               to leave unchanged
 | 
				
			||||||
 | 
					        Raises:
 | 
				
			||||||
 | 
					            StoreError: if the device is not found
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            defer.Deferred
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        updates = {}
 | 
				
			||||||
 | 
					        if new_display_name is not None:
 | 
				
			||||||
 | 
					            updates["display_name"] = new_display_name
 | 
				
			||||||
 | 
					        if not updates:
 | 
				
			||||||
 | 
					            return defer.succeed(None)
 | 
				
			||||||
 | 
					        return self._simple_update_one(
 | 
				
			||||||
 | 
					            table="devices",
 | 
				
			||||||
 | 
					            keyvalues={"user_id": user_id, "device_id": device_id},
 | 
				
			||||||
 | 
					            updatevalues=updates,
 | 
				
			||||||
 | 
					            desc="update_device",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @defer.inlineCallbacks
 | 
					    @defer.inlineCallbacks
 | 
				
			||||||
    def get_devices_by_user(self, user_id):
 | 
					    def get_devices_by_user(self, user_id):
 | 
				
			||||||
        """Retrieve all of a user's registered devices.
 | 
					        """Retrieve all of a user's registered devices.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -140,6 +140,22 @@ class DeviceTestCase(unittest.TestCase):
 | 
				
			||||||
        # we'd like to check the access token was invalidated, but that's a
 | 
					        # we'd like to check the access token was invalidated, but that's a
 | 
				
			||||||
        # bit of a PITA.
 | 
					        # bit of a PITA.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @defer.inlineCallbacks
 | 
				
			||||||
 | 
					    def test_update_device(self):
 | 
				
			||||||
 | 
					        yield self._record_users()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        update = {"display_name": "new display"}
 | 
				
			||||||
 | 
					        yield self.handler.update_device(user1, "abc", update)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        res = yield self.handler.get_device(user1, "abc")
 | 
				
			||||||
 | 
					        self.assertEqual(res["display_name"], "new display")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @defer.inlineCallbacks
 | 
				
			||||||
 | 
					    def test_update_unknown_device(self):
 | 
				
			||||||
 | 
					        update = {"display_name": "new_display"}
 | 
				
			||||||
 | 
					        with self.assertRaises(synapse.api.errors.NotFoundError):
 | 
				
			||||||
 | 
					            yield self.handler.update_device("user_id", "unknown_device_id",
 | 
				
			||||||
 | 
					                                             update)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @defer.inlineCallbacks
 | 
					    @defer.inlineCallbacks
 | 
				
			||||||
    def _record_users(self):
 | 
					    def _record_users(self):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -15,6 +15,7 @@
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from twisted.internet import defer
 | 
					from twisted.internet import defer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import synapse.api.errors
 | 
				
			||||||
import tests.unittest
 | 
					import tests.unittest
 | 
				
			||||||
import tests.utils
 | 
					import tests.utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -67,3 +68,38 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
 | 
				
			||||||
            "device_id": "device2",
 | 
					            "device_id": "device2",
 | 
				
			||||||
            "display_name": "display_name 2",
 | 
					            "display_name": "display_name 2",
 | 
				
			||||||
        }, res["device2"])
 | 
					        }, res["device2"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @defer.inlineCallbacks
 | 
				
			||||||
 | 
					    def test_update_device(self):
 | 
				
			||||||
 | 
					        yield self.store.store_device(
 | 
				
			||||||
 | 
					            "user_id", "device_id", "display_name 1"
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        res = yield self.store.get_device("user_id", "device_id")
 | 
				
			||||||
 | 
					        self.assertEqual("display_name 1", res["display_name"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # do a no-op first
 | 
				
			||||||
 | 
					        yield self.store.update_device(
 | 
				
			||||||
 | 
					            "user_id", "device_id",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        res = yield self.store.get_device("user_id", "device_id")
 | 
				
			||||||
 | 
					        self.assertEqual("display_name 1", res["display_name"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # do the update
 | 
				
			||||||
 | 
					        yield self.store.update_device(
 | 
				
			||||||
 | 
					            "user_id", "device_id",
 | 
				
			||||||
 | 
					            new_display_name="display_name 2",
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # check it worked
 | 
				
			||||||
 | 
					        res = yield self.store.get_device("user_id", "device_id")
 | 
				
			||||||
 | 
					        self.assertEqual("display_name 2", res["display_name"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @defer.inlineCallbacks
 | 
				
			||||||
 | 
					    def test_update_unknown_device(self):
 | 
				
			||||||
 | 
					        with self.assertRaises(synapse.api.errors.StoreError) as cm:
 | 
				
			||||||
 | 
					            yield self.store.update_device(
 | 
				
			||||||
 | 
					                "user_id", "unknown_device_id",
 | 
				
			||||||
 | 
					                new_display_name="display_name 2",
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        self.assertEqual(404, cm.exception.code)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue