Skip to content

Instantly share code, notes, and snippets.

@luiscape
Created August 13, 2024 20:16
Show Gist options
  • Save luiscape/22bd7c26ed7e952b40b8b69e71efd731 to your computer and use it in GitHub Desktop.
Save luiscape/22bd7c26ed7e952b40b8b69e71efd731 to your computer and use it in GitHub Desktop.
Running ZeroMQ in Modal with a Modal Tunnel.
import sys
import time
import zmq
from common import Message, app
from server import process_message
def zmq_query(address: str, n_messages: int = 10):
    """Send ``n_messages`` requests to a ZeroMQ REP server, awaiting each reply.

    Args:
        address: ``host:port`` of the server; ``tcp://`` is prepended here.
        n_messages: Number of request/reply round trips to attempt.

    The loop stops at the first reply that does not arrive within the
    receive timeout. The socket and context are released on every exit path.
    """
    context = zmq.Context()
    print(f"connecting to {address=}")
    soc = context.socket(zmq.REQ)
    try:
        soc.connect(f"tcp://{address}")
        soc.RCVTIMEO = 12 * 1000  # Timeout interval (in milliseconds)

        for i in range(n_messages):
            soc.send(Message(data=f"message {i}").serialize())

            # Handle timeout errors.
            try:
                response = soc.recv()
            except zmq.error.Again:
                print("Failed to receive message.")
                break

            # NOTE(security): Message.deserialize unpickles the reply, so the
            # peer at `address` must be trusted. The result is discarded; the
            # call only checks that the reply can be decoded.
            Message.deserialize(response)
    finally:
        # linger=0 discards any request still queued for an unreachable
        # server, so terminating the context cannot block indefinitely.
        soc.close(linger=0)
        context.term()
def modal_query(n_messages: int = 10):
    """Call ``process_message`` remotely once for each generated message.

    Messages carry the text ``"message <i>"`` for ``i`` in
    ``range(n_messages)``. The values returned by the remote calls are
    discarded.
    """
    # The generator is lazy, so each message is built right before it is sent.
    payloads = (Message(data=f"message {index}") for index in range(n_messages))
    for payload in payloads:
        process_message.remote(payload)
if __name__ == "__main__":
    # The server address is the only required argument; exit with a usage
    # hint instead of an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <host:port>")

    n_messages = 10

    # Time the round trips made over the raw ZeroMQ socket.
    start = time.perf_counter()
    zmq_query(sys.argv[1], n_messages=n_messages)
    print(f"zmq_query total time: {time.perf_counter() - start:.3f}s")

    # Time the same number of calls made through Modal remote invocation.
    with app.run(show_progress=False):
        start = time.perf_counter()
        modal_query(n_messages=n_messages)
        print(f"modal_query total time: {time.perf_counter() - start:.3f}s")
from dataclasses import dataclass
import modal
import cloudpickle
# Container image shared by every function registered on `app`.
# The ZeroMQ bindings imported as `zmq` are distributed on PyPI as "pyzmq";
# install them by that canonical name rather than relying on the similarly
# named "zmq" project.
image = modal.Image.debian_slim().pip_install("pyzmq", "cloudpickle")
app = modal.App(image=image)
@dataclass
class Message:
    """Payload exchanged between the client and the server.

    ``data`` holds the message text; ``processed`` starts as ``False``.
    """

    data: str
    processed: bool = False

    def serialize(self):
        """Return this message pickled to bytes with ``cloudpickle``."""
        payload = cloudpickle.dumps(self)
        return payload

    @staticmethod
    def deserialize(data: bytes) -> "Message":
        """Unpickle ``data`` with ``cloudpickle`` and return the result.

        NOTE(review): unpickling can execute arbitrary code, and the result
        is not checked to be a ``Message``; only pass bytes from a trusted
        peer.
        """
        restored = cloudpickle.loads(data)
        return restored
import zmq
from common import Message, app
import modal
@app.function()
def start_server(port: int = 5555):
    """Run a blocking ZeroMQ REP server exposed through a Modal tunnel.

    Args:
        port: Local TCP port to bind and to forward through the tunnel.

    Each request is deserialized into a ``Message``, passed to
    ``process_message.local`` and the result is sent back to the client.
    The loop never returns on its own; the socket and context are released
    when it is interrupted by an exception or shutdown.
    """
    context = zmq.Context()
    socket = context.socket(zmq.REP)
    try:
        socket.bind(f"tcp://*:{port}")
        print("Server is running...")

        with modal.forward(port, unencrypted=True) as tunnel:
            origin, public_port = tunnel.tcp_socket
            print(f"running at {origin}:{public_port}")

            while True:
                # Wait for next request from client.
                data = socket.recv()
                # NOTE(security): the tunnel is unencrypted and the payload is
                # unpickled, so anyone who can reach this address can execute
                # arbitrary code here. Restrict this to trusted clients.
                message = Message.deserialize(data)
                print(f"received message: {message}")

                # Update and send back to client.
                message = process_message.local(message)
                socket.send(message.serialize())
    finally:
        # linger=0 drops any reply that could not be delivered, so that
        # terminating the context does not hang during shutdown.
        socket.close(linger=0)
        context.term()
@app.function()
def process_message(message: Message) -> Message:
    """Mark ``message`` as processed and return that same object."""
    # The argument is updated in place; no copy is made.
    message.processed = True
    return message
@app.local_entrypoint()
def main():
    """Local entrypoint: invoke ``start_server`` remotely."""
    start_server.remote()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment