A simple client for using an LLM with gradio
# Apache License 2.0
# See the comments on the gist for usage
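# Example invocations (a sketch; the flag names come from the argparse section at the bottom of this
# file, and the script/model file names are placeholders -- adjust them to your environment):
#   python llm_gradio_client.py -m ./model.gguf -ngl 35 -c 4096                  # completion UI
#   python llm_gradio_client.py -m ./model.gguf -ngl 35 --chat -ch command-r     # chat UI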
import argparse
from typing import List, Optional, Union, Iterator

from llama_cpp.llama_chat_format import _convert_completion_to_chat, register_chat_completion_handler
import llama_cpp.llama_types as llama_types
from llama_cpp.llama import LogitsProcessorList, LlamaGrammar
from llama_cpp import Llama, llama_chat_format
import gradio as gr

debug_flag = False


@register_chat_completion_handler("command-r")
def command_r_chat_handler(
    llama: Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    **kwargs,  # type: ignore
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
    bos_token = "<BOS_TOKEN>"
    start_turn_token = "<|START_OF_TURN_TOKEN|>"
    end_turn_token = "<|END_OF_TURN_TOKEN|>"
    user_token = "<|USER_TOKEN|>"
    chatbot_token = "<|CHATBOT_TOKEN|>"
    system_token = "<|SYSTEM_TOKEN|>"

    prompt = bos_token
    if len(messages) > 0 and messages[0]["role"] == "system":
        prompt += start_turn_token + system_token + messages[0]["content"] + end_turn_token
        messages = messages[1:]

    for message in messages:
        if message["role"] == "user":
            prompt += start_turn_token + user_token + message["content"] + end_turn_token
        elif message["role"] == "assistant":
            prompt += start_turn_token + chatbot_token + message["content"] + end_turn_token

    prompt += start_turn_token + chatbot_token

    if debug_flag:
        print(f"Prompt: {prompt}")

    stop_tokens = [end_turn_token]  # , bos_token]

    return _convert_completion_to_chat(
        llama.create_completion(
            prompt=prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            typical_p=typical_p,
            stream=stream,
            stop=stop_tokens,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repeat_penalty=repeat_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
            grammar=grammar,
        ),
        stream=stream,
    )
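
# For reference, the handler above renders a system prompt plus one user turn as (line breaks added
# here only for readability; the actual prompt is a single string):
#   <BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{system}<|END_OF_TURN_TOKEN|>
#   <|START_OF_TURN_TOKEN|><|USER_TOKEN|>{user}<|END_OF_TURN_TOKEN|>
#   <|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
# Generation stops at <|END_OF_TURN_TOKEN|>, which the handler passes as the stop sequence.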


def generate_completion(llama, prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
    global stop_generating
    stop_generating = False
    output = prompt

    if debug_flag:
        print(
            f"temperature: {temperature}, top_p: {top_p}, top_k: {top_k}, repeat_penalty: {repeat_penalty}, max_tokens: {max_tokens}"
        )

    for chunk in llama(
        prompt,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        repeat_penalty=repeat_penalty,
        top_k=top_k,
        stream=True,
    ):
        if debug_flag:
            print(chunk)
        if stop_generating:
            break
        if "choices" in chunk and len(chunk["choices"]) > 0:
            if "text" in chunk["choices"][0]:
                text = chunk["choices"][0]["text"]
                # check EOS_TOKEN
                if text.endswith("<EOS_TOKEN>"):  # llama.tokenizer.EOS_TOKEN):
                    output += text[: -len("<EOS_TOKEN>")]
                    yield output[len(prompt) :]
                    break
                output += text
                yield output[len(prompt) :]  # remove prompt
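
# Standalone usage sketch for the generator above (assumes `llama` is an initialized Llama instance;
# each yielded value is the cumulative completion text with the prompt stripped):
#   for partial in generate_completion(llama, "Once upon a time", 64, 0.3, 0.9, 1.1, 40):
#       print(partial)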


def launch_completion(llama, listen=False):
    # css = """
    # .prompt textarea {font-size:1.0em !important}
    # """
    # with gr.Blocks(css=css) as demo:
    with gr.Blocks() as demo:
        with gr.Row():
            # change font size
            io_textbox = gr.Textbox(
                label="Input/Output: Text may not be scrolled automatically. Shift+Enter to newline. テキストは自動スクロールしないことがあります。Shift+Enterで改行。",
                placeholder="Enter your prompt here...",
                interactive=True,
                elem_classes=["prompt"],
                autoscroll=True,
            )

        with gr.Row():
            generate_button = gr.Button("Generate")
            stop_button = gr.Button("Stop", visible=False)

        with gr.Row():
            max_tokens = gr.Slider(minimum=1, maximum=2048, value=128, step=1, label="Max Tokens")
            temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Temperature")
            top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top P")
            repeat_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repeat Penalty")
            top_k = gr.Slider(minimum=1, maximum=200, value=40, step=1, label="Top K")

        def generate_and_display(prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
            output_generator = generate_completion(llama, prompt, max_tokens, temperature, top_p, repeat_penalty, top_k)
            for output in output_generator:
                yield gr.update(value=prompt + output, autoscroll=True), gr.update(visible=False), gr.update(visible=True)
            yield gr.update(value=prompt + output, autoscroll=True), gr.update(visible=True), gr.update(visible=False)

        def stop_generation():
            globals().update(stop_generating=True)
            return gr.update(visible=True), gr.update(visible=False)

        generate_button.click(
            generate_and_display,
            inputs=[io_textbox, max_tokens, temperature, top_p, repeat_penalty, top_k],
            outputs=[io_textbox, generate_button, stop_button],
            show_progress=True,
        )

        stop_button.click(
            stop_generation,
            outputs=[generate_button, stop_button],
        )

        # add event to textbox to add new line on enter
        io_textbox.submit(
            lambda x: x + "\n",
            inputs=[io_textbox],
            outputs=[io_textbox],
        )

    demo.launch(server_name="0.0.0.0" if listen else None)
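
# Note on the Stop button above: stop_generation() sets the module-level stop_generating flag, which
# generate_completion() checks between streamed chunks, so generation halts after the current chunk
# rather than immediately.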


def launch_chat(llama, handler_name, listen=False):
    def chat(message, history, system_prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
        user_input = message
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        for message in history:
            messages.append({"role": "user", "content": message[0]})
            messages.append({"role": "assistant", "content": message[1]})
        messages.append({"role": "user", "content": user_input})

        if debug_flag:
            print(f"Messages: {messages}")
            print(
                f"System prompt: {system_prompt}, temperature: {temperature}, top_p: {top_p}, top_k: {top_k}, repeat_penalty: {repeat_penalty}, max_tokens: {max_tokens}"
            )

        handler = llama_chat_format.get_chat_completion_handler(handler_name)
        chat_completion_chunks = handler(
            llama=llama,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            repeat_penalty=repeat_penalty,
            top_k=int(top_k),
            stream=True,
        )

        response = ""
        for chunk in chat_completion_chunks:
            if debug_flag:
                print(chunk)
            if "choices" in chunk and len(chunk["choices"]) > 0:
                if "delta" in chunk["choices"][0]:
                    if "content" in chunk["choices"][0]["delta"]:
                        response += chunk["choices"][0]["delta"]["content"]
                        yield response

    system_prompt = gr.Textbox(label="System Prompt", placeholder="Enter system prompt here...")
    max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max Tokens")
    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Temperature")
    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top P")
    repeat_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repeat Penalty")
    top_k = gr.Slider(minimum=1, maximum=200, value=40, step=1, label="Top K")
    additional_inputs = [system_prompt, max_tokens, temperature, top_p, repeat_penalty, top_k]

    chatbot = gr.ChatInterface(chat, additional_inputs=additional_inputs)
    chatbot.launch(server_name="0.0.0.0" if listen else None)
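
# As the chat() callback above assumes, gr.ChatInterface passes `history` as a list of
# [user_message, assistant_message] pairs; chat() flattens it into OpenAI-style
# {"role": ..., "content": ...} dicts before invoking the registered chat handler
# (e.g. the "command-r" handler defined at the top of this file).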


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default=None, help="Model file path")
    parser.add_argument("-ngl", "--n_gpu_layers", type=int, default=0, help="Number of GPU layers")
    parser.add_argument("-c", "--n_ctx", type=int, default=2048, help="Context length")
    parser.add_argument(
        "-ch",
        "--chat_handler",
        type=str,
        default="command-r",
        help="Chat handler, e.g. command-r, mistral-instruct, alpaca, llama-3 etc. default: command-r",
    )
    parser.add_argument("--chat", action="store_true", help="Chat mode")
    parser.add_argument("--listen", action="store_true", help="Listen mode")
    parser.add_argument(
        "-ts", "--tensor_split", type=str, default=None, help="Tensor split, float values separated by comma for each gpu"
    )
    parser.add_argument("--debug", action="store_true", help="Debug mode")
    args = parser.parse_args()

    # tokenizer initialization doesn't seem to be needed
    # print("Initializing tokenizer")
    # tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    print(f"Initializing Llama. Model ID: {args.model}, N_GPU_LAYERS: {args.n_gpu_layers}, N_CTX: {args.n_ctx}")
    # llama_tokenizer = LlamaHFTokenizer(tokenizer)

    tensor_split = None if args.tensor_split is None else [float(x) for x in args.tensor_split.split(",")]

    llama = Llama(
        model_path=args.model,
        n_gpu_layers=args.n_gpu_layers,
        tensor_split=tensor_split,
        n_ctx=args.n_ctx,
        # tokenizer=llama_tokenizer,
        # n_threads=n_threads,
    )

    debug_flag = args.debug

    if args.chat:
        print(f"Launching chat with handler: {args.chat_handler}")
        launch_chat(llama, args.chat_handler, args.listen)
    else:
        print("Launching completion")
        launch_completion(llama, args.listen)
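
# Notes (based on the code above): --listen binds the Gradio server to 0.0.0.0 so the UI is reachable
# from other machines; --tensor_split takes comma-separated floats, e.g. "-ts 0.5,0.5" to split the
# model across two GPUs (the exact split behavior is determined by llama.cpp).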