Tracing text using the algorithm proposed in OLMoTrace: https://arxiv.org/abs/2504.07096
import ast
import math
import random
from collections import defaultdict

from infini_gram.engine import InfiniGramEngine
from transformers import AutoTokenizer
def compute_longest_prefix(query, doc):
    """Helper for computing the longest prefix of the query that exists
    within a document."""

    def shared_prefix_length(list1, list2):
        prefix_length = 0
        for elem1, elem2 in zip(list1, list2):
            if elem1 == elem2:
                prefix_length += 1
            else:
                break
        return prefix_length

    first_id = query[0]
    start_idx = [index for index, value in enumerate(doc) if value == first_id]
    longest_prefix = 0
    for si in start_idx:
        longest_prefix = max(
            longest_prefix,
            shared_prefix_length(query, doc[si:]),
        )
    return longest_prefix
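
# Quick sanity check of the helper on a toy example (ours, not from the gist):
# the query [5, 6, 7] matches doc position 1 with length 2 and position 4
# with length 3, so the longest prefix match is 3 tokens.
assert compute_longest_prefix([5, 6, 7], [1, 5, 6, 9, 5, 6, 7]) == 3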
# setup
enc = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", add_bos_token=False, add_eos_token=False)
engine = InfiniGramEngine(index_dir=<path to index>, eos_token_id=enc.eos_token_id)
unigram_probs = {1: 0.5, 2: 0.5}  # toy stand-in: load real pre-computed unigram probabilities here
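
# A minimal sketch of how a real unigram table could be built offline,
# assuming `engine.find` exposes corpus counts via its `cnt` field (as used
# below); `vocab_size` and the add-one smoothing are our own choices, not
# OLMoTrace's. This issues one index lookup per vocabulary entry, so it is
# slow and the result should be cached to disk.
def build_unigram_probs(engine, vocab_size, smoothing=1.0):
    counts = {tok_id: engine.find(input_ids=[tok_id])['cnt'] + smoothing
              for tok_id in range(vocab_size)}
    total = sum(counts.values())
    return {tok_id: cnt / total for tok_id, cnt in counts.items()}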
# LLM output / query to search
generation = 'Here is the output of the LLM that we want to search for in our data.'
gen_ids = enc.encode(generation)
"""
Step One: find maximal matching spans
"""
L = len(gen_ids)
max_doc_toks = len(gen_ids) * 2  # max tokens to retrieve per document (2x the query length)
# find the longest prefix match for every suffix of the query
spans = []
for start in range(len(gen_ids) - 1):
    _suffix = gen_ids[start:]
    _suff_res = engine.find(input_ids=_suffix)

    if _suff_res['cnt'] == 0:
        # no exact match: get the longest matching prefix using the find result
        _shards = _suff_res['segment_by_shard']
        assert len(_shards) == 1  # assume only one shard
        _doc_ids = engine.get_doc_by_rank(
            s=0,  # assume only one shard
            rank=_shards[0][0],
            max_disp_len=max_doc_toks,
        )['token_ids']
        matched_toks = compute_longest_prefix(_suffix, _doc_ids)
    else:
        # the full suffix occurs in the index
        matched_toks = len(_suffix)
    spans.append((start, start + matched_toks))
# remove partial and non-self-contained spans
full_spans = []
for start, end in spans:
    span_ids = gen_ids[start: end]
    span_text = enc.decode(span_ids)

    # check for internal punctuation
    has_internal_punc = False
    punc_chars = "!.?\n"
    for ch in span_text[:-1]:
        if ch in punc_chars:
            has_internal_punc = True
            break
    if has_internal_punc:
        continue

    # check if the first token is a continuation of a word
    first_tok_id = span_ids[0]
    first_tok = enc.convert_ids_to_tokens(first_tok_id)
    if first_tok[0] != '▁':  # assumes Llama 2 token format
        continue

    # check that no sub-token follows the last token
    if end < len(gen_ids) and enc.convert_ids_to_tokens(gen_ids[end])[0] != '▁':
        continue
    full_spans.append((start, end, span_ids, span_text))
# remove non-maximal spans
maximal_spans = []
max_end_pos = -1
full_spans = sorted(full_spans)
for start, end, ids, text in full_spans:
    if end > max_end_pos:
        maximal_spans.append((start, end, ids, text))
        max_end_pos = end
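
# Toy check of the maximality rule (our own example): among sorted spans,
# (1, 4) is contained in (0, 5) and dropped; (2, 7) extends past 5 and is kept.
_demo, _demo_max_end = [], -1
for _s, _e in [(0, 5), (1, 4), (2, 7)]:
    if _e > _demo_max_end:
        _demo.append((_s, _e))
        _demo_max_end = _e
assert _demo == [(0, 5), (2, 7)]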
"""
Step Two: filter to keep long / unique spans
"""
K = math.ceil(0.05 * L)  # number of spans to keep, proportional to query length
assert K > 0
filt_spans = []
for start, end, ids, text in maximal_spans:
    # score each span by the product of its tokens' unigram probabilities
    span_uni_prob = [unigram_probs.get(_id, 1.0) for _id in ids]  # assumes the table covers the vocab; missing ids contribute nothing
    span_uni_prob = math.prod(span_uni_prob)
    filt_spans.append((start, end, ids, text, span_uni_prob))
filt_spans = sorted(filt_spans, key=lambda x: x[-1])
filt_spans = filt_spans[:K]
filt_spans = sorted(filt_spans)  # sort by start position again
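
# Worked toy numbers for the score above (our own example): a span whose
# tokens have unigram probabilities [0.01, 0.002, 0.05] scores
# 0.01 * 0.002 * 0.05 = 1e-6; lower-probability (rarer) spans sort first
# and survive the top-K cut.
assert math.isclose(math.prod([0.01, 0.002, 0.05]), 1e-6)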
"""
Step Three: retrieve enclosing docs
"""
docs_per_span = 10
span_to_docs = defaultdict(list)
for i, (start, end, ids, text, uni_prob) in enumerate(filt_spans):
    # run retrieval in the infini-gram index to get documents
    span_res = engine.find(input_ids=ids)
    assert span_res['cnt'] > 0
    assert len(span_res['segment_by_shard']) == 1  # assume only one shard
    rank_start, rank_end = span_res['segment_by_shard'][0]
    ranks = [r for r in range(rank_start, rank_end)]
    if len(ranks) > docs_per_span:
        # retrieve a fixed number of documents for each span
        ranks = sorted(random.sample(ranks, docs_per_span))

    # NOTE: we could instead rank documents by BM25 score here (see the sketch after this loop)!
    for r in ranks:
        _doc = engine.get_doc_by_rank(
            s=0,  # assume only one shard
            rank=r,
            max_disp_len=max_doc_toks,
        )
        _doc_meta = ast.literal_eval(_doc['metadata'])['metadata']
        _doc_text = enc.decode(_doc['token_ids'])
        _doc_data = {
            "text": _doc_text,
            **_doc_meta,
        }
        span_to_docs[i].append(_doc_data)
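
# A minimal sketch of the BM25 alternative noted above, assuming the
# third-party `rank_bm25` package; the whitespace tokenization and top-k
# truncation are our own simplifications, not part of OLMoTrace. Documents
# would be scored against the span text and the top `docs_per_span` kept,
# instead of sampling ranks uniformly at random.
def rank_docs_bm25(span_text, doc_texts, k):
    from rank_bm25 import BM25Okapi  # pip install rank-bm25
    tokenized_docs = [d.split() for d in doc_texts]
    scores = BM25Okapi(tokenized_docs).get_scores(span_text.split())
    order = sorted(range(len(doc_texts)), key=lambda j: -scores[j])
    return [doc_texts[j] for j in order[:k]]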
"""
Step Four: merge overlapping spans
"""
# get indices of spans to merge together
merged_spans = [[0]]
curr_idx = 0
curr_start = filt_spans[0][0]
curr_end = filt_spans[0][1]
for i, next_span in enumerate(filt_spans[1:]):
    start = next_span[0]
    end = next_span[1]
    if start < curr_end:
        # overlapping spans: extend the current merged group
        curr_end = max(curr_end, end)
        merged_spans[curr_idx].append(i + 1)
    else:
        # disjoint span: start a new merged group
        curr_start, curr_end = start, end
        curr_idx += 1
        merged_spans.append([i + 1])
assert len(merged_spans) == curr_idx + 1
# merge spans into a final set
final_spans = []
for ms in merged_spans:
    all_docs = []
    docs_per_merged_span = math.ceil(docs_per_span / float(len(ms)))  # subsample docs for spans being merged
    for i in ms:
        # take top docs from each span being merged
        all_docs.extend(span_to_docs[i][:docs_per_merged_span])
    _spans = [filt_spans[i] for i in ms]
    start = min([x[0] for x in _spans])
    end = max([x[1] for x in _spans])
    text = enc.decode(gen_ids[start: end])
    final_spans.append({
        "start": start,
        "end": end,
        "text": text,
        "docs": all_docs,
    })
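
# Toy example of the merge (our own numbers): filtered spans (2, 10) and
# (7, 15) overlap, so they collapse into one merged span covering tokens
# [2, 15), and each contributes ceil(10 / 2) = 5 of its retrieved documents.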
"""
Step Five: observe tracing results
"""
docs_to_print = 5
print(f'Query Text: {enc.decode(gen_ids)}')
for i, sp in enumerate(final_spans):
    print("\n" + "=" * 20 + f" SPAN {i + 1} / {len(final_spans)} " + "=" * 20)
    print(f"Span Text: {sp['text']}\n")
    for j, doc in enumerate(sp['docs'][:docs_to_print]):
        print("-" * 10 + f" Document {j + 1} / {len(sp['docs'])} " + "-" * 10)
        # metadata keys below are specific to the corpus that was indexed
        for k in ['text', 'movie_id', 'src_lang', 'start_frame', 'end_frame']:
            if k == 'text':
                v = doc[k].replace('\n', ' ')  # flatten newlines for display
            else:
                v = doc[k]
            print(f"- {k} --> {v}")