Skip to content

Instantly share code, notes, and snippets.

@helton
Last active October 7, 2024 17:33
Show Gist options
  • Save helton/e5ea607592e02a516e31fd385fa60fe3 to your computer and use it in GitHub Desktop.
[WIP] Redis Cluster support for Kombu (used by Celery)
from .redis_cluster_transport import RedisClusterTransport

# Register the custom transport with Kombu.
# NOTE(review): kombu exposes no `register_transport` function. The supported
# way to add a custom transport is to insert an alias into
# kombu.transport.TRANSPORT_ALIASES, mapping the URL scheme to a
# "dotted.module.path:ClassName" string that kombu imports lazily.
# Adjust the module path to wherever redis_cluster_transport lives in
# this project's package layout.
from kombu.transport import TRANSPORT_ALIASES

TRANSPORT_ALIASES['redis_cluster'] = (
    'redis_cluster_transport:RedisClusterTransport'
)
# ---
from celery import Celery

# Celery application wired to the custom Redis Cluster transport.
app = Celery('your_project')

# The URL scheme must match the transport alias registered with Kombu;
# all connection details are carried in the transport options below.
app.conf.broker_url = 'redis_cluster://'

# Options consumed by RedisClusterTransport / RedisClusterChannel.
_cluster_options = {
    'hosts': [
        'redis-cluster-node1:6379',
        'redis-cluster-node2:6379',
        'redis-cluster-node3:6379',
    ],
    'password': 'your_redis_password',  # if applicable
    'db': 0,
    'queue_prefix': 'celery:',
    'exchange_prefix': 'exchange:',
}
app.conf.broker_transport_options = _cluster_options

# Optional: result backend configuration
app.conf.result_backend = 'redis://redis-cluster-node1:6379/1'
import json
import time
import uuid
from queue import Empty

from kombu import Connection
from kombu.exceptions import TransportError
from kombu.transport import virtual
from kombu.transport.base import Message
from kombu.utils import cached_property
from kombu.utils.encoding import bytes_to_str
from redis.cluster import ClusterNode, RedisCluster
class RedisClusterTransport(virtual.Transport):
    """Kombu virtual transport backed by a Redis Cluster.

    Connection parameters are read from the owning connection's
    ``transport_options`` mapping: ``hosts`` (list of ``"host:port"``
    strings), ``password``, ``db`` and ``socket_timeout``.

    Raises:
        TransportError: if the Redis Cluster client cannot be created.
    """

    # Placeholder only -- RedisClusterChannel is defined later in this
    # module, so the real class is bound in __init__ below.  kombu
    # instantiates ``self.Channel(...)``, and a bare string would crash.
    Channel = 'RedisClusterChannel'

    driver_type = 'redis_cluster'
    driver_name = 'redis_cluster'

    # kombu reads ``default_port`` as a class attribute when parsing broker
    # URLs; defining it as a method (as the original did) would shadow that
    # contract with a bound method.
    default_port = 6379

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # kombu's base Transport stores the owning Connection as
        # ``self.client``; there is no ``self.connection`` attribute, so the
        # original ``self.connection.transport_options`` raised AttributeError.
        self.options = self.client.transport_options or {}
        self.redis_hosts = self.options.get('hosts', ['localhost:6379'])
        self.redis_password = self.options.get('password', None)
        self.db = self.options.get('db', 0)
        self.socket_timeout = self.options.get('socket_timeout', 5)
        # Resolve the Channel placeholder now that the whole module (and
        # therefore RedisClusterChannel) has been loaded.
        self.Channel = RedisClusterChannel
        # Initialize the RedisCluster client eagerly; any failure is surfaced
        # as a kombu TransportError so callers can handle it uniformly.
        try:
            # redis-py's RedisCluster expects ClusterNode objects in
            # ``startup_nodes``, not plain dicts.
            from redis.cluster import ClusterNode
            startup_nodes = []
            for host_port in self.redis_hosts:
                host, _, port = host_port.partition(':')
                startup_nodes.append(
                    ClusterNode(host, int(port or self.default_port))
                )
            # NOTE(review): ``skip_full_coverage_check`` was a
            # redis-py-cluster option; redis-py's RedisCluster does not
            # accept it, so it has been dropped.
            self.redis = RedisCluster(
                startup_nodes=startup_nodes,
                password=self.redis_password,
                decode_responses=True,
                socket_timeout=self.socket_timeout,
            )
        except Exception as e:
            raise TransportError(
                f"Failed to connect to Redis Cluster: {e}"
            ) from e

    def driver_version(self):
        """Return the installed redis-py package version string."""
        import redis
        return redis.__version__

    def _decode(self, body):
        """Decode ``bytes`` payloads to ``str``; pass other types through."""
        if isinstance(body, bytes):
            return body.decode('utf-8')
        return body
class RedisClusterChannel(virtual.Channel):
    """Virtual channel that stores each queue as a Redis list.

    Messages are JSON envelopes ``{"id": <uuid4>, "body": <json string>}``
    pushed with RPUSH and popped with LPOP/BLPOP, so delivery is FIFO.
    Queue and exchange keys are namespaced with configurable prefixes to
    avoid key collisions in a shared cluster.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # kombu constructs a virtual.Channel with the *transport* as its
        # first argument and stores it on ``self.connection`` -- the original
        # ``self.connection.transport`` lookup would raise AttributeError.
        self.transport = self.connection
        self.redis = self.transport.redis
        # Prefixes for queue/exchange keys to avoid key collisions.
        self.queue_prefix = self.transport.options.get('queue_prefix', 'celery:')
        self.exchange_prefix = self.transport.options.get('exchange_prefix', 'exchange:')

    def _queue_key(self, queue):
        """Return the namespaced Redis key for *queue*."""
        return f"{self.queue_prefix}{queue}"

    def _exchange_key(self, exchange):
        """Return the namespaced Redis key for *exchange*."""
        return f"{self.exchange_prefix}{exchange}"

    def _put(self, queue, message):
        """Append *message* to *queue* (RPUSH), wrapped in a JSON envelope."""
        key = self._queue_key(queue)
        message_id = str(uuid.uuid4())
        # The body is JSON-encoded separately so _from_message can recover
        # the original payload with a second json.loads.
        body = json.dumps(message)
        try:
            self.redis.rpush(key, json.dumps({'id': message_id, 'body': body}))
        except Exception as e:
            raise TransportError(f"Failed to enqueue message: {e}") from e

    def _get(self, queue, timeout=None):
        """Pop one envelope from *queue*, or return None if it is empty.

        With a *timeout*, blocks via BLPOP; otherwise does a non-blocking
        LPOP. Redis errors are re-raised as TransportError.
        """
        key = self._queue_key(queue)
        try:
            if timeout is not None:
                # BLPOP returns (key, value) or None on timeout.
                result = self.redis.blpop(key, timeout=timeout)
                if result:
                    _, value = result
                    return json.loads(value)
            else:
                value = self.redis.lpop(key)
                if value:
                    return json.loads(value)
        except Exception as e:
            raise TransportError(f"Failed to dequeue message: {e}") from e
        return None

    def drain_events(self, connection, timeout=None):
        """Deliver one message from the default queue or raise ``Empty``.

        The original raised ``self.connection.connection_errors`` -- a tuple,
        which is a TypeError at raise time; kombu's drain loop expects a
        ``queue.Empty`` to signal "nothing available".
        """
        # NOTE(review): ``default_queue`` is not set anywhere in this file --
        # presumably assigned by kombu channel setup; confirm against caller.
        queue = self.default_queue
        message = self._get(queue, timeout=timeout)
        if message:
            return self._from_message(message)
        raise Empty()

    def _from_message(self, message):
        """Unwrap an envelope dict into a kombu ``Message``."""
        body = json.loads(message['body'])
        # NOTE(review): the Message constructor signature varies across kombu
        # versions (some expect the channel as first argument) -- verify
        # against the pinned kombu release.
        return Message(body, content_type='application/json',
                       content_encoding='utf-8')

    def declare_queue(self, queue, **kwargs):
        """No-op: Redis lists are created implicitly on first RPUSH."""
        pass

    def queue_exists(self, queue):
        """Return True if the queue's backing list key exists."""
        key = self._queue_key(queue)
        return self.redis.exists(key) == 1

    def basic_publish(self, message, exchange, routing_key, declare=False, **kwargs):
        """Publish directly to the queue named by *routing_key*.

        NOTE(review): *exchange* is ignored -- only direct routing is
        implemented in this WIP.
        """
        self._put(routing_key, message)

    def basic_consume(self, queue, callback, **kwargs):
        """Poll *queue* forever, invoking *callback* for each message.

        NOTE(review): this loop never returns and ignores kombu's consumer
        registration protocol -- acceptable for a WIP spike only.
        """
        while True:
            message = self._get(queue, timeout=kwargs.get('timeout', 1))
            if message:
                callback(message)
            else:
                time.sleep(0.1)  # Prevent a tight busy-wait loop.

    def close(self):
        """No-op: the shared RedisCluster client needs no per-channel close."""
        pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment