Mem-Agent
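A minimal recurrent-memory QA script: it streams a long context through an LLM in fixed-size token chunks, folding each chunk into a running natural-language memory, then answers the question from that memory alone.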
#!/usr/bin/env -S uv run --with openai,transformers,tqdm --script
import argparse
import asyncio

from openai import AsyncOpenAI
from transformers import AutoTokenizer
from tqdm import tqdm

parser = argparse.ArgumentParser(description="RL Memory Agent")
parser.add_argument(
    "--api_base",
    type=str,
    default="http://localhost:30000/v1",
    help="Base URL for the API",
)
parser.add_argument("--model", type=str, default="model", help="Model name to use")
parser.add_argument(
    "--api_key",
    type=str,
    default="your_api_key_here",
    help="API key for authentication",
)
parser.add_argument(
    "--context_file", type=str, default="./context.txt", help="Path to the context file"
)
parser.add_argument(
    "--tokenizer_name",
    type=str,
    default="BytedTsinghua-SIA/RL-MemoryAgent-14B",
    help="Tokenizer name to use",
)
args = parser.parse_args()

API_BASE_URL = args.api_base
API_MODEL_NAME = args.model
API_KEY = args.api_key
CONTEXT_FILE = args.context_file
TOKENIZER_NAME = args.tokenizer_name

# OpenAI-compatible async client; the tokenizer is only used for chunking the context.
ai = AsyncOpenAI(
    base_url=API_BASE_URL,
    api_key=API_KEY,
)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
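
# Two prompt templates drive the agent: the first folds each context chunk
# into the running memory, the second produces the final boxed answer from
# that memory alone.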
TEMPLATE_UPDATE_MEMORY = """You are presented with a problem, a section of an article that may contain the answer to the problem, and a previous memory. Please read the provided section carefully and update the memory with the new information that helps to answer the problem. Be sure to retain all relevant details from the previous memory while adding any new, useful information.
<problem>
{prompt}
</problem>
<memory>
{memory}
</memory>
<section>
{chunk}
</section>
Updated memory:
"""

TEMPLATE_ANSWER = """You are presented with a problem and a previous memory. Please answer the problem based on the previous memory and put the answer in \\boxed{{}}.
<problem>
{prompt}
</problem>
<memory>
{memory}
</memory>
Your answer:
"""

async def request_llm(msg: str) -> str:
    """Send a single-turn user message and return the model's reply."""
    resp = await ai.chat.completions.create(
        model=API_MODEL_NAME,
        messages=[{"role": "user", "content": msg}],
        temperature=0.7,
        top_p=0.95,
        max_tokens=1024,
    )
    return resp.choices[0].message.content

def clip_long_string(string: str, max_length=2000) -> str:
    """Clip long string to a maximum length, keeping its head and tail."""
    # assert max_length > 50, "max_length must be greater than 50"
    if len(string) <= max_length:
        return string
    target_len = max_length - len("\n\n...(truncated)\n\n")
    return (
        string[: target_len // 2]
        + "\n\n...(truncated)\n\n"
        + string[-target_len // 2 :]
    )

def init_memory() -> str:
    return "No previous memory"
async def query(
    prompt: str,
    context: str,
    memory=init_memory(),
    recurrent_max_content_len=120000,
    recurrent_chunk_size=5000,
) -> tuple[str, str]:
    input_ids = tokenizer.encode(context)
    if len(input_ids) > recurrent_max_content_len:
        # Keep the first and last halves of the token budget.
        input_ids = (
            input_ids[: recurrent_max_content_len // 2]
            + input_ids[-recurrent_max_content_len // 2 :]
        )
    for i in tqdm(
        range(0, len(input_ids), recurrent_chunk_size), desc="Processing chunks"
    ):
        chunk = input_ids[i : i + recurrent_chunk_size]
        memory = await request_llm(
            TEMPLATE_UPDATE_MEMORY.format(
                prompt=prompt, chunk=tokenizer.decode(chunk), memory=memory
            )
        )
    return await request_llm(
        TEMPLATE_ANSWER.format(prompt=prompt, memory=memory)
    ), memory
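
# Simple CLI: read one question from stdin, stream the context file through
# the agent, then print the answer and the final memory.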
async def main():
    memory = init_memory()
    prompt = input("Question: ")
    with open(CONTEXT_FILE) as f:
        context = f.read()
    answer, memory = await query(prompt, context, memory=memory)
    print(answer, memory)

if __name__ == "__main__":
    asyncio.run(main())
Requirements
uv and a good GPU (the defaults assume a locally served BytedTsinghua-SIA/RL-MemoryAgent-14B).
Usage
Fill context.txt with your context (e.g. the full text of a long article or book), then run the script and type your question at the prompt.
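A minimal invocation might look like this (assuming the script is saved as mem_agent.py and an OpenAI-compatible server such as SGLang or vLLM is already serving the model at http://localhost:30000/v1):

uv run --with openai,transformers,tqdm mem_agent.py --context_file ./context.txt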