|
#!/usr/bin/env python3 |
|
# /// script |
|
# requires-python = ">=3.10" |
|
# dependencies = [ |
|
# "numpy>=1.24", |
|
# "matplotlib>=3.7", |
|
# "scipy>=1.10", |
|
# ] |
|
# /// |
|
""" |
|
TurboQuant Core Concepts Demo |
|
============================= |
|
|
|
This script demonstrates the key mathematical insights behind TurboQuant: |
|
|
|
1. Random Rotation: Any vector on a unit sphere, when multiplied by a random |
|
rotation matrix, becomes uniformly distributed on the sphere. |
|
|
|
2. Coordinate Distribution: Each coordinate of a uniformly random point on |
|
the d-dimensional unit sphere follows a Beta distribution: |
|
|
|
f_X(x) = Γ(d/2) / (√π · Γ((d-1)/2)) · (1 - x²)^((d-3)/2) |
|
|
|
For large d, this converges to a Gaussian N(0, 1/d). |
|
|
|
3. Near-Independence: In high dimensions, distinct coordinates become nearly |
|
independent, allowing us to apply optimal scalar quantizers to each |
|
coordinate separately. |
|
|
|
These properties are what make TurboQuant "data-oblivious" - the quantizer |
|
doesn't need to see the data distribution ahead of time because random |
|
rotation induces a known, predictable distribution on the coordinates. |
|
|
|
Reference: "TurboQuant: Online Vector Quantization with Near-optimal Distortion Rate" |
|
arXiv:2504.19874 |
|
""" |
|
|
|
import numpy as np |
|
from scipy import special |
|
from scipy import stats |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
def generate_random_rotation_matrix(d: int, rng: np.random.Generator) -> np.ndarray:
    """
    Generate a random orthogonal (rotation) matrix in d dimensions.

    Method: QR decomposition of a matrix with i.i.d. standard normal
    entries, followed by a sign correction on the columns. With the
    correction, Q is Haar (uniformly) distributed over the orthogonal
    group O(d) — its determinant is ±1, not necessarily +1, which is
    fine for this demo since both components of O(d) preserve the
    uniform distribution on the sphere.

    This is the matrix Π in TurboQuant that transforms any input vector
    into a uniformly random point on the sphere.

    Args:
        d: Dimension of the space
        rng: NumPy random generator for reproducibility

    Returns:
        Π: A d×d orthogonal matrix (Π^T Π = I)
    """
    # Generate a matrix with i.i.d. standard normal entries
    gaussian_matrix = rng.standard_normal((d, d))

    # QR decomposition gives an orthogonal matrix Q, but the raw output
    # of np.linalg.qr is NOT Haar-distributed: the factorization's sign
    # convention biases it. Multiplying column j of Q by the sign of
    # R[j, j] removes that bias (Mezzadri, 2007).
    Q, R = np.linalg.qr(gaussian_matrix)

    # np.sign(0) is 0, which would zero out an entire column of Q and
    # destroy orthogonality. The event R[j, j] == 0 has probability zero
    # for Gaussian input, but mapping it to +1 costs nothing and makes
    # the function robust.
    signs = np.sign(np.diag(R))
    signs[signs == 0] = 1.0

    # Broadcasting scales column j by signs[j] (same as Q @ diag(signs)).
    return Q * signs
|
|
|
|
|
def theoretical_coordinate_pdf(x: np.ndarray, d: int) -> np.ndarray:
    """
    Evaluate the exact PDF of one coordinate of a uniform random point
    on the unit sphere S^(d-1) embedded in R^d.

    From Lemma 1 in the paper:
        f_X(x) = Γ(d/2) / (√π · Γ((d-1)/2)) · (1 - x²)^((d-3)/2)

    Geometric intuition: a single coordinate of a point on the sphere
    lives in [-1, 1], and almost all surface area of a high-dimensional
    sphere is concentrated near the equator (x ≈ 0), so the density
    sharpens around zero as d grows.

    Args:
        x: Evaluation points (must be in [-1, 1])
        d: Dimension of the ambient space

    Returns:
        PDF values at each point x
    """
    # Everything is computed in log space: Γ(d/2) overflows for modest d,
    # and (1 - x²)^((d-3)/2) underflows near |x| = 1 when d is large.
    log_norm = (special.gammaln(d / 2)
                - 0.5 * np.log(np.pi)
                - special.gammaln((d - 1) / 2))

    # Clamp the base away from zero so np.log never receives exactly 0
    # (the tiny floor 1e-300 only matters at |x| = 1, where the true
    # density is 0 anyway for d > 3).
    base = np.maximum(1 - x**2, 1e-300)
    return np.exp(log_norm + ((d - 3) / 2) * np.log(base))
|
|
|
|
|
def demonstrate_rotation_uniformity(d: int = 128, n_samples: int = 10000):
    """
    Demonstrate that random rotation makes ANY fixed vector uniformly
    distributed on the sphere.

    Key insight: If x is a fixed unit vector and Π is a random rotation,
    then Πx is uniformly distributed on the sphere S^(d-1).

    This is why TurboQuant is "data-oblivious" - no matter what vectors
    you give it, after rotation they all look the same statistically.

    Args:
        d: Ambient dimension of the sphere.
        n_samples: Number of independent random rotations to draw.

    Side effects:
        Writes 'demo1_rotation_uniformity.png' to the working directory
        and prints summary statistics to stdout.
    """
    print("=" * 70)
    print("DEMO 1: Random Rotation → Uniform Distribution on Sphere")
    print("=" * 70)

    # Fixed seed so the figure and printed statistics are reproducible.
    rng = np.random.default_rng(42)

    # Start with a very "biased" vector: all weight on first coordinate,
    # x = [1, 0, 0, ..., 0]. Any rotation of this vector is simply the
    # first column of the rotation matrix.
    x_biased = np.zeros(d)
    x_biased[0] = 1.0

    print(f"\nInput vector: x = [1, 0, 0, ..., 0] in R^{d}")
    print(f"This is maximally 'non-uniform' - all mass on one coordinate.\n")

    # Apply many random rotations and collect the first coordinate.
    # NOTE(review): this draws a fresh d×d rotation per sample, i.e.
    # O(n_samples · d³) work. That is deliberate — the demo is about the
    # rotation itself — but it dominates the runtime of this script.
    rotated_first_coords = []

    for _ in range(n_samples):
        Pi = generate_random_rotation_matrix(d, rng)
        y = Pi @ x_biased  # Rotated vector
        rotated_first_coords.append(y[0])

    rotated_first_coords = np.array(rotated_first_coords)

    # Compare empirical distribution to theoretical Beta distribution.
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left plot: histogram of rotated coordinates, overlaid with the
    # exact density and its large-d Gaussian approximation.
    ax1 = axes[0]
    ax1.hist(rotated_first_coords, bins=50, density=True, alpha=0.7,
             label='Empirical (rotated x₁)', color='steelblue', edgecolor='white')

    # Overlay theoretical PDF (evaluated slightly inside [-1, 1] to
    # avoid the endpoint singular/zero behavior).
    x_range = np.linspace(-0.99, 0.99, 200)
    theoretical_pdf = theoretical_coordinate_pdf(x_range, d)
    ax1.plot(x_range, theoretical_pdf, 'r-', linewidth=2.5,
             label=f'Theoretical Beta PDF (d={d})')

    # Also show the Gaussian approximation N(0, 1/d) for large d.
    gaussian_std = 1 / np.sqrt(d)
    gaussian_pdf = stats.norm.pdf(x_range, 0, gaussian_std)
    ax1.plot(x_range, gaussian_pdf, 'g--', linewidth=2,
             label=f'Gaussian N(0, 1/{d}) approx')

    ax1.set_xlabel('Coordinate value', fontsize=12)
    ax1.set_ylabel('Density', fontsize=12)
    ax1.set_title(f'Distribution of First Coordinate After Random Rotation\n'
                  f'(d={d}, n={n_samples:,} samples)', fontsize=12)
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)

    # Right plot: Q-Q plot comparing empirical order statistics to
    # theoretical quantiles of the coordinate distribution.
    ax2 = axes[1]

    # Sort empirical values (these are the empirical quantiles).
    sorted_coords = np.sort(rotated_first_coords)
    n = len(sorted_coords)

    # Theoretical quantiles via the inverse CDF: if X is a coordinate of
    # a uniform sphere point, then (X + 1)/2 ~ Beta((d-1)/2, (d-1)/2),
    # so quantiles of X come from scipy's beta ppf scaled to [-1, 1].
    # NOTE(review): the probability grid linspace(0.001, 0.999, n) is an
    # approximation of the usual plotting positions (i - 0.5)/n; fine
    # for a visual diagnostic.
    alpha = (d - 1) / 2
    theoretical_quantiles = stats.beta.ppf(np.linspace(0.001, 0.999, n), alpha, alpha)
    theoretical_quantiles = 2 * theoretical_quantiles - 1  # Scale from [0,1] to [-1,1]

    ax2.scatter(theoretical_quantiles, sorted_coords, alpha=0.3, s=1, color='steelblue')
    # Reference identity line; drawn only over [-0.5, 0.5] where nearly
    # all of the mass lies for d = 128.
    ax2.plot([-0.5, 0.5], [-0.5, 0.5], 'r-', linewidth=2, label='Perfect fit line')
    ax2.set_xlabel('Theoretical Quantiles', fontsize=12)
    ax2.set_ylabel('Empirical Quantiles', fontsize=12)
    ax2.set_title('Q-Q Plot: Empirical vs Theoretical Distribution', fontsize=12)
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)
    ax2.set_aspect('equal')

    plt.tight_layout()
    plt.savefig('demo1_rotation_uniformity.png', dpi=150, bbox_inches='tight')
    plt.close()

    # Sanity statistics: mean 0 and variance 1/d are the exact moments of
    # a sphere coordinate (variance 1/d because the d coordinate
    # variances are equal and sum to ||x||² = 1).
    print(f"Empirical mean: {np.mean(rotated_first_coords):.6f} (theory: 0)")
    print(f"Empirical variance: {np.var(rotated_first_coords):.6f} (theory: 1/{d} = {1/d:.6f})")
    print(f"\n✓ Saved plot to: demo1_rotation_uniformity.png")
|
|
|
|
|
def demonstrate_dimension_effect():
    """
    Show how the coordinate distribution changes with dimension d.

    Key insight: As d increases, the Beta distribution concentrates
    more and more around 0, converging to N(0, 1/d).

    This "concentration of measure" is why high-dimensional geometry
    is so different from our 3D intuition!

    Side effects:
        Writes 'demo2_dimension_effect.png' to the working directory
        and prints a table of standard deviations to stdout.
    """
    print("\n" + "=" * 70)
    print("DEMO 2: Effect of Dimension on Coordinate Distribution")
    print("=" * 70)

    # Dimensions to overlay in a single figure (3 = uniform density 1/2,
    # 512 = sharply peaked near 0).
    dimensions = [3, 10, 50, 128, 512]

    fig, ax = plt.subplots(figsize=(12, 6))

    # Evaluate slightly inside [-1, 1] to avoid endpoint blow-up for d < 3
    # behavior and exact zeros at |x| = 1.
    x_range = np.linspace(-0.99, 0.99, 500)
    # One viridis color per curve, avoiding the near-yellow top of the map.
    colors = plt.cm.viridis(np.linspace(0, 0.9, len(dimensions)))

    for d, color in zip(dimensions, colors):
        pdf = theoretical_coordinate_pdf(x_range, d)
        ax.plot(x_range, pdf, linewidth=2.5, color=color, label=f'd = {d}')

    ax.set_xlabel('Coordinate value x', fontsize=12)
    ax.set_ylabel('Probability density f(x)', fontsize=12)
    ax.set_title('Coordinate Distribution on Unit Sphere for Various Dimensions\n'
                 'Higher d → More concentrated around 0 (approaches Gaussian)', fontsize=12)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(-1, 1)

    plt.tight_layout()
    plt.savefig('demo2_dimension_effect.png', dpi=150, bbox_inches='tight')
    plt.close()

    print(f"\nStandard deviation of coordinate for various d:")
    print(f"{'Dimension d':<15} {'Std Dev (theory)':<20} {'1/√d':<15}")
    print("-" * 50)
    for d in dimensions:
        # Variance of a coordinate on the unit sphere is exactly 1/d, so
        # the "theory" column and the 1/√d column print the same value
        # on purpose — the point is that 1/√d IS the exact answer here.
        std_theory = 1 / np.sqrt(d)
        print(f"{d:<15} {std_theory:<20.6f} {std_theory:<15.6f}")

    print(f"\n✓ Saved plot to: demo2_dimension_effect.png")
|
|
|
|
|
def demonstrate_coordinate_independence(d: int = 128, n_samples: int = 5000):
    """
    Demonstrate that distinct coordinates become nearly independent
    in high dimensions.

    This is crucial for TurboQuant: it means we can quantize each
    coordinate separately using optimal scalar quantizers, without
    worrying about correlations between coordinates.

    The correlation between any two coordinates x_i and x_j (i ≠ j)
    for a uniform point on the sphere is exactly -1/(d-1), which
    vanishes as d → ∞.

    Args:
        d: Ambient dimension of the sphere.
        n_samples: Number of uniform sphere points to sample.

    Side effects:
        Writes 'demo3_coordinate_independence.png' to the working
        directory and prints the empirical vs theoretical correlation.
    """
    print("\n" + "=" * 70)
    print("DEMO 3: Near-Independence of Coordinates in High Dimensions")
    print("=" * 70)

    # Fixed seed for reproducible figure and printed numbers.
    rng = np.random.default_rng(123)

    # Generate uniform random points on the sphere by normalizing
    # Gaussian vectors — equivalent in distribution to applying a Haar
    # rotation to a fixed unit vector, but far cheaper than the per-
    # sample QR used in demo 1.
    gaussian_samples = rng.standard_normal((n_samples, d))
    norms = np.linalg.norm(gaussian_samples, axis=1, keepdims=True)
    sphere_samples = gaussian_samples / norms

    # Extract the first two coordinates of each sample.
    x1 = sphere_samples[:, 0]
    x2 = sphere_samples[:, 1]

    # Compare empirical Pearson correlation to the exact value -1/(d-1)
    # (which follows from symmetry plus the constraint Σ x_i² = 1).
    empirical_corr = np.corrcoef(x1, x2)[0, 1]
    theoretical_corr = -1 / (d - 1)

    print(f"\nDimension d = {d}")
    print(f"Theoretical correlation between x₁ and x₂: {theoretical_corr:.6f}")
    print(f"Empirical correlation (n={n_samples:,}): {empirical_corr:.6f}")
    print(f"\nAs d → ∞, correlation → 0 (independence)")

    # Create scatter plot: joint distribution on the left, correlation
    # decay curve on the right.
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: joint distribution of (x1, x2) — for large d it looks like an
    # isotropic Gaussian blob around the origin.
    ax1 = axes[0]
    ax1.scatter(x1, x2, alpha=0.2, s=5, color='steelblue')
    ax1.set_xlabel('$x_1$ (first coordinate)', fontsize=12)
    ax1.set_ylabel('$x_2$ (second coordinate)', fontsize=12)
    ax1.set_title(f'Joint Distribution of Two Coordinates (d={d})\n'
                  f'Correlation = {empirical_corr:.4f} ≈ 0', fontsize=12)
    ax1.set_aspect('equal')
    ax1.grid(True, alpha=0.3)

    # Add a circle showing the hard constraint x1² + x2² ≤ 1 (the 2D
    # projection of the sphere).
    theta = np.linspace(0, 2*np.pi, 100)
    ax1.plot(np.cos(theta), np.sin(theta), 'r--', alpha=0.5, linewidth=1,
             label='Unit circle (2D projection)')
    ax1.legend(fontsize=10)

    # Right: |correlation| = 1/(d-1) on log-log axes, showing the 1/d decay.
    ax2 = axes[1]
    dims = np.array([3, 5, 10, 20, 50, 100, 200, 500, 1000])
    correlations = -1 / (dims - 1)

    ax2.plot(dims, np.abs(correlations), 'o-', linewidth=2, markersize=8, color='steelblue')
    ax2.set_xlabel('Dimension d', fontsize=12)
    ax2.set_ylabel('|Correlation between coordinates|', fontsize=12)
    ax2.set_title('Correlation Vanishes with Increasing Dimension\n'
                  '|ρ(x₁, x₂)| = 1/(d-1) → 0', fontsize=12)
    ax2.set_xscale('log')
    ax2.set_yscale('log')
    ax2.grid(True, alpha=0.3, which='both')

    plt.tight_layout()
    plt.savefig('demo3_coordinate_independence.png', dpi=150, bbox_inches='tight')
    plt.close()

    print(f"\n✓ Saved plot to: demo3_coordinate_independence.png")
|
|
|
|
|
def demonstrate_simple_quantization(d: int = 128, b: int = 2, n_vectors: int = 100):
    """
    Demonstrate the basic TurboQuant quantization process.

    For a b-bit quantizer:
      1. Rotate vector with random Π
      2. For each coordinate, find which of 2^b buckets it falls into
      3. Store bucket indices (b bits each)
      4. To reconstruct: look up centroids, rotate back with Π^T

    Args:
        d: Ambient dimension of the vectors.
        b: Bit-width of the scalar quantizer (2^b levels per coordinate).
        n_vectors: Number of random unit vectors to quantize.

    Side effects:
        Writes 'demo4_quantization_example.png' to the working directory
        and prints MSE statistics to stdout.
    """
    print("\n" + "=" * 70)
    print(f"DEMO 4: Simple {b}-bit Quantization Example")
    print("=" * 70)

    # Fixed seed so the rotation, test vectors, and all printed numbers
    # are reproducible.
    rng = np.random.default_rng(456)

    # Choose scalar-quantizer centroids matched to the coordinate
    # distribution, which after rotation is approximately N(0, 1/d).
    # Standard-normal centroids are therefore scaled by 1/√d.
    if b == 1:
        # 1-bit: sign quantizer; ±√(2/π) is the conditional mean of a
        # standard normal's magnitude, scaled here by 1/√d.
        centroids = np.array([-1, 1]) * np.sqrt(2 / (np.pi * d))
    elif b == 2:
        # 2-bit: 4 levels — Lloyd-Max centroids for a standard normal
        # (±0.4528, ±1.510; tabulated values — TODO confirm against a
        # Lloyd-Max reference), scaled by 1/√d.
        centroids = np.array([-1.510, -0.4528, 0.4528, 1.510]) / np.sqrt(d)
    else:
        # For higher bits, uniform spacing over ±3 standard deviations is
        # used as a rough approximation of the optimal quantizer.
        centroids = np.linspace(-3, 3, 2**b) / np.sqrt(d)

    print(f"\nQuantization setup:")
    print(f"  Dimension d = {d}")
    print(f"  Bit-width b = {b} ({2**b} levels)")
    # Print the centroids in standard-normal units (multiply back by √d).
    print(f"  Centroids: {np.array2string(centroids * np.sqrt(d), precision=4)} × 1/√d")

    # Generate one random rotation matrix shared across all vectors —
    # in TurboQuant the same Π is reused for the whole dataset.
    Pi = generate_random_rotation_matrix(d, rng)

    # Generate random test vectors on the unit sphere.
    test_vectors = rng.standard_normal((n_vectors, d))
    test_vectors /= np.linalg.norm(test_vectors, axis=1, keepdims=True)

    mse_errors = []

    for x in test_vectors:
        # Step 1: Rotate.
        y = Pi @ x

        # Step 2: Quantize each coordinate to its nearest centroid.
        # distances has shape (d, 2^b): |y_i - c_k| for every pair.
        distances = np.abs(y[:, np.newaxis] - centroids[np.newaxis, :])
        indices = np.argmin(distances, axis=1)

        # Step 3: Dequantize (look up centroids by index).
        y_quantized = centroids[indices]

        # Step 4: Rotate back (Π is orthogonal, so Π^T = Π^{-1}).
        x_reconstructed = Pi.T @ y_quantized

        # Total squared error across all d coordinates, as in the paper:
        # D_mse = E[||x - x̃||²].
        mse = np.sum((x - x_reconstructed) ** 2)
        mse_errors.append(mse)

    mse_errors = np.array(mse_errors)

    # Reference distortion values per bit-width for unit-norm vectors —
    # taken from the paper (TODO confirm against arXiv:2504.19874).
    theoretical_mse = {1: 0.36, 2: 0.117, 3: 0.03, 4: 0.009}

    print(f"\nResults over {n_vectors} random unit vectors:")
    print(f"  Empirical MSE (||x - x̃||²): {np.mean(mse_errors):.4f} ± {np.std(mse_errors):.4f}")
    if b in theoretical_mse:
        print(f"  Theoretical upper bound: {theoretical_mse[b]:.4f}")
        print(f"  Information-theoretic lower: {1/4**b:.4f}")

    # Visualize one example end to end.
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Re-run the pipeline on the first test vector so the plotted numbers
    # match one concrete quantization pass.
    x = test_vectors[0]
    y = Pi @ x
    distances = np.abs(y[:, np.newaxis] - centroids[np.newaxis, :])
    indices = np.argmin(distances, axis=1)
    y_quantized = centroids[indices]
    x_reconstructed = Pi.T @ y_quantized

    # Plot 1: distribution of rotated coordinates with centroid positions.
    ax1 = axes[0]
    ax1.hist(y, bins=30, density=True, alpha=0.7, color='steelblue', edgecolor='white')
    for c in centroids:
        ax1.axvline(c, color='red', linestyle='--', linewidth=2, alpha=0.7)
    ax1.set_xlabel('Rotated coordinate value', fontsize=11)
    ax1.set_ylabel('Density', fontsize=11)
    ax1.set_title('Step 1-2: Rotated Coordinates\n(red lines = quantization centroids)', fontsize=11)
    ax1.grid(True, alpha=0.3)

    # Plot 2: histogram of per-coordinate squared quantization error.
    ax2 = axes[1]
    coord_errors = (y - y_quantized) ** 2
    ax2.hist(coord_errors, bins=30, density=True, alpha=0.7, color='orange', edgecolor='white')
    ax2.set_xlabel('Squared error per coordinate', fontsize=11)
    ax2.set_ylabel('Density', fontsize=11)
    ax2.set_title(f'Per-coordinate Quantization Error\nMean = {np.mean(coord_errors):.4f}', fontsize=11)
    ax2.grid(True, alpha=0.3)

    # Plot 3: original vs reconstructed values for the first 50 coordinates.
    ax3 = axes[2]
    n_show = 50
    ax3.plot(range(n_show), x[:n_show], 'b-', linewidth=1.5, label='Original x', alpha=0.8)
    ax3.plot(range(n_show), x_reconstructed[:n_show], 'r--', linewidth=1.5,
             label='Reconstructed x̃', alpha=0.8)
    ax3.set_xlabel('Coordinate index', fontsize=11)
    ax3.set_ylabel('Value', fontsize=11)
    ax3.set_title(f'Original vs Reconstructed (first {n_show} coords)\n'
                  f'||x - x̃||² = {np.sum((x - x_reconstructed)**2):.4f}', fontsize=11)
    ax3.legend(fontsize=10)
    ax3.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('demo4_quantization_example.png', dpi=150, bbox_inches='tight')
    plt.close()

    print(f"\n✓ Saved plot to: demo4_quantization_example.png")
