|
import json |
|
from datetime import datetime |
|
from typing import Any, Coroutine, Dict, List |
|
|
|
from psycopg.rows import class_row |
|
from psycopg_pool import AsyncConnectionPool |
|
from pydantic import BaseModel |
|
from taskiq import ScheduleSource |
|
from taskiq.kicker import AsyncKicker |
|
from taskiq.scheduler.scheduler import ScheduledTask |
|
from typing_extensions import ParamSpec |
|
|
|
# Parameter specification of a kicked task: ``DbScheduleSource.add_task`` uses
# ``_PAR.args`` / ``_PAR.kwargs`` so its ``*args``/``**kwargs`` are type-checked
# against the parameters of the ``AsyncKicker`` it receives.
_PAR = ParamSpec("_PAR")
|
|
|
|
|
class DbSchedule(BaseModel):
    """One row of the ``taskiq_schedules`` table.

    Used as the ``class_row`` row factory target in
    ``DbScheduleSource.get_schedules``; the field names therefore have to
    match the column names returned by the query.
    """

    # Primary key of the row (``SERIAL`` column).
    id: int

    # Name of the task to send.
    task_name: str

    # Positional arguments for the task (``JSONB`` column).
    args: List[Any]

    # Keyword arguments for the task (``JSONB`` column).
    kwargs: Dict[str, Any]

    # Labels of the task (``JSONB`` column).
    labels: Dict[str, Any]

    # When the task should be sent (``TIMESTAMP`` column, i.e. without time zone).
    time: datetime
|
|
|
|
|
class DbScheduleSource(ScheduleSource):
    """Taskiq schedule source backed by a PostgreSQL ``taskiq_schedules`` table.

    Each row describes one task invocation to be sent at ``time``. Rows are
    read by :meth:`get_schedules`; once a task has been sent, :meth:`post_send`
    deletes the row it came from.
    """

    def __init__(self, db_url: str) -> None:
        """Create the source.

        :param db_url: PostgreSQL connection string for the connection pool.
            The pool is created closed and is opened in :meth:`startup`.
        """
        self.pool = AsyncConnectionPool(db_url, open=False)

    async def startup(self) -> None:
        """Open the connection pool and create the schedules table if missing."""
        await self.pool.open()
        async with self.pool.connection() as conn:
            # NOTE(review): ``TIMESTAMP`` is "without time zone", so time-zone
            # information on datetimes passed to ``add_task`` is not kept;
            # confirm this matches how the scheduler compares schedule times.
            await conn.execute(
                """
                CREATE TABLE IF NOT EXISTS
                taskiq_schedules(
                    id SERIAL PRIMARY KEY,
                    task_name TEXT NOT NULL,
                    args JSONB NOT NULL,
                    kwargs JSONB NOT NULL,
                    labels JSONB NOT NULL,
                    time TIMESTAMP NOT NULL
                );
                """
            )

    async def shutdown(self) -> None:
        """Close the connection pool."""
        await self.pool.close()

    async def get_schedules(self) -> List[ScheduledTask]:
        """Load every stored row as a :class:`ScheduledTask`.

        The row's primary key is attached as the ``_sched_id`` label so that
        :meth:`post_send` can delete the row after the task is sent.

        :return: one ``ScheduledTask`` per row of ``taskiq_schedules``.
        """
        async with self.pool.connection() as conn:
            async with conn.cursor(
                binary=True,
                row_factory=class_row(DbSchedule),
            ) as cur:
                await cur.execute("SELECT * FROM taskiq_schedules;")
                rows = await cur.fetchall()

        return [
            ScheduledTask(
                source=self,
                task_name=row.task_name,
                args=row.args,
                kwargs=row.kwargs,
                # ``_sched_id`` is merged last so that a stored label with the
                # same key can never override the real row id used for deletion.
                labels={**row.labels, "_sched_id": row.id},
                time=row.time,
            )
            for row in rows
        ]

    async def add_task(
        self,
        task: AsyncKicker[_PAR, Any],
        time: datetime,
        *args: _PAR.args,
        **kwargs: _PAR.kwargs,
    ) -> None:
        """Store a schedule that sends ``task`` at ``time``.

        :param task: kicker whose ``task_name`` and ``labels`` are stored.
        :param time: when the task should be sent.
        :param args: positional arguments for the task.
        :param kwargs: keyword arguments for the task.
        :raises TypeError: if ``args``, ``kwargs`` or the task's labels are
            not JSON-serializable (raised by ``json.dumps``).
        """
        async with self.pool.connection() as conn:
            await conn.execute(
                """
                INSERT INTO
                taskiq_schedules(
                    task_name,
                    args,
                    kwargs,
                    labels,
                    time
                )
                VALUES (
                    %(name)s,
                    %(args)s,
                    %(kwargs)s,
                    %(labels)s,
                    %(time)s
                );""",
                {
                    "name": task.task_name,
                    # JSONB columns receive JSON text; Postgres casts it on insert.
                    "args": json.dumps(list(args)),
                    "kwargs": json.dumps(kwargs),
                    "labels": json.dumps(task.labels),
                    "time": time,
                },
            )

    async def remove_schedule(self, schedule_id: int) -> None:
        """Delete the row whose primary key is ``schedule_id`` (no-op if absent)."""
        async with self.pool.connection() as conn:
            await conn.execute(
                "DELETE FROM taskiq_schedules WHERE id=%(id)s",
                {
                    "id": schedule_id,
                },
            )

    async def post_send(self, task: ScheduledTask) -> None:
        """Delete the row a just-sent task was loaded from.

        Tasks without a ``_sched_id`` label (i.e. not produced by
        :meth:`get_schedules`) are left alone.
        """
        # The previous annotation ``Coroutine[...] | None`` was both wrong for an
        # ``async def`` and a TypeError at definition time on Python < 3.10.
        schedule_id = task.labels.get("_sched_id")
        if schedule_id is None:
            return
        await self.remove_schedule(int(schedule_id))