Created
February 9, 2025 13:13
-
-
Save kohya-ss/bdc5cdaf6ff4566e42a50391b3cf61e5 to your computer and use it in GitHub Desktop.
ノベルゲーム風簡易CUI:会話とプロンプト生成を別のchatにして一貫性を向上させる
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ライセンスは Apache License 2.0 です。 | |
# お勧めのオプションは: | |
# --flash_attn --q8_kv_cache --n_gpu_layers <GPUのVRAMに応じて> --n_ctx <VRAMとモデルのcontext lengthに応じて> | |
import argparse | |
import random | |
import msvcrt # Windowsの場合 | |
import toml | |
from typing import List, Optional, Union, Iterator | |
from datetime import datetime | |
import importlib | |
import os | |
import llama_cpp | |
from llama_cpp.llama_chat_format import _convert_completion_to_chat, register_chat_completion_handler | |
import llama_cpp.llama_types as llama_types | |
from llama_cpp.llama import LogitsProcessorList, LlamaGrammar | |
from llama_cpp import Llama, llama_chat_format | |
MESSAGE_HISTORY_SIZE = 20 # コンテキスト長に余裕があれば長くするとよい | |
# デフォルト設定 | |
DEFAULT_CHARACTER_NAME = "美咲" | |
DEFAULT_USER_NAME = "彰" | |
DEFAULT_CHARACTER_SYSTEM_PROMPT = """あなたは高校二年生の女子生徒、美咲です。 | |
ある日、学校の帰り道、突然の雨に見舞われ、ひとつの古びた洋館に避難することになりました。 | |
その洋館には、あなたのクラスメイトである男子生徒、彰(ユーザー)がいました。 | |
彼は、あなたに「ここには、幽霊が出ると言われている」と教えてくれました。 | |
あなたは、彰と一緒に、その洋館の中を探索することになります。 | |
ユーザーの言葉に対応して、美咲のセリフだけを出力してください。""" | |
DEFAULT_PROMPTER_SYSTEM_PROMPT = """あなたは有能な脚本家です。 | |
女子高生・美咲と男子高生・彰(ユーザー)が、古びた洋館で出会い、幽霊の噂を聞いて探索するストーリーを考えてください。 | |
会話の履歴から、適切な次のセリフを考えてください。""" | |
DEFAULT_PROMPTER_USER_PROMPT = """以下が直近の会話です。 | |
--- | |
$conversation | |
--- | |
行動「$policy」を表す適切なユーザーの次のセリフを考えてください。セリフだけを出力してください。""" | |
DEFAULT_POLICIES = [ | |
"美咲と会話する", | |
"消極的に行動する", | |
"積極的に行動する", | |
"あたりを調べる", | |
"核心に迫る", | |
"もしかしてあれは……?", | |
"逃げ出す", | |
"なんかいい感じに振る舞う", | |
] | |
llama_state_chat = None | |
llama_state_prompter = None | |
debug_mode = False | |
def log(log_file, message): | |
if log_file: | |
with open(log_file, "a", encoding="utf-8") as f: | |
f.write(message + "\n") | |
def debug(log_msg): | |
# TODO 標準のloggingモジュールを使う | |
if debug_mode: | |
print(log_msg) | |
# the latest llama.cpp seems to have "command-r" handler, but we keep this until llama-cpp-python is updated | |
# we can also use the chat template from GGUF | |
@register_chat_completion_handler("command-r") | |
def command_r_chat_handler( | |
llama: Llama, | |
messages: List[llama_types.ChatCompletionRequestMessage], | |
functions: Optional[List[llama_types.ChatCompletionFunction]] = None, | |
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, | |
tools: Optional[List[llama_types.ChatCompletionTool]] = None, | |
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, | |
temperature: float = 0.2, | |
top_p: float = 0.95, | |
top_k: int = 40, | |
min_p: float = 0.05, | |
typical_p: float = 1.0, | |
stream: bool = False, | |
stop: Optional[Union[str, List[str]]] = [], | |
response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None, | |
max_tokens: Optional[int] = None, | |
presence_penalty: float = 0.0, | |
frequency_penalty: float = 0.0, | |
repeat_penalty: float = 1.1, | |
tfs_z: float = 1.0, | |
mirostat_mode: int = 0, | |
mirostat_tau: float = 5.0, | |
mirostat_eta: float = 0.1, | |
model: Optional[str] = None, | |
logits_processor: Optional[LogitsProcessorList] = None, | |
grammar: Optional[LlamaGrammar] = None, | |
seed: Optional[int] = None, | |
**kwargs, # type: ignore | |
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]: | |
# bos_token = "<BOS_TOKEN>" | |
start_turn_token = "<|START_OF_TURN_TOKEN|>" | |
end_turn_token = "<|END_OF_TURN_TOKEN|>" | |
user_token = "<|USER_TOKEN|>" | |
chatbot_token = "<|CHATBOT_TOKEN|>" | |
system_token = "<|SYSTEM_TOKEN|>" | |
prompt = "" # bos_token # suppress warning | |
if len(messages) > 0 and messages[0]["role"] == "system": | |
prompt += start_turn_token + system_token + messages[0]["content"] + end_turn_token | |
messages = messages[1:] | |
for message in messages: | |
if message["role"] == "user": | |
prompt += start_turn_token + user_token + message["content"] + end_turn_token | |
elif message["role"] == "assistant": | |
prompt += start_turn_token + chatbot_token + message["content"] + end_turn_token | |
prompt += start_turn_token + chatbot_token | |
# prompt += start_turn_token + chatbot_token + "<think>" # temp: add <think> to test thinking mode | |
# if debug_flag: | |
# print(f"Prompt: {prompt}") | |
stop_tokens = [end_turn_token] # , bos_token] | |
return _convert_completion_to_chat( | |
llama.create_completion( | |
prompt=prompt, | |
temperature=temperature, | |
top_p=top_p, | |
top_k=top_k, | |
min_p=min_p, | |
typical_p=typical_p, | |
stream=stream, | |
stop=stop_tokens, | |
max_tokens=max_tokens, | |
presence_penalty=presence_penalty, | |
frequency_penalty=frequency_penalty, | |
repeat_penalty=repeat_penalty, | |
tfs_z=tfs_z, | |
mirostat_mode=mirostat_mode, | |
mirostat_tau=mirostat_tau, | |
mirostat_eta=mirostat_eta, | |
model=model, | |
logits_processor=logits_processor, | |
grammar=grammar, | |
seed=seed, | |
), | |
stream=stream, | |
) | |
def generate(llama, handler, generation_params, messages, max_tokens): | |
debug(f"messages: {messages}") | |
seed = random.randint(0, 2**31 - 1) | |
response = handler(llama=llama, messages=messages, stream=False, max_tokens=max_tokens, seed=seed, **generation_params) | |
debug(f"response: {response}") | |
content = response["choices"][0]["message"]["content"] | |
return content | |
def generate_chat( | |
llama: Llama, | |
handler, | |
tts, | |
tts_char_name, | |
tts_call_letters, | |
generation_params, | |
character_system_prompt, | |
conversation, | |
max_tokens, | |
): | |
# モデルの状態を復元 | |
global llama_state_chat | |
if llama_state_chat is not None: | |
llama.load_state(llama_state_chat) | |
if tts_call_letters: | |
tts_call_letters += "\n" | |
# モデルが想定するmessagesの形式に変換 | |
messages = [] | |
if character_system_prompt: | |
messages.append({"role": "system", "content": character_system_prompt}) | |
for role, message in conversation: | |
messages.append({"role": role, "content": message}) | |
# メッセージのトークン数を制限 | |
n_ctx = llama.n_ctx() | |
total_tokens = 0 | |
token_counts = [] | |
for message in messages: | |
message_bytes = message["content"].encode("utf-8") | |
tokens = len(llama.tokenize(message_bytes, add_bos=False)) + 3 # add start/end | |
total_tokens += tokens | |
token_counts.append(tokens) | |
if total_tokens > n_ctx - max_tokens: | |
print(f"Total tokens: {total_tokens}, n_ctx: {n_ctx}, max_tokens: {max_tokens}") | |
index = 1 if character_system_prompt else 0 # システムプロンプトがある場合は除外 | |
tokens_removed = 0 | |
while total_tokens > n_ctx - max_tokens: | |
if len(token_counts) == 0: | |
break | |
total_tokens -= token_counts[index] | |
tokens_removed += token_counts[index] | |
token_counts.pop(index) | |
messages.pop(index) | |
# tokens and messages are removed, so we don't need to update index | |
print(f"Removed {tokens_removed} tokens") | |
# # モデルによる応答生成 | |
# response = generate(llama, handler, generation_params, messages) | |
# # モデルが"「」"で囲まれたセリフを生成することがあるので、"」"以降を削除し、さらに"「"を削除 | |
# response = response.split("」", 1)[0].replace("「", "") | |
# streamで生成し、都度ttsを呼び出す | |
chat_completion_chunks = handler(llama=llama, messages=messages, stream=True, max_tokens=max_tokens, **generation_params) | |
response = "" | |
response_sent_to_tts = "" | |
i = 0 | |
for chunk in chat_completion_chunks: | |
debug(chunk) | |
i += 1 | |
if not debug_mode and i % 20 == 0: | |
print(".", end="", flush=True) | |
if "choices" in chunk and len(chunk["choices"]) > 0: | |
if "delta" in chunk["choices"][0]: | |
if "content" in chunk["choices"][0]["delta"]: | |
content = chunk["choices"][0]["delta"]["content"] | |
response += content | |
# TTSを呼び出す区切り文字が含まれている場合、TTSを呼び出す | |
if tts and any([letter in content for letter in tts_call_letters]): | |
text_to_tts = response[len(response_sent_to_tts) :] # 前回からの差分 | |
print(f"\nTTS: {text_to_tts}") | |
response_sent_to_tts = response | |
tts(tts_char_name, text_to_tts) | |
# TTSへ未送信のテキストがある場合、最後にTTSを呼び出す | |
if tts and response != response_sent_to_tts: | |
text_to_tts = response[len(response_sent_to_tts) :] | |
response_sent_to_tts = response | |
tts(tts_char_name, text_to_tts) | |
# モデルの状態を保存 | |
llama_state_chat = llama.save_state() | |
# モデルが"「」"で囲まれたセリフを生成することがあるので、"」"以降を削除し、さらに"「"を削除 | |
response = response.strip() | |
if response.startswith("「") and response.endswith("」"): | |
response = response[1:] | |
response = response[:-1] | |
return response | |
def generate_user_response( | |
llama, handler, generation_params, prompter_system_prompt, prompter_user_prompt, conversation_str, policy, max_tokens | |
): | |
# モデルの状態を復元 | |
global llama_state_prompter | |
if llama_state_prompter is not None: | |
llama.load_state(llama_state_prompter) | |
# プロンプトの作成 | |
prompt = prompter_user_prompt.replace("$conversation", conversation_str).replace("$policy", policy) | |
# モデルが想定するmessagesの形式に変換 | |
messages = [] | |
messages.append({"role": "system", "content": prompter_system_prompt}) | |
messages.append({"role": "user", "content": prompt}) | |
# モデルによる応答生成 | |
response = generate(llama, handler, generation_params, messages, max_tokens) # , first_turn=first_turn) | |
# モデルが"「」"で囲まれたセリフを生成することがあるので、"」"以降を削除し、さらに"「"を削除 | |
response = response.split("」", 1)[0].replace("「", "") | |
# 同様に `"` も削除 | |
response = response.replace('"', "") | |
# モデルの状態を保存 | |
llama_state_prompter = llama.save_state() | |
return response | |
def get_key(): | |
"""エンターキーを押さずに入力を取得する""" | |
try: | |
return msvcrt.getch().decode() | |
except UnicodeDecodeError: | |
return "" | |
def display_policies(policies): | |
"""ポリシーの選択肢を表示する""" | |
print("\n=== 次の行動を選んでください ===") | |
for i, policy in enumerate(policies, 1): | |
c = str(i) if i < 10 else chr(ord("a") + i - 10) | |
print(f"{c}. {policy}, ", end="") | |
print("\nx. 自由入力, y. 直前の選択をやり直す, z. 終了") | |
def format_conversation(conversation, user_name, character_name): | |
"""会話履歴をフォーマットする""" | |
if not conversation: | |
return "(会話はまだありません)" | |
if len(conversation) > MESSAGE_HISTORY_SIZE: | |
# 残りがN件以上になるように、N/2件単位で先頭から削除:モデルのキャッシュを有効に使うため | |
while len(conversation) - MESSAGE_HISTORY_SIZE // 2 >= MESSAGE_HISTORY_SIZE: | |
conversation = conversation[MESSAGE_HISTORY_SIZE // 2 :] | |
result = "" | |
for role, message in conversation: | |
name = user_name if role == "user" else character_name | |
result += f"{name}: {message}\n" | |
return result | |
def main(): | |
parser = argparse.ArgumentParser() | |
parser.add_argument("--config", type=str, default=None, help="toml file path") | |
parser.add_argument("-m", "--model", type=str, default=None, help="Model file path") | |
parser.add_argument("-ngl", "--n_gpu_layers", type=int, default=0, help="Number of GPU layers") | |
parser.add_argument("-c", "--n_ctx", type=int, default=4096, help="Context length") | |
parser.add_argument( | |
"-ch", | |
"--chat_handler", | |
type=str, | |
default=None, | |
help="Chat handler, e.g. command-r, mistral-instruct, alpaca, llama-3 etc. default: None (use template from GGUF)", | |
) | |
parser.add_argument("--max_tokens", type=int, default=512, help="Max tokens for each prompt") | |
parser.add_argument( | |
"-ts", "--tensor_split", type=str, default=None, help="Tensor split, float values separated by comma for each gpu" | |
) | |
parser.add_argument("--disable_mmap", action="store_true", help="Disable mmap") | |
parser.add_argument("--q8_kv_cache", action="store_true", help="Use quantized kv cache (Q8)") | |
parser.add_argument("--q4_kv_cache", action="store_true", help="Use quantized kv cache (Q4)") | |
parser.add_argument("--flash_attn", action="store_true", help="Use flash attention") | |
parser.add_argument("--debug", action="store_true", help="Debug mode") | |
parser.add_argument("--tts_module", type=str, default=None, help="TTS module. must have `tts(tts_char_name, message)` function") | |
parser.add_argument( | |
"--tts_call_letters", type=str, default=None, help="letters to call TTS in middle of the message, like `。?」♪`" | |
) | |
parser.add_argument("--tts_char_name", type=str, default=None, help="Character name for TTS (TTS model selection)") | |
parser.add_argument("--output_dir", type=str, default=None, help="Output directory of logs") | |
args = parser.parse_args() | |
if args.debug: | |
global debug_mode | |
debug_mode = True | |
# load .toml file | |
if args.config is not None: | |
config = toml.load(args.config) | |
character_name = config.get("character_name", DEFAULT_CHARACTER_NAME) | |
user_name = config.get("user_name", DEFAULT_USER_NAME) | |
character_system_prompt = config.get("character_system_prompt", DEFAULT_CHARACTER_SYSTEM_PROMPT) | |
prompter_system_prompt = config.get("prompter_system_prompt", DEFAULT_PROMPTER_SYSTEM_PROMPT) | |
prompter_user_prompt = config.get("prompter_user_prompt", DEFAULT_PROMPTER_USER_PROMPT) | |
policies = config.get("policies", DEFAULT_POLICIES) | |
character_generation_params = config.get("character_generation_params", {}) | |
prompter_generation_params = config.get("prompter_generation_params", character_generation_params) | |
else: | |
character_name = DEFAULT_CHARACTER_NAME | |
user_name = DEFAULT_USER_NAME | |
character_system_prompt = DEFAULT_CHARACTER_SYSTEM_PROMPT | |
prompter_system_prompt = DEFAULT_PROMPTER_SYSTEM_PROMPT | |
prompter_user_prompt = DEFAULT_PROMPTER_USER_PROMPT | |
policies = DEFAULT_POLICIES | |
character_generation_params = {} | |
prompter_generation_params = character_generation_params | |
print(f"character generation params: {character_generation_params}") | |
print(f"prompter generation params: {prompter_generation_params}") | |
# initialize Llama | |
print(f"Initializing Llama. Model ID: {args.model}, N_GPU_LAYERS: {args.n_gpu_layers}, N_CTX: {args.n_ctx}") | |
tensor_split = None if args.tensor_split is None else [float(x) for x in args.tensor_split.split(",")] | |
llama = Llama( | |
model_path=args.model, | |
n_gpu_layers=args.n_gpu_layers, | |
tensor_split=tensor_split, | |
n_ctx=args.n_ctx, | |
use_mmap=not args.disable_mmap, | |
type_k=llama_cpp.GGML_TYPE_Q8_0 if args.q8_kv_cache else (llama_cpp.GGML_TYPE_Q4_0 if args.q4_kv_cache else None), | |
type_v=llama_cpp.GGML_TYPE_Q8_0 if args.q8_kv_cache else (llama_cpp.GGML_TYPE_Q4_0 if args.q4_kv_cache else None), | |
flash_attn=args.flash_attn, | |
) | |
if args.chat_handler is not None: | |
handler = llama_chat_format.get_chat_completion_handler(args.chat_handler) | |
else: | |
handler = llama._chat_handlers[llama.chat_format] | |
# TTS | |
if args.tts_module: | |
print(f"Loading TTS module: {args.tts_module}") | |
tts_module = importlib.import_module(args.tts_module) | |
tts = tts_module.tts | |
tts_call_letters = args.tts_call_letters | |
tts_char_name = args.tts_char_name | |
else: | |
tts = tts_call_letters = tts_char_name = None | |
# ログファイル | |
log_file = None | |
if args.output_dir: | |
os.makedirs(args.output_dir, exist_ok=True) | |
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") | |
log_file = f"{args.output_dir}/chat_log_{timestamp}.txt" | |
print(f"Log file: {log_file}") | |
# 会話履歴 | |
conversation = [] | |
print("\n=== ゲームを開始します ===") | |
print("\n") | |
while True: | |
# 1. ポリシーの選択 | |
display_policies(policies) | |
choice = get_key() | |
print(f"選択: {choice}") | |
if choice == "z": | |
print("\nゲームを終了します。") | |
break | |
# ユーザーの発言を取得 | |
if choice == "y": | |
if len(conversation) >= 2: | |
conversation = conversation[:-2] | |
print( | |
f"\n直前の選択をやり直しました。({character_name}の最後の台詞: {conversation[-1][1] if len(conversation)>0 else ''})" | |
) | |
continue | |
else: | |
print("\n直前の選択がありません。") | |
continue | |
if choice == "x": | |
print("\n自由入力モード(入力後、Enterを押してください):") | |
user_message = input("> ") | |
elif choice.isdigit() or (choice.isalpha() and ord("a") <= ord(choice) <= ord("z")): | |
if choice.isalpha(): | |
choice = ord(choice) - ord("a") + 9 | |
else: | |
choice = int(choice) - 1 | |
policy = policies[choice] | |
# 2. プロンプターによるセリフ生成 | |
conversation_str = format_conversation(conversation, user_name, character_name) | |
user_message = generate_user_response( | |
llama, | |
handler, | |
prompter_generation_params, | |
prompter_system_prompt, | |
prompter_user_prompt, | |
conversation_str, | |
policy, | |
args.max_tokens, | |
) | |
else: | |
print("\n無効な選択です。") | |
continue | |
# 3. ユーザーのセリフを表示 | |
print(f"\nあなた: {user_message}") | |
log(log_file, f"{user_name}: {user_message}") | |
conversation.append(("user", user_message)) | |
# 4. キャラクターの応答を生成 | |
character_response = generate_chat( | |
llama, | |
handler, | |
tts, | |
tts_char_name, | |
tts_call_letters, | |
character_generation_params, | |
character_system_prompt, | |
conversation, | |
args.max_tokens, | |
) | |
print(f"{character_name}: {character_response}") | |
log(log_file, f"{character_name}: {character_response}") | |
conversation.append(("assistant", character_response)) | |
print("\n" + "=" * 50) | |
if __name__ == "__main__": | |
main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
character_name = "美咲" | |
user_name = "彰" | |
character_system_prompt = """あなたは高校二年生の女子生徒、美咲です。 | |
ある日、学校の帰り道、突然の雨に見舞われ、ひとつの古びた洋館に避難することになりました。 | |
その洋館には、あなたのクラスメイトである男子生徒、彰(ユーザー)がいました。 | |
彼は、あなたに「ここには、幽霊が出ると言われている」と教えてくれました。 | |
あなたは、彰と一緒に、その洋館の中を探索することになります。 | |
ユーザーの言葉に対応して、美咲のセリフだけを出力してください。""" | |
prompter_system_prompt = """あなたは有能な脚本家です。 | |
女子高生・美咲と男子高生・彰(ユーザー)が、古びた洋館で出会い、幽霊の噂を聞いて探索するストーリーを考えてください。 | |
会話の履歴から、適切な次のセリフを考えてください。""" | |
prompter_user_prompt = """以下が直近の会話です。 | |
--- | |
$conversation | |
--- | |
行動「$policy」を表す適切なユーザーの次のセリフを考えてください。セリフだけを出力してください。""" | |
policies = [ | |
"美咲と会話する", | |
"消極的に行動する", | |
"積極的に行動する", | |
"あたりを調べる", | |
"核心に迫る", | |
"もしかしてあれは……?", | |
"逃げ出す", | |
"なんかいい感じに振る舞う", | |
] | |
character_generation_params = {"temperature"=0.5} | |
prompter_generation_params = {"temperature"=0.5} | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment