Implement updating devices

You can update the displayname of devices now.
pull/949/head
Richard van der Hoff 2016-07-25 17:51:24 +01:00
parent 436bffd15f
commit 012b4c1913
5 changed files with 120 additions and 9 deletions

View File

@ -141,6 +141,30 @@ class DeviceHandler(BaseHandler):
yield self.store.user_delete_access_tokens(user_id, yield self.store.user_delete_access_tokens(user_id,
device_id=device_id) device_id=device_id)
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
Args:
user_id (str):
device_id (str):
content (dict): body of update request
Returns:
defer.Deferred:
"""
try:
yield self.store.update_device(
user_id,
device_id,
new_display_name=content.get("display_name")
)
except errors.StoreError, e:
if e.code == 404:
raise errors.NotFoundError()
else:
raise
def _update_device_from_client_ips(device, client_ips): def _update_device_from_client_ips(device, client_ips):

View File

@ -13,19 +13,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns
import logging import logging
from twisted.internet import defer
from synapse.http import servlet
from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet): class DevicesRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
def __init__(self, hs): def __init__(self, hs):
@ -47,7 +45,7 @@ class DevicesRestServlet(RestServlet):
defer.returnValue((200, {"devices": devices})) defer.returnValue((200, {"devices": devices}))
class DeviceRestServlet(RestServlet): class DeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
releases=[], v2_alpha=False) releases=[], v2_alpha=False)
@ -84,6 +82,18 @@ class DeviceRestServlet(RestServlet):
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
body = servlet.parse_json_object_from_request(request)
yield self.device_handler.update_device(
requester.user.to_string(),
device_id,
body
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
DevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server)

View File

@ -81,7 +81,7 @@ class DeviceStore(SQLBaseStore):
Args: Args:
user_id (str): The ID of the user which owns the device user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to retrieve device_id (str): The ID of the device to delete
Returns: Returns:
defer.Deferred defer.Deferred
""" """
@ -91,6 +91,31 @@ class DeviceStore(SQLBaseStore):
desc="delete_device", desc="delete_device",
) )
def update_device(self, user_id, device_id, new_display_name=None):
"""Update a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to update
new_display_name (str|None): new displayname for device; None
to leave unchanged
Raises:
StoreError: if the device is not found
Returns:
defer.Deferred
"""
updates = {}
if new_display_name is not None:
updates["display_name"] = new_display_name
if not updates:
return defer.succeed(None)
return self._simple_update_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues=updates,
desc="update_device",
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_devices_by_user(self, user_id): def get_devices_by_user(self, user_id):
"""Retrieve all of a user's registered devices. """Retrieve all of a user's registered devices.

View File

@ -140,6 +140,22 @@ class DeviceTestCase(unittest.TestCase):
# we'd like to check the access token was invalidated, but that's a # we'd like to check the access token was invalidated, but that's a
# bit of a PITA. # bit of a PITA.
@defer.inlineCallbacks
def test_update_device(self):
yield self._record_users()
update = {"display_name": "new display"}
yield self.handler.update_device(user1, "abc", update)
res = yield self.handler.get_device(user1, "abc")
self.assertEqual(res["display_name"], "new display")
@defer.inlineCallbacks
def test_update_unknown_device(self):
update = {"display_name": "new_display"}
with self.assertRaises(synapse.api.errors.NotFoundError):
yield self.handler.update_device("user_id", "unknown_device_id",
update)
@defer.inlineCallbacks @defer.inlineCallbacks
def _record_users(self): def _record_users(self):

View File

@ -15,6 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
import synapse.api.errors
import tests.unittest import tests.unittest
import tests.utils import tests.utils
@ -67,3 +68,38 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
"device_id": "device2", "device_id": "device2",
"display_name": "display_name 2", "display_name": "display_name 2",
}, res["device2"]) }, res["device2"])
@defer.inlineCallbacks
def test_update_device(self):
yield self.store.store_device(
"user_id", "device_id", "display_name 1"
)
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
yield self.store.update_device(
"user_id", "device_id",
)
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"])
# do the update
yield self.store.update_device(
"user_id", "device_id",
new_display_name="display_name 2",
)
# check it worked
res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks
def test_update_unknown_device(self):
with self.assertRaises(synapse.api.errors.StoreError) as cm:
yield self.store.update_device(
"user_id", "unknown_device_id",
new_display_name="display_name 2",
)
self.assertEqual(404, cm.exception.code)