Created
September 8, 2022 19:26
-
-
Save cat-state/6308e46f323b909825d5146afa2945a0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"from math import pi\n", | |
"from typing import Tuple" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Routing" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 212, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(tensor([[1.6784, 2.4763, 1.1539]], grad_fn=<RepeatBackward>),\n", | |
" tensor([[0.3223, 2.3602, 1.2148]], grad_fn=<NormBackward3>))" | |
] | |
}, | |
"execution_count": 212, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
# An (angle, intensity) pair flowing between layers.
Input = Tuple[torch.Tensor, torch.Tensor]


class SLMRouter(nn.Module):
    """Bank of learnable angle filters acting on (angle, intensity) inputs.

    Each of the ``n_filters`` filters owns one learnable angle, initialised
    uniformly in ``[0, pi)``. ``forward`` mixes every input into every filter
    and returns one transmitted intensity per filter, together with the
    filters' own angles (so layers can be chained).
    """

    def __init__(self, n_filters: int):
        super().__init__()
        # Use the imported ``pi`` instead of the truncated literal 3.14159.
        self._angles = nn.Parameter(torch.empty(n_filters).uniform_(0, pi))

    def forward(self, x: Input) -> Input:
        """Route a batch of (angle, intensity) inputs through the filters.

        Args:
            x: ``(angle, intensity)``. Both must be rank 2 with matching
                shapes, presumably ``(batch, n_inputs)``. The indexing below
                requires rank 2.

        Returns:
            ``(angles, transmitted)``, both ``(batch, n_filters)``.
            ``angles`` is the filter angles repeated for every batch row.
            ``transmitted[b, f]`` is the Euclidean norm, over the (cos, sin)
            components, of ``filter_component_f * sum_i(input_component_i *
            intensity_i)``.
        """
        angle, intensity = x

        # (1, 2, n_filters): (cos, sin) of each filter angle.
        self_vecs = torch.stack([self._angles.cos(), self._angles.sin()]).unsqueeze(0)
        # (batch, 2, n_inputs): (cos, sin) of each incoming angle.
        in_vecs = torch.stack([angle.cos(), angle.sin()], dim=1)

        # (…, n_filters, 1) @ (…, 1, n_inputs) is a per-component outer
        # product -> (batch, 2, n_filters, n_inputs).
        infall = self_vecs[:, :, :, None] @ in_vecs[:, :, None, :]
        # Weight by input intensity, sum over inputs, then norm over the two
        # (cos, sin) components -> (batch, n_filters).
        transmitted = (infall * intensity[:, None, None, :]).sum(dim=3).norm(dim=1)

        return (self._angles.repeat(intensity.shape[0], 1), transmitted)
" \n", | |
# Sanity check: a 3-filter router fed three inputs at angle 0 with unit
# intensity. Seeded so the displayed (random) filter angles are reproducible.
torch.manual_seed(0)
demo_router = SLMRouter(3)
demo_input = (torch.zeros(1, 3), torch.ones(1, 3))
demo_router(demo_input)
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Non-linearity" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 213, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[<matplotlib.lines.Line2D at 0x1e5622f1400>]" | |
] | |
}, | |
"execution_count": 213, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAH5JJREFUeJzt3Xl81dWd//HXJ/u+h7CEXZBNUAibjrbVuqCtOqNtXQouFepvqu2vHa3a2tqOXbSddrppKQqtrVUq1YdLQS0dNxwFCQoBwhbWhC0rWclyc8/8kbSNiOYCN/ne5f18PO4jdzm593MIvHM43/M9X3POISIikSXG6wJERCT4FO4iIhFI4S4iEoEU7iIiEUjhLiISgRTuIiIRSOEuIhKBFO4iIhFI4S4iEoHivPrgvLw8N2LECK8+XkQkLK1bt67aOZffWzvPwn3EiBEUFxd79fEiImHJzPYG0k7TMiIiEUjhLiISgRTuIiIRSOEuIhKBFO4iIhFI4S4iEoEU7iIiEcizde4iIpGko9NPU6uP5nYfzW2dNLf7aPn71+7nWtp9tLR3Mm14NueO6fU8pFOicBeRqOfr9NPQ6qP+aAf1RztoONpBQ2sHDUd9NLZ23W9s9XXfuu43tXXfWn00tvlo9/kD/rxbPzZa4S4iciJaOzqpa2mnpqmd2uZ26lq6vza3U9fSQV1LO0daOjhytOtrfUsHjW2+j3zP2BgjPSmO9KQ40hLjSU+MY2BGEmlJcaQmxpGeGEdaYtf9tMQ4UhJj/3k/IZbUhK7nUhLiSI6PJTbG+vzPQeEuIiGvzddJVWMblY1tVPW8NbVR3dhGdVMbNc3t1Da1f2hQm0FmcjzZKQlkpcSTn5bI2AHpZKbEk5n8wVtGcjwZSfGkJ3UFtFnfB3IwKdxFxFONrR0cONLKgfqjHK5v5WB9K4fqWznU0Mrh7ltdS8cHvs8MclISyE1LIC8tkSmFWeSkJpCbmkBuWmLX/bQEslMSyElNIDM5vl9GzKFC4S4ifcY5x5GWDsrrWiivPUpFXQsVdUfZf+QoB450fW1sff9I2wzy0hIZmJFEYXYK04ZnU5CRxID0RAZkJDIgPYn89ERyUxOIi9WCvw+jcBeRU+L3Ow43trK7upm9NS3sqWlmb3ULe2tbqKht+cA0SWZyPEOykinMTmHmyBwGZyUzKCuZwZlJDMpKZkB6IvEK7VOmcBeRgBxt72RnVVPXrbKJndXN7KxsYk9NM60d/1wpkhAbQ2FOMsNzUpgxIpuhOSldt+wUCnOSyUiK97AX0UPhLiLv0+brpKyyie2HG9l2qIkdhxvZXtlIRd1RnOtqE2MwNCeFUXmpnD06j5H5qYzKS2V4bgqDMpOjam47VCncRaJYTVMbmw80UHqwgS0HGyg90MCu6mY6/V0pHh9rjMpLY0phFldPHcqYgjROG5DG8NwUEuNiPa5ePorCXSRK1DS1UbK/no0V9ZRU1LP5QD0H61v/8frgzCTGD8rg4okDOX1gOuMGpjMiL1Xz32FK4S4Sgdp8nWza38B7++pYX36E9eVHqKg7CnStRhmZl8qMkTlMGpzJxMEZTBicQVZKgsdVSzAp3EUiQG1zO+v21lG8p5a1e2rZtL+B9s6ug5xDspKZMjSTebOHM7kwi4mDM0jXQc2Ip3AXCUOVja2s2VXLmt01rNlVy47KJqBrpcoZhZnceM4Ipg7LYuqwbAZkJHlcrXhB4S4SBhpaO1i9s4a3dtbwv2XV/wjz1IRYikbkcOVZQ5gxMoczhmSSFK8DnaJwFwlJnX5HScURXt9exRvbq9hQUU+n35EcH8v0kTlcNa2Q2aNymTg4Q2dpynEp3EVCxJGWdl7fXsX/bKlk1Y4q6lo6MIPJhVn8v4+N5l/G5DF1WDYJcQpz6V1A4W5mlwA/B2KBR51zDxzzeibwODCs+z3/yzn32yDXKhJx9lQ3s7L0MCu3HKZ4Ty
1+B7mpCXzi9AF87PR8zhuTT3aqVrHIies13M0sFngIuBCoANaa2fPOudIezb4ElDrnPm1m+cA2M/ujc669T6oWCVPOObYcbOSlTQd5afMhth/umjsfPyiDL33iNM4fN4AphVnE6AxPOUWBjNxnAGXOuV0AZrYUuALoGe4OSLeuDY/TgFrgo3e/F4kSzjm2HmrkLyUHWF5ykD01LcQYTB+Rw7c/NYELJxQwNCfF6zIlwgQS7kOA8h6PK4CZx7T5FfA8cABIBz7nnAv8mlMiEWhvTTPPrT/Ac+v3s7OqmRiDs0fnseC80Vw0sYC8tESvS5QIFki4H+//h+6YxxcD64HzgdHASjNb5ZxreN8bmS0AFgAMGzbsxKsVCXFHWtp5oeQgz7xbwXv7jgAwY2QON54zkjmTBirQpd8EEu4VwNAejwvpGqH3dBPwgHPOAWVmthsYB7zTs5FzbhGwCKCoqOjYXxAiYanT73hjexXL1pXzt9JK2jv9nF6Qzt1zxnH5lMEMzkr2ukSJQoGE+1pgjJmNBPYD1wDXHdNmH3ABsMrMCoDTgV3BLFQk1JTXtvBUcTnLiis41NBKTmoC188axlVTC5k4OCPsrrkpkaXXcHfO+czsNuBlupZCLnHObTazW7tfXwjcD/zOzDbSNY1zl3Ouug/rFvFEp9/x6tZK/rhmL69tr8KA88bmc9+nJ3DB+AKtQZeQEdA6d+fcCmDFMc8t7HH/AHBRcEsTCR11ze0sXVvO46v3sv/IUQakJ3L7J07jczOGMUTTLhKCdIaqyEfYdqiRJW/u5tn1+2nz+Zk9Kpd7LxvPJycUaJ9zCWkKd5FjOOd4Y0c1j67axaod1STFx3DVtEJuPHsEYwvSvS5PJCAKd5FuHZ1+lpccZOHrO9l6qJEB6YncefHpXDdjmLYAkLCjcJeo19rRybJ1FSx8bSf7jxzltAFp/PjqyVxx5hAdIJWwpXCXqNXa0ckf1+zjN6/vpLKxjTOHZvGdyydywbgB2ttFwp7CXaLO30N94es7qWpsY9aoHH72uTOZPTpXa9MlYijcJWq0+/w8VVzOL1/ZweGGNmaPyuVX157FzFG5XpcmEnQKd4l4fr/j+Q0H+OnK7eyrbWH6iGx+fs1ZzFKoSwRTuEtEW7Wjih+u2ErpwQYmDMrgtzdN5+Nj8zX9IhFP4S4RaduhRr63vJRVO6opzE7m59ecyacnD9aBUokaCneJKDVNbfx05XaefGcf6Unx3HvZeObOHk5iXKzXpYn0K4W7RARfp58/rN7LT1dup6W9k3mzR/CVC8bo5COJWgp3CXurd9Vw33Ob2Xa4kXPH5PHtT01gjLYJkCincJewVd3Uxg+Wb+GZ9/YzJCuZhZ+fxsUTC3SwVASFu4Qhv9+xdG05D7y4haMdndx+/mn8+8dPIzlB8+oif6dwl7BSVtnEPc+UsHZPHbNG5fC9K8/gtAFpXpclEnIU7hIWOjr9LHxtJ798pYzkhFh+dPVkPjOtUFMwIh9C4S4hr/RAA3cs20DpwQY+NXkQ9316IvnpiV6XJRLSFO4Ssjo6/Tz86k5++coOslIS+M3caVw8caDXZYmEBYW7hKSyyia++qf1bNxfz+VTBvPdyydqzbrICVC4S0jx+x2/f3sPP3xxKykJsfz6+qnMOWOQ12WJhB2Fu4SMqsY27li2gde3V/Hx0/P50VWTGZCR5HVZImFJ4S4h4bVtldyxbAONrT7uv2Iin581XCthRE6Bwl081e7z8+OXt/LIqt2MG5jOE/NnMVZbB4icMoW7eKairoXbnniP9eVHmDd7ON+4dDxJ8TrLVCQYFO7iib+VHuZrT63HOXj4+qlcqoOmIkGlcJd+1el3/OSv23j4tZ1MGpLBQ9dNZXhuqtdliUQchbv0m5qmNr689D3+t6yGa2cM5b5PT9Q0jEgfUbhLv9hYUc8X/1BMdXM7P7p6Mp8tGup1SSIRTeEufe7Z9/Zz19Ml5KYm8PStZ3NGYabXJYlEPI
W79JlOv+PBl7ay6I1dzByZw0PXTyUvTRt+ifQHhbv0icbWDr6ydD2vbK1k3uzhfOtTE4iPjfG6LJGooXCXoCuvbeELj61lZ1Uz9185ibmzhntdkkjUUbhLUL27r475jxXT0ennsZtm8C9j8rwuSSQqKdwlaF7ceJD//6f1FGQk8dubpjM6X5e/E/GKwl1OmXOOxW/u5vsrtnDW0CwemVdErg6cingqoCNcZnaJmW0zszIzu/tD2nzczNab2WYzez24ZUqo8vsd9/9lC99bvoVLJw3iifmzFOwiIaDXkbuZxQIPARcCFcBaM3veOVfao00W8DBwiXNun5kN6KuCJXS0+Tr52lMbWF5ykJvPGcm9l40nJkbb9IqEgkCmZWYAZc65XQBmthS4Aijt0eY64Bnn3D4A51xlsAuV0NLY2sGC36/j7V01fOPSccw/d5T2XxcJIYFMywwByns8ruh+rqexQLaZvWZm68xs3vHeyMwWmFmxmRVXVVWdXMXiuZqmNq57ZA1r99Ty35+bwoLzRivYRUJMICP34/2rdcd5n2nABUAy8LaZrXbObX/fNzm3CFgEUFRUdOx7SBjYf+QocxevYX/dURbNm8b54wq8LklEjiOQcK8Aeu7yVAgcOE6baudcM9BsZm8AU4DtSMTYVdXE5x9dQ2Obj8dvmcn0ETlelyQiHyKQaZm1wBgzG2lmCcA1wPPHtHkOONfM4swsBZgJbAluqeKlbYca+exvVtPm87N0wSwFu0iI63Xk7pzzmdltwMtALLDEObfZzG7tfn2hc26Lmb0ElAB+4FHn3Ka+LFz6z8aKeuYuWUNiXAx/vGU2pw3QyUkioc6c82bqu6ioyBUXF3vy2RK4d/fVccPid8hMieeJW2YxLDfF65JEopqZrXPOFfXWTmeoyodat7eWG5asJS8tgSfmz2JwVrLXJYlIgLQHqxzX2j21zFv8DgPSE1m6YLaCXSTMaOQuH7B2Ty03LHmHgZlJLJ0/iwEZSV6XJCInSCN3eZ91e+u4UcEuEvYU7vIP68uPcOOSd8hPT+RJBbtIWFO4CwCbD9Qzb/EaslMTeHLBLAoU7CJhTeEulFU2MnfxO6QnxfPE/JkMytTBU5Fwp3CPcntrmrnukTXExhh/vGUmhdlaxy4SCRTuUexg/VGuf3QN7Z1+Hv/CTEbkpXpdkogEicI9StU1tzN38Tscaeng9zfP4PSB6V6XJCJBpHXuUai5zceNv1vLvtoWHrtpBpMLs7wuSUSCTCP3KNPm6+TWx9exseIIv7r2LGaPzvW6JBHpAxq5RxG/3/EfT21g1Y5qfnz1ZC6aONDrkkSkj2jkHiWcc3xv+Rb+UnKQu+eM4zNFQ3v/JhEJWwr3KPHIql0s+d/d3HTOCL543iivyxGRPqZwjwLPrd/PD1Zs5bLJg/jWZRN0MWuRKKBwj3Bv76zhjmUbmDkyh59+dgoxMQp2kWigcI9gZZWNfPEPxQzPTWXR3CIS42K9LklE+onCPUJVNrZyw5K1JMTF8tsbp5OZEu91SSLSjxTuEehoeyfzHyumtrmdJTcWMTRH+8WIRButc48wfr/jP5atp2R/PQs/P01nn4pEKY3cI8xPVm5jxcZD3DNnHBfrJCWRqKVwjyBPr6vgoVd3cs30ocw/V2vZRaKZwj1CrNtbyz3PbGT2qFzuv3KS1rKLRDmFewTYf+QoX/zDOgZlJfHrz08lPlY/VpFopwOqYa6l3cf8x4pp6/CzdEERWSkJXpckIiFA4R7GnHPcuayELYcaWHLDdE4boAtuiEgX/f89jD382k6WbzzIXZeM4xPjBnhdjoiEEIV7mHp1ayX/9ddtXD5lsHZ5FJEPULiHod3VzXx56XuMH5jBg1dN1soYEfkAhXuYaW7zseD3xcTHxrBo3jSSE7QZmIh8kMI9jDjn+PqfS9hZ1cSvrj2LwmztGSMix6dwDyOPrNrF8o0H+fol4zj7tDyvyxGREKZwDxNvlVXzwItbmTNpoA
6gikivFO5h4FB9K7c/+R4j81L58Wem6ACqiPQqoHA3s0vMbJuZlZnZ3R/RbrqZdZrZ1cErMbp1dPq57Yl3OdrRyW/mTiMtUeediUjveg13M4sFHgLmABOAa81swoe0exB4OdhFRrMHX9xK8d46Hrhqss5AFZGABTJynwGUOed2OefagaXAFcdpdzvwNFAZxPqi2kubDvLom7uZN3s4l08Z7HU5IhJGAgn3IUB5j8cV3c/9g5kNAf4VWBi80qLb3ppm7lxWwpShWXzzsvFelyMiYSaQcD/e0Tt3zOOfAXc55zo/8o3MFphZsZkVV1VVBVpj1GnzdfKlJ97FDB667iwS43SikoicmECOzlUAQ3s8LgQOHNOmCFjavYojD7jUzHzOuWd7NnLOLQIWARQVFR37C0K6fX/5Fjbtb+CReUU6UUlETkog4b4WGGNmI4H9wDXAdT0bOOdG/v2+mf0O+MuxwS6BWV5ykN+/vZf5547kwgkFXpcjImGq13B3zvnM7Da6VsHEAkucc5vN7Nbu1zXPHiT7alq4++kSzhqWxdcvGed1OSISxgJaNO2cWwGsOOa544a6c+7GUy8r+rT7/Nz+ZNc8+y+uOUuXyhORU6IzYkLEf/11Gxsq6vn19VMZmqN5dhE5NRoehoBXt1Wy6I1dfH7WMOacMcjrckQkAijcPVbZ0ModT21g3MB07r3sAyf+ioicFE3LeMjvd/zHsg00t/tYeu0skuK1nl1EgkMjdw8tfnM3q3ZU861PTWBMgfaNEZHgUbh7ZNP+en708lYumlDAdTOGeV2OiEQYhbsHWtp9fPnJ98hNTdQFrkWkT2jO3QP3/2ULu2ua+eMtM8lOTfC6HBGJQBq597OVpYd58p19LDh3FGeP1nVQRaRvKNz7UWVjK3c9XcKEQRl87aKxXpcjIhFM4d5PnHPcuayE5jYfP7/mTG3jKyJ9SuHeTx5fvZfXt1fxjUvHa9mjiPQ5hXs/2FXVxPdXbOG8sfnMmz3c63JEJAoo3PuYr9PPV5/aQGJcLD++WsseRaR/aClkH3vo1Z1sKD/CL689i4KMJK/LEZEooZF7HyqpOMIvXtnBFWcO5tNTBntdjohEEYV7H2nt6ORrT20gPy2R/7x8ktfliEiU0bRMH/nJX7dRVtnEYzfPIDMl3utyRCTKaOTeB97ZXcujb+7mupnD+NjYfK/LEZEopHAPsuY2H3cs20BhdjLfvHS81+WISJTStEyQ/fDFLZTXtbB0/ixSE/XHKyLe0Mg9iN4qq+bx1fu4+ZyRzByV63U5IhLFFO5B0tTm484/lzAqL5U7Ljrd63JEJMpp3iBIfrBiCwfqj/LnW2eTnKBNwUTEWxq5B8GbO6p5Ys0+5p87imnDc7wuR0RE4X6qmtp83PV0CaPyU/nahdqjXURCg6ZlTtEDL/5zOiYpXtMxIhIaNHI/BW/t/OfqGE3HiEgoUbifpJZ2H3c/vZERuSlaHSMiIUfTMifpxy9vY19tC39aMEurY0Qk5GjkfhLW7a3ld2/tYd7s4TpZSURCksL9BLV2dPL1P5cwODOZr18yzutyRESOS9MyJ+iXr+xgZ1Uzj908gzTtHSMiIUoj9xOw+UA9C1/fxVVTC7WVr4iENIV7gHydfu56uoTslAS+9Slt5SsioU3zCgFa/OZuNu1v4OHrp5KVkuB1OSIiHymgkbuZXWJm28yszMzuPs7r15tZSfftLTObEvxSvbOnupmfrtzORRMKmDNpoNfliIj0qtdwN7NY4CFgDjABuNbMJhzTbDfwMefcZOB+YFGwC/WKc457ntlIQmwM9185CTPzuiQRkV4FMnKfAZQ553Y559qBpcAVPRs4595yztV1P1wNFAa3TO88VVzO27tquOfS8RRkJHldjohIQAIJ9yFAeY/HFd3PfZgvAC8e7wUzW2BmxWZWXFVVFXiVHqlsbOX7y7cwc2QO10wf6nU5IiIBCyTcjzcP4Y7b0OwTdIX7Xcd73Tm3yDlX5J
wrys8P/aWE332+lFafnx/+2xnExGg6RkTCRyDhXgH0HLYWAgeObWRmk4FHgSucczXBKc87fys9zPKNB/ny+acxKj/N63JERE5IIOG+FhhjZiPNLAG4Bni+ZwMzGwY8A8x1zm0Pfpn9q7G1g289t4nTC9JZcN5or8sRETlhva5zd875zOw24GUgFljinNtsZrd2v74Q+DaQCzzcvZrE55wr6ruy+9ZP/rqdQw2tPHT9VBLidJ6XiISfgE5ics6tAFYc89zCHvdvAW4JbmneeG9fHY+9vYe5s4YzdVi21+WIiJwUDUt76Oj0c88zGylIT+LOi3UBDhEJX9p+oIfFb+5m66FGfjN3GulJ8V6XIyJy0jRy77avpoWf/a1ri4GLJ2qLAREJbwp3urYYuPe5TcTFxPDdKyZ6XY6IyClTuAMvlBzkje1V3HHRWAZlJntdjojIKYv6cK9v6eA/XyhlcmEmc2eP8LocEZGgiPoDqg++vJXa5jZ+d9N0YrXFgIhEiKgeua/bW8sTa/Zx8zkjmTQk0+tyRESCJmrDvaPTzzee2cTgzCS+euFYr8sREQmqqJ2WWfzmbrYdbmTR3GmkJkbtH4OIRKioHLmX13atab9wQgEXaU27iESgqAt35xz3Pb+ZGDO+c7nWtItIZIq6cH958yFe2VrJVz85liFZWtMuIpEpqsK9qc3Hd54vZfygDG46Z4TX5YiI9JmoCvefrdzO4cZWvnflJOJio6rrIhJloibhSg808Nu39nDN9GFMG6592kUkskVFuPv9jm8+u5Gs5HjuukT7tItI5IuKcF+6tpz39h3hm5eNJyslwetyRET6XMSHe3VTGw++tJVZo3L417OGeF2OiEi/iPhw/8GKLbS0+/jelZPovni3iEjEi+hwX72rhmfe3c/8c0dx2oB0r8sREek3ERvu7T4/9z67icLsZG4/f4zX5YiI9KuI3THr0Td3UVbZxOIbikhOiPW6HBGRfhWRI/eKuhZ+8T87uGhCAReML/C6HBGRfheR4f7dF0oxjPu0MZiIRKmIC/e/lR5mZelhvvLJMdoYTESiVkSF+9H2Tr7zwmbGDEjj5nNGel2OiIhnIuqA6q9e3UFF3VH+tGAWCXER9XtLROSEREwCllU2seiNXfzb1CHMHJXrdTkiIp6KiHB3zvHt5zaRHB/LPXPGe12OiIjnIiLcXyg5yFs7a7jzknHkpyd6XY6IiOfCPtwbWju4/y+lTC7M5LoZw7wuR0QkJIT9AdX/Xrmd6qY2Ft9QRGyMNgYTEYEwH7mXHmjgsbf2cP3MYUwuzPK6HBGRkBG24e73O+59diPZKQncedE4r8sREQkpAYW7mV1iZtvMrMzM7j7O62Zmv+h+vcTMpga/1Pdbtq6cd/cd4Z5Lx5OZEt/XHyciElZ6DXcziwUeAuYAE4BrzWzCMc3mAGO6bwuAXwe5zvepa27ngRe3MmNEDldN1dWVRESOFcjIfQZQ5pzb5ZxrB5YCVxzT5grg967LaiDLzAYFudZ/+NHLW2lo9XG/rq4kInJcgYT7EKC8x+OK7udOtE1QvLuvjiffKefmc0Zw+kBdXUlE5HgCCffjDY3dSbTBzBaYWbGZFVdVVQVS3wfEmHHumDy+8smxJ/X9IiLRIJBwrwCG9nhcCBw4iTY45xY554qcc0X5+fknWisAZw7N4g9fmElaYtgv0RcR6TOBhPtaYIyZjTSzBOAa4Plj2jwPzOteNTMLqHfOHQxyrSIiEqBeh7/OOZ+Z3Qa8DMQCS5xzm83s1u7XFwIrgEuBMqAFuKnvShYRkd4ENLfhnFtBV4D3fG5hj/sO+FJwSxMRkZMVtmeoiojIh1O4i4hEIIW7iEgEUriLiEQghbuISASyroUuHnywWRWw9yS/PQ+oDmI54UB9jg7qc3Q4lT4Pd871ehaoZ+F+Ksys2DlX5HUd/Ul9jg7qc3Tojz5rWkZEJAIp3EVEIlC4hvsirwvwgPocHdTn6NDnfQ7LOXcREf
lo4TpyFxGRjxDS4R6KF+buawH0+fruvpaY2VtmNsWLOoOptz73aDfdzDrN7Or+rK8vBNJnM/u4ma03s81m9np/1xhsAfzdzjSzF8xsQ3efw3p3WTNbYmaVZrbpQ17v2/xyzoXkja7thXcCo4AEYAMw4Zg2lwIv0nUlqFnAGq/r7oc+nw1kd9+fEw197tHuFbp2J73a67r74eecBZQCw7ofD/C67n7o8zeAB7vv5wO1QILXtZ9Cn88DpgKbPuT1Ps2vUB65h9yFuftBr312zr3lnKvrfriarqtehbNAfs4AtwNPA5X9WVwfCaTP1wHPOOf2ATjnwr3fgfTZAenWddX7NLrC3de/ZQaPc+4NuvrwYfo0v0I53EPqwtz95ET78wW6fvOHs177bGZDgH8FFhIZAvk5jwWyzew1M1tnZvP6rbq+EUiffwWMp+sSnRuBrzjn/P1Tnif6NL9C+UKkQbswdxgJuD9m9gm6wv1f+rSivhdIn38G3OWc6+wa1IW9QPocB0wDLgCSgbfNbLVzbntfF9dHAunzxcB64HxgNLDSzFY55xr6ujiP9Gl+hXK4B+3C3GEkoP6Y2WTgUWCOc66mn2rrK4H0uQhY2h3secClZuZzzj3bPyUGXaB/t6udc81As5m9AUwBwjXcA+nzTcADrmtCuszMdgPjgHf6p8R+16f5FcrTMtF4Ye5e+2xmw4BngLlhPIrrqdc+O+dGOudGOOdGAH8G/j2Mgx0C+7v9HHCumcWZWQowE9jSz3UGUyB93kfX/1QwswLgdGBXv1bZv/o0v0J25O6i8MLcAfb520Au8HD3SNbnwnjTpQD7HFEC6bNzbouZvQSUAH7gUefccZfUhYMAf873A78zs410TVnc5ZwL290izexJ4ONAnplVAPcB8dA/+aUzVEVEIlAoT8uIiMhJUriLiEQghbuISARSuIuIRCCFu4hIBFK4i4hEIIW7iEgEUriLiESg/wNhrl3YwrAXzwAAAABJRU5ErkJggg==\n", | |
"text/plain": [ | |
"<Figure size 432x288 with 1 Axes>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
def optical_limiting(x, gain=5, scale=1.8):
    """Smooth saturating non-linearity standing in for an optical limiter.

    Computes ``(sigmoid(gain * x) - 0.5) * scale``. The result is 0 at
    ``x = 0``, has slope ``gain * scale / 4`` at the origin, and saturates at
    ``scale / 2`` for large positive ``x`` (``-scale / 2`` for large negative
    ``x``).

    Args:
        x: input tensor (intensities in this notebook).
        gain: sigmoid steepness; the default 5 is the original hard-coded value.
        scale: output scale; the default 1.8 is the original value, giving a
            saturation level of 0.9.

    Returns:
        Tensor with the same shape as ``x``.
    """
    return ((x * gain).sigmoid() - 0.5) * scale
"\n", | |
# Visualise the limiter's response over input intensities in [0, 1].
intensities = torch.linspace(0, 1)
fig, ax = plt.subplots()
ax.plot(intensities.numpy(), optical_limiting(intensities).numpy())
ax.set(title="Optical limiting non-linearity", xlabel="input intensity", ylabel="output intensity");
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 214, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
class Act(nn.Module):
    """Apply an activation to the intensity half of an ``(angle, intensity)`` pair.

    The angle tensor passes through untouched; only the intensity is transformed.
    """

    def __init__(self, activation):
        super().__init__()
        # Any tensor -> tensor callable, e.g. ``optical_limiting``.
        self._activation = activation

    def forward(self, x):
        """Return ``(angle, activation(intensity))`` for ``x = (angle, intensity)``."""
        angles, intensities = x
        activated = self._activation(intensities)
        return angles, activated
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 225, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tensor(0.6629, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000],\n", | |
" [1.1859],\n", | |
" [1.2430],\n", | |
" [1.1676]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4733, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000],\n", | |
" [1.2728],\n", | |
" [1.2726],\n", | |
" [0.0262]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4276, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000e+00],\n", | |
" [1.7377e+00],\n", | |
" [1.7377e+00],\n", | |
" [3.2093e-05]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000e+00],\n", | |
" [1.7579e+00],\n", | |
" [1.7579e+00],\n", | |
" [6.3294e-06]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000e+00],\n", | |
" [1.7580e+00],\n", | |
" [1.7580e+00],\n", | |
" [1.0729e-07]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000],\n", | |
" [1.7580],\n", | |
" [1.7580],\n", | |
" [0.0000]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000],\n", | |
" [1.7580],\n", | |
" [1.7580],\n", | |
" [0.0000]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000],\n", | |
" [1.7580],\n", | |
" [1.7580],\n", | |
" [0.0000]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000],\n", | |
" [1.7580],\n", | |
" [1.7580],\n", | |
" [0.0000]], grad_fn=<NormBackward3>)\n", | |
"tensor(0.4261, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)\n", | |
"tensor([0, 1, 1, 0])\n", | |
"tensor([[0.0000],\n", | |
" [1.7580],\n", | |
" [1.7580],\n", | |
" [0.0000]], grad_fn=<NormBackward3>)\n" | |
] | |
} | |
], | |
"source": [ | |
torch.manual_seed(0)  # reproducible filter / start-angle initialisation

# XOR truth table: the single source for the inputs and targets below.
xor = {
    (0, 0): 0,
    (0, 1): 1,
    (1, 0): 1,
    (1, 1): 0
}

batch_x = torch.tensor(list(xor.keys())).float()  # (4, 2) input intensities
batch_y = torch.tensor(list(xor.values()))        # (4,) targets

net = nn.Sequential(
    SLMRouter(2),
    Act(optical_limiting),
    SLMRouter(2),
    Act(optical_limiting),
    SLMRouter(1)
)

# One learnable pair of input angles shared by every sample. The Parameter
# itself stays (1, 2) and is broadcast per forward pass. Making the Parameter
# an ``.expand``-ed view (stride 0) would make Adam's in-place update write
# several elements to the same memory location. PyTorch documents that as
# giving incorrect results, and recent versions raise an error. Broadcasting
# in forward also sums the per-sample gradients correctly.
start_angles = nn.Parameter(torch.empty(1, 2).uniform_(0, pi))

optim = torch.optim.Adam([start_angles, *net.parameters()], lr=0.01)

for i in range(1000):
    optim.zero_grad()

    x = (start_angles.expand(batch_x.shape[0], -1), batch_x)
    out_angle, out_intensity = net(x)
    # NOTE(review): the outputs are norms (>= 0). Used as logits, they can
    # never push p(y=1) below 0.5 for the class-0 samples. The mean loss is
    # therefore >= log(2)/2 ~= 0.347, and the saved run plateaued at 0.4261
    # even with the classes separated. Consider a bias/threshold or another
    # loss if calibrated outputs matter.
    # squeeze(-1) (not squeeze()) keeps the batch dim even for batch size 1.
    loss = F.binary_cross_entropy_with_logits(out_intensity.squeeze(-1), batch_y.float())
    loss.backward()

    optim.step()
    if i % 100 == 0:
        print(f"step {i:4d}  loss {loss.item():.4f}")
        print("  targets:", batch_y.tolist())
        print("  outputs:", [round(v, 4) for v in out_intensity.squeeze(-1).tolist()])
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment