Created
September 20, 2023 20:52
-
-
Save evinism/73ae08efa63396495b1f9d6030cb4b96 to your computer and use it in GitHub Desktop.
basic ray example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from starlette.requests import Request | |
from typing import Dict | |
import ray | |
from ray import serve | |
from ray.serve.handle import RayServeHandle | |
from transformers import pipeline | |
@serve.deployment | |
class Translator: | |
def __init__(self): | |
self.language = "french" | |
self.model = pipeline("translation_en_to_fr", model="t5-small") | |
def translate(self, text: str) -> str: | |
model_output = self.model(text) | |
translation = model_output[0]["translation_text"] | |
return translation | |
def reconfigure(self, config: Dict): | |
self.language = config.get("language", "french") | |
if self.language.lower() == "french": | |
self.model = pipeline("translation_en_to_fr", model="t5-small") | |
elif self.language.lower() == "german": | |
self.model = pipeline("translation_en_to_de", model="t5-small") | |
elif self.language.lower() == "romanian": | |
self.model = pipeline("translation_en_to_ro", model="t5-small") | |
else: | |
pass | |
@serve.deployment | |
class Summarizer: | |
def __init__(self, translator: RayServeHandle): | |
# Load model | |
self.model = pipeline("summarization", model="t5-small") | |
self.translator = translator | |
self.min_length = 5 | |
self.max_length = 15 | |
def summarize(self, text: str) -> str: | |
# Run inference | |
model_output = self.model( | |
text, min_length=self.min_length, max_length=self.max_length | |
) | |
# Post-process output to return only the summary text | |
summary = model_output[0]["summary_text"] | |
return summary | |
async def __call__(self, http_request: Request) -> str: | |
english_text: str = await http_request.json() | |
summary = self.summarize(english_text) | |
translation_ref = await self.translator.translate.remote(summary) | |
translation = await translation_ref | |
return translation | |
def reconfigure(self, config: Dict): | |
self.min_length = config.get("min_length", 5) | |
self.max_length = config.get("max_length", 15) | |
app = Summarizer.bind(Translator.bind()) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment