Created
December 1, 2024 11:46
-
-
Save devops-school/55f7a4c36a4ee57428f613ca6dce11f4 to your computer and use it in GitHub Desktop.
PyTorch Lab - 9 - PyTorch CUDA Semantics
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import numpy as np \n", | |
| "import torch\n", | |
| "import torch.nn as nn" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "W = torch.randn(6)\n", | |
| "\n", | |
| "x = torch.tensor([10.0, 10.0, 10.0, 10.0, 10.0, 10.0])\n", | |
| "\n", | |
| "b = torch.tensor(3)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([-1.5724, -0.6086, 1.2982, -0.3485, 0.0616, 0.7579])" | |
| ] | |
| }, | |
| "execution_count": 19, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "W" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 20, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "y = W*x + b" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([-12.7245, -3.0857, 15.9816, -0.4845, 3.6161, 10.5788])" | |
| ] | |
| }, | |
| "execution_count": 21, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "y" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "#### Working Dynamically" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([-15.7245, -6.0857, 12.9816, -3.4845, 0.6161, 7.5788])" | |
| ] | |
| }, | |
| "execution_count": 22, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "W*x" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 23, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor(3)" | |
| ] | |
| }, | |
| "execution_count": 23, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "b" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 32, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "W1 = torch.tensor(6)\n", | |
| "W2 = torch.tensor(6)\n", | |
| "W3 = torch.tensor(6)\n", | |
| "\n", | |
| "x1 = torch.tensor([2, 2, 2])\n", | |
| "x2 = torch.tensor([3, 3, 3])\n", | |
| "x3 = torch.tensor([4, 4, 4])\n", | |
| "\n", | |
| "b = torch.tensor(10)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 33, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(tensor(6), tensor(6), tensor(6))" | |
| ] | |
| }, | |
| "execution_count": 33, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "W1, W2, W3" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 34, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([30, 30, 30])" | |
| ] | |
| }, | |
| "execution_count": 34, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "intermediate_value = W1 * x1 + W2 * x2\n", | |
| "\n", | |
| "intermediate_value" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 35, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "tensor([64, 64, 64])" | |
| ] | |
| }, | |
| "execution_count": 35, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "final_value = W1 * x1 + W2 * x2 + W3 * x3 + b\n", | |
| "\n", | |
| "final_value" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "### Viewing PyTorch computation graphs\n", | |
| "\n", | |
| "The graphs below are rendered with the HiddenLayer library: https://github.com/waleedka/hiddenlayer" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "#### Computation Graph" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 38, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import hiddenlayer as hl" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 39, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "x_train = np.array ([[1.7], [2.5], [5.5], [7.9], [8.8],\n", | |
| " [2.4],[2.4], [8.89], [5], [4.4]],\n", | |
| " dtype = np.float32)\n", | |
| "\n", | |
| "y_train = np.array ([[1.9], [2.68], [4.22], [8.19], [9.69],\n", | |
| " [3.4],[2.6], [8.8], [5.6], [4.7]],\n", | |
| " dtype = np.float32)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 40, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "torch.Size([10, 1])" | |
| ] | |
| }, | |
| "execution_count": 40, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "X_train = torch.tensor(x_train)\n", | |
| "Y_train = torch.tensor(y_train)\n", | |
| "\n", | |
| "X_train.shape" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 77, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "inp = 1\n", | |
| "out = 1\n", | |
| "\n", | |
| "hid = 100" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 78, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "model1 = torch.nn.Sequential(torch.nn.Linear(inp, hid),\n", | |
| " torch.nn.Linear(hid, out))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 76, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/svg+xml": [ | |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
| "<!-- Generated by graphviz version 2.36.0 (20140111.2315)\n", | |
| " -->\n", | |
| "<!-- Title: %3 Pages: 1 -->\n", | |
| "<svg width=\"198pt\" height=\"118pt\"\n", | |
| " viewBox=\"0.00 0.00 198.00 118.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(72 82)\">\n", | |
| "<title>%3</title>\n", | |
| "<polygon fill=\"#ffffff\" stroke=\"none\" points=\"-72,36 -72,-82 126,-82 126,36 -72,36\"/>\n", | |
| "<!-- 15774382701664537878 -->\n", | |
| "<g id=\"node1\" class=\"node\"><title>15774382701664537878</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"54,-45.5 0,-45.5 0,-0.5 54,-0.5 54,-45.5\"/>\n", | |
| "<text text-anchor=\"start\" x=\"13.5\" y=\"-28.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Linear</text>\n", | |
| "<text text-anchor=\"start\" x=\"34\" y=\"-7\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">x2</text>\n", | |
| "</g>\n", | |
| "</g>\n", | |
| "</svg>\n" | |
| ], | |
| "text/plain": [ | |
| "<hiddenlayer.graph.Graph at 0x114e7e588>" | |
| ] | |
| }, | |
| "execution_count": 76, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "hl.build_graph(model1, torch.zeros([10, 1]))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 79, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/svg+xml": [ | |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
| "<!-- Generated by graphviz version 2.36.0 (20140111.2315)\n", | |
| " -->\n", | |
| "<!-- Title: %3 Pages: 1 -->\n", | |
| "<svg width=\"277pt\" height=\"434pt\"\n", | |
| " viewBox=\"0.00 0.00 277.00 434.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(72 398)\">\n", | |
| "<title>%3</title>\n", | |
| "<polygon fill=\"#ffffff\" stroke=\"none\" points=\"-72,36 -72,-398 205,-398 205,36 -72,36\"/>\n", | |
| "<!-- Sequential/Linear[0]/outputs/5 -->\n", | |
| "<g id=\"node1\" class=\"node\"><title>Sequential/Linear[0]/outputs/5</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"58.25,-362 -0.25,-362 -0.25,-326 58.25,-326 58.25,-362\"/>\n", | |
| "<text text-anchor=\"start\" x=\"8.5\" y=\"-341.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Transpose</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[0]/outputs/6 -->\n", | |
| "<g id=\"node2\" class=\"node\"><title>Sequential/Linear[0]/outputs/6</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"56,-288 2,-288 2,-252 56,-252 56,-288\"/>\n", | |
| "<text text-anchor=\"start\" x=\"12.5\" y=\"-267.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">MatMul</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[0]/outputs/5->Sequential/Linear[0]/outputs/6 -->\n", | |
| "<g id=\"edge1\" class=\"edge\"><title>Sequential/Linear[0]/outputs/5->Sequential/Linear[0]/outputs/6</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M29,-325.937C29,-317.807 29,-307.876 29,-298.705\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"32.5001,-298.441 29,-288.441 25.5001,-298.441 32.5001,-298.441\"/>\n", | |
| "</g>\n", | |
| "<!-- /outputs/7 -->\n", | |
| "<g id=\"node3\" class=\"node\"><title>/outputs/7</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"56,-204 2,-204 2,-168 56,-168 56,-204\"/>\n", | |
| "<text text-anchor=\"start\" x=\"20.5\" y=\"-183.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Add</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[0]/outputs/6->/outputs/7 -->\n", | |
| "<g id=\"edge2\" class=\"edge\"><title>Sequential/Linear[0]/outputs/6->/outputs/7</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M29,-251.61C29,-240.774 29,-226.601 29,-214.291\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"32.5001,-214.084 29,-204.084 25.5001,-214.084 32.5001,-214.084\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"37\" y=\"-225.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">100</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[1]/outputs/9 -->\n", | |
| "<g id=\"node5\" class=\"node\"><title>Sequential/Linear[1]/outputs/9</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"104,-120 50,-120 50,-84 104,-84 104,-120\"/>\n", | |
| "<text text-anchor=\"start\" x=\"60.5\" y=\"-99.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">MatMul</text>\n", | |
| "</g>\n", | |
| "<!-- /outputs/7->Sequential/Linear[1]/outputs/9 -->\n", | |
| "<g id=\"edge3\" class=\"edge\"><title>/outputs/7->Sequential/Linear[1]/outputs/9</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M39.1796,-167.61C45.7786,-156.336 54.4924,-141.45 61.8925,-128.809\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"64.9684,-130.482 66.9997,-120.084 58.9273,-126.946 64.9684,-130.482\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"64\" y=\"-141.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">100</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[1]/outputs/8 -->\n", | |
| "<g id=\"node4\" class=\"node\"><title>Sequential/Linear[1]/outputs/8</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"133.25,-204 74.75,-204 74.75,-168 133.25,-168 133.25,-204\"/>\n", | |
| "<text text-anchor=\"start\" x=\"83.5\" y=\"-183.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Transpose</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[1]/outputs/8->Sequential/Linear[1]/outputs/9 -->\n", | |
| "<g id=\"edge4\" class=\"edge\"><title>Sequential/Linear[1]/outputs/8->Sequential/Linear[1]/outputs/9</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M98.274,-167.61C94.6701,-156.665 89.945,-142.315 85.8635,-129.919\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"89.0772,-128.488 82.6252,-120.084 82.4283,-130.677 89.0772,-128.488\"/>\n", | |
| "</g>\n", | |
| "<!-- /outputs/10 -->\n", | |
| "<g id=\"node6\" class=\"node\"><title>/outputs/10</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"104,-36 50,-36 50,-0 104,-0 104,-36\"/>\n", | |
| "<text text-anchor=\"start\" x=\"68.5\" y=\"-15.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Add</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[1]/outputs/9->/outputs/10 -->\n", | |
| "<g id=\"edge5\" class=\"edge\"><title>Sequential/Linear[1]/outputs/9->/outputs/10</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M77,-83.6099C77,-72.7743 77,-58.6012 77,-46.2913\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"80.5001,-46.0838 77,-36.0838 73.5001,-46.0839 80.5001,-46.0838\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"80\" y=\"-57.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">1</text>\n", | |
| "</g>\n", | |
| "</g>\n", | |
| "</svg>\n" | |
| ], | |
| "text/plain": [ | |
| "<hiddenlayer.graph.Graph at 0x114e81828>" | |
| ] | |
| }, | |
| "execution_count": 79, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "hl.build_graph(model1, torch.zeros([1]))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 80, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "model2 = torch.nn.Sequential(torch.nn.Linear(inp, hid),\n", | |
| " torch.nn.Linear(hid, hid),\n", | |
| " torch.nn.Sigmoid(),\n", | |
| " torch.nn.Linear(hid, out))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 81, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/svg+xml": [ | |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
| "<!-- Generated by graphviz version 2.36.0 (20140111.2315)\n", | |
| " -->\n", | |
| "<!-- Title: %3 Pages: 1 -->\n", | |
| "<svg width=\"202pt\" height=\"286pt\"\n", | |
| " viewBox=\"0.00 0.00 202.00 286.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(72 250)\">\n", | |
| "<title>%3</title>\n", | |
| "<polygon fill=\"#ffffff\" stroke=\"none\" points=\"-72,36 -72,-250 130,-250 130,36 -72,36\"/>\n", | |
| "<!-- Sequential/Sigmoid[2]/outputs/9 -->\n", | |
| "<g id=\"node1\" class=\"node\"><title>Sequential/Sigmoid[2]/outputs/9</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"54,-120 0,-120 0,-84 54,-84 54,-120\"/>\n", | |
| "<text text-anchor=\"start\" x=\"10.5\" y=\"-99.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Sigmoid</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Sigmoid[2]/outputs/10 -->\n", | |
| "<g id=\"node2\" class=\"node\"><title>Sequential/Sigmoid[2]/outputs/10</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"54,-36 0,-36 0,-0 54,-0 54,-36\"/>\n", | |
| "<text text-anchor=\"start\" x=\"13.5\" y=\"-15.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Linear</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Sigmoid[2]/outputs/9->Sequential/Sigmoid[2]/outputs/10 -->\n", | |
| "<g id=\"edge1\" class=\"edge\"><title>Sequential/Sigmoid[2]/outputs/9->Sequential/Sigmoid[2]/outputs/10</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M27,-83.6099C27,-72.7743 27,-58.6012 27,-46.2913\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"30.5001,-46.0838 27,-36.0838 23.5001,-46.0839 30.5001,-46.0838\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"42.5\" y=\"-57.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">10x100</text>\n", | |
| "</g>\n", | |
| "<!-- 11839368744124124526 -->\n", | |
| "<g id=\"node3\" class=\"node\"><title>11839368744124124526</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"54,-213.5 0,-213.5 0,-168.5 54,-168.5 54,-213.5\"/>\n", | |
| "<text text-anchor=\"start\" x=\"13.5\" y=\"-196.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Linear</text>\n", | |
| "<text text-anchor=\"start\" x=\"34\" y=\"-175\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">x2</text>\n", | |
| "</g>\n", | |
| "<!-- 11839368744124124526->Sequential/Sigmoid[2]/outputs/9 -->\n", | |
| "<g id=\"edge2\" class=\"edge\"><title>11839368744124124526->Sequential/Sigmoid[2]/outputs/9</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M27,-168.494C27,-156.975 27,-142.662 27,-130.357\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"30.5001,-130.182 27,-120.182 23.5001,-130.182 30.5001,-130.182\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"42.5\" y=\"-141.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">10x100</text>\n", | |
| "</g>\n", | |
| "</g>\n", | |
| "</svg>\n" | |
| ], | |
| "text/plain": [ | |
| "<hiddenlayer.graph.Graph at 0x114e87ac8>" | |
| ] | |
| }, | |
| "execution_count": 81, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "hl.build_graph(model2, torch.zeros([10, 1]))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 82, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/svg+xml": [ | |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
| "<!-- Generated by graphviz version 2.36.0 (20140111.2315)\n", | |
| " -->\n", | |
| "<!-- Title: %3 Pages: 1 -->\n", | |
| "<svg width=\"325pt\" height=\"686pt\"\n", | |
| " viewBox=\"0.00 0.00 325.00 686.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(72 650)\">\n", | |
| "<title>%3</title>\n", | |
| "<polygon fill=\"#ffffff\" stroke=\"none\" points=\"-72,36 -72,-650 253,-650 253,36 -72,36\"/>\n", | |
| "<!-- Sequential/Linear[0]/outputs/7 -->\n", | |
| "<g id=\"node1\" class=\"node\"><title>Sequential/Linear[0]/outputs/7</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"58.25,-614 -0.25,-614 -0.25,-578 58.25,-578 58.25,-614\"/>\n", | |
| "<text text-anchor=\"start\" x=\"8.5\" y=\"-593.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Transpose</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[0]/outputs/8 -->\n", | |
| "<g id=\"node2\" class=\"node\"><title>Sequential/Linear[0]/outputs/8</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"56,-540 2,-540 2,-504 56,-504 56,-540\"/>\n", | |
| "<text text-anchor=\"start\" x=\"12.5\" y=\"-519.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">MatMul</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[0]/outputs/7->Sequential/Linear[0]/outputs/8 -->\n", | |
| "<g id=\"edge1\" class=\"edge\"><title>Sequential/Linear[0]/outputs/7->Sequential/Linear[0]/outputs/8</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M29,-577.937C29,-569.807 29,-559.876 29,-550.705\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"32.5001,-550.441 29,-540.441 25.5001,-550.441 32.5001,-550.441\"/>\n", | |
| "</g>\n", | |
| "<!-- /outputs/9 -->\n", | |
| "<g id=\"node3\" class=\"node\"><title>/outputs/9</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"56,-456 2,-456 2,-420 56,-420 56,-456\"/>\n", | |
| "<text text-anchor=\"start\" x=\"20.5\" y=\"-435.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Add</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[0]/outputs/8->/outputs/9 -->\n", | |
| "<g id=\"edge2\" class=\"edge\"><title>Sequential/Linear[0]/outputs/8->/outputs/9</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M29,-503.61C29,-492.774 29,-478.601 29,-466.291\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"32.5001,-466.084 29,-456.084 25.5001,-466.084 32.5001,-466.084\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"37\" y=\"-477.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">100</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[1]/outputs/11 -->\n", | |
| "<g id=\"node5\" class=\"node\"><title>Sequential/Linear[1]/outputs/11</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"104,-372 50,-372 50,-336 104,-336 104,-372\"/>\n", | |
| "<text text-anchor=\"start\" x=\"60.5\" y=\"-351.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">MatMul</text>\n", | |
| "</g>\n", | |
| "<!-- /outputs/9->Sequential/Linear[1]/outputs/11 -->\n", | |
| "<g id=\"edge3\" class=\"edge\"><title>/outputs/9->Sequential/Linear[1]/outputs/11</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M39.1796,-419.61C45.7786,-408.336 54.4924,-393.45 61.8925,-380.809\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"64.9684,-382.482 66.9997,-372.084 58.9273,-378.946 64.9684,-382.482\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"64\" y=\"-393.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">100</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[1]/outputs/10 -->\n", | |
| "<g id=\"node4\" class=\"node\"><title>Sequential/Linear[1]/outputs/10</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"133.25,-456 74.75,-456 74.75,-420 133.25,-420 133.25,-456\"/>\n", | |
| "<text text-anchor=\"start\" x=\"83.5\" y=\"-435.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Transpose</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[1]/outputs/10->Sequential/Linear[1]/outputs/11 -->\n", | |
| "<g id=\"edge4\" class=\"edge\"><title>Sequential/Linear[1]/outputs/10->Sequential/Linear[1]/outputs/11</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M98.274,-419.61C94.6701,-408.665 89.945,-394.315 85.8635,-381.919\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"89.0772,-380.488 82.6252,-372.084 82.4283,-382.677 89.0772,-380.488\"/>\n", | |
| "</g>\n", | |
| "<!-- /outputs/12 -->\n", | |
| "<g id=\"node6\" class=\"node\"><title>/outputs/12</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"104,-288 50,-288 50,-252 104,-252 104,-288\"/>\n", | |
| "<text text-anchor=\"start\" x=\"68.5\" y=\"-267.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Add</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[1]/outputs/11->/outputs/12 -->\n", | |
| "<g id=\"edge5\" class=\"edge\"><title>Sequential/Linear[1]/outputs/11->/outputs/12</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M77,-335.61C77,-324.774 77,-310.601 77,-298.291\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"80.5001,-298.084 77,-288.084 73.5001,-298.084 80.5001,-298.084\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"85\" y=\"-309.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">100</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Sigmoid[2]/outputs/13 -->\n", | |
| "<g id=\"node7\" class=\"node\"><title>Sequential/Sigmoid[2]/outputs/13</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"104,-204 50,-204 50,-168 104,-168 104,-204\"/>\n", | |
| "<text text-anchor=\"start\" x=\"60.5\" y=\"-183.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Sigmoid</text>\n", | |
| "</g>\n", | |
| "<!-- /outputs/12->Sequential/Sigmoid[2]/outputs/13 -->\n", | |
| "<g id=\"edge6\" class=\"edge\"><title>/outputs/12->Sequential/Sigmoid[2]/outputs/13</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M77,-251.61C77,-240.774 77,-226.601 77,-214.291\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"80.5001,-214.084 77,-204.084 73.5001,-214.084 80.5001,-214.084\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"85\" y=\"-225.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">100</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[3]/outputs/15 -->\n", | |
| "<g id=\"node9\" class=\"node\"><title>Sequential/Linear[3]/outputs/15</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"152,-120 98,-120 98,-84 152,-84 152,-120\"/>\n", | |
| "<text text-anchor=\"start\" x=\"108.5\" y=\"-99.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">MatMul</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Sigmoid[2]/outputs/13->Sequential/Linear[3]/outputs/15 -->\n", | |
| "<g id=\"edge7\" class=\"edge\"><title>Sequential/Sigmoid[2]/outputs/13->Sequential/Linear[3]/outputs/15</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M87.1796,-167.61C93.7786,-156.336 102.492,-141.45 109.892,-128.809\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"112.968,-130.482 115,-120.084 106.927,-126.946 112.968,-130.482\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"111\" y=\"-141.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">100</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[3]/outputs/14 -->\n", | |
| "<g id=\"node8\" class=\"node\"><title>Sequential/Linear[3]/outputs/14</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"181.25,-204 122.75,-204 122.75,-168 181.25,-168 181.25,-204\"/>\n", | |
| "<text text-anchor=\"start\" x=\"131.5\" y=\"-183.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Transpose</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[3]/outputs/14->Sequential/Linear[3]/outputs/15 -->\n", | |
| "<g id=\"edge8\" class=\"edge\"><title>Sequential/Linear[3]/outputs/14->Sequential/Linear[3]/outputs/15</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M146.274,-167.61C142.67,-156.665 137.945,-142.315 133.864,-129.919\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"137.077,-128.488 130.625,-120.084 130.428,-130.677 137.077,-128.488\"/>\n", | |
| "</g>\n", | |
| "<!-- /outputs/16 -->\n", | |
| "<g id=\"node10\" class=\"node\"><title>/outputs/16</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"152,-36 98,-36 98,-0 152,-0 152,-36\"/>\n", | |
| "<text text-anchor=\"start\" x=\"116.5\" y=\"-15.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Add</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[3]/outputs/15->/outputs/16 -->\n", | |
| "<g id=\"edge9\" class=\"edge\"><title>Sequential/Linear[3]/outputs/15->/outputs/16</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M125,-83.6099C125,-72.7743 125,-58.6012 125,-46.2913\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"128.5,-46.0838 125,-36.0838 121.5,-46.0839 128.5,-46.0838\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"128\" y=\"-57.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">1</text>\n", | |
| "</g>\n", | |
| "</g>\n", | |
| "</svg>\n" | |
| ], | |
| "text/plain": [ | |
| "<hiddenlayer.graph.Graph at 0x114e89358>" | |
| ] | |
| }, | |
| "execution_count": 82, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "hl.build_graph(model2, torch.zeros([1]))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 83, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "inp = 2\n", | |
| "out = 1\n", | |
| "\n", | |
| "hid = 100" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 84, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "model3 = torch.nn.Sequential(torch.nn.Linear(inp, hid),\n", | |
| " torch.nn.Linear(hid, hid),\n", | |
| " torch.nn.Sigmoid(),\n", | |
| " torch.nn.Linear(hid, out))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 85, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/svg+xml": [ | |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
| "<!-- Generated by graphviz version 2.36.0 (20140111.2315)\n", | |
| " -->\n", | |
| "<!-- Title: %3 Pages: 1 -->\n", | |
| "<svg width=\"202pt\" height=\"286pt\"\n", | |
| " viewBox=\"0.00 0.00 202.00 286.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(72 250)\">\n", | |
| "<title>%3</title>\n", | |
| "<polygon fill=\"#ffffff\" stroke=\"none\" points=\"-72,36 -72,-250 130,-250 130,36 -72,36\"/>\n", | |
| "<!-- Sequential/Sigmoid[2]/outputs/9 -->\n", | |
| "<g id=\"node1\" class=\"node\"><title>Sequential/Sigmoid[2]/outputs/9</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"54,-120 0,-120 0,-84 54,-84 54,-120\"/>\n", | |
| "<text text-anchor=\"start\" x=\"10.5\" y=\"-99.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Sigmoid</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Sigmoid[2]/outputs/10 -->\n", | |
| "<g id=\"node2\" class=\"node\"><title>Sequential/Sigmoid[2]/outputs/10</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"54,-36 0,-36 0,-0 54,-0 54,-36\"/>\n", | |
| "<text text-anchor=\"start\" x=\"13.5\" y=\"-15.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Linear</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Sigmoid[2]/outputs/9->Sequential/Sigmoid[2]/outputs/10 -->\n", | |
| "<g id=\"edge1\" class=\"edge\"><title>Sequential/Sigmoid[2]/outputs/9->Sequential/Sigmoid[2]/outputs/10</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M27,-83.6099C27,-72.7743 27,-58.6012 27,-46.2913\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"30.5001,-46.0838 27,-36.0838 23.5001,-46.0839 30.5001,-46.0838\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"42.5\" y=\"-57.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">10x100</text>\n", | |
| "</g>\n", | |
| "<!-- 12500975005242187305 -->\n", | |
| "<g id=\"node3\" class=\"node\"><title>12500975005242187305</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"54,-213.5 0,-213.5 0,-168.5 54,-168.5 54,-213.5\"/>\n", | |
| "<text text-anchor=\"start\" x=\"13.5\" y=\"-196.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Linear</text>\n", | |
| "<text text-anchor=\"start\" x=\"34\" y=\"-175\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">x2</text>\n", | |
| "</g>\n", | |
| "<!-- 12500975005242187305->Sequential/Sigmoid[2]/outputs/9 -->\n", | |
| "<g id=\"edge2\" class=\"edge\"><title>12500975005242187305->Sequential/Sigmoid[2]/outputs/9</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M27,-168.494C27,-156.975 27,-142.662 27,-130.357\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"30.5001,-130.182 27,-120.182 23.5001,-130.182 30.5001,-130.182\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"42.5\" y=\"-141.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">10x100</text>\n", | |
| "</g>\n", | |
| "</g>\n", | |
| "</svg>\n" | |
| ], | |
| "text/plain": [ | |
| "<hiddenlayer.graph.Graph at 0x114e93c18>" | |
| ] | |
| }, | |
| "execution_count": 85, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "hl.build_graph(model3, torch.zeros([10, 2]))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 86, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "image/svg+xml": [ | |
| "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", | |
| "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", | |
| " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", | |
| "<!-- Generated by graphviz version 2.36.0 (20140111.2315)\n", | |
| " -->\n", | |
| "<!-- Title: %3 Pages: 1 -->\n", | |
| "<svg width=\"325pt\" height=\"686pt\"\n", | |
| " viewBox=\"0.00 0.00 325.00 686.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", | |
| "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(72 650)\">\n", | |
| "<title>%3</title>\n", | |
| "<polygon fill=\"#ffffff\" stroke=\"none\" points=\"-72,36 -72,-650 253,-650 253,36 -72,36\"/>\n", | |
| "<!-- Sequential/Linear[0]/outputs/7 -->\n", | |
| "<g id=\"node1\" class=\"node\"><title>Sequential/Linear[0]/outputs/7</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"58.25,-614 -0.25,-614 -0.25,-578 58.25,-578 58.25,-614\"/>\n", | |
| "<text text-anchor=\"start\" x=\"8.5\" y=\"-593.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Transpose</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[0]/outputs/8 -->\n", | |
| "<g id=\"node2\" class=\"node\"><title>Sequential/Linear[0]/outputs/8</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"56,-540 2,-540 2,-504 56,-504 56,-540\"/>\n", | |
| "<text text-anchor=\"start\" x=\"12.5\" y=\"-519.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">MatMul</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[0]/outputs/7->Sequential/Linear[0]/outputs/8 -->\n", | |
| "<g id=\"edge1\" class=\"edge\"><title>Sequential/Linear[0]/outputs/7->Sequential/Linear[0]/outputs/8</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M29,-577.937C29,-569.807 29,-559.876 29,-550.705\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"32.5001,-550.441 29,-540.441 25.5001,-550.441 32.5001,-550.441\"/>\n", | |
| "</g>\n", | |
| "<!-- /outputs/9 -->\n", | |
| "<g id=\"node3\" class=\"node\"><title>/outputs/9</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"56,-456 2,-456 2,-420 56,-420 56,-456\"/>\n", | |
| "<text text-anchor=\"start\" x=\"20.5\" y=\"-435.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Add</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[0]/outputs/8->/outputs/9 -->\n", | |
| "<g id=\"edge2\" class=\"edge\"><title>Sequential/Linear[0]/outputs/8->/outputs/9</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M29,-503.61C29,-492.774 29,-478.601 29,-466.291\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"32.5001,-466.084 29,-456.084 25.5001,-466.084 32.5001,-466.084\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"37\" y=\"-477.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">100</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[1]/outputs/11 -->\n", | |
| "<g id=\"node5\" class=\"node\"><title>Sequential/Linear[1]/outputs/11</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"104,-372 50,-372 50,-336 104,-336 104,-372\"/>\n", | |
| "<text text-anchor=\"start\" x=\"60.5\" y=\"-351.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">MatMul</text>\n", | |
| "</g>\n", | |
| "<!-- /outputs/9->Sequential/Linear[1]/outputs/11 -->\n", | |
| "<g id=\"edge3\" class=\"edge\"><title>/outputs/9->Sequential/Linear[1]/outputs/11</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M39.1796,-419.61C45.7786,-408.336 54.4924,-393.45 61.8925,-380.809\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"64.9684,-382.482 66.9997,-372.084 58.9273,-378.946 64.9684,-382.482\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"64\" y=\"-393.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">100</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[1]/outputs/10 -->\n", | |
| "<g id=\"node4\" class=\"node\"><title>Sequential/Linear[1]/outputs/10</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"133.25,-456 74.75,-456 74.75,-420 133.25,-420 133.25,-456\"/>\n", | |
| "<text text-anchor=\"start\" x=\"83.5\" y=\"-435.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Transpose</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[1]/outputs/10->Sequential/Linear[1]/outputs/11 -->\n", | |
| "<g id=\"edge4\" class=\"edge\"><title>Sequential/Linear[1]/outputs/10->Sequential/Linear[1]/outputs/11</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M98.274,-419.61C94.6701,-408.665 89.945,-394.315 85.8635,-381.919\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"89.0772,-380.488 82.6252,-372.084 82.4283,-382.677 89.0772,-380.488\"/>\n", | |
| "</g>\n", | |
| "<!-- /outputs/12 -->\n", | |
| "<g id=\"node6\" class=\"node\"><title>/outputs/12</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"104,-288 50,-288 50,-252 104,-252 104,-288\"/>\n", | |
| "<text text-anchor=\"start\" x=\"68.5\" y=\"-267.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Add</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[1]/outputs/11->/outputs/12 -->\n", | |
| "<g id=\"edge5\" class=\"edge\"><title>Sequential/Linear[1]/outputs/11->/outputs/12</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M77,-335.61C77,-324.774 77,-310.601 77,-298.291\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"80.5001,-298.084 77,-288.084 73.5001,-298.084 80.5001,-298.084\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"85\" y=\"-309.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">100</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Sigmoid[2]/outputs/13 -->\n", | |
| "<g id=\"node7\" class=\"node\"><title>Sequential/Sigmoid[2]/outputs/13</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"104,-204 50,-204 50,-168 104,-168 104,-204\"/>\n", | |
| "<text text-anchor=\"start\" x=\"60.5\" y=\"-183.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Sigmoid</text>\n", | |
| "</g>\n", | |
| "<!-- /outputs/12->Sequential/Sigmoid[2]/outputs/13 -->\n", | |
| "<g id=\"edge6\" class=\"edge\"><title>/outputs/12->Sequential/Sigmoid[2]/outputs/13</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M77,-251.61C77,-240.774 77,-226.601 77,-214.291\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"80.5001,-214.084 77,-204.084 73.5001,-214.084 80.5001,-214.084\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"85\" y=\"-225.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">100</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[3]/outputs/15 -->\n", | |
| "<g id=\"node9\" class=\"node\"><title>Sequential/Linear[3]/outputs/15</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"152,-120 98,-120 98,-84 152,-84 152,-120\"/>\n", | |
| "<text text-anchor=\"start\" x=\"108.5\" y=\"-99.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">MatMul</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Sigmoid[2]/outputs/13->Sequential/Linear[3]/outputs/15 -->\n", | |
| "<g id=\"edge7\" class=\"edge\"><title>Sequential/Sigmoid[2]/outputs/13->Sequential/Linear[3]/outputs/15</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M87.1796,-167.61C93.7786,-156.336 102.492,-141.45 109.892,-128.809\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"112.968,-130.482 115,-120.084 106.927,-126.946 112.968,-130.482\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"111\" y=\"-141.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">100</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[3]/outputs/14 -->\n", | |
| "<g id=\"node8\" class=\"node\"><title>Sequential/Linear[3]/outputs/14</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"181.25,-204 122.75,-204 122.75,-168 181.25,-168 181.25,-204\"/>\n", | |
| "<text text-anchor=\"start\" x=\"131.5\" y=\"-183.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Transpose</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[3]/outputs/14->Sequential/Linear[3]/outputs/15 -->\n", | |
| "<g id=\"edge8\" class=\"edge\"><title>Sequential/Linear[3]/outputs/14->Sequential/Linear[3]/outputs/15</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M146.274,-167.61C142.67,-156.665 137.945,-142.315 133.864,-129.919\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"137.077,-128.488 130.625,-120.084 130.428,-130.677 137.077,-128.488\"/>\n", | |
| "</g>\n", | |
| "<!-- /outputs/16 -->\n", | |
| "<g id=\"node10\" class=\"node\"><title>/outputs/16</title>\n", | |
| "<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"152,-36 98,-36 98,-0 152,-0 152,-36\"/>\n", | |
| "<text text-anchor=\"start\" x=\"116.5\" y=\"-15.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Add</text>\n", | |
| "</g>\n", | |
| "<!-- Sequential/Linear[3]/outputs/15->/outputs/16 -->\n", | |
| "<g id=\"edge9\" class=\"edge\"><title>Sequential/Linear[3]/outputs/15->/outputs/16</title>\n", | |
| "<path fill=\"none\" stroke=\"#000000\" d=\"M125,-83.6099C125,-72.7743 125,-58.6012 125,-46.2913\"/>\n", | |
| "<polygon fill=\"#000000\" stroke=\"#000000\" points=\"128.5,-46.0838 125,-36.0838 121.5,-46.0839 128.5,-46.0838\"/>\n", | |
| "<text text-anchor=\"middle\" x=\"128\" y=\"-57.5\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">1</text>\n", | |
| "</g>\n", | |
| "</g>\n", | |
| "</svg>\n" | |
| ], | |
| "text/plain": [ | |
| "<hiddenlayer.graph.Graph at 0x114e96978>" | |
| ] | |
| }, | |
| "execution_count": 86, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "hl.build_graph(model3, torch.zeros([2]))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.7.1" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment