Skip to content

Instantly share code, notes, and snippets.

@pppppyamaP
Forked from kohya-ss/gradio_llm.py
Created May 2, 2024 05:57
Show Gist options
  • Save pppppyamaP/cc8183dabfa2470bb027ae75e3dd8bb6 to your computer and use it in GitHub Desktop.
Save pppppyamaP/cc8183dabfa2470bb027ae75e3dd8bb6 to your computer and use it in GitHub Desktop.
gradioでLLMを利用する簡易クライアント
# Apache License 2.0
# See the comments on the gist for usage instructions.
import argparse
from typing import List, Optional, Union, Iterator
from llama_cpp.llama_chat_format import _convert_completion_to_chat, register_chat_completion_handler
import llama_cpp.llama_types as llama_types
from llama_cpp.llama import LogitsProcessorList, LlamaGrammar
from llama_cpp import Llama, llama_chat_format
import gradio as gr
# Module-wide switch for verbose diagnostic printing (prompts, sampling
# parameters, raw chunks). Off by default; the __main__ section overwrites it
# with the value of the --debug command-line flag.
debug_flag = False
@register_chat_completion_handler("command-r")
def command_r_chat_handler(
    llama: Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
    tools: Optional[List[llama_types.ChatCompletionTool]] = None,
    tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    min_p: float = 0.05,
    typical_p: float = 1.0,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
    max_tokens: Optional[int] = None,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    grammar: Optional[LlamaGrammar] = None,
    **kwargs,  # type: ignore
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
    """Chat completion handler registered under the name "command-r".

    Flattens ``messages`` into a single prompt string using the turn markers
    below, runs ``llama.create_completion`` on it and converts the result to
    the chat-completion shape.

    Prompt layout:
        ``<BOS_TOKEN>``, then an optional system turn (only when the first
        message has role ``"system"``), then one turn per ``"user"`` /
        ``"assistant"`` message, then an opened chatbot turn for the model to
        continue. Messages with any other role are skipped.

    Args:
        llama: Model used to run the completion.
        messages: Conversation so far, oldest first.
        stop: Extra stop sequence(s) supplied by the caller. They are used in
            addition to the end-of-turn token, which is always a stop
            sequence. The default list is never mutated.
        stream: When true, an iterator of chunks is returned.
        functions, function_call, tools, tool_choice, response_format, kwargs:
            Accepted so the signature matches other handlers; not used here.
        (remaining parameters): Forwarded unchanged to
            ``llama.create_completion``.

    Returns:
        Whatever ``_convert_completion_to_chat`` produces for the completion
        (a single chat completion, or an iterator of chunks when streaming).
    """
    bos_token = "<BOS_TOKEN>"
    start_turn_token = "<|START_OF_TURN_TOKEN|>"
    end_turn_token = "<|END_OF_TURN_TOKEN|>"
    user_token = "<|USER_TOKEN|>"
    chatbot_token = "<|CHATBOT_TOKEN|>"
    system_token = "<|SYSTEM_TOKEN|>"

    def render_turn(role_token, content):
        # A message whose content is None is rendered as an empty turn rather
        # than crashing the string concatenation.
        return start_turn_token + role_token + (content or "") + end_turn_token

    parts = [bos_token]

    # Only a system message in first position is honoured.
    if len(messages) > 0 and messages[0]["role"] == "system":
        parts.append(render_turn(system_token, messages[0]["content"]))
        messages = messages[1:]

    role_tokens = {"user": user_token, "assistant": chatbot_token}
    for message in messages:
        role_token = role_tokens.get(message["role"])
        if role_token is not None:
            parts.append(render_turn(role_token, message["content"]))

    # Leave a chatbot turn open so the model writes the reply.
    parts.append(start_turn_token + chatbot_token)
    prompt = "".join(parts)

    if debug_flag:
        print(f"Prompt: {prompt}")

    # Honour caller-supplied stop sequences instead of silently dropping them.
    if stop is None:
        extra_stops = []
    elif isinstance(stop, str):
        extra_stops = [stop]
    else:
        extra_stops = list(stop)
    # Empty strings are dropped (an empty stop sequence would match
    # immediately), as are duplicates of the end-of-turn token.
    stop_tokens = [end_turn_token] + [s for s in extra_stops if s and s != end_turn_token]

    return _convert_completion_to_chat(
        llama.create_completion(
            prompt=prompt,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            typical_p=typical_p,
            stream=stream,
            stop=stop_tokens,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repeat_penalty=repeat_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
            grammar=grammar,
        ),
        stream=stream,
    )
def generate_completion(llama, prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
    """Stream a raw completion for ``prompt``, yielding the text generated so far.

    ``llama`` is called with the prompt and the sampling settings in streaming
    mode. After every chunk that carries text, the accumulated generated text
    (without the prompt) is yielded.

    Generation ends when the stream is exhausted, when the module-level
    ``stop_generating`` flag is found set, or when a chunk's text ends with
    the literal ``<EOS_TOKEN>`` marker (the marker itself is stripped from the
    final yield). The flag is reset to False each time generation starts.
    """
    global stop_generating
    stop_generating = False

    eos_marker = "<EOS_TOKEN>"
    generated = ""

    if debug_flag:
        print(
            f"temperature: {temperature}, top_p: {top_p}, top_k: {top_k}, repeat_penalty: {repeat_penalty}, max_tokens: {max_tokens}"
        )

    sampling = dict(
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        repeat_penalty=repeat_penalty,
        top_k=top_k,
        stream=True,
    )
    for chunk in llama(prompt, **sampling):
        if debug_flag:
            print(chunk)
        if stop_generating:
            break

        # Skip chunks that carry no text for the first choice.
        choices = chunk["choices"] if "choices" in chunk else ()
        if len(choices) == 0 or "text" not in choices[0]:
            continue

        piece = choices[0]["text"]
        if piece.endswith(eos_marker):
            # Drop the end-of-sequence marker and finish.
            generated += piece[: -len(eos_marker)]
            yield generated
            break

        generated += piece
        yield generated
def launch_completion(llama, listen=False):
    """Build and launch the Gradio UI for raw text completion.

    A single textbox serves as both input and output: pressing "Generate"
    streams the model's continuation into it after the existing text. While
    generation runs, the "Generate" button is hidden and a "Stop" button is
    shown instead.

    Args:
        llama: Model object handed to ``generate_completion``.
        listen: When true, the server is bound to ``0.0.0.0``; otherwise
            Gradio's default server name is used.
    """
    # Optional CSS the original author left disabled (the textbox keeps the
    # "prompt" class so it can be re-enabled):
    # css = """
    # .prompt textarea {font-size:1.0em !important}
    # """
    # with gr.Blocks(css=css) as demo:
    with gr.Blocks() as demo:
        with gr.Row():
            io_textbox = gr.Textbox(
                label="Input/Output: Text may not be scrolled automatically. Shift+Enter to newline. テキストは自動スクロールしないことがあります。Shift+Enterで改行。",
                placeholder="Enter your prompt here...",
                interactive=True,
                elem_classes=["prompt"],
                autoscroll=True,
            )
        with gr.Row():
            generate_button = gr.Button("Generate")
            stop_button = gr.Button("Stop", visible=False)
        with gr.Row():
            max_tokens = gr.Slider(minimum=1, maximum=2048, value=128, step=1, label="Max Tokens")
            temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Temperature")
            top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top P")
            repeat_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repeat Penalty")
            top_k = gr.Slider(minimum=1, maximum=200, value=40, step=1, label="Top K")

        def generate_and_display(prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
            # Start from an empty continuation so the final update below is
            # still valid when the generator yields nothing (e.g. "Stop" is
            # pressed before the first chunk arrives).
            output = ""
            for output in generate_completion(llama, prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
                # While streaming: hide "Generate", show "Stop".
                yield gr.update(value=prompt + output, autoscroll=True), gr.update(visible=False), gr.update(visible=True)
            # Finished: restore the buttons.
            yield gr.update(value=prompt + output, autoscroll=True), gr.update(visible=True), gr.update(visible=False)

        def stop_generation():
            # Raise the module-level flag that generate_completion polls.
            global stop_generating
            stop_generating = True
            return gr.update(visible=True), gr.update(visible=False)

        generate_button.click(
            generate_and_display,
            inputs=[io_textbox, max_tokens, temperature, top_p, repeat_penalty, top_k],
            outputs=[io_textbox, generate_button, stop_button],
            show_progress=True,
        )
        stop_button.click(
            stop_generation,
            outputs=[generate_button, stop_button],
        )
        # Submitting the textbox (Enter) appends a newline instead of generating.
        io_textbox.submit(
            lambda x: x + "\n",
            inputs=[io_textbox],
            outputs=[io_textbox],
        )

    demo.launch(server_name="0.0.0.0" if listen else None)
def launch_chat(llama, handler_name, listen=False):
    """Build and launch a Gradio chat interface backed by a chat handler.

    Args:
        llama: Model object passed to the chat completion handler.
        handler_name: Name used to look up the handler through
            ``llama_chat_format.get_chat_completion_handler``.
        listen: When true, the server is bound to ``0.0.0.0``; otherwise
            Gradio's default server name is used.
    """

    def chat(message, history, system_prompt, max_tokens, temperature, top_p, repeat_penalty, top_k):
        # Rebuild the conversation: optional system prompt, the past
        # (user, assistant) turns, then the new user message.
        messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
        for turn in history:
            messages.extend(
                (
                    {"role": "user", "content": turn[0]},
                    {"role": "assistant", "content": turn[1]},
                )
            )
        messages.append({"role": "user", "content": message})

        if debug_flag:
            print(f"Messages: {messages}")
            print(
                f"System prompt: {system_prompt}, temperature: {temperature}, top_p: {top_p}, top_k: {top_k}, repeat_penalty: {repeat_penalty}, max_tokens: {max_tokens}"
            )

        complete = llama_chat_format.get_chat_completion_handler(handler_name)
        chunk_stream = complete(
            llama=llama,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            repeat_penalty=repeat_penalty,
            top_k=int(top_k),
            stream=True,
        )

        # Yield the reply accumulated so far after every content delta.
        reply = ""
        for chunk in chunk_stream:
            if debug_flag:
                print(chunk)
            if "choices" not in chunk or len(chunk["choices"]) == 0:
                continue
            first_choice = chunk["choices"][0]
            if "delta" not in first_choice:
                continue
            delta = first_choice["delta"]
            if "content" not in delta:
                continue
            reply += delta["content"]
            yield reply

    # Extra controls shown alongside the chat; their order matches the
    # trailing parameters of chat().
    system_prompt = gr.Textbox(label="System Prompt", placeholder="Enter system prompt here...")
    max_tokens = gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max Tokens")
    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Temperature")
    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top P")
    repeat_penalty = gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repeat Penalty")
    top_k = gr.Slider(minimum=1, maximum=200, value=40, step=1, label="Top K")

    chatbot = gr.ChatInterface(
        chat,
        additional_inputs=[system_prompt, max_tokens, temperature, top_p, repeat_penalty, top_k],
    )
    chatbot.launch(server_name="0.0.0.0" if listen else None)
if __name__ == "__main__":
    # Command-line entry point: parse options, load the model, then start
    # either the chat UI (--chat) or the raw completion UI (default).
    parser = argparse.ArgumentParser()
    # The model path is mandatory: without it the Llama constructor below is
    # handed None and fails with an opaque error, so let argparse report the
    # missing option instead.
    parser.add_argument("-m", "--model", type=str, required=True, help="Model file path")
    parser.add_argument("-ngl", "--n_gpu_layers", type=int, default=0, help="Number of GPU layers")
    parser.add_argument("-c", "--n_ctx", type=int, default=2048, help="Context length")
    parser.add_argument(
        "-ch",
        "--chat_handler",
        type=str,
        default="command-r",
        help="Chat handler, e.g. command-r, mistral-instruct, alpaca, llama-3 etc. default: command-r",
    )
    parser.add_argument("--chat", action="store_true", help="Chat mode")
    parser.add_argument("--listen", action="store_true", help="Listen mode")
    parser.add_argument(
        "-ts", "--tensor_split", type=str, default=None, help="Tensor split, float values separated by comma for each gpu"
    )
    parser.add_argument("--debug", action="store_true", help="Debug mode")
    args = parser.parse_args()

    # NOTE: a separate tokenizer initialization doesn't seem to be needed.
    print(f"Initializing Llama. Model ID: {args.model}, N_GPU_LAYERS: {args.n_gpu_layers}, N_CTX: {args.n_ctx}")

    # "0.5,0.5" -> [0.5, 0.5]; left as None when the option is not given.
    tensor_split = None if args.tensor_split is None else [float(x) for x in args.tensor_split.split(",")]

    llama = Llama(
        model_path=args.model,
        n_gpu_layers=args.n_gpu_layers,
        tensor_split=tensor_split,
        n_ctx=args.n_ctx,
    )

    # Overwrites the module-level flag read by the handler and UI functions.
    debug_flag = args.debug

    if args.chat:
        print(f"Launching chat with handler: {args.chat_handler}")
        launch_chat(llama, args.chat_handler, args.listen)
    else:
        print("Launching completion")
        launch_completion(llama, args.listen)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment