Last active
October 7, 2024 17:33
-
-
Save helton/e5ea607592e02a516e31fd385fa60fe3 to your computer and use it in GitHub Desktop.
[WIP] Redis Cluster support for Kombu (used by Celery)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from celery import Celery
from kombu.transport import TRANSPORT_ALIASES

from .redis_cluster_transport import RedisClusterTransport

# Register the transport with Kombu under the ``redis_cluster://`` URL scheme.
# ``kombu.transport`` has no ``register_transport`` helper; custom transports
# are made resolvable by adding an alias to ``TRANSPORT_ALIASES``. The value is
# a "module:Class" path string, the same form kombu's built-in aliases use.
TRANSPORT_ALIASES['redis_cluster'] = (
    f'{RedisClusterTransport.__module__}:{RedisClusterTransport.__name__}'
)

# ---

app = Celery('your_project')

# Broker URL format for the custom transport: the scheme selects the alias
# registered above; connection details come from broker_transport_options.
app.conf.broker_url = 'redis_cluster://'

# Transport options (read by RedisClusterTransport / RedisClusterChannel).
app.conf.broker_transport_options = {
    'hosts': ['redis-cluster-node1:6379', 'redis-cluster-node2:6379', 'redis-cluster-node3:6379'],
    'password': 'your_redis_password',  # if applicable
    'db': 0,  # Redis Cluster only supports database 0
    'queue_prefix': 'celery:',
    'exchange_prefix': 'exchange:',
}

# Optional: result backend configuration.
# Redis Cluster only supports database 0 (SELECT to any other DB is rejected),
# so the backend URL must use /0.
# NOTE(review): the plain ``redis://`` backend talks to a single node and is
# not cluster-aware; keys owned by other nodes will fail with MOVED errors.
app.conf.result_backend = 'redis://redis-cluster-node1:6379/0'
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from kombu.transport import virtual | |
from kombu.utils import cached_property | |
from kombu.utils.encoding import bytes_to_str | |
from kombu.exceptions import TransportError | |
from kombu import Connection | |
from kombu.transport.base import Message | |
from redis.cluster import RedisCluster | |
import json | |
import uuid | |
import time | |
class RedisClusterTransport(virtual.Transport):
    """Kombu virtual transport backed by a Redis Cluster (``redis_cluster://``).

    Settings are read from the kombu connection's ``transport_options``:

    * ``hosts`` -- list (or comma-separated string) of ``"host:port"`` startup
      nodes; a bare ``"host"`` uses ``default_port``. IPv6 addresses must be
      bracketed, e.g. ``"[::1]:6379"``. Default ``['localhost:6379']``.
    * ``password`` -- cluster password, default ``None``.
    * ``socket_timeout`` -- socket timeout in seconds, default ``5``.
    * ``db`` -- stored on the instance but not sent to the server (Redis
      Cluster only supports database 0).

    One ``RedisCluster`` client is created per transport (``self.redis``) and
    is reused by every channel of this transport.
    """

    driver_type = 'redis_cluster'
    driver_name = 'redis_cluster'

    # kombu reads ``default_port`` as a plain attribute, not as a method.
    default_port = 6379

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # virtual.Transport stores the kombu Connection as ``self.client``;
        # a transport has no ``connection`` attribute.
        self.options = self.client.transport_options or {}
        self.redis_hosts = self.options.get('hosts', ['localhost:6379'])
        self.redis_password = self.options.get('password', None)
        self.db = self.options.get('db', 0)
        self.socket_timeout = self.options.get('socket_timeout', 5)

        # Initialize the RedisCluster client. Host parsing happens inside the
        # try so malformed host strings are also reported as TransportError.
        try:
            self.redis = RedisCluster(
                startup_nodes=self._startup_nodes(self.redis_hosts),
                password=self.redis_password,
                decode_responses=True,
                socket_timeout=self.socket_timeout,
                # redis-py's equivalent of redis-py-cluster's
                # ``skip_full_coverage_check=True``; unknown kwargs such as
                # that one are silently dropped by redis-py.
                require_full_coverage=False,
            )
        except Exception as e:
            raise TransportError(f"Failed to connect to Redis Cluster: {e}") from e

    @property
    def Channel(self):
        """Return the channel class instantiated by ``create_channel``.

        Resolved lazily because ``RedisClusterChannel`` is defined after this
        class; a plain string here would not be callable.
        """
        return RedisClusterChannel

    def _startup_nodes(self, hosts):
        """Convert ``"host[:port]"`` strings into redis-py ``ClusterNode`` objects.

        redis-py's ``RedisCluster`` expects ``ClusterNode`` instances as
        startup nodes, not ``{'host': ..., 'port': ...}`` dicts.

        Raises:
            ValueError: if a port is not an integer.
        """
        from redis.cluster import ClusterNode

        if isinstance(hosts, str):
            # Accept "h1:6379,h2:6379" as well as a list.
            hosts = [h.strip() for h in hosts.split(',') if h.strip()]
        nodes = []
        for host_port in hosts:
            # rpartition keeps any colons inside a bracketed IPv6 host intact.
            host, sep, port = host_port.rpartition(':')
            if not sep:  # bare hostname, no port given
                host, port = port, self.default_port
            nodes.append(ClusterNode(host.strip('[]'), int(port)))
        return nodes

    def driver_version(self):
        import redis
        return redis.__version__

    def _decode(self, body):
        """Decode ``bytes`` to a UTF-8 ``str``; return other values unchanged."""
        if isinstance(body, bytes):
            return body.decode('utf-8')
        return body
class RedisClusterChannel(virtual.Channel):
    """Virtual channel that stores each queue as a Redis list on the cluster.

    Kombu message payload dicts are JSON-encoded, appended with RPUSH and
    consumed with LPOP (or BLPOP when a timeout is given), i.e. FIFO order.
    Exchange/binding state is not stored in Redis by this class; it relies on
    ``virtual.Channel``'s default handling.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # virtual.Channel stores the owning *transport* as ``self.connection``.
        self.transport = self.connection
        self.redis = self.transport.redis
        # Prefixes for queue/exchange keys, to avoid key collisions.
        self.queue_prefix = self.transport.options.get('queue_prefix', 'celery:')
        self.exchange_prefix = self.transport.options.get('exchange_prefix', 'exchange:')

    def _queue_key(self, queue):
        return f"{self.queue_prefix}{queue}"

    def _exchange_key(self, exchange):
        return f"{self.exchange_prefix}{exchange}"

    def _put(self, queue, message, **kwargs):
        """Append the kombu payload dict ``message`` to ``queue``.

        ``message`` is stored as-is (single JSON encoding), so ``_get`` hands
        back exactly what ``virtual.Channel`` published.

        Raises:
            TransportError: if the Redis command fails.
        """
        key = self._queue_key(queue)
        try:
            self.redis.rpush(key, json.dumps(message))
        except Exception as e:
            raise TransportError(f"Failed to enqueue message: {e}") from e

    def _get(self, queue, timeout=None):
        """Pop and return the next payload dict from ``queue``.

        With ``timeout`` set, blocks up to that many seconds via BLPOP;
        otherwise uses a non-blocking LPOP.

        Raises:
            queue.Empty: if no message is available. virtual.Channel's polling
                and ``basic_get`` rely on this instead of a ``None`` return.
            TransportError: if the Redis command or JSON decoding fails.
        """
        # Imports only the name ``Empty``; the ``queue`` parameter is untouched.
        from queue import Empty

        key = self._queue_key(queue)
        try:
            if timeout is not None:
                # BLPOP returns (key, value) or None on timeout.
                result = self.redis.blpop(key, timeout=timeout)
                value = result[1] if result else None
            else:
                value = self.redis.lpop(key)
            message = json.loads(value) if value else None
        except Exception as e:
            raise TransportError(f"Failed to dequeue message: {e}") from e
        if message is None:
            raise Empty()
        return message

    def _size(self, queue):
        """Return the number of messages currently in ``queue``."""
        return self.redis.llen(self._queue_key(queue))

    def _purge(self, queue):
        """Delete all messages in ``queue`` and return how many there were.

        Not atomic: messages pushed between LLEN and DEL are removed but not
        counted.
        """
        key = self._queue_key(queue)
        size = self.redis.llen(key)
        self.redis.delete(key)
        return size

    def drain_events(self, connection=None, timeout=None, callback=None):
        """Deliver pending messages via ``virtual.Channel.drain_events``.

        ``connection`` is accepted for backward compatibility and ignored;
        kombu calls this with ``callback=`` and ``timeout=`` keywords.
        """
        return super().drain_events(timeout=timeout, callback=callback)

    def declare_queue(self, queue, **kwargs):
        # In Redis, queues (lists) are implicitly created when messages are pushed.
        pass

    def queue_exists(self, queue):
        key = self._queue_key(queue)
        return self.redis.exists(key) == 1

    def basic_publish(self, message, exchange, routing_key, declare=False, **kwargs):
        """Publish through ``virtual.Channel.basic_publish``.

        The base class adds ``delivery_tag``/``delivery_info``, encodes the
        body, and routes via ``exchange``, ending in ``self._put``.
        ``declare`` is accepted for backward compatibility and ignored.
        """
        return super().basic_publish(message, exchange, routing_key, **kwargs)

    def basic_consume(self, queue, callback=None, no_ack=False, consumer_tag=None, **kwargs):
        """Register ``callback`` as a consumer of ``queue`` and return immediately.

        Messages are delivered later by ``drain_events``; this no longer loops
        forever, which used to block the caller.
        """
        if consumer_tag is None:
            consumer_tag = str(uuid.uuid4())
        return super().basic_consume(queue, no_ack, callback, consumer_tag, **kwargs)

    def close(self):
        # The RedisCluster client belongs to the transport and is shared by all
        # channels, so it is not closed here; the base close still cancels
        # consumers, restores unacked messages and detaches the channel.
        super().close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment