#!/usr/bin/env python3
"""
Tokenizer Performance Comparison Script

A comprehensive tool for evaluating tokenizer efficiency across multiple languages.
Supports HuggingFace datasets, local files, and built-in samples.

Author: https://github.com/ParagEkbote
"""
from typing import Tuple, List, Dict, Optional, Callable, Literal
from dataclasses import dataclass
from pathlib import Path
import argparse
import logging

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from datasets import load_dataset
from transformers import AutoTokenizer
from datatrove.utils.word_tokenizers import load_word_tokenizer

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger(__name__)


# ============================================================================
# Data Classes and Constants
# ============================================================================


@dataclass
class Language:
    """Represents a language with its name and code."""

    name: str
    code: str


@dataclass
class TokenizerMetrics:
    """Container for tokenizer evaluation metrics."""

    word_fertility: float
    char_fertility: float
    pcw: float  # Proportion of continued words
    vocab_size: int


# Language config name mappings for finewiki
FINEWIKI_CONFIGS = {
    'ab', 'ace', 'ady', 'af', 'als', 'alt', 'ami', 'am', 'ang', 'anp', 'an', 'arc', 'ar',
    'ary', 'arz', 'ast', 'as', 'atj', 'avk', 'av', 'awa', 'ay', 'azb', 'az', 'ban', 'bar',
    'bat_smg', 'ba', 'bbc', 'bcl', 'be', 'bg', 'bh', 'bi', 'bjn', 'blk', 'bm', 'bn', 'bo',
    'bpy', 'br', 'bs', 'bug', 'bxr', 'ca', 'cbk_zam', 'cdo', 'ceb', 'ce', 'chr', 'ch',
    'chy', 'ckb', 'co', 'crh', 'cr', 'csb', 'cs', 'cu', 'cv', 'cy', 'dag', 'da', 'de',
    'dga', 'din', 'diq', 'dsb', 'dty', 'dv', 'dz', 'ee', 'el', 'eml', 'en', 'eo', 'es',
    'et', 'eu', 'ext', 'fat', 'fa', 'ff', 'fiu_vro', 'fi', 'fj', 'fon', 'fo', 'frp',
    'frr', 'fr', 'fur', 'fy', 'gag', 'gan', 'ga', 'gcr', 'gd', 'glk', 'gl', 'gn', 'gom',
    'gor', 'got', 'gpe', 'guc', 'gur', 'gu', 'guw', 'gv', 'hak', 'ha', 'haw', 'he',
    'hif', 'hi', 'hr', 'hsb', 'ht', 'hu', 'hy', 'hyw', 'ia', 'id', 'ie', 'ig', 'ik',
    'ilo', 'inh', 'io', 'is', 'it', 'iu', 'jam', 'ja', 'jbo', 'jv', 'kaa', 'kab', 'ka',
    'kbd', 'kbp', 'kcg', 'kg', 'ki', 'kk', 'kl', 'km', 'kn', 'koi', 'ko', 'krc', 'ksh',
    'ks', 'ku', 'kv', 'kw', 'ky', 'lad', 'la', 'lbe', 'lb', 'lez', 'lfn', 'lg', 'lij',
    'li', 'lld', 'lmo', 'ln', 'lo', 'ltg', 'lt', 'lv', 'mad', 'mai', 'map_bms', 'mdf',
    'mg', 'mhr', 'min', 'mi', 'mk', 'ml', 'mni', 'mn', 'mnw', 'mrj', 'mr', 'ms', 'mt',
    'mwl', 'myv', 'my', 'mzn', 'nah', 'nap', 'nds_nl', 'nds', 'ne', 'new', 'nia', 'nl',
    'nn', 'nov', 'no', 'nqo', 'nrm', 'nso', 'nv', 'ny', 'oc', 'olo', 'om', 'or', 'os',
    'pag', 'pam', 'pap', 'pa', 'pcd', 'pcm', 'pdc', 'pfl', 'pih', 'pi', 'pl', 'pms',
    'pnb', 'pnt', 'ps', 'pt', 'pwn', 'qu', 'rm', 'rmy', 'rn', 'roa_rup', 'roa_tara',
    'ro', 'rue', 'ru', 'rw', 'sah', 'sat', 'sa', 'scn', 'sco', 'sc', 'sd', 'se', 'sg',
    'shi', 'shn', 'sh', 'simple', 'si', 'skr', 'sk', 'sl', 'smn', 'sm', 'sn', 'so',
    'sq', 'srn', 'sr', 'ss', 'stq', 'st', 'su', 'sv', 'sw', 'szl', 'szy', 'ta', 'tay',
    'tcy', 'tet', 'te', 'tg', 'th', 'ti', 'tk', 'tl', 'tly', 'tn', 'to', 'tpi', 'trv',
    'tr', 'ts', 'tt', 'tum', 'tw', 'tyv', 'ty', 'udm', 'ug', 'uk', 'ur', 'uz', 'vec',
    'vep', 've', 'vi', 'vls', 'vo', 'war', 'wa', 'wo', 'wuu', 'xal', 'xh', 'xmf', 'yi',
    'yo', 'za', 'zea', 'zgh', 'zh_classical', 'zh_min_nan', 'zh_yue', 'zh', 'zu'
}

LANGUAGE_NAMES = {
    'en': 'English', 'es': 'Spanish', 'fr': 'French', 'zh': 'Chinese',
    'ar': 'Arabic', 'ja': 'Japanese', 'de': 'German', 'it': 'Italian',
    'pt': 'Portuguese', 'ru': 'Russian', 'hi': 'Hindi', 'ko': 'Korean',
    'nl': 'Dutch', 'sv': 'Swedish', 'pl': 'Polish', 'tr': 'Turkish',
    'id': 'Indonesian', 'vi': 'Vietnamese', 'th': 'Thai', 'cs': 'Czech',
    'ro': 'Romanian', 'el': 'Greek', 'hu': 'Hungarian', 'da': 'Danish',
    'fi': 'Finnish', 'no': 'Norwegian', 'uk': 'Ukrainian', 'he': 'Hebrew',
    'fa': 'Persian', 'bn': 'Bengali', 'simple': 'Simple English',
}

SAMPLE_TEXTS = {
    "en": """The history of artificial intelligence began in antiquity with myths, stories and rumors of artificial beings
endowed with intelligence or consciousness by master craftsmen. Modern AI research began in the mid-20th century
when scientists started exploring whether machines could think. The field was founded on the assumption that
human intelligence can be so precisely described that a machine can be made to simulate it.""",
    "es": """La inteligencia artificial es la simulación de procesos de inteligencia humana por parte de máquinas,
especialmente sistemas informáticos. Estos procesos incluyen el aprendizaje, el razonamiento y la autocorrección.
Las aplicaciones particulares de la IA incluyen sistemas expertos, reconocimiento de voz y visión artificial.""",
    "fr": """L'intelligence artificielle est un ensemble de théories et de techniques mises en œuvre en vue de réaliser
des machines capables de simuler l'intelligence humaine. Elle correspond à un ensemble de concepts et de
technologies plus qu'à une discipline autonome constituée.""",
    "zh": """人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。
该领域的研究包括机器人、语言识别、图像识别、自然语言处理和专家系统等。""",
    "ar": """الذكاء الاصطناعي هو فرع من علوم الحاسوب يهتم بإنشاء أنظمة قادرة على محاكاة الذكاء البشري. يشمل هذا المجال
التعلم الآلي والتعلم العميق ومعالجة اللغة الطبيعية والرؤية الحاسوبية.""",
    "ja": """人工知能とは、コンピュータ科学の一分野で、人間の知能を機械で模倣することを目指す技術です。この分野には、機械学習、
深層学習、自然言語処理、コンピュータビジョンなどが含まれます。""",
}


# ============================================================================
# Data Loading Functions
# ============================================================================


class DataLoader:
    """Handles loading text data from various sources."""

    def __init__(self, n_chars: int = 50000):
        self.n_chars = n_chars

    def load_from_http_parquet(
        self, lang_code: str, base_url: str, content_field: str = "text"
    ) -> Optional[str]:
        """Load text from HTTP parquet file."""
        parquet_url = f"{base_url}/{lang_code}wiki/000_00000.parquet"
        logger.info(f" → Attempting HTTP load from: {parquet_url}")

        data_files = {"train": parquet_url}
        ds = load_dataset(
            "parquet", data_files=data_files, split="train", streaming=True
        )
        return self._collect_text_samples(ds, content_field)

    def load_from_hf_dataset(
        self, dataset_name: str, lang_code: str, content_field: str = "text"
    ) -> Optional[str]:
        """Load text from HuggingFace dataset with config."""
        ds = load_dataset(dataset_name, name=lang_code, split="train", streaming=True)
        return self._collect_text_samples(ds, content_field)

    def load_from_file(self, filepath: Path) -> Optional[str]:
        """Load text from local file."""
        if not filepath.exists():
            return None

        with open(filepath, "r", encoding="utf-8") as f:
            return f.read()

    def load_sample_text(self, lang_code: str) -> Optional[str]:
        """Load and repeat built-in sample text."""
        if lang_code not in SAMPLE_TEXTS:
            return None

        sample = SAMPLE_TEXTS[lang_code]
        repetitions = (self.n_chars // len(sample)) + 1
        return (sample * repetitions)[: self.n_chars]

    def _collect_text_samples(self, dataset, content_field: str) -> Optional[str]:
        """Collect text samples from streaming dataset."""
        text_samples = []
        total_chars = 0

        for sample in dataset:
            content = sample.get(content_field, "")
            if content and isinstance(content, str):
                text_samples.append(content)
                total_chars += len(content)
                if total_chars >= self.n_chars:
                    break

        return "\n\n".join(text_samples)[: self.n_chars] if text_samples else None


def load_all_texts(
    languages: List[Language],
    data_dir: Optional[str] = None,
    hf_dataset: Optional[str] = None,
    use_http: bool = False,
    base_url: Optional[str] = None,
    content_field: str = "text",
    n_chars: int = 50000,
) -> Dict[str, str]:
    """
    Load text samples for all languages from available sources.

    Priority: HTTP/HF dataset → Local files → Built-in samples
    """
    loader = DataLoader(n_chars)
    texts = {}

    for lang in languages:
        logger.info(f"→ Loading {lang.name} ({lang.code})...")
        text = None

        # Try HTTP parquet loading
        if use_http and base_url:
            try:
                text = loader.load_from_http_parquet(lang.code, base_url, content_field)
                if text:
                    logger.info(f" ✓ Loaded from HTTP parquet: {len(text)} chars")
            except Exception as e:
                logger.debug(f" HTTP load failed: {e}")

        # Try HF dataset loading
        if not text and hf_dataset and hf_dataset != "HTTP":
            try:
                text = loader.load_from_hf_dataset(hf_dataset, lang.code, content_field)
                if text:
                    logger.info(f" ✓ Loaded from {hf_dataset}: {len(text)} chars")
            except Exception as e:
                logger.debug(f" HF dataset load failed: {e}")

        # Try local file
        if not text and data_dir:
            filepath = Path(data_dir) / f"{lang.code}.txt"
            text = loader.load_from_file(filepath)
            if text:
                logger.info(f" ✓ Loaded from file: {len(text)} chars")

        # Fall back to sample text
        if not text:
            text = loader.load_sample_text(lang.code)
            if text:
                logger.info(f" ✓ Using sample text: {len(text)} chars")
            else:
                logger.warning(f" ✗ No text available for {lang.name}")
                text = ""

        texts[lang.code] = text

    return texts
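

# Illustrative sketch (not called anywhere in this script): with no data sources
# configured, load_all_texts() falls back to the built-in samples, so only languages
# present in SAMPLE_TEXTS receive text and every other code maps to an empty string.
# The language choices below are arbitrary examples.
def _load_fallback_example() -> Dict[str, str]:
    langs = [Language("English", "en"), Language("Finnish", "fi")]  # "fi" has no sample
    # Expected shape: {"en": "<~50,000 chars of repeated sample>", "fi": ""}
    return load_all_texts(langs)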


# ============================================================================
# Tokenizer Analysis Functions
# ============================================================================


class TokenizerAnalyzer:
    """Analyzes tokenizer performance on text."""

    @staticmethod
    def detect_word_marker(token_strings: List[str]) -> Optional[str]:
        """Detect token marker indicating word boundaries."""
        markers = {"▁", "Ġ", "##"}
        for marker in markers:
            if any(t.startswith(marker) for t in token_strings):
                return marker
        return None

    @staticmethod
    def get_vocab_size(tokenizer) -> int:
        """Extract vocabulary size from tokenizer."""
        if hasattr(tokenizer, "vocab_size") and isinstance(tokenizer.vocab_size, int):
            return tokenizer.vocab_size

        if hasattr(tokenizer, "get_vocab"):
            return len(tokenizer.get_vocab())

        return len(tokenizer) if hasattr(tokenizer, "__len__") else -1

    def compute_metrics(
        self, tokenizer, text: str, word_tokenizer: Optional[Callable] = None
    ) -> TokenizerMetrics:
        """Compute all tokenizer metrics for given text."""
        if not text or len(text.strip()) == 0:
            return TokenizerMetrics(0.0, 0.0, 0.0, self.get_vocab_size(tokenizer))

        # Tokenize text
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        if len(token_ids) == 0:
            return TokenizerMetrics(0.0, 0.0, 0.0, self.get_vocab_size(tokenizer))

        # Get token strings
        try:
            token_strings = tokenizer.convert_ids_to_tokens(token_ids)
        except Exception:
            token_strings = [tokenizer.decode([t]) for t in token_ids]

        # Word tokenization
        if word_tokenizer:
            try:
                words = (
                    word_tokenizer.tokenize(text)
                    if hasattr(word_tokenizer, "tokenize")
                    else word_tokenizer(text)
                )
            except Exception:
                words = text.split()
        else:
            words = text.split()

        num_words = len(words)
        char_count = len(text.replace(" ", "").replace("\n", "").replace("\t", ""))

        # Calculate fertility scores
        word_fertility = len(token_ids) / num_words if num_words > 0 else 0.0
        char_fertility = len(token_ids) / char_count if char_count > 0 else 0.0

        # Calculate PCW (Proportion of Continued Words)
        marker = self.detect_word_marker(token_strings)
        if marker == "▁" or marker == "Ġ":
            continued = sum(1 for t in token_strings if not t.startswith(marker))
        elif marker == "##":
            continued = sum(1 for t in token_strings if t.startswith("##"))
        else:
            decoded = [tokenizer.decode([i]) for i in token_ids]
            continued = sum(1 for d in decoded if not d.startswith(" ") and d != "")

        pcw = continued / len(token_ids) if len(token_ids) > 0 else 0.0

        return TokenizerMetrics(
            round(word_fertility, 3),
            round(char_fertility, 3),
            round(pcw, 3),
            self.get_vocab_size(tokenizer),
        )
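

# Illustrative sketch of how the three metrics relate, using a tiny made-up
# tokenization (the tokens below are hypothetical, not produced by any real model);
# "▁" marks word-initial pieces, so pieces without it count as continuations.
def _metrics_by_hand_example() -> TokenizerMetrics:
    words = ["tokenization", "is", "fun"]           # 3 words
    tokens = ["▁token", "ization", "▁is", "▁fun"]   # 4 subword tokens
    chars = sum(len(w) for w in words)              # 17 non-whitespace characters
    word_fertility = len(tokens) / len(words)       # 4 / 3 ≈ 1.333 tokens per word
    char_fertility = len(tokens) / chars            # 4 / 17 ≈ 0.235 tokens per character
    continued = sum(1 for t in tokens if not t.startswith("▁"))  # 1 continuation piece
    pcw = continued / len(tokens)                   # 1 / 4 = 0.25
    # vocab_size is -1 here, mirroring get_vocab_size()'s "unknown" fallback
    return TokenizerMetrics(
        round(word_fertility, 3), round(char_fertility, 3), round(pcw, 3), -1
    )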


def evaluate_all_tokenizers(
    tokenizers: List[Tuple[str, str]], languages: List[Language], texts: Dict[str, str]
) -> pd.DataFrame:
    """Evaluate all tokenizers on all languages."""
    analyzer = TokenizerAnalyzer()
    results = []

    for tokenizer_name, tokenizer_path in tokenizers:
        logger.info(f"\nEvaluating {tokenizer_name} ({tokenizer_path})...")

        try:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_path, trust_remote_code=True
            )
        except Exception as e:
            logger.error(f" ✗ Error loading tokenizer: {e}")
            continue

        for lang in languages:
            text = texts.get(lang.code, "")
            if not text:
                logger.info(f" ⊘ Skipping {lang.name} (no text)")
                continue

            logger.info(f" → Processing {lang.name}...")

            # Load word tokenizer for the language
            word_tokenizer = None
            try:
                word_tokenizer = load_word_tokenizer(lang.code)
            except Exception:
                logger.debug(f" Using simple tokenization for {lang.name}")

            try:
                metrics = analyzer.compute_metrics(tokenizer, text, word_tokenizer)
                results.append(
                    {
                        "tokenizer": tokenizer_name,
                        "tokenizer_path": tokenizer_path,
                        "language": lang.name,
                        "language_code": lang.code,
                        "word_fertility": metrics.word_fertility,
                        "char_fertility": metrics.char_fertility,
                        "pcw": metrics.pcw,
                        "vocab_size": metrics.vocab_size,
                    }
                )
                logger.info(
                    f" ✓ word_fert={metrics.word_fertility:.3f}, "
                    f"char_fert={metrics.char_fertility:.3f}"
                )
            except Exception as e:
                logger.error(f" ✗ Error computing metrics: {e}")

    return pd.DataFrame(results)
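

# Illustrative sketch (not wired into the CLI): evaluating one tokenizer on the
# built-in English sample without going through main(). "gpt2" is just an example
# checkpoint; any Hub identifier accepted by AutoTokenizer would work the same way.
def _single_language_example() -> TokenizerMetrics:
    text = DataLoader(n_chars=5_000).load_sample_text("en") or ""
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    return TokenizerAnalyzer().compute_metrics(tokenizer, text)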


# ============================================================================
# Visualization Functions
# ============================================================================


class Visualizer:
    """Creates visualizations for tokenizer comparison."""

    def __init__(
        self,
        style: Literal["white", "dark", "whitegrid", "darkgrid", "ticks"] = "whitegrid",
    ):
        sns.set_theme(style=style)
        self.colors = sns.color_palette("husl", 8)

    def create_heatmaps(
        self, df: pd.DataFrame, output_prefix: str = "tokenizer_heatmap"
    ):
        """Create heatmap visualizations for word and character fertility."""
        if df.empty:
            logger.warning("Cannot create heatmaps: no data")
            return

        word_pivot = df.pivot(
            index="tokenizer", columns="language", values="word_fertility"
        )
        char_pivot = df.pivot(
            index="tokenizer", columns="language", values="char_fertility"
        )

        fig, axes = plt.subplots(1, 2, figsize=(16, 6))

        # Word-level fertility heatmap
        sns.heatmap(
            word_pivot,
            annot=True,
            fmt=".3f",
            cmap="RdYlGn_r",
            cbar_kws={"label": "Tokens per Word"},
            ax=axes[0],
            linewidths=0.5,
            linecolor="gray",
        )
        axes[0].set_title(
            "Word-Level Fertility\n(Lower is Better)", fontsize=14, fontweight="bold"
        )
        axes[0].set_xlabel("Language", fontsize=11)
        axes[0].set_ylabel("Tokenizer", fontsize=11)

        # Character-level fertility heatmap
        sns.heatmap(
            char_pivot,
            annot=True,
            fmt=".3f",
            cmap="RdYlGn_r",
            cbar_kws={"label": "Tokens per Character"},
            ax=axes[1],
            linewidths=0.5,
            linecolor="gray",
        )
        axes[1].set_title(
            "Character-Level Fertility\n(Lower is Better)",
            fontsize=14,
            fontweight="bold",
        )
        axes[1].set_xlabel("Language", fontsize=11)
        axes[1].set_ylabel("Tokenizer", fontsize=11)

        plt.tight_layout()
        filepath = f"{output_prefix}_heatmaps.png"
        fig.savefig(filepath, dpi=300, bbox_inches="tight")
        plt.close(fig)
        logger.info(f"\n✓ Heatmaps saved to {filepath}")

    def create_comparison_charts(
        self, df: pd.DataFrame, output_prefix: str = "tokenizer_comparison"
    ):
        """Create comprehensive comparison visualizations."""
        if df.empty:
            logger.warning("Cannot create comparison charts: no data")
            return

        fig = plt.figure(figsize=(18, 10))
        gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.25)

        # 1. Average fertility by tokenizer (bar chart)
        ax1 = fig.add_subplot(gs[0, :])
        avg_metrics = df.groupby("tokenizer")[
            ["word_fertility", "char_fertility"]
        ].mean()
        avg_metrics.plot(kind="bar", ax=ax1, color=["#3498db", "#e74c3c"], width=0.7)
        ax1.set_title(
            "Average Fertility Scores by Tokenizer", fontsize=14, fontweight="bold"
        )
        ax1.set_ylabel("Fertility Score", fontsize=11)
        ax1.set_xlabel("Tokenizer", fontsize=11)
        ax1.legend(["Word Fertility", "Character Fertility"], loc="upper right")
        ax1.grid(axis="y", alpha=0.3)
        plt.setp(ax1.xaxis.get_majorticklabels(), rotation=0)

        # 2. Language-specific performance (grouped bar)
        ax2 = fig.add_subplot(gs[1, 0])
        word_pivot = df.pivot(
            index="language", columns="tokenizer", values="word_fertility"
        )
        word_pivot.plot(kind="bar", ax=ax2, width=0.8)
        ax2.set_title("Word Fertility by Language", fontsize=12, fontweight="bold")
        ax2.set_ylabel("Tokens per Word", fontsize=10)
        ax2.set_xlabel("Language", fontsize=10)
        ax2.legend(title="Tokenizer", fontsize=8)
        ax2.grid(axis="y", alpha=0.3)
        plt.setp(ax2.xaxis.get_majorticklabels(), rotation=45, ha="right")

        # 3. PCW comparison (grouped bar)
        ax3 = fig.add_subplot(gs[1, 1])
        pcw_pivot = df.pivot(index="language", columns="tokenizer", values="pcw")
        pcw_pivot.plot(kind="bar", ax=ax3, width=0.8)
        ax3.set_title(
            "Proportion of Continued Words (PCW)", fontsize=12, fontweight="bold"
        )
        ax3.set_ylabel("PCW Score", fontsize=10)
        ax3.set_xlabel("Language", fontsize=10)
        ax3.legend(title="Tokenizer", fontsize=8)
        ax3.grid(axis="y", alpha=0.3)
        plt.setp(ax3.xaxis.get_majorticklabels(), rotation=45, ha="right")

        # 4. Vocabulary size vs average fertility (scatter)
        ax4 = fig.add_subplot(gs[2, 0])
        vocab_fertility = (
            df.groupby("tokenizer")
            .agg({"vocab_size": "first", "word_fertility": "mean"})
            .reset_index()
        )
        for _, row in vocab_fertility.iterrows():
            x = float(row["vocab_size"])
            y = float(row["word_fertility"])
            ax4.scatter(x, y, s=300, alpha=0.6, label=row["tokenizer"])
            ax4.text(x, y, str(row["tokenizer"]), fontsize=9, ha="center", va="center")
        ax4.set_title(
            "Vocabulary Size vs Average Fertility", fontsize=12, fontweight="bold"
        )
        ax4.set_xlabel("Vocabulary Size", fontsize=10)
        ax4.set_ylabel("Average Word Fertility", fontsize=10)
        ax4.grid(alpha=0.3)

        # 5. Distribution of metrics (violin plot)
        ax5 = fig.add_subplot(gs[2, 1])
        sns.violinplot(
            data=df, x="tokenizer", y="char_fertility", ax=ax5, palette="Set2"
        )
        ax5.set_title(
            "Character Fertility Distribution", fontsize=12, fontweight="bold"
        )
        ax5.set_ylabel("Tokens per Character", fontsize=10)
        ax5.set_xlabel("Tokenizer", fontsize=10)
        ax5.grid(axis="y", alpha=0.3)
        plt.setp(ax5.xaxis.get_majorticklabels(), rotation=45, ha="right")

        filepath = f"{output_prefix}_analysis.png"
        fig.savefig(filepath, dpi=300, bbox_inches="tight")
        plt.close(fig)
        logger.info(f"✓ Comparison charts saved to {filepath}")


# ============================================================================
# Output Functions
# ============================================================================


def save_results(df: pd.DataFrame, output_csv: str = "tokenizer_comparison.csv"):
    """Save results to CSV file."""
    df.to_csv(output_csv, index=False)
    logger.info(f"\n✓ Results saved to {output_csv}")


def print_summaries(df: pd.DataFrame):
    """Print formatted summary tables."""
    if df.empty:
        logger.warning("\nNo results to display.")
        return

    logger.info("\n" + "=" * 80)
    logger.info("WORD-LEVEL FERTILITY (tokens per word - lower is better)")
    logger.info("=" * 80)
    word_pivot = df.pivot(
        index="language", columns="tokenizer", values="word_fertility"
    )
    logger.info(word_pivot.to_string())

    logger.info("\n" + "=" * 80)
    logger.info("CHARACTER-LEVEL FERTILITY (tokens per character - lower is better)")
    logger.info("=" * 80)
    char_pivot = df.pivot(
        index="language", columns="tokenizer", values="char_fertility"
    )
    logger.info(char_pivot.to_string())

    logger.info("\n" + "=" * 80)
    logger.info("PCW (Proportion Continued Words)")
    logger.info("=" * 80)
    pcw_pivot = df.pivot(index="language", columns="tokenizer", values="pcw")
    logger.info(pcw_pivot.to_string())

    logger.info("\n" + "=" * 80)
    logger.info("VOCABULARY SIZES")
    logger.info("=" * 80)
    vocab_df = df[["tokenizer", "vocab_size"]].drop_duplicates()
    logger.info(vocab_df.to_string(index=False))

    logger.info("\n" + "=" * 80)
    logger.info("SUMMARY STATISTICS")
    logger.info("=" * 80)

    # Note: column selection on the GroupBy (not .loc) is required here
    avg_word = df.groupby("tokenizer")["word_fertility"].mean()
    logger.info("\nAverage Word-Level Fertility by Tokenizer:")
    logger.info(avg_word.to_string())

    avg_char: pd.Series = df.groupby("tokenizer")["char_fertility"].mean()
    logger.info("\nAverage Character-Level Fertility by Tokenizer:")
    logger.info(avg_char.to_string())


# ============================================================================
# Main CLI
# ============================================================================


def parse_arguments():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Compare tokenizers across languages.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Use built-in sample texts
  python %(prog)s

  # Load from HTTP parquet files (fastest)
  python %(prog)s --use-http --langs en,es,fr,zh,ar,ja

  # Load from HuggingFace dataset with configs
  python %(prog)s --hf-dataset "HuggingFaceFW/finewiki" --langs en,de,it

  # Use custom base URL for HTTP loading
  python %(prog)s --use-http --base-url "https://..." --langs en,es,fr

  # Use local text files
  python %(prog)s --data-dir ./data --langs en,es,fr

  # Customize tokenizers (each path must be loadable by AutoTokenizer)
  python %(prog)s --tokenizers "GPT2:gpt2,BERT:bert-base-uncased" --langs en,fr

  # Advanced: Multiple languages with custom output
  python %(prog)s --use-http --langs en,es,fr,de,it,pt,ru,ja,zh,ar,hi,ko \\
    --n-chars 100000 --output-prefix my_comparison
""",
    )

    # Data source arguments
    parser.add_argument(
        "--data-dir",
        type=str,
        default=None,
        help="Directory containing {lang_code}.txt files",
    )
    parser.add_argument(
        "--hf-dataset",
        type=str,
        default=None,
        help="HuggingFace dataset name (e.g., 'HuggingFaceFW/finewiki')",
    )
    parser.add_argument(
        "--use-http", action="store_true", help="Load parquet files directly via HTTP"
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default="https://huggingface.co/datasets/HuggingFaceFW/finewiki/resolve/main/data",
        help="Base URL for HTTP parquet loading",
    )
    parser.add_argument(
        "--content-field",
        type=str,
        default="text",
        help="Field name for text content in dataset",
    )

    # Language and tokenizer arguments
    parser.add_argument(
        "--langs",
        type=str,
        default="en,es,fr,zh,ar,ja",
        help="Comma-separated language codes",
    )
    parser.add_argument(
        "--tokenizers",
        type=str,
        default=None,
        help="Custom tokenizers as 'name:path,name:path'",
    )

    # Processing arguments
    parser.add_argument(
        "--n-chars", type=int, default=50000, help="Target characters per language"
    )
    parser.add_argument(
        "--output-prefix", type=str, default="tokenizer", help="Prefix for output files"
    )

    # Output control
    parser.add_argument(
        "--no-viz", action="store_true", help="Skip visualization generation"
    )
    parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")

    return parser.parse_args()


def create_default_tokenizers() -> List[Tuple[str, str]]:
    """Return default tokenizers to evaluate."""
    return [
        ("Llama3", "meta-llama/Llama-3.2-1B"),
        ("Gemma2", "google/gemma-2-2b"),
        ("Qwen2.5", "Qwen/Qwen2.5-0.5B"),
        ("Phi-3", "microsoft/Phi-3-mini-4k-instruct"),
    ]


def parse_custom_tokenizers(tokenizer_spec: str) -> List[Tuple[str, str]]:
    """Parse custom tokenizer specification."""
    tokenizers = []
    for tok_spec in tokenizer_spec.split(","):
        if ":" not in tok_spec:
            logger.warning(f"Invalid tokenizer spec: {tok_spec} (expected 'name:path')")
            continue
        name, path = tok_spec.split(":", 1)
        tokenizers.append((name.strip(), path.strip()))
    return tokenizers
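

# Illustrative sketch of the --tokenizers spec format accepted above; the names and
# paths are arbitrary examples, and entries without a colon are skipped with a warning.
def _spec_parsing_example() -> List[Tuple[str, str]]:
    spec = "GPT2:gpt2, BERT:bert-base-uncased, badentry"
    # Returns [("GPT2", "gpt2"), ("BERT", "bert-base-uncased")]; "badentry" is dropped.
    return parse_custom_tokenizers(spec)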


def create_languages(lang_codes: List[str]) -> List[Language]:
    """Create Language objects from codes."""
    languages = []
    for code in lang_codes:
        name = LANGUAGE_NAMES.get(code, code.upper())
        languages.append(Language(name, code))
    return languages


def validate_configuration(args):
    """Validate and report configuration."""
    lang_codes = [code.strip() for code in args.langs.split(",")]

    # Validate language codes
    if args.hf_dataset or args.use_http:
        invalid_codes = [code for code in lang_codes if code not in FINEWIKI_CONFIGS]
        if invalid_codes:
            logger.warning(f"⚠ Some language codes not in finewiki: {invalid_codes}")

    # Report configuration
    logger.info("=" * 80)
    logger.info("TOKENIZER COMPARISON TOOL")
    logger.info("=" * 80)
    logger.info(f"\nLanguages: {', '.join(lang_codes)}")
    logger.info(f"Characters per language: {args.n_chars:,}")

    # Data source
    if args.use_http:
        logger.info(f"\nPrimary source: HTTP parquet from '{args.base_url}'")
        logger.info(f"Content field: '{args.content_field}'")
    elif args.hf_dataset:
        logger.info(f"\nPrimary source: HuggingFace dataset '{args.hf_dataset}'")
        logger.info(f"Content field: '{args.content_field}'")

    if args.data_dir:
        logger.info(f"Secondary source: Local files in '{args.data_dir}'")

    if not args.use_http and not args.hf_dataset and not args.data_dir:
        logger.info("\nUsing built-in sample texts")
        logger.info(
            f"⚠ Only {len(SAMPLE_TEXTS)} languages have samples: {', '.join(SAMPLE_TEXTS.keys())}"
        )

    return lang_codes


def main():
    """Main execution function."""
    args = parse_arguments()

    # Configure logging level
    if args.verbose:
        logging.getLogger().setLevel(logging.DEBUG)

    # Validate and report configuration
    lang_codes = validate_configuration(args)

    # Prepare tokenizers
    if args.tokenizers:
        tokenizers = parse_custom_tokenizers(args.tokenizers)
    else:
        tokenizers = create_default_tokenizers()

    logger.info(f"\nTokenizers: {', '.join(t[0] for t in tokenizers)}")

    # Create language objects
    languages = create_languages(lang_codes)

    # Determine effective dataset parameter
    effective_hf_dataset = "HTTP" if args.use_http else args.hf_dataset

    # Load text data
    logger.info("\n" + "=" * 80)
    logger.info("LOADING TEXT DATA")
    logger.info("=" * 80)
    texts = load_all_texts(
        languages=languages,
        data_dir=args.data_dir,
        hf_dataset=effective_hf_dataset,
        use_http=args.use_http,
        base_url=args.base_url if args.use_http else None,
        content_field=args.content_field,
        n_chars=args.n_chars,
    )

    # Check if we have any texts
    if not any(texts.values()):
        logger.error("\n✗ Error: No text data available. Exiting.")
        return 1

    # Evaluate tokenizers
    logger.info("\n" + "=" * 80)
    logger.info("EVALUATING TOKENIZERS")
    logger.info("=" * 80)
    df = evaluate_all_tokenizers(tokenizers, languages, texts)

    if df.empty:
        logger.error("\n✗ No results generated. Check errors above.")
        return 1

    # Save results
    output_csv = f"{args.output_prefix}_comparison.csv"
    save_results(df, output_csv=output_csv)

    # Print summaries
    print_summaries(df)

    # Generate visualizations
    if not args.no_viz:
        logger.info("\n" + "=" * 80)
        logger.info("GENERATING VISUALIZATIONS")
        logger.info("=" * 80)

        visualizer = Visualizer()
        visualizer.create_heatmaps(df, output_prefix=args.output_prefix)
        visualizer.create_comparison_charts(df, output_prefix=args.output_prefix)

    logger.info("\n" + "=" * 80)
    logger.info("✓ COMPARISON COMPLETE")
    logger.info("=" * 80)

    return 0


if __name__ == "__main__":
    exit(main())
Usage: