import argparse
import random
import sys

from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
import torch

parser = argparse.ArgumentParser()
parser.add_argument("question", type=str)
parser.add_argument(
    "-m", "--model-name", default="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
)
parser.add_argument("-d", "--device", default="auto")
parser.add_argument(
    "-r", "--replacements", nargs="+", default=["\nWait, but", "\nHmm", "\nSo"]
)
parser.add_argument("-t", "--min-thinking-tokens", type=int, default=128)
parser.add_argument("-p", "--prefill", default="")
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.model_name)
model = AutoModelForCausalLM.from_pretrained(
    args.model_name, torch_dtype=torch.bfloat16, device_map=args.device
)

# Token ids for the <think> / </think> markers used by the R1 distills.
_, _start_think_token, end_think_token = tokenizer.encode("<think></think>")


@torch.inference_mode
def reasoning_effort(question: str, min_thinking_tokens: int):
    # Start the assistant turn with an open <think> block (plus optional prefill).
    tokens = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": question},
            {"role": "assistant", "content": "<think>\n" + args.prefill},
        ],
        continue_final_message=True,
        return_tensors="pt",
    )
    tokens = tokens.to(model.device)
    kv = DynamicCache()
    n_thinking_tokens = 0

    yield tokenizer.decode(list(tokens[0]))
    while True:
        out = model(input_ids=tokens, past_key_values=kv, use_cache=True)
        next_token = torch.multinomial(
            torch.softmax(out.logits[0, -1, :], dim=-1), 1
        ).item()
        kv = out.past_key_values

        if (
            next_token in (end_think_token, model.config.eos_token_id)
            and n_thinking_tokens < min_thinking_tokens
        ):
            # The model tried to stop thinking before the minimum budget was spent:
            # swap the end token for a continuation phrase and keep sampling.
            replacement = random.choice(args.replacements)
            yield replacement
            replacement_tokens = tokenizer.encode(replacement)
            n_thinking_tokens += len(replacement_tokens)
            tokens = torch.tensor([replacement_tokens]).to(tokens.device)
        elif next_token == model.config.eos_token_id:
            break
        else:
            yield tokenizer.decode([next_token])
            n_thinking_tokens += 1
            tokens = torch.tensor([[next_token]]).to(tokens.device)


for chunk in reasoning_effort(args.question, args.min_thinking_tokens):
    print(chunk, end="", flush=True)
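With the defaults above, the script is run from the command line, e.g. python r1.py "How many times does the letter r appear in strawberry?" -t 512 -m deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B (the filename r1.py is an assumption here; -t raises the minimum thinking budget and -m swaps in a smaller distill).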
That’s almost exactly what I was thinking yesterday, thank you for sharing!
I had a slightly different approach in mind: if the </think> tag appears, we inject a new prompt like:
The first level of problem-solving has been achieved. Now find a new logical solution through intensive contemplation, but do not copy the previous <think> thoughts, make new ones.
<think>
and repeat this process ~five times (see the sketch below).
I’m not a skilled developer and I don’t have R1-level hardware, so could you please try this approach too?
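A rough sketch of that multi-round idea (untested and purely illustrative: it assumes the model, tokenizer, and end_think_token globals from the script above; the helper name multi_round_thinking, the default round count, and the generation settings are assumptions, while the re-prompt text is taken verbatim from the comment):

import torch

REPROMPT = (
    "The first level of problem-solving has been achieved. Now find a new logical "
    "solution through intensive contemplation, but do not copy the previous <think> "
    "thoughts, make new ones.\n<think>\n"
)


@torch.inference_mode()
def multi_round_thinking(question: str, rounds: int = 5, max_new_tokens: int = 2048):
    # Start the assistant turn with an open <think> block, as in reasoning_effort().
    input_ids = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": question},
            {"role": "assistant", "content": "<think>\n"},
        ],
        continue_final_message=True,
        return_tensors="pt",
    ).to(model.device)

    for round_idx in range(rounds):
        # Let the model think until it closes the current <think> block.
        input_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            eos_token_id=end_think_token,  # stop once </think> is emitted
        )
        if round_idx < rounds - 1:
            # Inject the re-prompt and reopen <think> for the next round.
            reprompt_ids = torch.tensor(
                [tokenizer.encode(REPROMPT, add_special_tokens=False)]
            ).to(model.device)
            input_ids = torch.cat([input_ids, reprompt_ids], dim=-1)

    # After the last round, let the model finish and write its answer normally.
    final = model.generate(input_ids, max_new_tokens=max_new_tokens, do_sample=True)
    return tokenizer.decode(final[0], skip_special_tokens=False)

In use, multi_round_thinking(args.question) would replace the streaming loop at the bottom of the script; whether several forced rounds produce genuinely distinct lines of reasoning rather than paraphrases would need to be checked empirically.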
Hi, I tried this on my CPU-only machine and it was very slow, even with the Distill 1.5B model. So I asked Claude to generate a 'CPU-friendly' version: https://gist.github.com/sebington/ece931a90048109a38b1df1fa4dc4a03
Is there a way to format the output so I get just the main answer without the think text, after DeepSeek R1 has finished thinking and produced its output?
Yes, the loop at the bottom is over token strings, so just don't print until you see </think>:
has_stopped_thinking = False
for chunk in reasoning_effort(args.question, args.min_thinking_tokens):
    if not has_stopped_thinking:
        if "</think>" in chunk:
            has_stopped_thinking = True
    else:
        print(chunk, end="", flush=True)
(wrote this on my phone, untested)
We need someone to straight up eval this on ARC-AGI by setting min-thinking-tokens to 50k per task.
Ah, just had that exact error and thought I was going crazy... good thing I checked back here.