
@jb2170
Created March 1, 2024 02:13
Change the basis of one index of a tensor in NumPy by contracting it with a matrix and reordering the resulting indices
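In index notation, the contraction described by the docstring of `tensor_change_basis` sums over the `index`-th index x of T, and the new index w of M takes the position that x occupied:

$$
T'_{ijk\ldots w\ldots lmn} = \sum_{x} M_{wx}\, T_{ijk\ldots x\ldots lmn}
$$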
```python
import numpy as np
from numpy.typing import NDArray


def tensor_change_basis(tensor: NDArray, index: int, basis_matrix: NDArray) -> NDArray:
    """
    Contract the arbitrary-rank tensor `tensor` T_ijk...x...lmn
    and the rank-2 tensor `basis_matrix` M_wx
    to return T'_ijk...w...lmn, where x is the `index`-th index of `tensor`.
    """
    rank = tensor.ndim
    # Normalise a negative `index` (e.g. -1 for the last index); the permutation below assumes 0 <= index < rank
    if not -rank <= index < rank:
        raise IndexError(f"index {index} is out of range for a tensor of rank {rank}")
    index %= rank
    permutation = tuple(range(1, index + 1)) + (0,) + tuple(range(index + 1, rank))
    # `np.tensordot` returns T'_wijk...lmn; we need to rearrange the indices with `permutation`
    return np.tensordot(basis_matrix, tensor, (1, index)).transpose(permutation)
```
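As a sanity check, here is a minimal sketch of the same contraction written with `np.moveaxis` instead of an explicit permutation, compared against `np.einsum` on a small random rank-3 tensor. The function name `tensor_change_basis_moveaxis` and the test shapes are hypothetical, chosen only for illustration. `np.moveaxis(a, 0, index)` applies the same axis reordering as `transpose(permutation)`, and it also accepts a negative `index`.

```python
import numpy as np
from numpy.typing import NDArray


def tensor_change_basis_moveaxis(tensor: NDArray, index: int, basis_matrix: NDArray) -> NDArray:
    """Hypothetical alternative: contract M_wx with the `index`-th index of `tensor`,
    then move the new index w (axis 0 of the tensordot result) back to position `index`."""
    contracted = np.tensordot(basis_matrix, tensor, axes=(1, index))
    # np.moveaxis normalises a negative `index` against contracted.ndim, which equals tensor.ndim
    return np.moveaxis(contracted, 0, index)


rng = np.random.default_rng(0)
T = rng.standard_normal((2, 3, 4))  # T_ixk, where x (dimension 3) is index 1
M = rng.standard_normal((5, 3))     # M_wx, maps the 3-dimensional index x to a 5-dimensional w

result = tensor_change_basis_moveaxis(T, 1, M)
expected = np.einsum("wx,ixk->iwk", M, T)  # T'_iwk = sum_x M_wx T_ixk
print(result.shape)                   # (2, 5, 4)
print(np.allclose(result, expected))  # True
```

Contracting with the identity matrix leaves the tensor unchanged, which is a quick way to check that the indices end up in the right order.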