# Source: GitHub Gist by @cccntu (gist f3be8fd827bac7c4bee7f62705cb457e),
# last active November 28, 2024 13:41.
import torch
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, dtype=torch.float32, mod=True):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
angles = torch.outer(t, freqs) # type: ignore
if mod:
angles = angles % (2 * torch.pi)
# Now convert angles to the desired lower precision
angles = angles.to(dtype)
return angles.cos()
# Compare the cosine tables produced in float32, float16 and bfloat16, first
# without and then with wrapping the angles mod 2*pi before the dtype cast.
# Looping over (mod, dtype) guarantees each section really uses the mode its
# header advertises (the hand-written version passed mod=False under MOD=True).
for mod in (False, True):
    print(f'-------------- MOD={mod} -------------- ')
    for dtype in (torch.float32, torch.float16, torch.bfloat16):
        x = precompute_freqs_cis(64, 2048, dtype=dtype, mod=mod)
        print(x)
        print(x.dtype)
# (GitHub page footer: commenting on this gist requires signing in.)