from __future__ import annotations

import argparse
import asyncio
import json
import os
import sys

import aiofiles
import aiohttp
import yaml

url_base = "http://127.0.0.1:1234"  # Default URL of the llama.cpp server

min_context_count = 3  # Default minimum number of turns to keep in the context
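
# Typical invocation (a sketch; the file names are hypothetical, and a
# llama.cpp server must already be running at url_base):
#
#     python LLM_Translator.py -i input.txt -p en_ja.yaml -d terms.txt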


class LLMCompleter:
    """Performs text completion with an LLM."""
    _input_label_config_name = ""
    _output_label_config_name = ""

    def __init__(self, system_prompt="", max_length=200, grammar: str | None = None, url_base: str = url_base, debug=False):
        self.system_prompt = system_prompt
        self.url_base = url_base
        self.grammar = grammar
        self.generated_text = ""
        self.max_length = max_length
        self._stop = False
        self._debug = debug

    def __call__(self, text: str, verbose=False, grammar=""):
        return self._generator(text=text, verbose=verbose, grammar=grammar)

    @classmethod
    def _load_config_file(cls, prompt_path: str, dictionary_path: str | None = None):
        # Append ".yaml" if prompt_path does not already end with it.
        prompt_path = prompt_path if prompt_path.endswith(
            ".yaml") else prompt_path + ".yaml"

        # If the prompt file is not found, look for it next to this script.
        if not os.path.exists(prompt_path):
            prompt_path = os.path.join(os.path.dirname(os.path.abspath(
                __file__)), prompt_path) if not os.path.isabs(prompt_path) else prompt_path

        # Load the prompt file. The labels default to "" here so that the
        # base class, whose label config names are empty, still unpacks the
        # return value cleanly.
        input_label = ""
        output_label = ""
        if os.path.exists(prompt_path):
            with open(prompt_path, 'r', encoding='utf-8') as prompt_file:
                prompt_data = yaml.safe_load(prompt_file) or {}
                system_prompt = prompt_data.get('system_prompt', "")

                if cls._input_label_config_name:
                    input_label = prompt_data.get(
                        cls._input_label_config_name, "入力")
                if cls._output_label_config_name:
                    output_label = prompt_data.get(
                        cls._output_label_config_name, "出力")

                dictionary = prompt_data.get('dictionary', {})
                grammar = prompt_data.get('grammar', None)
                examples = prompt_data.get('examples', [])
        else:
            system_prompt = ""
            dictionary = {}
            grammar = None
            examples = []

        if isinstance(dictionary, dict):
            dictionary = {str(key): str(value)
                          for key, value in dictionary.items()}

        # Load the dictionary file and merge it into the prompt dictionary.
        # Each line holds "source<tab>target"; only the first tab splits, so
        # targets may themselves contain tabs.
        if dictionary_path:
            with open(dictionary_path, 'r', encoding='utf-8') as dict_file:
                for line in dict_file:
                    if '\t' in line:
                        source_term, target_term = line.strip().split('\t', 1)
                        dictionary[source_term] = target_term

        return system_prompt, input_label, output_label, dictionary, grammar, examples
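
    # A prompt YAML file is expected to look roughly like this (a sketch;
    # every key is optional, and the dictionary entry and example pair shown
    # here are hypothetical):
    #
    #     system_prompt: ...
    #     input: 入力        # key name comes from _input_label_config_name
    #     output: 出力       # key name comes from _output_label_config_name
    #     dictionary:
    #       Alice: アリス
    #     grammar: ...
    #     examples:
    #       - Hello!
    #       - こんにちは!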

    @classmethod
    def from_config(cls, prompt_path: str, dictionary_path: str | None = None, url_base: str = url_base) -> LLMCompleter:
        """Load settings from a config file and instantiate."""
        system_prompt, _, _, _, grammar, _ = cls._load_config_file(
            prompt_path, dictionary_path)

        return cls(
            system_prompt=system_prompt,
            grammar=grammar, url_base=url_base
        )

    def _check_wait_word(self, new_text: str, wait_words: list[str]):
        # A wait word appearing in full means generation has run past the end
        # of the current turn.
        for word in wait_words:
            if word in new_text:
                return "break"

        # If the text ends with a prefix of a wait word, hold the chunk back
        # until enough tokens arrive to decide either way.
        for word in wait_words:
            min_len = min(len(new_text), len(word))

            for i in range(1, min_len + 1):
                if new_text[-i:] == word[:i]:
                    return "continue"

        return "pass"

    async def _get_response(self, prompt, quiet=False, wait_words=None, stop_words=None, logit_bias=None, grammar="", max_length=0):
        url = f"{self.url_base}/completion"
        wait_words = wait_words or []
        data = {
            "prompt": prompt,
            "stream": True,
            "cache_prompt": True,
            "stop": stop_words or [],
            "n_predict": max_length if max_length > 0 else self.max_length
        }

        grammar = grammar if grammar else self.grammar

        # llama.cpp GBNF grammars need a top-level "root" rule; prepend one
        # when the given grammar does not define it.
        if grammar:
            if not grammar.replace(" ", "").replace("\n", "").startswith('root::='):
                grammar = f'root::={grammar}'
            data["grammar"] = grammar

        if logit_bias:
            data["logit_bias"] = logit_bias

        if self._debug:
            print(data)

        async with aiohttp.ClientSession() as session:
            async with session.post(url, json=data) as response:
                if response.status == 200:
                    content = ''
                    first_line_printed = False
                    async for line in response.content:
                        line_str = line.decode()
                        if line_str.startswith('data:'):
                            json_data: dict[str, str] = json.loads(
                                line_str[len('data:'):])
                            if 'content' in json_data:
                                content += json_data['content']
                                if not first_line_printed:
                                    content = content.lstrip()
                                    first_line_printed = True

                                loop_condition = self._check_wait_word(
                                    content, wait_words)

                                if loop_condition == "break":
                                    break
                                elif loop_condition == "continue":
                                    continue

                                yield content

                                if self._stop:
                                    self._stop = False
                                    break

                                if not quiet or self._debug:
                                    print(content, end='', flush=True)

                                content = ""
                else:
                    yield f"Error: {response.status}"

    async def complete(self, text="", verbose=False, grammar="", stop_words: list[str] | None = None, logit_bias: list | None = None, max_length=0):
        """Run LLM text completion."""
        generated = ""
        async for response in self._get_response(self.system_prompt + self.generated_text + text, quiet=not verbose, grammar=grammar, stop_words=stop_words, logit_bias=logit_bias, max_length=max_length):
            generated += response
        self.generated_text += generated
        return generated

    async def _generator(self, text="", verbose=False, grammar="", stop_words: list[str] | None = None, logit_bias: list | None = None, max_length=0):
        generated = ""
        async for response in self._get_response(self.system_prompt + self.generated_text + text, quiet=not verbose, grammar=grammar, stop_words=stop_words, logit_bias=logit_bias, max_length=max_length):
            generated += response
            yield response
        self.generated_text += generated

    def reset(self):
        """Clear the generated text."""
        self.generated_text = ""

    def stop(self):
        """Interrupt text generation."""
        self._stop = True
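
# Minimal usage sketch (assumes a llama.cpp server is listening at url_base;
# the prompt text here is a hypothetical example):
#
#     completer = LLMCompleter(system_prompt="Continue the following text.\n")
#     result = asyncio.run(completer.complete("Once upon a time", verbose=True))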


class LLMMultiTurnCompleter(LLMCompleter):
    """Performs multi-turn responses with an LLM."""
    _input_label_config_name = "input"
    _output_label_config_name = "output"

    def __init__(self, system_prompt: str = "", input_label: str = "入力", output_label: str = "出力", dictionary: dict | None = None, grammar: str | None = None, examples: list[str] | None = None, max_length=500, min_context_count=min_context_count, url_base: str = url_base, debug=False):
        super().__init__(system_prompt,
                         grammar=grammar, max_length=max_length, url_base=url_base, debug=debug)
        self.input_label = input_label
        self.output_label = output_label
        self.context = []
        self.dictionary = dictionary if dictionary else {}
        self.min_context_count = min_context_count
        self.wait_words = [f"\n{self.input_label}:", f"\n{self.output_label}:"]
        self.examples = examples if examples else []

    def __call__(self, text: str, initial_text="", multiline=False, verbose=False, grammar="", stop_words: list[str] | None = None, logit_bias: list | None = None):
        if not multiline:
            stop_words = ["\n"]
        return self._generator(text=text, initial_text=initial_text, multiline=multiline, verbose=verbose, grammar=grammar, stop_words=stop_words, logit_bias=logit_bias)

    @classmethod
    def from_config(cls, prompt_path: str, dictionary_path: str | None = None, min_context_count=min_context_count, url_base: str = url_base, debug=False) -> LLMMultiTurnCompleter:
        """Load settings from a config file and instantiate."""
        system_prompt, input_label, output_label, dictionary, grammar, examples = cls._load_config_file(
            prompt_path, dictionary_path)

        return cls(
            system_prompt, input_label=input_label, output_label=output_label, dictionary=dictionary, grammar=grammar, examples=examples, min_context_count=min_context_count, url_base=url_base, debug=debug
        )

    def reset(self):
        """Clear the context."""
        self.context.clear()

    def delete_last_context(self, turn=1):
        """Delete the most recent turns from the context."""
        self.context = self.context[:-turn * 2]

    def append_text(self, text: str):
        """Append text to the end of the previous turn."""
        if self.context:
            self.context[-1] += text

    def _replace_terms(self, text):
        if not self.dictionary:
            return text
        for source_term, target_term in self.dictionary.items():
            text = text.replace(source_term, target_term)
        return text

    def _update_context(self, user_input, assistant_response):
        self.context.append(f"{self.input_label}: {user_input}")
        self.context.append(
            f"{self.output_label}: {assistant_response}")
        if self.min_context_count == 0:
            self.reset()
        elif len(self.context) >= self.min_context_count * 4:
            # Let the context grow to twice the minimum number of turns, then
            # trim it back to the most recent min_context_count turns (two
            # entries per turn).
            self.context = self.context[-self.min_context_count * 2:]

    def _create_prompt(self, user_input: str, initial_text=""):
        context_text = self.system_prompt

        if self.examples:
            # examples holds alternating input/output strings; a trailing
            # unpaired entry is ignored.
            for i in range(0, len(self.examples) - 1, 2):
                context_text += f"\n{self.input_label}: {self.examples[i]}\n{self.output_label}: {self.examples[i + 1]}"

        if self.context:
            context_text += "\n" + '\n'.join(self.context)

        prompt = f"{context_text}\n{self.input_label}: {user_input}\n{self.output_label}:"
        if initial_text:
            prompt = f"{prompt} {initial_text}"
        return prompt
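
    # With the default labels, _create_prompt produces a prompt shaped like
    # (the user text "Hello" is a hypothetical example):
    #
    #     <system_prompt>
    #     入力: <example input>
    #     出力: <example output>
    #     入力: <previous user input>
    #     出力: <previous response>
    #     入力: Hello
    #     出力: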

    async def _process_line(self, user_input, initial_text, verbose, grammar, stop_words, logit_bias, max_length):
        if not user_input:
            return

        replaced_input = self._replace_terms(user_input)
        prompt = self._create_prompt(
            replaced_input, initial_text=initial_text)
        if (verbose or self._debug) and initial_text:
            print(initial_text, end="", flush=True)

        response_text = ""
        async for response in self._get_response(prompt, quiet=not verbose, wait_words=self.wait_words, stop_words=stop_words, logit_bias=logit_bias, grammar=grammar, max_length=max_length):
            response_text += response
            yield response
        self._update_context(
            user_input, initial_text + response_text)

    async def complete(self, text: str, initial_text="", multiline=False, include_input=False, verbose=False, grammar="", stop_words: list[str] | None = None, logit_bias: list | None = None, max_length=0) -> str | list[tuple[str, str]]:
        """Run LLM text completion and return the output for the input text."""

        lines = text.split("\n") if not multiline else [text]
        output = ""
        output_list = []

        if not multiline:
            stop_words = ["\n"]

        for index, user_input in enumerate(lines):
            assistant_response = ""

            if (verbose or self._debug) and include_input:
                print(
                    f"{self.input_label}: {user_input}\n{self.output_label}: ", end='', flush=True)

            async for response in self._process_line(user_input, initial_text, verbose, grammar, stop_words, logit_bias, max_length):
                assistant_response += response

            output += assistant_response
            output_list.append((user_input, assistant_response))

            if index != len(lines) - 1:
                output += "\n"

            if verbose or self._debug:
                print()

        if (verbose or self._debug) and include_input:
            print()

        if include_input:
            return output_list

        return output

    async def _generator(self, text: str, initial_text="", multiline=False, verbose=False, grammar="", stop_words: list[str] | None = None, logit_bias: list | None = None, max_length=0):
        lines = text.split("\n") if not multiline else [text]
        if not multiline:
            stop_words = ["\n"]

        for index, user_input in enumerate(lines):
            async for response in self._process_line(user_input, initial_text, verbose, grammar, stop_words, logit_bias, max_length):
                yield response

            if index != len(lines) - 1:
                yield "\n"

            if verbose or self._debug:
                print()

    async def complete_all(self, input_path="", output_path="", include_input=False, quiet=False):
        """Run text completion over input read from a file or standard input."""
        if not input_path:
            lines = sys.stdin
        else:
            async with aiofiles.open(input_path, 'r', encoding='utf-8') as input_file:
                lines = await input_file.readlines()

        first_line_output = False

        for line in lines:
            user_input = line.strip()
            if user_input:
                replaced_input = self._replace_terms(user_input)
                prompt = self._create_prompt(replaced_input)

                if include_input and (not quiet or self._debug):
                    print(f"{self.input_label}: {user_input}\n")
                    print(f"{self.output_label}: ", end='', flush=True)

                assistant_response = ""
                async for response in self._get_response(prompt, quiet=quiet, wait_words=self.wait_words, stop_words=["\n"]):
                    assistant_response += response

                self._update_context(
                    user_input, assistant_response)
            else:
                assistant_response = ""

            if output_path:
                # Overwrite on the first write, then append line by line so
                # partial results survive an interruption.
                mode = "a" if first_line_output else "w"
                first_line_output = True

                async with aiofiles.open(output_path, mode, encoding='utf-8') as output_file:
                    if include_input and user_input and assistant_response:
                        await output_file.write(f"{self.input_label}: {user_input}\n")
                        await output_file.write("\n")
                        await output_file.write(f"{self.output_label}: {assistant_response}\n")
                        await output_file.write("\n")
                    else:
                        await output_file.write(f"{assistant_response}\n")

            if not quiet or self._debug:
                print()


class LLMTranslator(LLMMultiTurnCompleter):
    """Performs translation with an LLM."""
    _input_label_config_name = "from_lang"
    _output_label_config_name = "to_lang"

    def __init__(self, system_prompt: str = "", from_lang: str = "", to_lang: str = "", dictionary: dict | None = None, grammar: str | None = None, examples: list[str] | None = None, min_context_count=min_context_count, url_base: str = url_base, debug=False):
        # The default prompt translates English into Japanese; its label
        # strings are part of the prompt data itself, so they stay Japanese.
        if not system_prompt:
            system_prompt = """英文を日本語に和訳してください。
English: Hello!
日本語: おはようございます。
English: How are you?
日本語: お元気ですか?
"""
            from_lang = "English"
            to_lang = "日本語"

        if not from_lang:
            from_lang = "入力"
        if not to_lang:
            to_lang = "出力"

        system_prompt = system_prompt.strip()
        super().__init__(system_prompt=system_prompt, input_label=from_lang, output_label=to_lang,
                         dictionary=dictionary, grammar=grammar, examples=examples, max_length=100000, min_context_count=min_context_count, url_base=url_base, debug=debug)

    @classmethod
    def from_config(cls, prompt_path: str, dictionary_path: str | None = None, min_context_count=min_context_count, url_base: str = url_base, debug=False) -> LLMTranslator:
        """Load settings from a config file and instantiate."""
        system_prompt, input_label, output_label, dictionary, grammar, examples = cls._load_config_file(
            prompt_path, dictionary_path)

        return cls(
            system_prompt=system_prompt, from_lang=input_label, to_lang=output_label, dictionary=dictionary, grammar=grammar, examples=examples, min_context_count=min_context_count, url_base=url_base, debug=debug
        )

    def __call__(self, text: str, verbose=False):
        return self._generator(text=text, multiline=False, verbose=verbose)

    async def translate(self, text: str, include_input=False, verbose=False):
        """Translate the input text."""
        return await self.complete(text=text, multiline=False, include_input=include_input, verbose=verbose)

    async def translate_all(self, input_path="", output_path="", include_input=False, quiet=False):
        """Translate text read from a file or standard input."""
        await self.complete_all(input_path, output_path, include_input, quiet)
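
# Minimal usage sketch (assumes an en_ja.yaml prompt file next to the script
# and a running llama.cpp server; the sentence is a hypothetical example):
#
#     translator = LLMTranslator.from_config("en_ja")
#     print(asyncio.run(translator.translate("How are you?")))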


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Translates text from a file or standard input with an LLM and prints the result.',
        usage='python LLM_Translator.py [-h] [-i INPUT_PATH] [-o OUTPUT_PATH] [-p PROMPT_PATH] [-d DICTIONARY_PATH] [-u LLAMA_CPP_SERVER_URL] [--quiet] [--no-outfile] [--include-input] [--min-context-count MIN_TURNS]'
    )
    parser.add_argument('-i', '--input-path', type=str,
                        help='Path of the input text file containing the source text. If omitted, standard input is used.', default=None)
    parser.add_argument('-o', '--output-path', type=str,
                        help='Path of the output text file for the translation. If omitted, defaults to <input_path>_<prompt_name>.<ext>.', default=None)
    parser.add_argument('-p', '--prompt-path', type=str,
                        help='Path of the prompt file used for translation. Defaults to en_ja.yaml; the extension may be omitted.', default="en_ja.yaml")
    parser.add_argument('-d', '-dict', '--dictionary-path', type=str,
                        help='Path of a dictionary file for proper-noun substitution: a text file with one source<tab>target pair per line.', default=None)
    parser.add_argument('-u', '--url', type=str,
                        help='URL of the llama.cpp server', default=url_base)
    parser.add_argument('-q', '--quiet', action='store_true',
                        help='Do not print translation results to standard output.')
    parser.add_argument('-nof', '--no-outfile', action='store_true',
                        help='Do not write translation results to a file.')
    parser.add_argument('-ii', '--include-input', action='store_true',
                        help='Include the input text in the output.')
    parser.add_argument('--min-context-count', type=int,
                        help=f'Minimum number of turns to keep in the context (default: {min_context_count}). '
                        'Smaller values keep translations closer to the examples in the system prompt, but weaken context tracking and slow down output. '
                        'Larger values improve context tracking and rarely slow down output, but make it more likely that the tone drifts from the system-prompt examples.', default=min_context_count)

    args = parser.parse_args()

    input_path = args.input_path
    output_path = args.output_path
    dictionary_path = args.dictionary_path
    prompt_path = args.prompt_path
    url_base = args.url
    quiet = args.quiet
    no_outfile = args.no_outfile
    include_input = args.include_input
    min_context_count = args.min_context_count

    # If no output file path was given, derive one from the input file path.
    if not output_path and not no_outfile and input_path:
        suffix = os.path.splitext(os.path.basename(prompt_path))[0]
        base, ext = os.path.splitext(input_path)
        output_path = f"{base}_{suffix}{ext}"

    translator = LLMTranslator.from_config(
        prompt_path=prompt_path, dictionary_path=dictionary_path, min_context_count=min_context_count, url_base=url_base)

    asyncio.run(translator.translate_all(input_path=input_path,
                                         output_path=output_path, include_input=include_input, quiet=quiet))