|
|
|
|
|
def main():
    """Run all demonstrations."""
    banner = "=" * 70
    print("\n" + banner)
    print("TurboQuant: Core Mathematical Concepts Demonstration")
    print(banner)
    print("\nThis script demonstrates the key insights that make TurboQuant work:")
    for insight in (
        "  1. Random rotation → uniform distribution on sphere",
        "  2. Known Beta distribution for each coordinate",
        "  3. Near-independence of coordinates in high dimensions",
        "  4. Simple scalar quantization achieves near-optimal MSE",
    ):
        print(insight)

    # Each demo writes one PNG to the working directory and prints its
    # own summary statistics.
    demonstrate_rotation_uniformity(d=128, n_samples=10000)
    demonstrate_dimension_effect()
    demonstrate_coordinate_independence(d=128, n_samples=5000)
    demonstrate_simple_quantization(d=128, b=2, n_vectors=1000)

    print("\n" + banner)
    print("All demonstrations complete!")
    print(banner)
    print("\nGenerated plots:")
    for plot_name in (
        "demo1_rotation_uniformity.png",
        "demo2_dimension_effect.png",
        "demo3_coordinate_independence.png",
        "demo4_quantization_example.png",
    ):
        print(f"  • {plot_name}")
    print("\nThese visualizations show why TurboQuant can achieve near-optimal")
    print("compression without needing to see the data distribution in advance.")
|
|
|
|
|
# Script entry point: run every demo when executed directly; importing
# this module stays side-effect free.
if __name__ == "__main__":
    main()
Figures from local run: