Task scheduler: add replication notify for new task to launch ASAP (#16184)
parent
224c2bbcfa
commit
501da8ecd8
|
@ -0,0 +1 @@
|
|||
Task scheduler: add replication notify for new task to launch ASAP.
|
|
@ -452,6 +452,17 @@ class LockReleasedCommand(Command):
|
|||
return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key])
|
||||
|
||||
|
||||
class NewActiveTaskCommand(_SimpleCommand):
    """Tells the instance handling background tasks that a new active task
    is available to run.

    The payload of the command is the ID of the task.

    Format::

        NEW_ACTIVE_TASK "<task_id>"
    """

    NAME = "NEW_ACTIVE_TASK"
|
||||
|
||||
|
||||
_COMMANDS: Tuple[Type[Command], ...] = (
|
||||
ServerCommand,
|
||||
RdataCommand,
|
||||
|
@ -466,6 +477,7 @@ _COMMANDS: Tuple[Type[Command], ...] = (
|
|||
RemoteServerUpCommand,
|
||||
ClearUserSyncsCommand,
|
||||
LockReleasedCommand,
|
||||
NewActiveTaskCommand,
|
||||
)
|
||||
|
||||
# Map of command name to command type.
|
||||
|
|
|
@ -40,6 +40,7 @@ from synapse.replication.tcp.commands import (
|
|||
Command,
|
||||
FederationAckCommand,
|
||||
LockReleasedCommand,
|
||||
NewActiveTaskCommand,
|
||||
PositionCommand,
|
||||
RdataCommand,
|
||||
RemoteServerUpCommand,
|
||||
|
@ -238,6 +239,10 @@ class ReplicationCommandHandler:
|
|||
if self._is_master:
|
||||
self._server_notices_sender = hs.get_server_notices_sender()
|
||||
|
||||
self._task_scheduler = None
|
||||
if hs.config.worker.run_background_tasks:
|
||||
self._task_scheduler = hs.get_task_scheduler()
|
||||
|
||||
if hs.config.redis.redis_enabled:
|
||||
# If we're using Redis, it's the background worker that should
|
||||
# receive USER_IP commands and store the relevant client IPs.
|
||||
|
@ -663,6 +668,15 @@ class ReplicationCommandHandler:
|
|||
cmd.instance_name, cmd.lock_name, cmd.lock_key
|
||||
)
|
||||
|
||||
async def on_NEW_ACTIVE_TASK(
    self, conn: IReplicationConnection, cmd: NewActiveTaskCommand
) -> None:
    """Handle a NEW_ACTIVE_TASK command received over replication.

    Nothing happens when this instance holds no task scheduler
    (`self._task_scheduler` is unset) or when the task cannot be retrieved.

    Args:
        conn: the connection the command arrived on (not used here).
        cmd: the received command; `cmd.data` is used as the ID of the task
            to look up and launch.
    """
    scheduler = self._task_scheduler
    if not scheduler:
        return

    task = await scheduler.get_task(cmd.data)
    if not task:
        return

    await scheduler._launch_task(task)
|
||||
|
||||
def new_connection(self, connection: IReplicationConnection) -> None:
|
||||
"""Called when we have a new connection."""
|
||||
self._connections.append(connection)
|
||||
|
@ -776,6 +790,10 @@ class ReplicationCommandHandler:
|
|||
if instance_name == self._instance_name:
|
||||
self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key))
|
||||
|
||||
def send_new_active_task(self, task_id: str) -> None:
    """Send a NEW_ACTIVE_TASK command for a task that has been scheduled for
    immediate launch and is ACTIVE.

    Args:
        task_id: ID of the task, carried as the payload of the command.
    """
    command = NewActiveTaskCommand(task_id)
    self.send_command(command)
|
||||
|
||||
|
||||
UpdateToken = TypeVar("UpdateToken")
|
||||
UpdateRow = TypeVar("UpdateRow")
|
||||
|
|
|
@ -57,14 +57,13 @@ class TaskScheduler:
|
|||
the code launching the task.
|
||||
You can also specify the `result` (and/or an `error`) when returning from the function.
|
||||
|
||||
The reconciliation loop runs every 5 mns, so this is not a precise scheduler. When wanting
|
||||
to launch now, the launch will still not happen before the next loop run.
|
||||
|
||||
Tasks will be run on the worker specified with `run_background_tasks_on` config,
|
||||
or the main one by default.
|
||||
The reconciliation loop runs every minute, so this is not a precise scheduler.
|
||||
There is a limit of 10 concurrent tasks, so tasks may be delayed if the pool is already
|
||||
full. In this regard, please take great care that scheduled tasks can actually finish.
|
||||
For now there is no mechanism to stop a running task if it is stuck.
|
||||
|
||||
Tasks will be run on the worker specified with `run_background_tasks_on` config,
|
||||
or the main one by default.
|
||||
"""
|
||||
|
||||
# Precision of the scheduler, evaluation of tasks to run will only happen
|
||||
|
@ -85,7 +84,7 @@ class TaskScheduler:
|
|||
self._actions: Dict[
|
||||
str,
|
||||
Callable[
|
||||
[ScheduledTask, bool],
|
||||
[ScheduledTask],
|
||||
Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]],
|
||||
],
|
||||
] = {}
|
||||
|
@ -98,11 +97,13 @@ class TaskScheduler:
|
|||
"handle_scheduled_tasks",
|
||||
self._handle_scheduled_tasks,
|
||||
)
|
||||
else:
|
||||
self.replication_client = hs.get_replication_command_handler()
|
||||
|
||||
def register_action(
|
||||
self,
|
||||
function: Callable[
|
||||
[ScheduledTask, bool],
|
||||
[ScheduledTask],
|
||||
Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]],
|
||||
],
|
||||
action_name: str,
|
||||
|
@ -115,10 +116,9 @@ class TaskScheduler:
|
|||
calling `schedule_task` but rather in an `__init__` method.
|
||||
|
||||
Args:
|
||||
function: The function to be executed for this action. The parameters
|
||||
passed to the function when launched are the `ScheduledTask` being run,
|
||||
and a `first_launch` boolean to signal if it's a resumed task or the first
|
||||
launch of it. The function should return a tuple of new `status`, `result`
|
||||
function: The function to be executed for this action. The parameter
|
||||
passed to the function when launched is the `ScheduledTask` being run.
|
||||
The function should return a tuple of new `status`, `result`
|
||||
and `error` as specified in `ScheduledTask`.
|
||||
action_name: The name of the action to be associated with the function
|
||||
"""
|
||||
|
@ -171,6 +171,12 @@ class TaskScheduler:
|
|||
)
|
||||
await self._store.insert_scheduled_task(task)
|
||||
|
||||
if status == TaskStatus.ACTIVE:
|
||||
if self._run_background_tasks:
|
||||
await self._launch_task(task)
|
||||
else:
|
||||
self.replication_client.send_new_active_task(task.id)
|
||||
|
||||
return task.id
|
||||
|
||||
async def update_task(
|
||||
|
@ -265,21 +271,13 @@ class TaskScheduler:
|
|||
Args:
|
||||
id: id of the task to delete
|
||||
"""
|
||||
if self.task_is_running(id):
|
||||
raise Exception(f"Task {id} is currently running and can't be deleted")
|
||||
task = await self.get_task(id)
|
||||
if task is None:
|
||||
raise Exception(f"Task {id} does not exist")
|
||||
if task.status == TaskStatus.ACTIVE:
|
||||
raise Exception(f"Task {id} is currently ACTIVE and can't be deleted")
|
||||
await self._store.delete_scheduled_task(id)
|
||||
|
||||
def task_is_running(self, id: str) -> bool:
    """Tell whether the task with the given id is among the running tasks.

    Can only be called from the worker handling the task scheduling; an
    assertion on `self._run_background_tasks` enforces it.

    Args:
        id: id of the task to check

    Returns:
        `True` if the id is present in `self._running_tasks`, `False` otherwise.
    """
    assert self._run_background_tasks
    is_running = id in self._running_tasks
    return is_running
|
||||
|
||||
async def _handle_scheduled_tasks(self) -> None:
|
||||
"""Main loop taking care of launching tasks and cleaning up old ones."""
|
||||
await self._launch_scheduled_tasks()
|
||||
|
@ -288,29 +286,11 @@ class TaskScheduler:
|
|||
async def _launch_scheduled_tasks(self) -> None:
|
||||
"""Retrieve and launch scheduled tasks that should be running at that time."""
|
||||
for task in await self.get_tasks(statuses=[TaskStatus.ACTIVE]):
|
||||
if not self.task_is_running(task.id):
|
||||
if (
|
||||
len(self._running_tasks)
|
||||
< TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS
|
||||
):
|
||||
await self._launch_task(task, first_launch=False)
|
||||
else:
|
||||
if (
|
||||
self._clock.time_msec()
|
||||
> task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS
|
||||
):
|
||||
logger.warn(
|
||||
f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck"
|
||||
)
|
||||
await self._launch_task(task)
|
||||
for task in await self.get_tasks(
|
||||
statuses=[TaskStatus.SCHEDULED], max_timestamp=self._clock.time_msec()
|
||||
):
|
||||
if (
|
||||
not self.task_is_running(task.id)
|
||||
and len(self._running_tasks)
|
||||
< TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS
|
||||
):
|
||||
await self._launch_task(task, first_launch=True)
|
||||
await self._launch_task(task)
|
||||
|
||||
running_tasks_gauge.set(len(self._running_tasks))
|
||||
|
||||
|
@ -320,27 +300,27 @@ class TaskScheduler:
|
|||
statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE]
|
||||
):
|
||||
# FAILED and COMPLETE tasks should never be running
|
||||
assert not self.task_is_running(task.id)
|
||||
assert task.id not in self._running_tasks
|
||||
if (
|
||||
self._clock.time_msec()
|
||||
> task.timestamp + TaskScheduler.KEEP_TASKS_FOR_MS
|
||||
):
|
||||
await self._store.delete_scheduled_task(task.id)
|
||||
|
||||
async def _launch_task(self, task: ScheduledTask, first_launch: bool) -> None:
|
||||
async def _launch_task(self, task: ScheduledTask) -> None:
|
||||
"""Launch a scheduled task now.
|
||||
|
||||
Args:
|
||||
task: the task to launch
|
||||
first_launch: `True` if it's the first time is launched, `False` otherwise
|
||||
"""
|
||||
assert task.action in self._actions
|
||||
assert self._run_background_tasks
|
||||
|
||||
assert task.action in self._actions
|
||||
function = self._actions[task.action]
|
||||
|
||||
async def wrapper() -> None:
|
||||
try:
|
||||
(status, result, error) = await function(task, first_launch)
|
||||
(status, result, error) = await function(task)
|
||||
except Exception:
|
||||
f = Failure()
|
||||
logger.error(
|
||||
|
@ -360,6 +340,20 @@ class TaskScheduler:
|
|||
)
|
||||
self._running_tasks.remove(task.id)
|
||||
|
||||
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS:
|
||||
return
|
||||
|
||||
if (
|
||||
self._clock.time_msec()
|
||||
> task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS
|
||||
):
|
||||
logger.warn(
|
||||
f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck"
|
||||
)
|
||||
|
||||
if task.id in self._running_tasks:
|
||||
return
|
||||
|
||||
self._running_tasks.add(task.id)
|
||||
await self.update_task(task.id, status=TaskStatus.ACTIVE)
|
||||
description = f"{task.id}-{task.action}"
|
||||
|
|
|
@ -22,10 +22,11 @@ from synapse.types import JsonMapping, ScheduledTask, TaskStatus
|
|||
from synapse.util import Clock
|
||||
from synapse.util.task_scheduler import TaskScheduler
|
||||
|
||||
from tests import unittest
|
||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||
from tests.unittest import HomeserverTestCase, override_config
|
||||
|
||||
|
||||
class TestTaskScheduler(unittest.HomeserverTestCase):
|
||||
class TestTaskScheduler(HomeserverTestCase):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.task_scheduler = hs.get_task_scheduler()
|
||||
self.task_scheduler.register_action(self._test_task, "_test_task")
|
||||
|
@ -34,7 +35,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
|
|||
self.task_scheduler.register_action(self._resumable_task, "_resumable_task")
|
||||
|
||||
async def _test_task(
|
||||
self, task: ScheduledTask, first_launch: bool
|
||||
self, task: ScheduledTask
|
||||
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
|
||||
# This test task will copy the parameters to the result
|
||||
result = None
|
||||
|
@ -77,7 +78,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
|
|||
self.assertIsNone(task)
|
||||
|
||||
async def _sleeping_task(
|
||||
self, task: ScheduledTask, first_launch: bool
|
||||
self, task: ScheduledTask
|
||||
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
|
||||
# Sleep for a second
|
||||
await deferLater(self.reactor, 1, lambda: None)
|
||||
|
@ -85,24 +86,18 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
|
|||
|
||||
def test_schedule_lot_of_tasks(self) -> None:
|
||||
"""Schedule more than `TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS` tasks and check the behavior."""
|
||||
timestamp = self.clock.time_msec() + 30 * 1000
|
||||
task_ids = []
|
||||
for i in range(TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS + 1):
|
||||
task_ids.append(
|
||||
self.get_success(
|
||||
self.task_scheduler.schedule_task(
|
||||
"_sleeping_task",
|
||||
timestamp=timestamp,
|
||||
params={"val": i},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# The timestamp being 30s after now, the task should have been executed
|
||||
# after the first scheduling loop is run
|
||||
self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
|
||||
|
||||
# This is to give the time to the sleeping tasks to finish
|
||||
# This is to give the time to the active tasks to finish
|
||||
self.reactor.advance(1)
|
||||
|
||||
# Check that only MAX_CONCURRENT_RUNNING_TASKS tasks have run and that one
|
||||
|
@ -120,10 +115,11 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
scheduled_tasks = [
|
||||
t for t in tasks if t is not None and t.status == TaskStatus.SCHEDULED
|
||||
t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE
|
||||
]
|
||||
self.assertEquals(len(scheduled_tasks), 1)
|
||||
|
||||
# We need to wait for the next run of the scheduler loop
|
||||
self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
|
||||
self.reactor.advance(1)
|
||||
|
||||
|
@ -138,7 +134,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
|
|||
)
|
||||
|
||||
async def _raising_task(
|
||||
self, task: ScheduledTask, first_launch: bool
|
||||
self, task: ScheduledTask
|
||||
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
|
||||
raise Exception("raising")
|
||||
|
||||
|
@ -146,15 +142,13 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
|
|||
"""Schedule a task raising an exception and check it runs to failure and report exception content."""
|
||||
task_id = self.get_success(self.task_scheduler.schedule_task("_raising_task"))
|
||||
|
||||
self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
|
||||
|
||||
task = self.get_success(self.task_scheduler.get_task(task_id))
|
||||
assert task is not None
|
||||
self.assertEqual(task.status, TaskStatus.FAILED)
|
||||
self.assertEqual(task.error, "raising")
|
||||
|
||||
async def _resumable_task(
|
||||
self, task: ScheduledTask, first_launch: bool
|
||||
self, task: ScheduledTask
|
||||
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
|
||||
if task.result and "in_progress" in task.result:
|
||||
return TaskStatus.COMPLETE, {"success": True}, None
|
||||
|
@ -169,8 +163,6 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
|
|||
"""Schedule a resumable task and check that it gets properly resumed and complete after simulating a synapse restart."""
|
||||
task_id = self.get_success(self.task_scheduler.schedule_task("_resumable_task"))
|
||||
|
||||
self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000))
|
||||
|
||||
task = self.get_success(self.task_scheduler.get_task(task_id))
|
||||
assert task is not None
|
||||
self.assertEqual(task.status, TaskStatus.ACTIVE)
|
||||
|
@ -184,3 +176,33 @@ class TestTaskScheduler(unittest.HomeserverTestCase):
|
|||
self.assertEqual(task.status, TaskStatus.COMPLETE)
|
||||
assert task.result is not None
|
||||
self.assertTrue(task.result.get("success"))
|
||||
|
||||
|
||||
class TestTaskSchedulerWithBackgroundWorker(BaseMultiWorkerStreamTestCase):
    """Task scheduler tests where background tasks are configured to run on a worker."""

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        """Keep a handle on the task scheduler of `hs` and register the test action on it."""
        self.task_scheduler = hs.get_task_scheduler()
        self.task_scheduler.register_action(self._test_task, "_test_task")

    async def _test_task(
        self, task: ScheduledTask
    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
        """Trivial action: reports COMPLETE with neither a result nor an error."""
        return TaskStatus.COMPLETE, None, None

    @override_config({"run_background_tasks_on": "worker1"})
    def test_schedule_task(self) -> None:
        """Check that a task scheduled to run now is launched right away on the background worker."""
        worker_hs = self.make_worker_hs(
            "synapse.app.generic_worker",
            extra_config={"worker_name": "worker1"},
        )
        # The action has to be registered on the worker's own scheduler too.
        worker_scheduler = worker_hs.get_task_scheduler()
        worker_scheduler.register_action(self._test_task, "_test_task")

        # Schedule through the scheduler obtained in `prepare`, without a timestamp.
        task_id = self.get_success(self.task_scheduler.schedule_task("_test_task"))

        task = self.get_success(self.task_scheduler.get_task(task_id))
        assert task is not None
        self.assertEqual(task.status, TaskStatus.COMPLETE)
|
||||
|
|
Loading…
Reference in New Issue