Created September 21, 2021 20:22
Save edoakes/af18a782cd607e3f5b4caa4a524a15ea to your computer and use it in GitHub Desktop.
Ray Serve plotly wrapper (working but hacky)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dash | |
from dash import dcc, html | |
from dash.dependencies import Input, Output | |
import pandas as pd | |
import plotly.graph_objs as obj | |
import uvicorn as uvicorn | |
from fastapi import FastAPI | |
from starlette.middleware.wsgi import WSGIMiddleware | |
import ray | |
from ray import serve | |
def build_plotly_app(
    first_year=1940,
    last_year=2020,
    step=5,
    requests_pathname_prefix="/dash/",
):
    """Build the Dash app that plots synthetic high/low temperatures by year.

    The page shows a scatter plot with two series ("High Temperatures" and
    "Low Temperatures") above a range slider. Whenever the slider changes, the
    figure is re-rendered with its x-axis limited to the selected years. The
    data is synthetic: ``TempHigh = year / 20`` and ``TempLow = TempHigh - 20``.

    The app holds weakrefs and is not cloudpickle-serializable. It therefore
    has to be built where it is used (inside the Serve replica) rather than
    captured and shipped.

    Args:
        first_year: First year (inclusive) of generated data. Defaults to 1940.
        last_year: Last year (inclusive) of generated data. Defaults to 2020.
        step: Slider step, and spacing of the labelled slider marks, in years.
            Defaults to 5.
        requests_pathname_prefix: URL prefix the browser uses to reach the
            app. It must agree with the path the app is mounted under by the
            HTTP server (``"/dash"`` in this file's ``__main__`` section).

    Returns:
        The configured ``dash.Dash`` application.

    Raises:
        ValueError: If ``last_year < first_year`` or ``step <= 0``.
    """
    # Validate before touching Dash so bad arguments fail fast and clearly
    # (an empty year range would otherwise produce NaN slider bounds).
    if last_year < first_year:
        raise ValueError(
            f"last_year ({last_year}) must be >= first_year ({first_year})"
        )
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")

    app = dash.Dash(__name__, requests_pathname_prefix=requests_pathname_prefix)

    years = list(range(first_year, last_year + 1))
    temp_high = [year / 20 for year in years]
    temp_low = [high - 20 for high in temp_high]
    df = pd.DataFrame({"Year": years, "TempHigh": temp_high, "TempLow": temp_low})

    slider = dcc.RangeSlider(
        id="slider",
        # Same values as df["Year"].min()/max(), taken from the generating
        # range so they are plain Python ints.
        value=[first_year, last_year],
        min=first_year,
        max=last_year,
        step=step,
        # One label every `step` years, generated rather than hand-written.
        marks={year: str(year) for year in range(first_year, last_year + 1, step)},
    )

    app.layout = html.Div(
        children=[
            html.H1(children="Data Visualization with Dash"),
            html.Div(children="High/Low Temperatures Over Time"),
            dcc.Graph(id="temp-plot"),
            slider,
        ]
    )

    @app.callback(Output("temp-plot", "figure"), [Input("slider", "value")])
    def add_graph(slider):
        """Return the scatter figure with the x-axis limited to ``slider``.

        ``slider`` is the RangeSlider's current value, a ``[low, high]`` list
        of years (initialised to ``[first_year, last_year]`` above).
        """
        trace_high = obj.Scatter(
            x=df["Year"], y=df["TempHigh"], mode="markers", name="High Temperatures"
        )
        trace_low = obj.Scatter(
            x=df["Year"], y=df["TempLow"], mode="markers", name="Low Temperatures"
        )
        layout = obj.Layout(
            xaxis=dict(range=[slider[0], slider[1]]),
            yaxis={"title": "Temperature"},
        )
        return obj.Figure(data=[trace_high, trace_low], layout=layout)

    return app
def lazy_middleware(*args, **kwargs):
    """Forward a WSGI call to the Dash app owned by the current Serve replica.

    Beware: this is deliberately hacky. The Dash app is not
    cloudpickle-serializable (it holds weakrefs), so it cannot be handed to
    ``WSGIMiddleware`` when the FastAPI app is built. Instead, this function is
    registered in its place and resolves the real WSGI application on every
    call.

    ``serve.get_replica_context().servable_object`` is the deployment instance
    running in this replica. By convention that instance stores the Dash app
    in its ``app`` attribute, and the app's ``server`` attribute is the WSGI
    callable. All arguments are passed through untouched, and that callable's
    result is returned.
    """
    replica_context = serve.get_replica_context()
    deployment_instance = replica_context.servable_object
    wsgi_server = deployment_instance.app.server
    return wsgi_server(*args, **kwargs)
if __name__ == "__main__":
    # Outer ASGI app used as the Serve ingress below. Requests under /dash are
    # passed to `lazy_middleware`, which looks up the Dash app's WSGI server
    # inside the replica at request time instead of capturing it here.
    server = FastAPI()
    server.mount("/dash", WSGIMiddleware(lazy_middleware))

    # NOTE(review): address="auto" presumably attaches to an already-running
    # Ray cluster, and detached=True presumably keeps Serve running after this
    # script exits. Confirm both against the Ray version in use.
    ray.init(address="auto", namespace="serve")
    serve.start(detached=True)

    @serve.deployment(route_prefix="/")
    @serve.ingress(server)
    class MyServeWrapper:
        """Serve deployment that serves `server` and owns the Dash app.

        `lazy_middleware` reads the Dash app back from this instance's `app`
        attribute, so that attribute name must not change.
        """

        def __init__(self):
            # We need to construct the plotly app within the constructor
            # because it is unfortunately not serializable using cloudpickle
            # (contains some weakrefs).
            self.app = build_plotly_app()

    # Deploy (or redeploy) the wrapper so HTTP traffic under "/" reaches it.
    MyServeWrapper.deploy()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.