Skip to content

Instantly share code, notes, and snippets.

@ameya98
Last active April 17, 2025 18:34
Show Gist options
  • Select an option

  • Save ameya98/2c21f50ccd6a20eb5e556d18e989a97f to your computer and use it in GitHub Desktop.

Select an option

Save ameya98/2c21f50ccd6a20eb5e556d18e989a97f to your computer and use it in GitHub Desktop.
Fast Bispectrum in PyTorch 2.0
import e3nn
import e3nn.o3
import e3nn_jax
import torch
import torch.nn as nn
class Bispectrum(nn.Module):
    """Compute the degree-3 O(3)-invariant features (bispectrum) of a signal.

    The change-of-basis tensor comes from
    ``e3nn_jax.reduced_symmetric_tensor_product_basis`` with ``degree=3``,
    keeping only the ``0e`` / ``0o`` output irreps. It is stored as the
    float32 buffer ``rtp_cob``, of shape
    ``(irreps_in.dim, irreps_in.dim, irreps_in.dim, irreps_out.dim)``.

    Attributes:
        irreps_in: Irreps of the input signal, normalized to ``e3nn.o3.Irreps``.
        irreps_out: Irreps of the output (only ``0e`` / ``0o`` components).
    """

    def __init__(self, irreps_in: "e3nn.o3.Irreps | str"):
        """Build the change-of-basis buffer for ``irreps_in``.

        Args:
            irreps_in: Anything accepted by ``e3nn.o3.Irreps`` (an ``Irreps``
                instance or a string such as ``"1x0e + 1x1o"``).
        """
        super().__init__()
        import numpy as np  # local import: numpy is not imported at file level

        # Normalize first so any input accepted by e3nn.o3.Irreps works; its
        # canonical string form is what gets handed to e3nn_jax.
        self.irreps_in = e3nn.o3.Irreps(irreps_in)
        rtp = e3nn_jax.reduced_symmetric_tensor_product_basis(
            str(self.irreps_in), degree=3, keep_ir=["0e", "0o"],
        )
        self.irreps_out = e3nn.o3.Irreps(str(rtp.irreps))

        # Convert JAX -> NumPy -> torch explicitly instead of handing a JAX
        # array to torch.as_tensor; np.array copies, so the array is writable.
        rtp_cob = torch.from_numpy(np.array(rtp.array, dtype=np.float32))
        # Registered as a buffer: it follows Module.to()/.cuda()/.double()
        # and is saved in state_dict, but is not a trainable parameter.
        self.register_buffer("rtp_cob", rtp_cob)

    def forward(self, sig: torch.Tensor) -> torch.Tensor:
        """Return the bispectrum of ``sig``.

        Args:
            sig: Tensor of shape ``(..., D)``, where ``D`` is the size of the
                first axis of ``rtp_cob``. Leading dimensions are batch
                dimensions.

        Returns:
            Tensor of shape ``(..., Z)``, where ``Z`` is the size of the last
            axis of ``rtp_cob``, with
            ``out[..., z] = sum_ijk rtp_cob[i, j, k, z] * sig[..., i] * sig[..., j] * sig[..., k]``.
            Its dtype is the promotion of ``sig.dtype`` and ``rtp_cob.dtype``.

        Raises:
            ValueError: If ``sig`` is 0-d or its last dimension is not ``D``.
        """
        dim_in = self.rtp_cob.shape[0]
        if sig.ndim == 0 or sig.shape[-1] != dim_in:
            raise ValueError(
                f"expected sig with last dimension {dim_in}, "
                f"got shape {tuple(sig.shape)}"
            )
        # einsum requires matching dtypes; promote instead of failing, so that
        # e.g. float64 input works without calling .double() on the module.
        dtype = torch.promote_types(self.rtp_cob.dtype, sig.dtype)
        cob = self.rtp_cob.to(dtype)
        sig = sig.to(dtype)
        return torch.einsum("ijkz,...i,...j,...k->...z", cob, sig, sig, sig)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment