Implement updating devices
You can update the displayname of devices now.pull/949/head
parent
436bffd15f
commit
012b4c1913
|
@ -141,6 +141,30 @@ class DeviceHandler(BaseHandler):
|
||||||
yield self.store.user_delete_access_tokens(user_id,
|
yield self.store.user_delete_access_tokens(user_id,
|
||||||
device_id=device_id)
|
device_id=device_id)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def update_device(self, user_id, device_id, content):
|
||||||
|
""" Update the given device
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str):
|
||||||
|
device_id (str):
|
||||||
|
content (dict): body of update request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred:
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.store.update_device(
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
new_display_name=content.get("display_name")
|
||||||
|
)
|
||||||
|
except errors.StoreError, e:
|
||||||
|
if e.code == 404:
|
||||||
|
raise errors.NotFoundError()
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def _update_device_from_client_ips(device, client_ips):
|
def _update_device_from_client_ips(device, client_ips):
|
||||||
|
|
|
@ -13,19 +13,17 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.http.servlet import RestServlet
|
|
||||||
|
|
||||||
from ._base import client_v2_patterns
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.http import servlet
|
||||||
|
from ._base import client_v2_patterns
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DevicesRestServlet(RestServlet):
|
class DevicesRestServlet(servlet.RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
|
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
@ -47,7 +45,7 @@ class DevicesRestServlet(RestServlet):
|
||||||
defer.returnValue((200, {"devices": devices}))
|
defer.returnValue((200, {"devices": devices}))
|
||||||
|
|
||||||
|
|
||||||
class DeviceRestServlet(RestServlet):
|
class DeviceRestServlet(servlet.RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
||||||
releases=[], v2_alpha=False)
|
releases=[], v2_alpha=False)
|
||||||
|
|
||||||
|
@ -84,6 +82,18 @@ class DeviceRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_PUT(self, request, device_id):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
body = servlet.parse_json_object_from_request(request)
|
||||||
|
yield self.device_handler.update_device(
|
||||||
|
requester.user.to_string(),
|
||||||
|
device_id,
|
||||||
|
body
|
||||||
|
)
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
DevicesRestServlet(hs).register(http_server)
|
DevicesRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -81,7 +81,7 @@ class DeviceStore(SQLBaseStore):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): The ID of the user which owns the device
|
user_id (str): The ID of the user which owns the device
|
||||||
device_id (str): The ID of the device to retrieve
|
device_id (str): The ID of the device to delete
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred
|
defer.Deferred
|
||||||
"""
|
"""
|
||||||
|
@ -91,6 +91,31 @@ class DeviceStore(SQLBaseStore):
|
||||||
desc="delete_device",
|
desc="delete_device",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def update_device(self, user_id, device_id, new_display_name=None):
|
||||||
|
"""Update a device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The ID of the user which owns the device
|
||||||
|
device_id (str): The ID of the device to update
|
||||||
|
new_display_name (str|None): new displayname for device; None
|
||||||
|
to leave unchanged
|
||||||
|
Raises:
|
||||||
|
StoreError: if the device is not found
|
||||||
|
Returns:
|
||||||
|
defer.Deferred
|
||||||
|
"""
|
||||||
|
updates = {}
|
||||||
|
if new_display_name is not None:
|
||||||
|
updates["display_name"] = new_display_name
|
||||||
|
if not updates:
|
||||||
|
return defer.succeed(None)
|
||||||
|
return self._simple_update_one(
|
||||||
|
table="devices",
|
||||||
|
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||||
|
updatevalues=updates,
|
||||||
|
desc="update_device",
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_devices_by_user(self, user_id):
|
def get_devices_by_user(self, user_id):
|
||||||
"""Retrieve all of a user's registered devices.
|
"""Retrieve all of a user's registered devices.
|
||||||
|
|
|
@ -140,6 +140,22 @@ class DeviceTestCase(unittest.TestCase):
|
||||||
# we'd like to check the access token was invalidated, but that's a
|
# we'd like to check the access token was invalidated, but that's a
|
||||||
# bit of a PITA.
|
# bit of a PITA.
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_update_device(self):
|
||||||
|
yield self._record_users()
|
||||||
|
|
||||||
|
update = {"display_name": "new display"}
|
||||||
|
yield self.handler.update_device(user1, "abc", update)
|
||||||
|
|
||||||
|
res = yield self.handler.get_device(user1, "abc")
|
||||||
|
self.assertEqual(res["display_name"], "new display")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_update_unknown_device(self):
|
||||||
|
update = {"display_name": "new_display"}
|
||||||
|
with self.assertRaises(synapse.api.errors.NotFoundError):
|
||||||
|
yield self.handler.update_device("user_id", "unknown_device_id",
|
||||||
|
update)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _record_users(self):
|
def _record_users(self):
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
import synapse.api.errors
|
||||||
import tests.unittest
|
import tests.unittest
|
||||||
import tests.utils
|
import tests.utils
|
||||||
|
|
||||||
|
@ -67,3 +68,38 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||||
"device_id": "device2",
|
"device_id": "device2",
|
||||||
"display_name": "display_name 2",
|
"display_name": "display_name 2",
|
||||||
}, res["device2"])
|
}, res["device2"])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_update_device(self):
|
||||||
|
yield self.store.store_device(
|
||||||
|
"user_id", "device_id", "display_name 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
res = yield self.store.get_device("user_id", "device_id")
|
||||||
|
self.assertEqual("display_name 1", res["display_name"])
|
||||||
|
|
||||||
|
# do a no-op first
|
||||||
|
yield self.store.update_device(
|
||||||
|
"user_id", "device_id",
|
||||||
|
)
|
||||||
|
res = yield self.store.get_device("user_id", "device_id")
|
||||||
|
self.assertEqual("display_name 1", res["display_name"])
|
||||||
|
|
||||||
|
# do the update
|
||||||
|
yield self.store.update_device(
|
||||||
|
"user_id", "device_id",
|
||||||
|
new_display_name="display_name 2",
|
||||||
|
)
|
||||||
|
|
||||||
|
# check it worked
|
||||||
|
res = yield self.store.get_device("user_id", "device_id")
|
||||||
|
self.assertEqual("display_name 2", res["display_name"])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_update_unknown_device(self):
|
||||||
|
with self.assertRaises(synapse.api.errors.StoreError) as cm:
|
||||||
|
yield self.store.update_device(
|
||||||
|
"user_id", "unknown_device_id",
|
||||||
|
new_display_name="display_name 2",
|
||||||
|
)
|
||||||
|
self.assertEqual(404, cm.exception.code)
|
||||||
|
|
Loading…
Reference in New Issue