Created
September 8, 2022 19:26
-
-
Save cat-state/6308e46f323b909825d5146afa2945a0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"from math import pi\n", | |
"from typing import Tuple" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
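"# Routing\n", | |
"\n", | |
"`SLMRouter` is a layer of learnable angle filters. It takes an `(angle, intensity)` pair of input beams and returns the filter angles together with the intensity each filter transmits." | |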
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 212, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(tensor([[1.6784, 2.4763, 1.1539]], grad_fn=<RepeatBackward>),\n", | |
" tensor([[0.3223, 2.3602, 1.2148]], grad_fn=<NormBackward3>))" | |
] | |
}, | |
"execution_count": 212, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"Input = Tuple[torch.Tensor, torch.Tensor]  # (angles, intensities)\n", | |
"\n", | |
"\n", | |
"class SLMRouter(nn.Module):\n", | |
"    \"\"\"Layer of learnable angle filters acting on (angle, intensity) beams.\n", | |
"\n", | |
"    Filter ``f`` has a learnable angle ``theta_f``. For input beams with angles\n", | |
"    ``phi_i`` and intensities ``I_i`` it transmits\n", | |
"\n", | |
"        sqrt((cos(theta_f) * sum_i I_i * cos(phi_i)) ** 2\n", | |
"             + (sin(theta_f) * sum_i I_i * sin(phi_i)) ** 2)\n", | |
"\n", | |
"    and the outgoing angle of each filter is its own ``theta_f``.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, n_filters: int):\n", | |
"        super().__init__()\n", | |
"        # Initial filter angles drawn uniformly between 0 and pi.\n", | |
"        self._angles = nn.Parameter(torch.zeros(n_filters).uniform_(0, pi))\n", | |
"\n", | |
"    def forward(self, x: Input) -> Input:\n", | |
"        \"\"\"Route a batch of beams through every filter.\n", | |
"\n", | |
"        Args:\n", | |
"            x: ``(angle, intensity)``; the indexing below treats both as\n", | |
"                ``(batch, n_inputs)`` tensors.\n", | |
"\n", | |
"        Returns:\n", | |
"            ``(angles, transmitted)``, both ``(batch, n_filters)``: the filter\n", | |
"            angles repeated for each example, and the transmitted intensities.\n", | |
"        \"\"\"\n", | |
"        angle, intensity = x\n", | |
"\n", | |
"        # (1, 2, n_filters): cos/sin of each filter angle.\n", | |
"        self_vecs = torch.stack([self._angles.cos(), self._angles.sin()]).unsqueeze(0)\n", | |
"        # (batch, 2, n_inputs): cos/sin of each input angle.\n", | |
"        in_vecs = torch.stack([angle.cos(), angle.sin()], dim=1)\n", | |
"\n", | |
"        # (batch, 2, n_filters, n_inputs): componentwise filter x input products.\n", | |
"        infall = self_vecs[:, :, :, None] @ in_vecs[:, :, None, :]\n", | |
"        # Weight by input intensity, sum over inputs, norm over the cos/sin axis.\n", | |
"        transmitted = (infall * intensity[:, None, None, :]).sum(dim=3).norm(dim=1)\n", | |
"\n", | |
"        return (self._angles.repeat(intensity.shape[0], 1), transmitted)\n", | |
" \n", | |
"router = SLMRouter(3)\n", | |
"\n", | |
"# A batch of one: three unit-intensity beams, all at angle 0.\n", | |
"demo_input = (torch.zeros(1, 3), torch.ones(1, 3))\n", | |
"router(demo_input)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
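"# Non-linearity\n", | |
"\n", | |
"`optical_limiting` is a saturating activation. The `Act` wrapper below applies it to the intensities between routing layers." | |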
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 213, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<matplotlib.lines.Line2D at 0x1e5622f1400>]" | |
] | |
}, | |
"execution_count": 213, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"def optical_limiting(x, gain=5, scale=1.8):\n", | |
"    \"\"\"Saturating nonlinearity ``(sigmoid(gain * x) - 0.5) * scale``.\n", | |
"\n", | |
"    Maps 0 to 0; for positive ``gain`` and ``scale`` it increases with ``x``\n", | |
"    and saturates towards ``0.5 * scale`` (0.9 with the defaults).\n", | |
"    ``x`` must be a tensor (``.sigmoid()`` is called on it).\n", | |
"    \"\"\"\n", | |
"    return ((x * gain).sigmoid() - 0.5) * scale\n", | |
"\n", | |
"intensity_grid = torch.linspace(0, 1, 100)\n", | |
"fig, ax = plt.subplots()\n", | |
"ax.plot(intensity_grid.numpy(), optical_limiting(intensity_grid).numpy())\n", | |
"ax.set(title='optical_limiting response', xlabel='input intensity', ylabel='output intensity');" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 214, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Act(nn.Module):\n", | |
"    \"\"\"Apply ``activation`` to the intensity of an ``(angle, intensity)`` pair.\n", | |
"\n", | |
"    The angle tensor is passed through unchanged.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, activation):\n", | |
"        super().__init__()\n", | |
"        self._activation = activation  # callable applied to the intensity tensor\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        angle, intensity = x\n", | |
"        activated = self._activation(intensity)\n", | |
"        return angle, activated" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 225, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tensor(0.6629, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000],\n", | |
" [1.1859],\n", | |
" [1.2430],\n", | |
" [1.1676]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4733, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000],\n", | |
" [1.2728],\n", | |
" [1.2726],\n", | |
" [0.0262]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4276, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000e+00],\n", | |
" [1.7377e+00],\n", | |
" [1.7377e+00],\n", | |
" [3.2093e-05]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000e+00],\n", | |
" [1.7579e+00],\n", | |
" [1.7579e+00],\n", | |
" [6.3294e-06]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000e+00],\n", | |
" [1.7580e+00],\n", | |
" [1.7580e+00],\n", | |
" [1.0729e-07]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000],\n", | |
" [1.7580],\n", | |
" [1.7580],\n", | |
" [0.0000]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000],\n", | |
" [1.7580],\n", | |
" [1.7580],\n", | |
" [0.0000]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000],\n", | |
" [1.7580],\n", | |
" [1.7580],\n", | |
" [0.0000]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000],\n", | |
" [1.7580],\n", | |
" [1.7580],\n", | |
" [0.0000]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000],\n", | |
" [1.7580],\n", | |
" [1.7580],\n", | |
" [0.0000]], grad_fn=<NormBackward3>)\n" | |
] | |
} | |
], | |
"source": [ | |
"# XOR truth table: (input_a, input_b) -> target.\n", | |
"xor = {\n", | |
"    (0, 0): 0,\n", | |
"    (0, 1): 1,\n", | |
"    (1, 0): 1,\n", | |
"    (1, 1): 0\n", | |
"}\n", | |
"\n", | |
"# Derive the batch from the table so inputs and targets cannot drift apart.\n", | |
"batch_x = torch.tensor(list(xor.keys())).float()  # (4, 2) input intensities\n", | |
"batch_y = torch.tensor(list(xor.values()))        # (4,) integer targets\n", | |
"targets = batch_y.float()\n", | |
"\n", | |
"torch.manual_seed(0)  # reproducible filter and start-angle initialisation\n", | |
"\n", | |
"net = nn.Sequential(\n", | |
"    SLMRouter(2),\n", | |
"    Act(optical_limiting),\n", | |
"    SLMRouter(2),\n", | |
"    Act(optical_limiting),\n", | |
"    SLMRouter(1)\n", | |
")\n", | |
"\n", | |
"# One learnable pair of input angles shared by all examples. Keep the leaf at\n", | |
"# shape (1, 2) and broadcast it each step. Wrapping an expand()-ed view in\n", | |
"# nn.Parameter makes every row alias the same two storage elements, and\n", | |
"# in-place optimizer updates on such a tensor are not safe (recent PyTorch\n", | |
"# versions may reject them outright).\n", | |
"start_angles = nn.Parameter(torch.zeros(1, 2).uniform_(0, pi))\n", | |
"\n", | |
"N_STEPS = 1000\n", | |
"LOG_EVERY = 100\n", | |
"LR = 0.01\n", | |
"\n", | |
"optim = torch.optim.Adam([start_angles, *net.parameters()], lr=LR)\n", | |
"\n", | |
"print(batch_y)\n", | |
"for i in range(N_STEPS):\n", | |
"    optim.zero_grad()\n", | |
"\n", | |
"    inputs = (start_angles.expand(len(batch_x), -1), batch_x)\n", | |
"    out_angle, out_intensity = net(inputs)\n", | |
"    # NOTE(review): out_intensity comes from a norm, so it is >= 0. Read as a\n", | |
"    # logit, that gives sigmoid >= 0.5, so each target-0 row costs at least\n", | |
"    # ln 2. This is consistent with the recorded plateau at 0.4261.\n", | |
"    # TODO: add a bias or threshold to the readout if that floor is unintended.\n", | |
"    loss = F.binary_cross_entropy_with_logits(out_intensity.squeeze(1), targets)\n", | |
"    loss.backward()\n", | |
"\n", | |
"    optim.step()\n", | |
"    if i % LOG_EVERY == 0:\n", | |
"        print(f'step {i}: loss {loss.item():.4f}')\n", | |
"        print(out_intensity.detach().squeeze(1))" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment