from celery import Celery
from celery.result import AsyncResult
from contextlib import contextmanager
from datetime import datetime, timedelta
from enum import Enum
from fastapi import FastAPI
from loguru import logger
from sqlalchemy import Column, Enum as SQLEnum
from sqlalchemy.exc import SQLAlchemyError
from sqlmodel import SQLModel, Field, Session, create_engine
from typing import Optional
import os
import threading
import time
# Environment variables for configuration
DATABASE_URL = os.environ.get("DATABASE_URL", "postgresql://username:password@localhost/mydatabase")
REDIS_BROKER_URL = os.environ.get("REDIS_BROKER_URL", "redis://localhost:6379/0")
REDIS_BACKEND_URL = os.environ.get("REDIS_BACKEND_URL", "redis://localhost:6379/0")

# Database setup for the multi-tenant application
engine = create_engine(DATABASE_URL)
# Enum of task statuses tracked in the database
class TaskStatus(str, Enum):
    STARTED = "STARTED"
    IN_PROGRESS = "IN_PROGRESS"
    COMPLETED = "COMPLETED"
    FAILURE = "FAILURE"
    REVOKED = "REVOKED"
    RESTARTED = "RESTARTED"
# Task model for database interactions
class TaskModel(SQLModel, table=True):
    __tablename__ = "tasks"

    id: Optional[int] = Field(default=None, primary_key=True)
    task_id: str = Field(sa_column_kwargs={"unique": True, "index": True})
    tenant_id: str
    start_time: datetime = Field(default_factory=datetime.now)
    # sa_column_kwargs has no "enum" key; the enum column type must be
    # supplied through sa_column instead
    status: TaskStatus = Field(sa_column=Column(SQLEnum(TaskStatus)))
    last_heartbeat: datetime = Field(default_factory=datetime.now)
    task_type_id: str

# Create the tables in the database
SQLModel.metadata.create_all(engine)
# Context manager for database sessions to ensure safe transactions
@contextmanager
def get_session():
    session = Session(engine)
    try:
        yield session
        session.commit()
    except SQLAlchemyError as e:
        session.rollback()
        logger.error(f"Database error: {e}")
        raise
    finally:
        session.close()
# Celery configuration using Redis for task queuing and result storage
celery_app = Celery(
    'tasks',
    broker=REDIS_BROKER_URL,
    backend=REDIS_BACKEND_URL,
)
# BaseTask handles common database logic, heartbeat updates, and automatic task registration
class BaseTask(celery_app.Task):
    task_types = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.task_type_id = cls.__name__
        cls.name = cls.task_type_id.lower()  # Unique task name derived from the class name
        # Register a single instance with Celery (register_task binds it to the
        # app) and keep that instance so it can be re-dispatched by type later
        instance = cls()
        celery_app.register_task(instance)
        BaseTask.task_types[cls.task_type_id] = instance

    def run(self, *args, **kwargs):
        self._stop_heartbeat = threading.Event()
        self.heartbeat_thread = None
        self._record_start(*args, **kwargs)
        self.start_heartbeat_thread(*args, **kwargs)
        try:
            result = self._run(*args, **kwargs)
            self.after_completion()
            return result
        except Exception as e:
            logger.error(f"Error in task: {e}")
            raise
        finally:
            self.stop_heartbeat_thread()

    # Named _record_start rather than before_start to avoid clashing with
    # Celery's own before_start(task_id, args, kwargs) handler
    def _record_start(self, tenant_id):
        with get_session() as db:
            new_task = TaskModel(
                task_id=self.request.id,
                tenant_id=tenant_id,
                status=TaskStatus.STARTED,
                task_type_id=self.task_type_id,
            )
            db.add(new_task)

    def start_heartbeat_thread(self, tenant_id):
        self.heartbeat_thread = threading.Thread(target=self.heartbeat_loop, args=(tenant_id,))
        self.heartbeat_thread.start()

    def heartbeat_loop(self, tenant_id):
        # Beat immediately, then wake every 60 seconds; using an Event means
        # stop_heartbeat_thread does not have to wait out a full sleep
        while not self._stop_heartbeat.is_set():
            self.update_heartbeat(tenant_id)
            self._stop_heartbeat.wait(60)

    def stop_heartbeat_thread(self):
        self._stop_heartbeat.set()
        if self.heartbeat_thread:
            self.heartbeat_thread.join()

    def update_heartbeat(self, tenant_id):
        with get_session() as db:
            task = db.query(TaskModel).filter(TaskModel.task_id == self.request.id).first()
            if task:
                task.last_heartbeat = datetime.now()

    def after_completion(self):
        with get_session() as db:
            task = db.query(TaskModel).filter(TaskModel.task_id == self.request.id).first()
            if task:
                task.status = TaskStatus.COMPLETED

    def _run(self, *args, **kwargs):
        raise NotImplementedError("Subclasses must implement this method")
# Example task: sleeps for two hours in one-minute increments so heartbeats
# keep flowing while it runs
class Wait2HoursTask(BaseTask):
    def _run(self, tenant_id: str):
        for _ in range(120):
            time.sleep(60)
        return 'Completed 2 hours wait'
# Watchdog: reconciles DB state with Celery result state and restarts tasks
# that have exceeded their maximum duration. Meant to run periodically via
# celery beat (see the note below the endpoint).
@celery_app.task
def monitor_tasks():
    with get_session() as db:
        tasks = db.query(TaskModel).filter(TaskModel.status == TaskStatus.STARTED).all()
        for task in tasks:
            task_result = AsyncResult(task.task_id, app=celery_app)
            if task_result.state in ['FAILURE', 'REVOKED']:
                task.status = TaskStatus.FAILURE
            elif task_result.state == 'STARTED':
                if needs_intervention(task, task_result):
                    restart_task(task.task_id)
                    task.status = TaskStatus.RESTARTED
def needs_intervention(task, task_result):
    # A task that has been running longer than 150 minutes is considered stuck
    max_duration = timedelta(minutes=150)
    return datetime.now() - task.start_time > max_duration
def restart_task(task_id):
    logger.info(f"Restarting task {task_id}")
    celery_app.control.revoke(task_id, terminate=True)
    with get_session() as db:
        task = db.query(TaskModel).filter(TaskModel.task_id == task_id).first()
        if task:
            # task_types maps task_type_id to the registered task instance,
            # so apply_async can be called on it directly
            task_instance = BaseTask.task_types[task.task_type_id]
            new_task = task_instance.apply_async(args=[task.tenant_id])
            task.task_id = new_task.id
app = FastAPI()

@app.post("/wait2hours/{tenant_id}")
def start_wait2hours_task(tenant_id: str):
    # Dispatch through the registered instance; calling apply_async on the
    # class itself would fail since it is an ordinary instance method. The
    # unused session dependency was dropped: get_session is a contextmanager,
    # not a FastAPI dependency.
    task = BaseTask.task_types["Wait2HoursTask"].apply_async(args=[tenant_id])
    return {"task_id": task.id, "tenant_id": tenant_id}
Orchestrating Your FastAPI and Celery Application with Docker Compose
To set up and run your FastAPI application, Celery workers, Redis, and PostgreSQL database using Docker Compose, follow these steps:
1. Prepare Your Application
Ensure there is a Dockerfile in the same directory as your Python application file. It will be used to build the Docker images for both the FastAPI application and the Celery workers (a sketch appears after this list).

2. Docker Compose Setup
Create a docker-compose.yml file in the same directory as your application code. This file defines and configures the necessary services: the FastAPI app, a Celery worker, Redis, and PostgreSQL (see the sketch at the end of this section).

3. Building and Running the Containers
Navigate to the directory containing your docker-compose.yml file, then run the following command to build and start the services defined in it:

docker-compose up --build

4. Accessing the Services
The FastAPI application will be accessible at http://localhost:8000. Redis will be exposed on port 6379, and PostgreSQL on port 5432 of your host machine.

5. Stopping the Services
To stop the services, press Ctrl+C in the terminal where docker-compose is running.
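The Dockerfile referenced in step 1 is not included in the gist. The following is a minimal sketch, assuming the application above lives in tasks.py with a requirements.txt alongside it (both file names, the base image, and the use of uvicorn are assumptions):

# Hedged Dockerfile sketch (assumed layout: tasks.py, requirements.txt)
FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY tasks.py .
# Default to the API server; the worker service overrides this command
CMD ["uvicorn", "tasks:app", "--host", "0.0.0.0", "--port", "8000"]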
Docker Compose File (docker-compose.yml)
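The compose file itself is not included in the gist; below is a minimal sketch consistent with the steps above. Service names, image tags, credentials, and the tasks module name are assumptions:

# Hedged docker-compose.yml sketch; names and credentials are assumptions
version: "3.8"
services:
  web:
    build: .
    ports:
      - "8000:8000"
    environment:
      DATABASE_URL: postgresql://username:password@db/mydatabase
      REDIS_BROKER_URL: redis://redis:6379/0
      REDIS_BACKEND_URL: redis://redis:6379/0
    depends_on:
      - db
      - redis
  worker:
    build: .
    command: celery -A tasks worker --loglevel=info
    environment:
      DATABASE_URL: postgresql://username:password@db/mydatabase
      REDIS_BROKER_URL: redis://redis:6379/0
      REDIS_BACKEND_URL: redis://redis:6379/0
    depends_on:
      - db
      - redis
  beat:
    build: .
    command: celery -A tasks beat --loglevel=info
    depends_on:
      - redis
  redis:
    image: redis:7
    ports:
      - "6379:6379"
  db:
    image: postgres:16
    environment:
      POSTGRES_USER: username
      POSTGRES_PASSWORD: password
      POSTGRES_DB: mydatabase
    ports:
      - "5432:5432"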
This setup will help you orchestrate your entire application, including the FastAPI server, Celery worker, Redis, and PostgreSQL database, in a coordinated manner using Docker Compose.