Compare commits
1 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
49c66f139f |
File diff suppressed because it is too large
Load Diff
|
|
@ -53,10 +53,6 @@ class RunnerRunning(BaseRunnerStatus):
|
|||
pass
|
||||
|
||||
|
||||
class RunnerShuttingDown(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
||||
class RunnerShutdown(BaseRunnerStatus):
|
||||
pass
|
||||
|
||||
|
|
@ -74,7 +70,6 @@ RunnerStatus = (
|
|||
| RunnerWarmingUp
|
||||
| RunnerReady
|
||||
| RunnerRunning
|
||||
| RunnerShuttingDown
|
||||
| RunnerShutdown
|
||||
| RunnerFailed
|
||||
)
|
||||
|
|
|
|||
|
|
@ -274,12 +274,6 @@ def _pending_tasks(
|
|||
if task.instance_id != runner.bound_instance.instance.instance_id:
|
||||
continue
|
||||
|
||||
# I have a design point here; this is a state race in disguise as the task status doesn't get updated to completed fast enough
|
||||
# however, realistically the task status should be set to completed by the LAST runner, so this is a true race
|
||||
# the actual solution is somewhat deeper than this bypass - TODO!
|
||||
if task.task_id in runner.completed:
|
||||
continue
|
||||
|
||||
# TODO: Check ordering aligns with MLX distributeds expectations.
|
||||
|
||||
if isinstance(runner.status, RunnerReady) and all(
|
||||
|
|
|
|||
|
|
@ -32,7 +32,6 @@ from exo.shared.types.worker.runners import (
|
|||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerShutdown,
|
||||
RunnerShuttingDown,
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
|
|
@ -188,14 +187,13 @@ def main(
|
|||
current_status = RunnerReady()
|
||||
logger.info("runner ready")
|
||||
case Shutdown():
|
||||
current_status = RunnerShuttingDown()
|
||||
logger.info("runner shutting down")
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(
|
||||
runner_id=runner_id, runner_status=current_status
|
||||
TaskStatusUpdated(
|
||||
task_id=task.task_id, task_status=TaskStatus.Complete
|
||||
)
|
||||
)
|
||||
current_status = RunnerShutdown()
|
||||
break
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Received {task.__class__.__name__} outside of state machine in {current_status=}"
|
||||
|
|
@ -210,8 +208,9 @@ def main(
|
|||
runner_id=runner_id, runner_status=current_status
|
||||
)
|
||||
)
|
||||
if isinstance(current_status, RunnerShutdown):
|
||||
break
|
||||
event_sender.send(
|
||||
RunnerStatusUpdated(runner_id=runner_id, runner_status=RunnerShutdown())
|
||||
)
|
||||
except ClosedResourceError:
|
||||
logger.warning("runner communication closed unexpectedly")
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -14,23 +14,13 @@ from anyio import (
|
|||
from anyio.abc import TaskGroup
|
||||
from loguru import logger
|
||||
|
||||
from exo.shared.types.events import (
|
||||
Event,
|
||||
RunnerStatusUpdated,
|
||||
TaskAcknowledged,
|
||||
TaskStatusUpdated,
|
||||
)
|
||||
from exo.shared.types.tasks import Task, TaskId, TaskStatus
|
||||
from exo.shared.types.events import Event, RunnerStatusUpdated, TaskAcknowledged
|
||||
from exo.shared.types.tasks import Task, TaskId
|
||||
from exo.shared.types.worker.instances import BoundInstance
|
||||
from exo.shared.types.worker.runners import (
|
||||
RunnerConnecting,
|
||||
RunnerFailed,
|
||||
RunnerIdle,
|
||||
RunnerLoading,
|
||||
RunnerRunning,
|
||||
RunnerShuttingDown,
|
||||
RunnerStatus,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.shared.types.worker.shards import ShardMetadata
|
||||
from exo.utils.channels import MpReceiver, MpSender, Sender, mp_channel
|
||||
|
|
@ -49,10 +39,10 @@ class RunnerSupervisor:
|
|||
_ev_recv: MpReceiver[Event]
|
||||
_task_sender: MpSender[Task]
|
||||
_event_sender: Sender[Event]
|
||||
# err_path: str
|
||||
_tg: TaskGroup | None = field(default=None, init=False)
|
||||
status: RunnerStatus = field(default_factory=RunnerIdle, init=False)
|
||||
pending: dict[TaskId, anyio.Event] = field(default_factory=dict, init=False)
|
||||
completed: set[TaskId] = field(default_factory=set, init=False)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
|
|
@ -87,6 +77,7 @@ class RunnerSupervisor:
|
|||
_ev_recv=ev_recv,
|
||||
_task_sender=task_sender,
|
||||
_event_sender=event_sender,
|
||||
# err_path=err_path,
|
||||
)
|
||||
|
||||
return self
|
||||
|
|
@ -127,10 +118,6 @@ class RunnerSupervisor:
|
|||
self._tg.cancel_scope.cancel()
|
||||
|
||||
async def start_task(self, task: Task):
|
||||
if task.task_id in self.completed:
|
||||
logger.info(
|
||||
f"Skipping invalid task {task} as it has already been completed"
|
||||
)
|
||||
logger.info(f"Starting task {task}")
|
||||
event = anyio.Event()
|
||||
self.pending[task.task_id] = event
|
||||
|
|
@ -151,22 +138,6 @@ class RunnerSupervisor:
|
|||
if isinstance(event, TaskAcknowledged):
|
||||
self.pending.pop(event.task_id).set()
|
||||
continue
|
||||
if (
|
||||
isinstance(event, TaskStatusUpdated)
|
||||
and event.task_status == TaskStatus.Complete
|
||||
):
|
||||
# If a task has just been completed, we should be working on it.
|
||||
assert isinstance(
|
||||
self.status,
|
||||
(
|
||||
RunnerRunning,
|
||||
RunnerWarmingUp,
|
||||
RunnerLoading,
|
||||
RunnerConnecting,
|
||||
RunnerShuttingDown,
|
||||
),
|
||||
)
|
||||
self.completed.add(event.task_id)
|
||||
await self._event_sender.send(event)
|
||||
except (ClosedResourceError, BrokenResourceError) as e:
|
||||
await self._check_runner(e)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from dataclasses import dataclass, field
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from exo.shared.types.common import NodeId
|
||||
from exo.shared.types.memory import Memory
|
||||
from exo.shared.types.models import ModelId, ModelMetadata
|
||||
from exo.shared.types.tasks import BaseTask, TaskId
|
||||
from exo.shared.types.tasks import BaseTask
|
||||
from exo.shared.types.worker.instances import (
|
||||
BoundInstance,
|
||||
Instance,
|
||||
|
|
@ -19,7 +21,6 @@ from exo.shared.types.worker.shards import PipelineShardMetadata, ShardMetadata
|
|||
class FakeRunnerSupervisor:
|
||||
bound_instance: BoundInstance
|
||||
status: RunnerStatus
|
||||
completed: set[TaskId] = field(default_factory=set)
|
||||
|
||||
|
||||
class OtherTask(BaseTask):
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@ from exo.shared.types.worker.runners import (
|
|||
RunnerReady,
|
||||
RunnerRunning,
|
||||
RunnerShutdown,
|
||||
RunnerShuttingDown,
|
||||
RunnerWarmingUp,
|
||||
)
|
||||
from exo.utils.channels import mp_channel
|
||||
|
|
@ -200,9 +199,6 @@ def test_events_processed_in_correct_order(patch_out_mlx: pytest.MonkeyPatch):
|
|||
RunnerStatusUpdated(runner_id=RUNNER_1_ID, runner_status=RunnerReady()),
|
||||
TaskStatusUpdated(task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Running),
|
||||
TaskAcknowledged(task_id=SHUTDOWN_TASK_ID),
|
||||
RunnerStatusUpdated(
|
||||
runner_id=RUNNER_1_ID, runner_status=RunnerShuttingDown()
|
||||
),
|
||||
TaskStatusUpdated(
|
||||
task_id=SHUTDOWN_TASK_ID, task_status=TaskStatus.Complete
|
||||
),
|
||||
|
|
|
|||
Loading…
Reference in New Issue