Last active: January 3, 2024 06:40
Save aleksandr-smechov/437f17a055d146229a1eb3f64c5c4b4b to your computer and use it in GitHub Desktop.
vLLM gradio server for skypilot
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse

import gradio as gr
import requests
def http_bot(prompt, model_input, api_key, temperature, max_tokens, top_p,
             model_url=None, timeout=300):
    """Send one non-streaming completion request and return the generated text.

    Args:
        prompt: Text to complete.
        model_input: Model name, forwarded as the ``model`` field.
        api_key: API key. When non-empty it is sent as a ``Bearer`` token in
            the ``Authorization`` header; it is also kept in the JSON body as
            the original client did.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate. Coerced to ``int``
            because a ``gr.Number`` input may deliver a float; ``None`` is
            forwarded unchanged.
        top_p: Nucleus-sampling probability mass.
        model_url: Completion endpoint URL. Defaults to ``args.model_url``
            parsed in the ``__main__`` block.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The ``text`` of the first entry in the response's ``choices`` list.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not answer within ``timeout``.
    """
    if model_url is None:
        # ``args`` only exists when this file is executed as a script.
        model_url = args.model_url

    headers = {"User-Agent": "vLLM Client"}
    if api_key:
        # Servers started with an API key check this header, not the body.
        headers["Authorization"] = f"Bearer {api_key}"

    payload = dict(
        model=model_input,
        api_key=api_key,
        prompt=prompt,
        temperature=temperature,
        max_tokens=int(max_tokens) if max_tokens is not None else None,
        top_p=top_p,
        stream=False,
    )
    response = requests.post(model_url, headers=headers, json=payload, timeout=timeout)
    # Surface HTTP errors directly instead of a confusing KeyError on "choices".
    response.raise_for_status()
    return response.json()["choices"][0]["text"]
def build_demo():
    """Build the Gradio UI for the vLLM text-completion demo.

    Returns:
        A ``gr.Blocks`` app. Pressing ENTER in the input textbox calls
        ``http_bot`` with the prompt, model name, API key, and sampling
        settings, and shows the returned text in the output textbox.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# vLLM text completion demo\n")
        inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
        model_input = gr.Textbox(label="Model")
        # Mask the secret so it is not displayed on screen.
        api_key = gr.Textbox(label="API key", type="password")
        temperature_slider = gr.Slider(label="Temperature", minimum=0, maximum=1, step=0.1, value=0.7)
        # precision=0 makes Gradio pass an int (not a float) to http_bot.
        max_tokens_input = gr.Number(label="Max Tokens", value=128, precision=0)
        # vLLM requires top_p in (0, 1], so 0 must not be selectable.
        top_p_slider = gr.Slider(label="Top P", minimum=0.1, maximum=1, step=0.1, value=1.0)
        outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
        inputbox.submit(
            http_bot,
            inputs=[inputbox, model_input, api_key, temperature_slider, max_tokens_input, top_p_slider],
            outputs=outputbox,
        )
    return demo
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Gradio front end for a vLLM completion server.")
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8001)
    parser.add_argument("--model-url", type=str, default="http://localhost:8000/completions")
    # A public share link stays the default; --no-share keeps the UI local-only.
    parser.add_argument("--no-share", dest="share", action="store_false",
                        help="Do not create a public Gradio share link.")
    # Module-level on purpose: http_bot reads args.model_url by default.
    args = parser.parse_args()

    demo = build_demo()
    try:
        # Gradio >= 4 replaced queue(concurrency_count=...) with this keyword.
        demo = demo.queue(default_concurrency_limit=100)
    except TypeError:
        # Gradio 3.x does not know default_concurrency_limit.
        demo = demo.queue(concurrency_count=100)
    demo.launch(server_name=args.host, server_port=args.port, share=args.share)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.