Created August 13, 2024 20:16
-
-
Save luiscape/22bd7c26ed7e952b40b8b69e71efd731 to your computer and use it in GitHub Desktop.
Running ZeroMQ in Modal through a Modal tunnel.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import time | |
import zmq | |
from common import Message, app | |
from server import process_message | |
def zmq_query(address: str, n_messages: int = 10):
    """Send ``n_messages`` requests to a ZeroMQ REP server, waiting for each reply.

    Args:
        address: ``host:port`` of the server, e.g. the tunnel endpoint that
            ``start_server`` prints.
        n_messages: Number of request/reply round trips to perform.

    Stops early (printing a notice) if a reply does not arrive within the
    receive timeout. The socket and context are closed on every exit path.
    """
    print(f"connecting to {address=}")
    with zmq.Context() as context, context.socket(zmq.REQ) as soc:
        soc.RCVTIMEO = 12 * 1000  # Timeout interval (in milliseconds)
        # Drop any unsent request when the socket closes; otherwise context
        # termination can block forever after a timed-out round trip.
        soc.LINGER = 0
        soc.connect(f"tcp://{address}")
        for i in range(n_messages):
            request = Message(data=f"message {i}")
            soc.send(request.serialize())
            try:
                response = soc.recv()
            except zmq.error.Again:
                # A REQ socket must receive a reply before it may send again,
                # so a timeout ends the session.
                print("Failed to receive message.")
                break
            # Decode the reply so its cost is part of the measured round trip.
            Message.deserialize(response)
def modal_query(n_messages: int = 10):
    """Perform ``n_messages`` round trips by calling ``process_message`` remotely.

    Serves as a baseline to compare against ``zmq_query``; replies are discarded.
    """
    outgoing = (Message(data=f"message {idx}") for idx in range(n_messages))
    for msg in outgoing:
        process_message.remote(msg)
if __name__ == "__main__":
    # Usage: python client.py <host:port> [n_messages]
    # <host:port> is the tunnel address printed by the server.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <host:port> [n_messages]")
    address = sys.argv[1]
    # Optional second argument overrides the default message count.
    n_messages = int(sys.argv[2]) if len(sys.argv) > 2 else 10

    start = time.perf_counter()
    zmq_query(address, n_messages=n_messages)
    print(f"zmq_query total time: {time.perf_counter() - start:.3f}s")

    # Baseline: the same number of round trips through Modal function calls.
    with app.run(show_progress=False):
        start = time.perf_counter()
        modal_query(n_messages=n_messages)
        print(f"modal_query total time: {time.perf_counter() - start:.3f}s")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
import modal | |
import cloudpickle | |
# Container image shared by every function in the app.
# "pyzmq" is the real ZeroMQ binding; the "zmq" name on PyPI is only a
# placeholder that pulls in pyzmq. cloudpickle is needed to decode Message
# payloads in the container. cloudpickle expects matching Python/cloudpickle
# versions on both ends, so keep the local client environment compatible.
image = modal.Image.debian_slim().pip_install("pyzmq", "cloudpickle")
app = modal.App(image=image)
@dataclass
class Message:
    """A payload exchanged between the ZeroMQ client and server."""

    data: str  # free-form payload text
    processed: bool = False  # set to True by process_message

    def serialize(self) -> bytes:
        """Encode this message as bytes using cloudpickle."""
        return cloudpickle.dumps(self)

    @staticmethod
    def deserialize(data: bytes) -> "Message":
        """Decode bytes produced by ``serialize`` back into a ``Message``.

        WARNING: ``cloudpickle.loads`` can execute arbitrary code embedded in
        the payload. Only call this on data from a trusted peer.

        Raises:
            TypeError: if the payload decodes to something other than a Message.
        """
        obj = cloudpickle.loads(data)
        if not isinstance(obj, Message):
            raise TypeError(f"expected Message, got {type(obj).__name__}")
        return obj
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import zmq | |
from common import Message, app | |
import modal | |
@app.function()
def start_server(port: int = 5555):
    """Serve ``Message`` requests on a ZeroMQ REP socket exposed via a Modal tunnel.

    Binds a REP socket on ``port``, forwards that port through an unencrypted
    Modal tunnel, prints the public address, then loops forever. Each request
    is marked processed and sent back.

    Args:
        port: Container-side TCP port to bind and forward (default 5555).
    """
    with zmq.Context() as context, context.socket(zmq.REP) as socket:
        socket.bind(f"tcp://*:{port}")
        print("Server is running...")
        # SECURITY: unencrypted=True exposes a raw public TCP endpoint, and every
        # request is unpickled with cloudpickle. Anyone who learns the address
        # can run arbitrary code in this container. Use only for trusted demos.
        with modal.forward(port, unencrypted=True) as tunnel:
            origin, tunnel_port = tunnel.tcp_socket
            print(f"running at {origin}:{tunnel_port}")
            while True:
                # Wait for next request from client.
                data = socket.recv()
                try:
                    message = Message.deserialize(data)
                except Exception as exc:  # request boundary: keep serving
                    print(f"failed to decode request: {exc!r}")
                    # A REP socket must reply before it can receive again;
                    # an empty frame signals the error to the client.
                    socket.send(b"")
                    continue
                print(f"received message: {message}")
                # Update and send back to client.
                message = process_message.local(message)
                socket.send(message.serialize())
@app.function()
def process_message(message: Message) -> Message:
    """Return a copy of ``message`` with ``processed`` set to True.

    A new object is returned rather than mutating the argument. This makes
    ``.local()`` calls, which share the caller's object, behave like
    ``.remote()`` calls, which operate on a deserialized copy.
    """
    from dataclasses import replace

    return replace(message, processed=True)
@app.local_entrypoint()
def main():
    """Local entry point: launch the ZeroMQ server as a remote Modal function.

    ``.remote()`` waits for the function to return. ``start_server`` loops
    forever, so this runs until the app is stopped.
    """
    start_server.remote()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment