Created
January 19, 2025 22:18
-
-
Save N8python/b3e24a4f88efa52bdd81a8762b7a7238 to your computer and use it in GitHub Desktop.
MLX matrix_exp.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mlx.core as mx | |
@mx.compile
def _compute_T1(A):
    """Return the degree-1 polynomial ``I + A``.

    The identity is sized from the last dimension of ``A`` and is created
    with ``mx.eye``'s default dtype.
    """
    size = A.shape[-1]
    identity = mx.eye(size)
    return identity + A
@mx.compile
def _compute_T2(A):
    """Return the degree-2 polynomial ``I + A + A^2 / 2``."""
    identity = mx.eye(A.shape[-1])
    squared = mx.matmul(A, A)
    half_squared = squared / 2
    return identity + A + half_squared
@mx.compile
def _compute_T4(A):
    """Return ``I + A + A^2 @ (I/2 + A/6 + A^2/24)``.

    Expanding the product gives I + A + A^2/2 + A^3/6 + A^4/24; the factored
    form needs only two matrix products.
    """
    identity = mx.eye(A.shape[-1])
    squared = mx.matmul(A, A)
    tail = identity / 2 + A / 6 + squared / 24
    return identity + A + mx.matmul(squared, tail)
@mx.compile
def _compute_T8(A):
    """Evaluate a degree-8 polynomial in ``A`` using three matrix products.

    The scalar coefficients are all derived from sqrt(177) and 2/3.
    NOTE(review): by its name this is meant to be the degree-8 Taylor
    approximation of exp(A) -- confirm the coefficients against the
    reference implementation.
    """
    root_177 = 0.1330413469565007072504e+2

    # Scalar coefficients of the factored polynomial.
    c3 = 2/3
    c1 = c3 * ((1 + root_177) / 88)
    c2 = c3 * ((1 + root_177) / 352)
    c4 = (-271 + 29 * root_177) / (315 * c3)
    c5 = (-11 + 11 * root_177) / (1260 * c3)
    c6 = (-99 + 11 * root_177) / (5040 * c3)
    c7 = (89 - root_177) / (5040 * c3)
    d2 = (857 - 58 * root_177) / 630

    identity = mx.eye(A.shape[-1])

    # Three matrix products in total: A^2, then the quartic and octic terms.
    squared = mx.matmul(A, A)
    quartic = mx.matmul(squared, c1*A + c2*squared)
    left = c3*squared + quartic
    right = c4*identity + c5*A + c6*squared + c7*quartic
    octic = mx.matmul(left, right)

    return identity + A + d2*squared + octic
@mx.compile
def compute_scale_factor(matrix_norm, threshold):
    """Return ``max(0, ceil(log2(matrix_norm / threshold)))`` as int32.

    The result has the same shape as ``matrix_norm``; entries whose norm is
    at or below ``threshold`` are clamped to zero.
    """
    ratio = matrix_norm / threshold
    halvings = mx.ceil(mx.log2(ratio))
    floor = mx.zeros_like(matrix_norm)
    clamped = mx.maximum(floor, halvings)
    return clamped.astype(mx.int32)
def _matrix_exp_forward(A):
    """Compute the matrix exponential of ``A`` over its last two dimensions.

    Uses an optimized Taylor polynomial whose degree is chosen from the
    matrix norm, falling back to scaling-and-squaring with the degree-8
    polynomial for larger norms. Based on PyTorch's implementation of:

        Bader, P.; Blanes, S.; Casas, F.
        Computing the Matrix Exponential with an Optimized Taylor Polynomial
        Approximation.

    Args:
        A: array of shape ``(..., n, n)``. Leading dimensions are treated as
            batch dimensions; each matrix gets its own scaling exponent.

    Returns:
        An array with the same shape as ``A``.
    """
    trailing = A.shape[-2:]
    if trailing == (0, 0):
        # mx.array has no ``clone`` method; copy through the constructor.
        return mx.array(A)
    if trailing == (1, 1):
        return mx.exp(A)

    # Norm thresholds taken from PyTorch's implementation. Only the first
    # four are used here: larger norms are handled by scaling and squaring
    # with the degree-8 polynomial rather than degrees 12 / 18.
    thresholds = [
        1.192092800768788e-07,  # deg 1
        5.978858893805233e-04,  # deg 2
        5.116619363445086e-02,  # deg 4
        5.800524627688768e-01,  # deg 8
        1.461661507209034e+00,  # deg 12
        3.010066362817634e+00,  # deg 18
    ]

    # One norm per matrix (over the last two axes), so that a batch is not
    # collapsed into a single global norm.
    matrix_norm = mx.linalg.norm(A, axis=(-2, -1))
    largest_norm = mx.max(matrix_norm).item()

    # If every matrix is small enough, one low-degree polynomial suffices.
    low_degree = (_compute_T1, _compute_T2, _compute_T4, _compute_T8)
    for bound, approximate in zip(thresholds, low_degree):
        if largest_norm <= bound:
            return approximate(A)

    # Scaling and squaring: exp(A) = exp(A / 2**s) ** (2**s).
    s = compute_scale_factor(matrix_norm, thresholds[3])
    # Shape (..., 1, 1) so that s broadcasts against (..., n, n).
    s = mx.expand_dims(mx.expand_dims(s, -1), -1)
    A_scaled = A / (2.0 ** s)

    X = _compute_T8(A_scaled)

    # Square back up. Matrices with a smaller exponent stop squaring once
    # their own counter reaches zero.
    max_s = int(mx.max(s).item())
    for _ in range(max_s):
        X = mx.where(s > 0, X @ X, X)
        s = s - 1
    return X
@mx.custom_function
def _matrix_exp_frechet(A):
    """Matrix exponential whose gradient is supplied by a custom VJP.

    The forward value is exactly what ``_matrix_exp_forward`` returns.
    """
    result = _matrix_exp_forward(A)
    return result
@_matrix_exp_frechet.vjp
def _matrix_exp_frechet_vjp(A, cotangent, output):
    """Custom VJP for ``_matrix_exp_frechet`` using the block-matrix trick.

    Builds the 2n x 2n block matrix ``M = [[A^T, G], [0, A^T]]`` (``G`` being
    the cotangent), exponentiates it, and returns the top-right n x n block.

    This function is deliberately NOT wrapped in ``mx.compile``:
    ``_matrix_exp_forward`` branches on array values and calls ``.item()``,
    and arrays cannot be evaluated while a compiled function is being traced.

    Args:
        A: the primal input, shape ``(..., n, n)``.
        cotangent: gradient with respect to the output; assigned into an
            n x n block, so it must be compatible with that shape.
        output: the forward result (unused).

    Returns:
        The gradient with respect to ``A``, shape ``(..., n, n)``.
    """
    n = A.shape[-1]

    # NOTE(review): this is a plain transpose, not a conjugate transpose --
    # confirm whether complex inputs need to be supported.
    A_T = mx.swapaxes(A, -2, -1)

    M = mx.zeros((*A.shape[:-2], 2 * n, 2 * n), dtype=A.dtype)
    M[..., :n, :n] = A_T
    M[..., n:, n:] = A_T
    M[..., :n, n:] = cotangent

    exp_M = _matrix_exp_forward(M)
    return exp_M[..., :n, n:]
@mx.custom_function
def matrix_exp(A):
    """Matrix exponential of ``A``, dispatching on matrix size.

    Matrices whose last dimension is at most 384 go through
    ``_matrix_exp_frechet`` (custom Fréchet-derivative gradient, which the
    original author found faster for small matrices); larger ones call
    ``_matrix_exp_forward`` directly.
    """
    if A.shape[-1] <= 384:
        return _matrix_exp_frechet(A)
    return _matrix_exp_forward(A)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Algorithm from: https://raw.githubusercontent.com/pytorch/pytorch/7f18ef14c1fed4e4376a75d626d98ba3c074809c/aten/src/ATen/native/LinearAlgebra.cpp