Last active August 6, 2020 15:41
Gist: alberduris/efa7b42d1d96691d589f32e77eef76f9
Python class that implements a base Node for creating Trees with PyTorch tensors #Others #JupyterNotebook #CodeSnippet
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Node class for Trees\n",
"\n",
"Python class that implements a base Node for creating Trees with PyTorch tensors"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"class Node:\n",
"    \"\"\"\n",
"    Class that implements a Node of a Tree\n",
"    \"\"\"\n",
"    def __init__(self, name=None, data=None, children=None):\n",
"\n",
"        self.name = name  # Name of the node as ID (optional)\n",
"        self.data = data  # Each node carries a differentiable one-element tensor of shape (1,), random if not given\n",
"        if self.data is None:\n",
"            self.data = torch.rand((1,), requires_grad=True)\n",
"\n",
"        # Default to a fresh list: a mutable default (children=[]) would be shared by every Node\n",
"        self.children = children if children is not None else []\n",
"\n",
"    def get_paths(self, node=None, path=None):\n",
"        \"\"\"\n",
"        Get all root-to-leaf paths of the Tree, as lists of Nodes\n",
"        \"\"\"\n",
"        if node is None:\n",
"            node = self\n",
"\n",
"        paths = []\n",
"        if path is None:\n",
"            path = []\n",
"        path.append(node)\n",
"\n",
"        if node.children:\n",
"            for child in node.children:\n",
"                paths.extend(self.get_paths(child, path[:]))  # copy so sibling branches don't share a path\n",
"        else:\n",
"            paths.append(path)  # leaf reached: the path is complete\n",
"\n",
"        return paths\n",
"\n",
"    def traverse(self, f, node=None):\n",
"        \"\"\"\n",
"        Traverse the Tree recursively (pre-order), replacing each Node's data in place with f(node)\n",
"        \"\"\"\n",
"        if node is None:\n",
"            node = self\n",
"\n",
"        if not callable(f):  # Sanity check\n",
"            raise TypeError(\"traverse() expects a callable f, got {!r}\".format(f))\n",
"        node.data = f(node)\n",
"\n",
"        for child in node.children:  # leaves have no children, so the recursion stops there\n",
"            self.traverse(f, child)\n",
"\n",
"    def __str__(self):\n",
"        return '{}: {}'.format(self.name, self.data)\n",
"\n",
"    __repr__ = __str__\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create your own Tree\n",
"\n",
"I have built the generic tree you sent me over WhatsApp.\n",
"\n",
"Of course, it can also be built programmatically..."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"tree = Node(name='x_1^1', children=[\n",
"    Node(name='x_1^2', children=[\n",
"        Node(name='x_1^3', children=[\n",
"            Node(name='x_1^4', children=[]), Node(name='x_2^4', children=[])\n",
"        ]),\n",
"        Node(name='x_2^3', children=[\n",
"            Node(name='x_3^4', children=[]), Node(name='x_4^4', children=[])\n",
"        ])\n",
"    ]),\n",
"    Node(name='x_2^2', children=[\n",
"        Node(name='x_3^3', children=[\n",
"            Node(name='x_5^4', children=[]), Node(name='x_6^4', children=[])\n",
"        ]),\n",
"        Node(name='x_4^3', children=[\n",
"            Node(name='x_7^4', children=[]), Node(name='x_8^4', children=[])\n",
"        ])\n",
"    ])\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Functions\n",
"\n",
"Some functions to try out...\n",
"\n",
"`mult_sigmoid` corresponds to the one you mentioned earlier, a.k.a. `\"weight times a number, then an activation function, gives another number\"`"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def sigmoid(node):\n",
"    \"\"\"\n",
"    Apply the Sigmoid function to a given Node's data\n",
"    \"\"\"\n",
"    sigm = nn.Sigmoid()\n",
"    return sigm(node.data)\n",
"\n",
"def mult_sigmoid(node):\n",
"    \"\"\"\n",
"    Multiply a given Node's data by a random weight in [0, 1), then apply Sigmoid\n",
"    \"\"\"\n",
"    sigm = nn.Sigmoid()\n",
"    weight = torch.rand((1,))  # fresh random weight on every call, not a learnable parameter\n",
"    return sigm(weight * node.data)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Explore & Traverse the Tree\n",
"\n",
"1. Explore the initial random Tree\n",
"2. Traverse the Tree, applying `mult_sigmoid` (\"weight times a number, then an activation function, gives another number\")\n",
"3. Explore the resulting Tree"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[x_1^1: tensor([0.2864], requires_grad=True),\n",
" x_1^2: tensor([0.1287], requires_grad=True),\n",
" x_1^3: tensor([0.0580], requires_grad=True),\n",
" x_1^4: tensor([0.8677], requires_grad=True)],\n",
" [x_1^1: tensor([0.2864], requires_grad=True),\n",
" x_1^2: tensor([0.1287], requires_grad=True),\n",
" x_1^3: tensor([0.0580], requires_grad=True),\n",
" x_2^4: tensor([0.5528], requires_grad=True)],\n",
" [x_1^1: tensor([0.2864], requires_grad=True),\n",
" x_1^2: tensor([0.1287], requires_grad=True),\n",
" x_2^3: tensor([0.6109], requires_grad=True),\n",
" x_3^4: tensor([0.2566], requires_grad=True)],\n",
" [x_1^1: tensor([0.2864], requires_grad=True),\n",
" x_1^2: tensor([0.1287], requires_grad=True),\n",
" x_2^3: tensor([0.6109], requires_grad=True),\n",
" x_4^4: tensor([0.1304], requires_grad=True)],\n",
" [x_1^1: tensor([0.2864], requires_grad=True),\n",
" x_2^2: tensor([0.5555], requires_grad=True),\n",
" x_3^3: tensor([0.9111], requires_grad=True),\n",
" x_5^4: tensor([0.9655], requires_grad=True)],\n",
" [x_1^1: tensor([0.2864], requires_grad=True),\n",
" x_2^2: tensor([0.5555], requires_grad=True),\n",
" x_3^3: tensor([0.9111], requires_grad=True),\n",
" x_6^4: tensor([0.9909], requires_grad=True)],\n",
" [x_1^1: tensor([0.2864], requires_grad=True),\n",
" x_2^2: tensor([0.5555], requires_grad=True),\n",
" x_4^3: tensor([0.8007], requires_grad=True),\n",
" x_7^4: tensor([0.7274], requires_grad=True)],\n",
" [x_1^1: tensor([0.2864], requires_grad=True),\n",
" x_2^2: tensor([0.5555], requires_grad=True),\n",
" x_4^3: tensor([0.8007], requires_grad=True),\n",
" x_8^4: tensor([0.3543], requires_grad=True)]]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 1. Explore the initial random Tree\n",
"tree.get_paths()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# 2. Traverse the Tree\n",
"tree.traverse(f=mult_sigmoid)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[[x_1^1: tensor([0.5006], grad_fn=<SigmoidBackward>),\n",
" x_1^2: tensor([0.5165], grad_fn=<SigmoidBackward>),\n",
" x_1^3: tensor([0.5013], grad_fn=<SigmoidBackward>),\n",
" x_1^4: tensor([0.5470], grad_fn=<SigmoidBackward>)],\n",
" [x_1^1: tensor([0.5006], grad_fn=<SigmoidBackward>),\n",
" x_1^2: tensor([0.5165], grad_fn=<SigmoidBackward>),\n",
" x_1^3: tensor([0.5013], grad_fn=<SigmoidBackward>),\n",
" x_2^4: tensor([0.6257], grad_fn=<SigmoidBackward>)],\n",
" [x_1^1: tensor([0.5006], grad_fn=<SigmoidBackward>),\n",
" x_1^2: tensor([0.5165], grad_fn=<SigmoidBackward>),\n",
" x_2^3: tensor([0.5691], grad_fn=<SigmoidBackward>),\n",
" x_3^4: tensor([0.5629], grad_fn=<SigmoidBackward>)],\n",
" [x_1^1: tensor([0.5006], grad_fn=<SigmoidBackward>),\n",
" x_1^2: tensor([0.5165], grad_fn=<SigmoidBackward>),\n",
" x_2^3: tensor([0.5691], grad_fn=<SigmoidBackward>),\n",
" x_4^4: tensor([0.5016], grad_fn=<SigmoidBackward>)],\n",
" [x_1^1: tensor([0.5006], grad_fn=<SigmoidBackward>),\n",
" x_2^2: tensor([0.5077], grad_fn=<SigmoidBackward>),\n",
" x_3^3: tensor([0.6667], grad_fn=<SigmoidBackward>),\n",
" x_5^4: tensor([0.5391], grad_fn=<SigmoidBackward>)],\n",
" [x_1^1: tensor([0.5006], grad_fn=<SigmoidBackward>),\n",
" x_2^2: tensor([0.5077], grad_fn=<SigmoidBackward>),\n",
" x_3^3: tensor([0.6667], grad_fn=<SigmoidBackward>),\n",
" x_6^4: tensor([0.6168], grad_fn=<SigmoidBackward>)],\n",
" [x_1^1: tensor([0.5006], grad_fn=<SigmoidBackward>),\n",
" x_2^2: tensor([0.5077], grad_fn=<SigmoidBackward>),\n",
" x_4^3: tensor([0.5693], grad_fn=<SigmoidBackward>),\n",
" x_7^4: tensor([0.5114], grad_fn=<SigmoidBackward>)],\n",
" [x_1^1: tensor([0.5006], grad_fn=<SigmoidBackward>),\n",
" x_2^2: tensor([0.5077], grad_fn=<SigmoidBackward>),\n",
" x_4^3: tensor([0.5693], grad_fn=<SigmoidBackward>),\n",
" x_8^4: tensor([0.5736], grad_fn=<SigmoidBackward>)]]"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 3. Explore the resulting Tree\n",
"tree.get_paths()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
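The notebook builds its example tree by hand and notes that it could also be built programmatically. Below is a minimal sketch of one way to do that. It assumes a complete tree with a fixed branching factor and the notebook's `x_i^d` naming, where `i` counts nodes left to right on level `d` and both start at 1. `TreeNode`, `build_full_tree`, and `leaf_paths` are hypothetical names. `TreeNode` is a simplified, torch-free stand-in for the notebook's `Node`, used so the sketch runs without PyTorch.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional


@dataclass
class TreeNode:
    """Hypothetical, torch-free stand-in for the notebook's Node."""
    name: str
    data: Any = None
    children: List["TreeNode"] = field(default_factory=list)  # fresh list per node


def build_full_tree(depth: int, branching: int = 2,
                    make_data: Optional[Callable[[], Any]] = None) -> TreeNode:
    """Build a complete tree; the i-th node (left to right) on level d is named 'x_i^d'."""
    if depth < 1:
        raise ValueError("depth must be at least 1")
    counters = [0] * (depth + 1)  # counters[d] = nodes created so far on level d

    def make(level: int) -> TreeNode:
        counters[level] += 1
        node = TreeNode(name=f"x_{counters[level]}^{level}",
                        data=make_data() if make_data is not None else None)
        if level < depth:
            # Depth-first, left to right: keeps each level's numbering left to right
            node.children = [make(level + 1) for _ in range(branching)]
        return node

    return make(1)


def leaf_paths(node: TreeNode) -> List[List[str]]:
    """Return every root-to-leaf path as a list of node names."""
    if not node.children:
        return [[node.name]]
    return [[node.name] + path for child in node.children for path in leaf_paths(child)]


if __name__ == "__main__":
    tree = build_full_tree(depth=4)
    for path in leaf_paths(tree):
        print(path)
```

With `depth=4` and the default `branching=2`, this yields the same eight root-to-leaf name sequences as the hand-built tree. One counter per level works because a left-to-right depth-first build creates every level-`d` node of an earlier subtree before any level-`d` node of a later one. To attach tensors like the notebook does, one could pass `make_data=lambda: torch.rand((1,), requires_grad=True)`.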