jb2170 / tensor_change_basis.py
Created March 1, 2024 02:13
Change the basis of one index of a tensor in NumPy by contracting it with a matrix and reordering the resulting indices.
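As a hedged illustration of the description above, here is one way to express "contract with a matrix, then reorder the indices" using NumPy's `np.tensordot` and `np.moveaxis`. The helper name `change_basis_sketch` is hypothetical. This is a minimal sketch that assumes the contraction is T'_{...w...} = sum_x M_{wx} T_{...x...}, with the new index w placed where x used to be. It is not necessarily how `tensor_change_basis` below is implemented.

```python
import numpy as np
from numpy.typing import NDArray


def change_basis_sketch(tensor: NDArray, index: int, basis_matrix: NDArray) -> NDArray:
    """Return T'_{...w...} = sum_x M_{wx} T_{...x...}, with w at position `index`.

    Hypothetical helper: a sketch, not the gist's own implementation.
    """
    # tensordot sums over axis 1 of M (x) and axis `index` of T.
    # The result's axes are (w, <remaining axes of T in their original order>).
    contracted = np.tensordot(basis_matrix, tensor, axes=([1], [index]))
    # Move the new index w from the front back to the slot x used to occupy.
    return np.moveaxis(contracted, 0, index)


# Usage: a random rank-3 tensor, changing the basis of its middle index (index=1).
rng = np.random.default_rng(0)
T = rng.standard_normal((2, 3, 4))
M = rng.standard_normal((5, 3))  # M_wx maps a size-3 index to a size-5 one

T_new = change_basis_sketch(T, 1, M)
print(T_new.shape)  # (2, 5, 4)

# Cross-check against an explicit einsum of the same contraction.
expected = np.einsum("wx,ixk->iwk", M, T)
print(np.allclose(T_new, expected))  # True
```

The `tensordot` + `moveaxis` pair works for any rank and any (including negative) `index`, because both functions normalize negative axes the same way. An `einsum` subscript string would have to be built per rank.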
import numpy as np
from numpy.typing import NDArray


def tensor_change_basis(tensor: NDArray, index: int, basis_matrix: NDArray) -> NDArray:
    """
    Contract the arbitrary-rank tensor `tensor` T_ijk...x...lmn
    and the rank-2 tensor `basis_matrix` M_wx
    to return T'_ijk...w...lmn = sum_x M_wx T_ijk...x...lmn,
    where x is the `index`-th index of `tensor`
    """
    rank = len(tensor.shape)