Created
February 14, 2021 23:25
-
-
Save sorrge/ef0b4cbf53c3496a09596d53663655f9 to your computer and use it in GitHub Desktop.
DALL-E training on a toy dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "%load_ext autoreload\n", | |
| "%autoreload 2" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "<ipython-input-2-6efe993dfcaf>:13: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", | |
| " from tqdm.autonotebook import tqdm, trange\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "import math\n", | |
| "import itertools\n", | |
| "import os\n", | |
| "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n", | |
| "import glob\n", | |
| "\n", | |
| "import numpy as np\n", | |
| "import cairo\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "import matplotlib.colors as mcolors\n", | |
| "%config InlineBackend.figure_format = 'retina'\n", | |
| "from PIL import Image\n", | |
| "from tqdm.autonotebook import tqdm, trange\n", | |
| "import torch\n", | |
| "from dalle_pytorch import DiscreteVAE, DALLE" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": {}, | |
| "source": [ | |
| "# Dataset generation" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Total shapes: 9216\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "<matplotlib.image.AxesImage at 0x7fc55e6c0d30>" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| }, | |
| { | |
| "data": { |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment