Created
November 16, 2023 12:49
-
-
Save Bollegala/376bab625acd4d44ebc93d1334a31bef to your computer and use it in GitHub Desktop.
Einsum Examples
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
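"# Einsum Examples\n", | |
"\n", | |
"In this notebook, we study the einsum (Einstein summation) notation. Each example writes an operation with `torch.einsum`, and most examples compare it with the equivalent built-in PyTorch operation.\n", | |
"\n", | |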
"## References\n", | |
"1. https://www.youtube.com/watch?v=pkVwUVEHmfI&ab_channel=AladdinPersson\n", | |
"2. https://rockt.github.io/2018/04/30/einsum\n", | |
"3. https://ajcr.net/Basic-guide-to-einsum/" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Fix the random seed so that Restart & Run All reproduces the same tensors.\n", | |
"torch.manual_seed(0)\n", | |
"\n", | |
"# A is (3, 5) and B is (5, 2), so the matrix product A @ B is (3, 2).\n", | |
"# (The same einsum strings also work with np.einsum on NumPy arrays.)\n", | |
"A = torch.rand(3, 5)\n", | |
"B = torch.rand(5, 2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Matrix multiplication: C[i, j] = sum_k A[i, k] * B[k, j]\n", | |
"# i and j are free indices (they appear in the output); k is the summation index.\n", | |
"C_einsum = torch.einsum('ik,kj->ij', A, B)\n", | |
"C_matmul = A @ B\n", | |
"print(C_einsum)\n", | |
"print(C_matmul)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Outer product: C[i, j] = a[i] * b[j]; no index is summed.\n", | |
"a = torch.rand(5)\n", | |
"b = torch.rand(3)\n", | |
"C1 = torch.outer(a, b)\n", | |
"C2 = torch.einsum('i,j->ij', a, b)\n", | |
"# Compare every entry (with float tolerance), not just whether some difference is zero.\n", | |
"assert torch.allclose(C1, C2)\n", | |
"\n", | |
"# torch.outer only accepts 1-D inputs; einsum generalises the outer product to matrices:\n", | |
"# D[i, j, k, l] = A[i, j] * B[k, l], giving shape (3, 5, 5, 2).\n", | |
"D1 = A[:, :, None, None] * B[None, None, :, :]\n", | |
"D2 = torch.einsum('ij,kl->ijkl', A, B)\n", | |
"assert D2.shape == (3, 5, 5, 2)\n", | |
"assert torch.allclose(D1, D2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Transpose: swapping the output indices permutes the axes.\n", | |
"A_T = torch.einsum('ij->ji', A)\n", | |
"assert torch.equal(A_T, A.T)\n", | |
"\n", | |
"# A 1-D tensor has a single index, so 'ij->ji' cannot be applied to it directly.\n", | |
"# Reshape it into a (5, 1) column first, then transpose it into a (1, 5) row.\n", | |
"a_col = a.reshape(5, 1)\n", | |
"a_row = torch.einsum('ij->ji', a_col)\n", | |
"assert torch.equal(a_row, a_col.T)\n", | |
"a_row" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Get the diagonal elements of a square matrix: 'ii->i' keeps entries whose two indices coincide.\n", | |
"E = torch.rand(4, 4)\n", | |
"print(E)\n", | |
"diag_ref = E.diag()\n", | |
"diag_einsum = torch.einsum('ii->i', E)\n", | |
"assert torch.equal(diag_ref, diag_einsum)\n", | |
"\n", | |
"# Even if the tensor is not square, torch.diag returns the main diagonal,\n", | |
"# whereas einsum raises an error because the repeated index i must have one size (as rightly so!).\n", | |
"F = torch.rand(4, 5)\n", | |
"print(F.diag())\n", | |
"try:\n", | |
"    torch.einsum('ii->i', F)\n", | |
"except RuntimeError as err:\n", | |
"    print('einsum error:', err)\n", | |
"\n", | |
"diag_einsum" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Dot product: sum_i x[i] * y[i]. The output spec is empty, so the result is a 0-d scalar.\n", | |
"x = torch.rand(5)\n", | |
"y = torch.rand(5)\n", | |
"dot_ref = torch.dot(x, y)\n", | |
"dot_einsum = torch.einsum('i,i->', x, y)\n", | |
"print(dot_ref)\n", | |
"\n", | |
"print(dot_einsum)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Batch matrix multiplication\n", | |
"# B1 is a batch of n = 10 matrices, each of shape (p, q) = (3, 5).\n", | |
"# B2 is a separate batch of n = 10 matrices, each of shape (q, l) = (5, 4).\n", | |
"# We multiply these two batches pairwise to get a batch of n matrices, each of shape (p, l) = (3, 4).\n", | |
"# The batch index i appears in both operands and in the output, so it is not summed; k is summed.\n", | |
"\n", | |
"B1 = torch.rand(10, 3, 5)\n", | |
"B2 = torch.rand(10, 5, 4)\n", | |
"\n", | |
"B3 = torch.einsum('ijk,ikl->ijl', B1, B2)\n", | |
"assert B3.shape == (10, 3, 4)\n", | |
"assert torch.allclose(B3, torch.bmm(B1, B2))\n", | |
"B3.shape" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Bilinear transformation: x^T W y = sum_j sum_k x[j] * W[j, k] * y[k]\n", | |
"# x and y are column vectors, so their second index (i and n below) has size 1.\n", | |
"# The matrix is named W so that the matrix A defined earlier is not overwritten.\n", | |
"x = torch.rand(10, 1)\n", | |
"y = torch.rand(5, 1)\n", | |
"W = torch.rand(10, 5)\n", | |
"\n", | |
"R = torch.einsum('ji,jk,kn->', x, W, y)  # empty output: every index is summed -> 0-d scalar\n", | |
"S = x.T @ W @ y                          # shape (1, 1)\n", | |
"assert torch.allclose(R, S.squeeze())\n", | |
"print(R)\n", | |
"print(S)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Initial commit