Skip to content

Commit d135c94

Browse files
authored
Add deps (#54)
* Add deps * Add deps * Add deps
1 parent 24f4830 commit d135c94

File tree

6 files changed

+61
-31
lines changed

6 files changed

+61
-31
lines changed

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
BSD 3-Clause License
22

3-
Copyright (c) 2024, Quantmind
3+
Copyright (c) 2025, Quantmind
44
All rights reserved.
55

66
Redistribution and use in source and binary forms, with or without

examples/tasks/__init__.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,30 @@
11
import asyncio
22
import os
33
import time
4+
from dataclasses import dataclass, field
45
from datetime import timedelta
5-
from typing import cast
6+
from typing import Self, cast
67

78
from fastapi import FastAPI
89
from pydantic import BaseModel, Field
910

1011
from fluid.scheduler import TaskRun, TaskScheduler, every, task
1112
from fluid.scheduler.broker import RedisTaskBroker
1213
from fluid.scheduler.endpoints import setup_fastapi
14+
from fluid.utils.http_client import HttpxClient
15+
16+
17+
@dataclass
18+
class Deps:
19+
http_client: HttpxClient = field(default_factory=HttpxClient)
20+
21+
@classmethod
22+
def get(cls, context: TaskRun) -> Self:
23+
return context.deps
1324

1425

1526
def task_scheduler() -> TaskScheduler:
16-
task_manager = TaskScheduler()
27+
task_manager = TaskScheduler(deps=Deps())
1728
task_manager.register_from_dict(globals())
1829
return task_manager
1930

@@ -63,3 +74,16 @@ async def cpu_bound(context: TaskRun) -> None:
6374
broker = cast(RedisTaskBroker, context.task_manager.broker)
6475
redis = broker.redis_cli
6576
await redis.setex(context.id, os.getpid(), 10)
77+
78+
79+
class Scrape(BaseModel):
80+
url: str = Field(default="https://httpbin.org/get", description="URL to scrape")
81+
82+
83+
@task
84+
async def scrape(context: TaskRun[Scrape]) -> None:
85+
"""Scrape a website"""
86+
deps = Deps.get(context)
87+
response = await deps.http_client.get(context.params.url, callback=True)
88+
text = await response.text()
89+
context.logger.info(text)

fluid/scheduler/consumer.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Any, Awaitable, Callable, Self
99

1010
from inflection import underscore
11+
from starlette.datastructures import State
1112
from typing_extensions import Annotated, Doc
1213

1314
from fluid.utils import log
@@ -47,10 +48,32 @@ class TaskManager:
4748
"""The task manager is the main class for managing tasks"""
4849

4950
def __init__(
50-
self, *, config: TaskManagerConfig | None = None, **kwargs: Any
51+
self,
52+
*,
53+
deps: Any = None,
54+
config: TaskManagerConfig | None = None,
55+
**kwargs: Any,
5156
) -> None:
52-
self.state: dict[str, Any] = {}
53-
self.config: TaskManagerConfig = config or TaskManagerConfig(**kwargs)
57+
self.deps: Annotated[
58+
Any,
59+
Doc(
60+
"""
61+
Dependencies for the task manager.
62+
63+
Production applications requires global dependencies to be
64+
available to all tasks. This can be achieved by setting
65+
the `deps` attribute of the task manager to an object
66+
with the required dependencies.
67+
68+
Each task can cast the dependencies to the required type.
69+
"""
70+
),
71+
] = (
72+
deps if deps is not None else State()
73+
)
74+
self.config: Annotated[
75+
TaskManagerConfig, Doc("""Task manager configuration""")
76+
] = config or TaskManagerConfig(**kwargs)
5477
self.dispatcher: Annotated[
5578
TaskDispatcher,
5679
Doc(

fluid/scheduler/models.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ class TaskRun(BaseModel, Generic[TP], arbitrary_types_allowed=True):
191191
async def execute(self) -> None:
192192
try:
193193
self.set_state(TaskState.running)
194-
await self.task.executor(self)
194+
await self.task.executor(self) # type: ignore [arg-type]
195195
except Exception:
196196
self.set_state(TaskState.failure)
197197
raise
@@ -247,6 +247,10 @@ def is_done(self) -> bool:
247247
def is_failure(self) -> bool:
248248
return self.state.is_failure
249249

250+
@property
251+
def deps(self) -> Any:
252+
return self.task_manager.deps
253+
250254
def set_state(
251255
self,
252256
state: TaskState,
@@ -293,7 +297,7 @@ def lock(self, timeout: float | None) -> Lock:
293297
return self.task_manager.broker.lock(self.name, timeout=timeout)
294298

295299
def _dispatch(self) -> None:
296-
self.task_manager.dispatcher.dispatch(self.model_copy())
300+
self.task_manager.dispatcher.dispatch(self.model_copy()) # type: ignore [arg-type]
297301

298302

299303
@dataclass

tests/scheduler/test_endpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
async def test_get_tasks(cli: TaskClient) -> None:
1111
data = await cli.get(f"{cli.url}/tasks")
12-
assert len(data) == 4
12+
assert len(data) == 5
1313
tasks = {task["name"]: TaskInfo(**task) for task in data}
1414
dummy = tasks["dummy"]
1515
assert dummy.name == "dummy"

tests/scheduler/test_scheduler.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import os
3-
from dataclasses import dataclass, field
3+
from dataclasses import dataclass
44

55
import pytest
66
from redis.asyncio import Redis
@@ -17,17 +17,6 @@
1717
pytestmark = pytest.mark.asyncio(loop_scope="module")
1818

1919

20-
@dataclass
21-
class WaitFor:
22-
name: str
23-
times: int = 2
24-
runs: list[TaskRun] = field(default_factory=list)
25-
26-
def __call__(self, task_run: TaskRun) -> None:
27-
if task_run.name == self.name:
28-
self.runs.append(task_run)
29-
30-
3120
async def test_scheduler_manager(task_scheduler: TaskScheduler) -> None:
3221
assert task_scheduler
3322
assert task_scheduler.broker.registry
@@ -71,16 +60,6 @@ async def test_dummy_rate_limit(task_scheduler: TaskScheduler) -> None:
7160
assert s1.state is TaskState.rate_limited or s2.state is TaskState.rate_limited
7261

7362

74-
@pytest.mark.flaky
75-
async def test_scheduled(task_scheduler: TaskScheduler) -> None:
76-
handler = WaitFor(name="scheduled")
77-
task_scheduler.dispatcher.register_handler("end.scheduled", handler)
78-
try:
79-
await wait_for(lambda: len(handler.runs) >= 2, timeout=3)
80-
finally:
81-
task_scheduler.dispatcher.unregister_handler("end.handler")
82-
83-
8463
@pytest.mark.flaky
8564
async def test_cpubound_execution(
8665
task_scheduler: TaskScheduler, redis: Redis # type: ignore

0 commit comments

Comments
 (0)