Tracing text using the algorithm proposed by OLMoTrace: https://arxiv.org/abs/2504.07096
import ast
import math
import random
from collections import defaultdict

from infini_gram.engine import InfiniGramEngine
from transformers import AutoTokenizer

def compute_longest_prefix(query, doc):
    """helper function for computing longest prefix of query that exists
    within a document"""

    def shared_prefix_length(list1, list2):
        prefix_length = 0
        for elem1, elem2 in zip(list1, list2):
            if elem1 == elem2:
                prefix_length += 1
            else:
                break
        return prefix_length

    first_id = query[0]
    start_idx = [index for index, value in enumerate(doc) if value == first_id]
    longest_prefix = 0
    for si in start_idx:
        longest_prefix = max(
            longest_prefix,
            shared_prefix_length(query, doc[si:]),
        )
    return longest_prefix
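
# quick sanity check: only the two-token prefix [5, 6] of the query occurs in the
# document (starting at index 1), so the longest matching prefix has length 2
assert compute_longest_prefix([5, 6, 7], [1, 5, 6, 9]) == 2
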
# setup
enc = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", add_bos_token=False, add_eos_token=False)
engine = InfiniGramEngine(index_dir=<path to index>, eos_token_id=enc.eos_token_id)
unigram_probs = {1: 0.5, 2: 0.5}  # placeholder; replace with pre-computed unigram probabilities (see sketch below)
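
# Sketch (not part of the original script) of one way to pre-compute unigram_probs:
# count token frequencies over a raw-text sample of the indexed corpus and normalize.
# The corpus_path argument is a hypothetical placeholder for wherever that text lives.
def build_unigram_probs(corpus_path):
    from collections import Counter

    counts = Counter()
    with open(corpus_path) as f:
        for line in f:
            counts.update(enc.encode(line))
    total = sum(counts.values())
    return {tok_id: cnt / total for tok_id, cnt in counts.items()}
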
# LLM output / query to search
generation = 'Here is the output of the LLM that we want to search for in our data.'
gen_ids = enc.encode(generation)

"""
Step One: find maximal matching spans
"""
L = len(gen_ids)
max_doc_toks = len(gen_ids) * 2  # size of spans to retrieve in documents
# find longest prefix match for every suffix in the query
spans = []
for start in range(len(gen_ids) - 1):
    _suffix = gen_ids[start:]
    _suff_res = engine.find(input_ids=_suffix)

    # if no match, get the longest matching prefix using find result
    if _suff_res['cnt'] == 0:
        _shards = _suff_res['segment_by_shard']
        assert len(_shards) == 1  # assume only one shard
        _doc_ids = engine.get_doc_by_rank(
            s=0,  # assume only one shard
            rank=_shards[0][0],
            max_disp_len=max_doc_toks,
        )['token_ids']
        matched_toks = compute_longest_prefix(_suffix, _doc_ids)  # get longest matching prefix
    elif _suff_res['cnt'] > 0:
        matched_toks = len(_suffix)
    spans.append((start, start + matched_toks))
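
# each entry of spans is (start, end): the longest prefix of gen_ids[start:] that
# appears verbatim somewhere in the indexed corpus
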
# remove partial and non-self-contained spans
full_spans = []
for start, end in spans:
    span_ids = gen_ids[start: end]
    span_text = enc.decode(span_ids)

    # check for internal punctuation
    has_internal_punc = False
    punc_chars = "!.?\n"
    for ch in span_text[:-1]:
        if ch in punc_chars:
            has_internal_punc = True
            break
    if has_internal_punc:
        continue

    # check if first token is a continuation of a word
    first_tok_id = span_ids[0]
    first_tok = enc.convert_ids_to_tokens(first_tok_id)
    if first_tok[0] != '▁':  # assumes Llama 2 token format
        continue

    # no sub-token follows the last token
    if end < len(gen_ids) and enc.convert_ids_to_tokens(gen_ids[end])[0] != "▁":
        continue
    full_spans.append((start, end, span_ids, span_text))
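
# full_spans now holds only "self-contained" spans: spans that do not cross a
# sentence boundary and that begin and end on word boundaries
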
# remove non-maximal spans
maximal_spans = []
max_end_pos = -1
full_spans = sorted(full_spans)
for start, end, ids, text in full_spans:
    if end > max_end_pos:
        maximal_spans.append((start, end, ids, text))
        max_end_pos = end
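
# e.g. if full_spans covers token ranges (0, 6), (1, 6), and (2, 8), then (1, 6) is
# dropped: it is contained in (0, 6), which starts earlier and ends at the same
# position, so it never extends past max_end_pos
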
""" | |
Step Two: filter to keep long / unique spans | |
""" | |
K = math.ceil(0.05 * L) | |
assert K > 0 | |
filt_spans = [] | |
for start, end, ids, text in maximal_spans: | |
span_uni_prob = [unigram_probs.get(_id) for _id in ids] | |
span_uni_prob = math.prod(span_uni_prob) | |
filt_spans.append((start, end, ids, text, span_uni_prob)) | |
filt_spans = sorted(filt_spans, key=lambda x: x[-1]) | |
filt_spans = filt_spans[:K] | |
filt_spans = sorted(filt_spans) # sort based on start position again | |
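# K scales with generation length: e.g. a 40-token generation gives
# K = ceil(0.05 * 40) = 2, so only the two spans whose tokens are jointly least
# likely under the unigram model (the longest / most distinctive spans) survive
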
""" | |
Step Three: retrieve Enclosing Docs | |
""" | |
docs_per_span = 10 | |
span_to_docs = defaultdict(list) | |
for i, (start, end, ids, text, uni_prob) in enumerate(filt_spans): | |
# run retrieval in infinigram index to get documents | |
span_res = engine.find(input_ids=ids) | |
assert span_res['cnt'] > 0 | |
assert len(span_res['segment_by_shard']) == 1 # assume only one shard | |
rank_start, rank_end = span_res['segment_by_shard'][0] | |
ranks = [r for r in range(rank_start, rank_end)] | |
if len(ranks) > docs_per_span: | |
# retrieve fixed number of documents for each span | |
ranks = sorted(random.sample(ranks, docs_per_span)) | |
        # NOTE: we can instead rank documents by BM25 score here
        # (see the rank_docs_by_bm25 sketch below)!
    for r in ranks:
        _doc = engine.get_doc_by_rank(
            s=0,
            rank=r,
            max_disp_len=max_doc_toks,
        )
        _doc_meta = ast.literal_eval(_doc['metadata'])['metadata']
        _doc_text = enc.decode(_doc['token_ids'])
        _doc_data = {
            "text": _doc_text,
            **_doc_meta,
        }
        span_to_docs[i].append(_doc_data)
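
# Sketch (not part of the original script) of the BM25 alternative noted above:
# score each candidate document against the span text using the optional
# rank_bm25 package and keep the top_k highest-scoring ranks instead of sampling
# at random. Defined only for illustration; nothing below calls it.
def rank_docs_by_bm25(ranks, span_ids, top_k):
    from rank_bm25 import BM25Okapi  # assumed optional dependency

    docs = [
        enc.decode(engine.get_doc_by_rank(s=0, rank=r, max_disp_len=max_doc_toks)['token_ids'])
        for r in ranks
    ]
    bm25 = BM25Okapi([d.split() for d in docs])
    scores = bm25.get_scores(enc.decode(span_ids).split())
    best = sorted(range(len(ranks)), key=lambda j: scores[j], reverse=True)[:top_k]
    return [ranks[j] for j in best]
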
""" | |
Step Four: merge overlapping spans | |
""" | |
# get indices of spans to merge together | |
merged_spans = [[0]] | |
curr_idx = 0 | |
curr_start = filt_spans[0][0] | |
curr_end = filt_spans[0][1] | |
for i, next_span in enumerate(filt_spans[1:]): | |
start = next_span[0] | |
end = next_span[1] | |
if start < curr_end: | |
curr_end = max(curr_end, end) | |
merged_spans[curr_idx].append(i + 1) | |
else: | |
curr_start, curr_end = start, end | |
curr_idx += 1 | |
merged_spans.append([i + 1]) | |
assert len(merged_spans) == curr_idx + 1 | |
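
# each entry of merged_spans groups the indices of overlapping spans in filt_spans:
# e.g. if the filtered spans cover token ranges (0, 5), (3, 9), and (12, 15), the
# first two overlap and merged_spans ends up as [[0, 1], [2]]
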
# merge spans into a final set
final_spans = []
for ms in merged_spans:
    all_docs = []
    docs_per_merged_span = math.ceil(docs_per_span / float(len(ms)))  # subsample docs for spans being merged
    for i in ms:
        # take top docs from each span being merged
        all_docs.extend(span_to_docs[i][:docs_per_merged_span])
    _spans = [filt_spans[i] for i in ms]
    start = min([x[0] for x in _spans])
    end = max([x[1] for x in _spans])
    text = enc.decode(gen_ids[start: end])
    final_spans.append({
        "start": start,
        "end": end,
        "text": text,
        "docs": all_docs,
    })

""" | |
Step Five: observe tracing results | |
""" | |
docs_to_print = 5 | |
print(f'Query Text: {enc.decode(gen_ids)}') | |
for i, sp in enumerate(final_spans): | |
print("\n" + "="*20 + f" SPAN {i + 1} / {len(final_spans)} " + "="*20) | |
print(f"Span Text: {sp['text']}\n") | |
for j, doc in enumerate(sp['docs']): | |
print("-"*10 + f" Document {j + 1} / {len(sp['docs'])} " + "-"*10) | |
for k in ['text', 'movie_id', 'src_lang', 'start_frame', 'end_frame']: | |
if k == 'text': | |
v = doc[k].replace('\n', ' ') | |
else: | |
v = doc[k] | |
print(f"- {k} --> {v}") |