Skip to content

Instantly share code, notes, and snippets.

@N8python
Created January 19, 2025 22:18
Show Gist options
  • Save N8python/b3e24a4f88efa52bdd81a8762b7a7238 to your computer and use it in GitHub Desktop.
MLX matrix_exp.
import mlx.core as mx
@mx.compile
def _compute_T1(A):
    """Degree-1 Taylor polynomial of exp(A): I + A."""
    n = A.shape[-1]
    return mx.eye(n) + A
@mx.compile
def _compute_T2(A):
    """Degree-2 Taylor polynomial of exp(A): I + A + A^2/2."""
    n = A.shape[-1]
    A_sq = A @ A
    return mx.eye(n) + A + A_sq / 2
@mx.compile
def _compute_T4(A):
    """Degree-4 Taylor polynomial of exp(A).

    Evaluated in the factored form I + A + A^2 @ (I/2 + A/6 + A^2/24),
    which costs only two matrix products.
    """
    identity = mx.eye(A.shape[-1])
    A_sq = A @ A
    inner = identity / 2 + A / 6 + A_sq / 24
    return identity + A + A_sq @ inner
@mx.compile
def _compute_T8(A):
    """Degree-8 Taylor polynomial of exp(A) in an optimized factorization.

    Evaluates the degree-8 polynomial with only 5 matrix products, per
    Bader/Blanes/Casas, "Computing the Matrix Exponential with an
    Optimized Taylor Polynomial Approximation". The x*/y* coefficients
    are the exact constants from that scheme and are not meant to be
    human-readable.
    """
    # sqrt(177) written out to full double precision.
    sqrt_177 = 0.1330413469565007072504e+2
    x3 = 2/3
    x1 = x3 * ((1 + sqrt_177) / 88)
    x2 = x3 * ((1 + sqrt_177) / 352)
    x4 = (-271 + 29 * sqrt_177) / (315 * x3)
    x5 = (-11 + 11 * sqrt_177) / (1260 * x3)
    x6 = (-99 + 11 * sqrt_177) / (5040 * x3)
    x7 = (89 - sqrt_177) / (5040 * x3)
    y2 = (857 - 58 * sqrt_177) / 630
    A2 = A @ A
    # A4 and A8 are named after the polynomial degree they contribute;
    # they are coefficient combinations, not literal powers of A.
    A4 = A2 @ (x1*A + x2*A2)
    A8 = (x3*A2 + A4) @ (x4*mx.eye(A.shape[-1]) + x5*A + x6*A2 + x7*A4)
    return mx.eye(A.shape[-1]) + A + y2*A2 + A8
@mx.compile
def compute_scale_factor(matrix_norm, threshold):
    """Return s = max(0, ceil(log2(matrix_norm / threshold))) as int32.

    Dividing A by 2**s brings its norm under `threshold`, which is what
    the scaling-and-squaring step needs.
    """
    raw = mx.ceil(mx.log2(matrix_norm / threshold))
    floor_at_zero = mx.maximum(mx.zeros_like(matrix_norm), raw)
    return floor_at_zero.astype(mx.int32)
def _matrix_exp_forward(A):
    """Compute the matrix exponential via optimized Taylor polynomials.

    Follows PyTorch's implementation of:
        Bader, P.; Blanes, S.; Casas, F.
        "Computing the Matrix Exponential with an Optimized Taylor
        Polynomial Approximation."

    Args:
        A: square matrix (or batch of square matrices) as an mlx array.

    Returns:
        exp(A), same shape as A.
    """
    if A.shape[-2:] == (0, 0):
        # Bug fix: mlx arrays have no .clone() method (that is PyTorch
        # API); mx.array(A) returns a copy of the empty input instead.
        return mx.array(A)
    elif A.shape[-2:] == (1, 1):
        # 1x1 matrix: the matrix exponential is the elementwise exp.
        return mx.exp(A)

    # NOTE(review): mx.linalg.norm(A) with no axis reduces over ALL
    # elements, so for batched input every matrix shares one norm and
    # one scale factor s — confirm this matches the intended semantics.
    matrix_norm = mx.linalg.norm(A)

    # Degree-selection thresholds taken from PyTorch's implementation;
    # they are chosen per the paper above.
    thresholds = [
        1.192092800768788e-07,  # deg 1
        5.978858893805233e-04,  # deg 2
        5.116619363445086e-02,  # deg 4
        5.800524627688768e-01,  # deg 8
        1.461661507209034e+00,  # deg 12
        3.010066362817634e+00,  # deg 18
    ]

    # Small norms: a low-degree Taylor polynomial is already accurate.
    if matrix_norm <= thresholds[0]:
        return _compute_T1(A)
    elif matrix_norm <= thresholds[1]:
        return _compute_T2(A)
    elif matrix_norm <= thresholds[2]:
        return _compute_T4(A)
    elif matrix_norm <= thresholds[3]:
        return _compute_T8(A)

    # Larger norms: scale A down by 2**s so T8 is accurate, then use
    # exp(A) = exp(A / 2**s) ** (2**s) by squaring s times.
    s = compute_scale_factor(matrix_norm, thresholds[3])
    A_scaled = A / mx.expand_dims(mx.expand_dims(2.0**s, -1), -1)
    X = _compute_T8(A_scaled)

    # Square back up; mx.where leaves entries whose remaining count hit
    # zero unchanged (relevant only if s were per-matrix).
    max_s = int(mx.max(s).item())
    for _ in range(max_s):
        X = mx.where(s > 0, X @ X, X)
        s = s - 1
    return X
@mx.custom_function
def _matrix_exp_frechet(A):
    """Matrix exponential whose gradient comes from a custom VJP
    (the Fréchet-derivative block-matrix trick registered below);
    the forward pass is the plain Taylor/scaling implementation."""
    return _matrix_exp_forward(A)
@_matrix_exp_frechet.vjp
def _matrix_exp_frechet_vjp(A, cotangent, output):
    """Custom VJP for _matrix_exp_frechet.

    Uses the block-matrix identity: for M = [[A^T, G], [0, A^T]] with
    G the incoming cotangent, the top-right n x n block of exp(M) is
    the desired gradient with respect to A. `output` (the forward
    result) is unused.
    """
    # Matrix size along the last two axes.
    n = A.shape[-1]
    # Build the 2n x 2n block matrix M = [[A^T, G], [0, A^T]].
    # NOTE(review): swapaxes is a plain transpose, not a conjugate
    # transpose — for complex dtypes the Hermitian adjoint may be
    # intended; confirm before using with complex input.
    A_H = mx.swapaxes(A, -2, -1)
    M = mx.zeros((*A.shape[:-2], 2*n, 2*n), dtype=A.dtype)
    # Fill the blocks using index assignment.
    M[..., :n, :n] = A_H
    M[..., n:, n:] = A_H
    M[..., :n, n:] = cotangent
    # Compute exp(M); the gradient is its top-right block.
    exp_M = _matrix_exp_forward(M)
    grad_A = exp_M[..., :n, n:]
    return grad_A
@mx.custom_function
def matrix_exp(A):
    """Matrix exponential.

    Dispatches to the Fréchet-derivative-backed implementation for
    matrices up to 384 x 384 (where it is faster), and to the plain
    forward implementation for anything larger.
    """
    if A.shape[-1] <= 384:
        return _matrix_exp_frechet(A)
    return _matrix_exp_forward(A)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment