RAG workflow
@AshtonIzmev, created July 16, 2025

from typing import List, Dict, Optional
import logging
import statistics

from utils import count_tokens, pdf_to_text_with_fonts_pdfplumber

# Configure logging for the chunker
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class FontAwareRecursiveChunker:
    """A recursive text chunker that uses font sizes to determine optimal split points.

    This chunker prioritizes splitting at larger fonts (headings) to maintain
    semantic coherence while respecting token limits.
    """

    def __init__(
        self,
        min_tokens: int = 500,
        max_tokens: int = 2500,
        model: str = "gpt-4o",
        min_heading_size_percentile: float = 70.0
    ):
        """Initialize the font-aware chunker.

        Args:
            min_tokens: Minimum tokens per chunk
            max_tokens: Maximum tokens per chunk
            model: Model to use for tokenization
            min_heading_size_percentile: Percentile threshold above which a font size is treated as a heading
        """
        self.min_tokens = min_tokens
        self.max_tokens = max_tokens
        self.model = model
        self.min_heading_size_percentile = min_heading_size_percentile
        logger.info(f"FontAwareRecursiveChunker initialized with min_tokens={min_tokens}, max_tokens={max_tokens}")
    def chunk_by_fonts(self, text_elements: List[Dict]) -> List[Dict]:
        """Chunk text elements using a font-aware recursive strategy.

        Args:
            text_elements: List of text elements with font information

        Returns:
            List of chunks with metadata
        """
        logger.info(f"Starting font-aware chunking for {len(text_elements)} text elements")

        # Analyze font distribution to determine heading thresholds
        font_analysis = self._analyze_font_sizes(text_elements)
        heading_threshold = font_analysis['heading_threshold']
        logger.info(f"Font analysis complete: heading_threshold={heading_threshold:.1f}pt, median={font_analysis.get('median_size', 0):.1f}pt, range=[{font_analysis.get('min_size', 0):.1f}, {font_analysis.get('max_size', 0):.1f}]pt")

        # Identify potential split points based on font sizes
        split_points = self._identify_split_points(text_elements, heading_threshold)
        logger.info(f"Identified {len(split_points)} potential split points based on font sizes")
        if split_points:
            logger.debug("Split points preview:")
            for i, sp in enumerate(split_points[:5]):  # Show first 5
                logger.debug(f" {i+1}. Index {sp['index']}: {sp['font_size']:.1f}pt - '{sp['text'][:50]}{'...' if len(sp['text']) > 50 else ''}'")

        # Perform recursive chunking
        logger.info("Starting recursive chunking process")
        chunks = self._recursive_chunk(text_elements, split_points, 0, len(text_elements))
        logger.info(f"Recursive chunking complete: generated {len(chunks)} raw chunks")
        return chunks
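
    # Expected input shape (an assumption inferred from the keys accessed in
    # this class; the actual schema comes from utils.pdf_to_text_with_fonts_pdfplumber):
    #   {'text': 'Section 1: Introduction', 'font_size': 18.0, ...}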
    def _analyze_font_sizes(self, text_elements: List[Dict]) -> Dict:
        """Analyze font size distribution to determine heading thresholds."""
        font_sizes = []
        for element in text_elements:
            size = element.get('font_size', 0)
            text = element.get('text', '').strip()
            if size > 0 and text:  # Only consider elements with actual text
                font_sizes.append(size)

        if not font_sizes:
            logger.warning("No valid font sizes found, using default threshold")
            return {'heading_threshold': 12.0, 'size_distribution': []}

        font_sizes.sort()

        # Calculate percentile-based threshold for headings
        threshold_index = int(len(font_sizes) * self.min_heading_size_percentile / 100)
        heading_threshold = font_sizes[threshold_index] if threshold_index < len(font_sizes) else font_sizes[-1]
        logger.debug(f"Font size analysis: {len(font_sizes)} valid sizes, using {self.min_heading_size_percentile}th percentile")

        return {
            'heading_threshold': heading_threshold,
            'size_distribution': font_sizes,
            'median_size': statistics.median(font_sizes),
            'max_size': max(font_sizes),
            'min_size': min(font_sizes)
        }
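
    # Worked example (illustrative numbers, not from a real PDF): with sorted
    # sizes [10, 10, 10, 10, 12, 12, 14, 16, 18, 24] and the default 70th
    # percentile, threshold_index = int(10 * 70 / 100) = 7, so
    # heading_threshold = 16.0pt; only the 16/18/24pt elements can become
    # split points. median_size here would be 12.0pt.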
    def _identify_split_points(self, text_elements: List[Dict], heading_threshold: float) -> List[Dict]:
        """Identify potential split points based on font sizes."""
        split_points = []
        for i, element in enumerate(text_elements):
            font_size = element.get('font_size', 0)
            text = element.get('text', '').strip()
            # Consider as a split point if the font size is above the threshold and the text is meaningful
            if font_size >= heading_threshold and text and len(text) > 3:
                split_points.append({
                    'index': i,
                    'font_size': font_size,
                    'text': text,
                    'element': element
                })

        # Sort by font size (largest first) for priority-based splitting
        split_points.sort(key=lambda x: x['font_size'], reverse=True)
        logger.debug(f"Found {len(split_points)} split points with font sizes >= {heading_threshold:.1f}pt")
        return split_points
    def _recursive_chunk(
        self,
        text_elements: List[Dict],
        split_points: List[Dict],
        start_idx: int,
        end_idx: int,
        depth: int = 0
    ) -> List[Dict]:
        """Recursively chunk text elements using font-based split points."""
        indent = " " * depth  # For visual hierarchy in logs

        # Extract text segment
        segment_elements = text_elements[start_idx:end_idx]
        segment_text = " ".join([elem.get('text', '') for elem in segment_elements])
        segment_tokens = count_tokens(segment_text, self.model)

        logger.debug(f"{indent}πŸ“Š Analyzing segment [{start_idx}:{end_idx}] (depth {depth}): {segment_tokens} tokens, {len(segment_elements)} elements")
        if segment_elements:
            first_text = segment_elements[0].get('text', '')[:30]
            logger.debug(f"{indent} Starts with: '{first_text}{'...' if len(first_text) >= 30 else ''}'")

        # Base case: segment fits within token limits
        if segment_tokens <= self.max_tokens:
            if segment_tokens >= self.min_tokens or depth == 0:  # Accept if above min or at root level
                logger.info(f"{indent}βœ… Creating chunk [{start_idx}:{end_idx}]: {segment_tokens} tokens (depth {depth})")
                return [{
                    'text': segment_text,
                    'elements': segment_elements,
                    'start_idx': start_idx,
                    'end_idx': end_idx,
                    'token_count': segment_tokens,
                    'font_sizes': [elem.get('font_size', 0) for elem in segment_elements],
                    'leading_font_size': segment_elements[0].get('font_size', 0) if segment_elements else 0
                }]
            else:
                logger.debug(f"{indent}❌ Segment too small ({segment_tokens} < {self.min_tokens} tokens), will merge later")
                return []  # Too small, will be merged with an adjacent chunk

        # Find the best split point within this segment
        best_split = self._find_best_split_point(split_points, start_idx, end_idx)
        if best_split is None:
            logger.warning(f"{indent}⚠️ No font-based split found for segment [{start_idx}:{end_idx}], forcing token-based split")
            # No good split point found, force a split at a token boundary
            return self._force_split_at_token_boundary(text_elements, start_idx, end_idx, depth)

        split_idx = best_split['index']
        logger.info(f"{indent}πŸ”€ Font-based split at index {split_idx} (font_size={best_split['font_size']:.1f}pt): '{best_split['text'][:40]}{'...' if len(best_split['text']) > 40 else ''}'")

        # Recursively chunk the left and right segments
        logger.debug(f"{indent}⬅️ Processing left segment [{start_idx}:{split_idx}]")
        left_chunks = self._recursive_chunk(text_elements, split_points, start_idx, split_idx, depth + 1)
        logger.debug(f"{indent}➑️ Processing right segment [{split_idx}:{end_idx}]")
        right_chunks = self._recursive_chunk(text_elements, split_points, split_idx, end_idx, depth + 1)

        logger.debug(f"{indent}πŸ”— Combining {len(left_chunks)} left + {len(right_chunks)} right chunks")
        return left_chunks + right_chunks
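
    # Illustrative trace (made-up numbers): a 6000-token document with headings
    # at indices 40 (24pt) and 90 (18pt) first splits at index 40, the largest
    # font in range; each half then either fits within max_tokens and becomes a
    # chunk, or is split again at the next-largest heading inside it.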
    def _find_best_split_point(self, split_points: List[Dict], start_idx: int, end_idx: int) -> Optional[Dict]:
        """Find the best split point within the given range."""
        valid_splits = [
            sp for sp in split_points
            if start_idx < sp['index'] < end_idx - 1  # Ensure split is within range and not at edges
        ]
        if not valid_splits:
            return None

        # Return the split with the largest font size (the list is already sorted)
        best = valid_splits[0]
        logger.debug(f" Found {len(valid_splits)} valid splits in range, choosing best: index {best['index']}, {best['font_size']:.1f}pt")
        return best
    def _force_split_at_token_boundary(
        self,
        text_elements: List[Dict],
        start_idx: int,
        end_idx: int,
        depth: int
    ) -> List[Dict]:
        """Force a split at a token boundary when no good font-based split is found."""
        indent = " " * depth
        logger.debug(f"{indent}πŸ”§ Force-splitting segment [{start_idx}:{end_idx}] at token boundary")

        segment_elements = text_elements[start_idx:end_idx]

        # Guard: a single element that exceeds max_tokens cannot be split further;
        # emit it as an oversized chunk instead of recursing forever.
        if end_idx - start_idx <= 1:
            logger.warning(f"{indent}⚠️ Single element exceeds max_tokens, emitting oversized chunk [{start_idx}:{end_idx}]")
            segment_text = " ".join([elem.get('text', '') for elem in segment_elements])
            return [{
                'text': segment_text,
                'elements': segment_elements,
                'start_idx': start_idx,
                'end_idx': end_idx,
                'token_count': count_tokens(segment_text, self.model),
                'font_sizes': [elem.get('font_size', 0) for elem in segment_elements],
                'leading_font_size': segment_elements[0].get('font_size', 0) if segment_elements else 0
            }]

        # Find an approximate middle point based on tokens
        cumulative_tokens = 0
        target_tokens = self.max_tokens // 2
        split_idx = start_idx
        for i, element in enumerate(segment_elements):
            element_text = element.get('text', '')
            element_tokens = count_tokens(element_text, self.model)
            cumulative_tokens += element_tokens
            if cumulative_tokens >= target_tokens:
                split_idx = start_idx + i + 1
                break

        # Ensure we don't split at the very beginning or end
        split_idx = max(start_idx + 1, min(split_idx, end_idx - 1))
        logger.debug(f"{indent} Token-based split at index {split_idx} (target: {target_tokens} tokens)")

        # Recursively chunk the left and right segments
        left_chunks = self._recursive_chunk(text_elements, [], start_idx, split_idx, depth + 1)
        right_chunks = self._recursive_chunk(text_elements, [], split_idx, end_idx, depth + 1)
        return left_chunks + right_chunks
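
    # Worked example (illustrative): with max_tokens=2500, target_tokens=1250;
    # walking elements of ~100 tokens each, the cumulative count first reaches
    # 1250 at the 13th element, so the segment splits 13 elements in and both
    # halves are re-chunked recursively.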
    def post_process_chunks(self, chunks: List[Dict]) -> List[Dict]:
        """Post-process chunks to merge small chunks."""
        if not chunks:
            logger.warning("No chunks to post-process")
            return chunks

        logger.info(f"Post-processing {len(chunks)} chunks (merging small chunks)")
        processed_chunks = []
        merged_count = 0

        for i, chunk in enumerate(chunks):
            # Merge very small chunks with adjacent ones
            if chunk['token_count'] < self.min_tokens and processed_chunks:
                # Merge with the previous chunk
                prev_chunk = processed_chunks[-1]
                prev_tokens = prev_chunk['token_count']
                prev_chunk['text'] += " " + chunk['text']
                prev_chunk['elements'].extend(chunk['elements'])
                prev_chunk['end_idx'] = chunk['end_idx']
                prev_chunk['token_count'] = count_tokens(prev_chunk['text'], self.model)
                prev_chunk['font_sizes'].extend(chunk['font_sizes'])
                merged_count += 1
                logger.info(f"πŸ”— Merged small chunk {i} ({chunk['token_count']} tokens) with previous: {prev_tokens} -> {prev_chunk['token_count']} tokens")
            else:
                processed_chunks.append(chunk)
                logger.debug(f"βœ… Kept chunk {i}: {chunk['token_count']} tokens")

        logger.info(f"Post-processing complete: {len(processed_chunks)} final chunks ({merged_count} merges)")

        # Log a summary of the final chunks
        for i, chunk in enumerate(processed_chunks):
            leading_font = chunk.get('leading_font_size', 0)
            logger.info(f"Final chunk {i+1}: {chunk['token_count']} tokens, leading font {leading_font:.1f}pt, range [{chunk['start_idx']}:{chunk['end_idx']}]")

        return processed_chunks
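
    # Worked example (illustrative token counts): with min_tokens=500, raw
    # chunks of [1200, 180, 900] tokens become roughly [1380, 900]: the
    # 180-token chunk is folded into its predecessor, and the merged text is
    # re-tokenized, so the result is approximate rather than an exact sum.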

def chunk_pdf_by_fonts(
    filename: str,
    min_tokens: int = 500,
    max_tokens: int = 2500,
    model: str = "gpt-4o"
) -> List[Dict]:
    """Complete pipeline to chunk a PDF using the font-aware strategy.

    Args:
        filename: Path to the PDF file
        min_tokens: Minimum tokens per chunk
        max_tokens: Maximum tokens per chunk
        model: Model to use for tokenization

    Returns:
        List of chunks with text and metadata
    """
    logger.info(f"πŸš€ Starting PDF chunking pipeline for: {filename}")

    # Extract text with font information
    logger.info("πŸ“„ Extracting text with font information...")
    text_elements = pdf_to_text_with_fonts_pdfplumber(filename)
    if not text_elements:
        error_msg = f"No text elements extracted from {filename}"
        logger.error(error_msg)
        raise ValueError(error_msg)
    logger.info(f"πŸ“ Extracted {len(text_elements)} text elements from PDF")

    # Initialize the chunker
    chunker = FontAwareRecursiveChunker(
        min_tokens=min_tokens,
        max_tokens=max_tokens,
        model=model
    )

    # Perform chunking
    logger.info("βš™οΈ Performing font-aware chunking...")
    raw_chunks = chunker.chunk_by_fonts(text_elements)

    # Post-process chunks
    logger.info("πŸ”§ Post-processing chunks...")
    final_chunks = chunker.post_process_chunks(raw_chunks)

    logger.info(f"πŸŽ‰ PDF chunking complete! Generated {len(final_chunks)} final chunks")
    return final_chunks
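

# Minimal usage sketch (the file name and print format are illustrative, not
# part of the gist; assumes a local "document.pdf" and that utils provides the
# count_tokens and pdf_to_text_with_fonts_pdfplumber helpers imported above):
if __name__ == "__main__":
    chunks = chunk_pdf_by_fonts("document.pdf", min_tokens=500, max_tokens=2500)
    for i, chunk in enumerate(chunks):
        print(f"Chunk {i+1}: {chunk['token_count']} tokens, "
              f"leading font {chunk['leading_font_size']:.1f}pt")
        print(chunk['text'][:200])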