{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
},
"colab": {
"name": "Style_Transfer_Exercise.ipynb",
"provenance": [],
"include_colab_link": true
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Ankita-Das/ef6b8c84168995392d79b3c3f2f00b8e/style_transfer_exercise.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zHG3u8iqU8eA",
"colab_type": "text"
},
"source": [
"# Style Transfer with Deep Neural Networks\n",
"\n",
"\n",
"In this notebook, we’ll *recreate* a style transfer method that is outlined in the paper, [Image Style Transfer Using Convolutional Neural Networks, by Gatys](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Gatys_Image_Style_Transfer_CVPR_2016_paper.pdf) in PyTorch.\n",
"\n",
"In this paper, style transfer uses the features found in the 19-layer VGG Network, which is comprised of a series of convolutional and pooling layers, and a few fully-connected layers. In the image below, the convolutional layers are named by stack and their order in the stack. Conv_1_1 is the first convolutional layer that an image is passed through, in the first stack. Conv_2_1 is the first convolutional layer in the *second* stack. The deepest convolutional layer in the network is conv_5_4.\n",
"\n",
"<img src='notebook_ims/vgg19_convlayers.png' width=80% />\n",
"\n",
"### Separating Style and Content\n",
"\n",
"Style transfer relies on separating the content and style of an image. Given one content image and one style image, we aim to create a new, _target_ image which should contain our desired content and style components:\n",
"* objects and their arrangement are similar to that of the **content image**\n",
"* style, colors, and textures are similar to that of the **style image**\n",
"\n",
"An example is shown below, where the content image is of a cat, and the style image is of [Hokusai's Great Wave](https://en.wikipedia.org/wiki/The_Great_Wave_off_Kanagawa). The generated target image still contains the cat but is stylized with the waves, blue and beige colors, and block print textures of the style image!\n",
"\n",
"<img src='notebook_ims/style_tx_cat.png' width=80% />\n",
"\n",
"In this notebook, we'll use a pre-trained VGG19 Net to extract content or style features from a passed in image. We'll then formalize the idea of content and style _losses_ and use those to iteratively update our target image until we get a result that we want. You are encouraged to use a style and content image of your own and share your work on Twitter with @udacity; we'd love to see what you come up with!"
]
},
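{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick preview, the next cell is a minimal, self-contained sketch (using random placeholder tensors rather than actual VGG features) of the loss structure behind this approach: a Gram matrix captures the feature correlations that define style, and the target image is optimized against a weighted sum of a content loss and a style loss."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# minimal sketch of the loss structure this notebook builds up;\n",
"# the feature maps below are random placeholders standing in for VGG activations\n",
"import torch\n",
"\n",
"def gram_matrix(tensor):\n",
"    # flatten each feature map to (depth, height*width) and correlate it with itself\n",
"    _, d, h, w = tensor.size()\n",
"    tensor = tensor.view(d, h * w)\n",
"    return torch.mm(tensor, tensor.t())\n",
"\n",
"target_feat = torch.rand(1, 64, 32, 32)   # stand-in for target-image features\n",
"content_feat = torch.rand(1, 64, 32, 32)  # stand-in for content-image features\n",
"style_feat = torch.rand(1, 64, 32, 32)    # stand-in for style-image features\n",
"\n",
"content_loss = torch.mean((target_feat - content_feat) ** 2)\n",
"style_loss = torch.mean((gram_matrix(target_feat) - gram_matrix(style_feat)) ** 2)\n",
"\n",
"# the content/style weighting (alpha and beta in the paper); style is weighted far more heavily\n",
"content_weight, style_weight = 1, 1e6\n",
"total_loss = content_weight * content_loss + style_weight * style_loss\n",
"print(total_loss.item())"
],
"execution_count": 0,
"outputs": []
},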
{
"cell_type": "code",
"metadata": {
"id": "sXlGMQfDU8eH",
"colab_type": "code",
"colab": {}
},
"source": [
"# import resources\n",
"%matplotlib inline\n",
"\n",
"from PIL import Image\n",
"from io import BytesIO\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"import torch\n",
"import torch.optim as optim\n",
"import requests\n",
"from torchvision import transforms, models\n",
"\n",
"models.vgg19?\n"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "K78s4SVwU8eV",
"colab_type": "text"
},
"source": [
"## Load in VGG19 (features)\n",
"\n",
"VGG19 is split into two portions:\n",
"* `vgg19.features`, which are all the convolutional and pooling layers\n",
"* `vgg19.classifier`, which are the three linear, classifier layers at the end\n",
"\n",
"We only need the `features` portion, which we're going to load in and \"freeze\" the weights of, below."
]
},
{
"cell_type": "code",
"metadata": {
"id": "t5pUWQInU8eY",
"colab_type": "code",
"colab": {}
},
"source": [
"# get the \"features\" portion of VGG19 (we will not need the \"classifier\" portion)\n",
"vgg = models.vgg19(pretrained=True).features\n",
"\n",
"# freeze all VGG parameters since we're only optimizing the target image\n",
"for param in vgg.parameters():\n",
" param.requires_grad_(False)"
],
"execution_count": 0,
"outputs": []
},
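{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick optional sanity check: confirm that every VGG parameter now has gradients disabled."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# every VGG parameter should now report requires_grad == False\n",
"print(all(not p.requires_grad for p in vgg.parameters()))"
],
"execution_count": 0,
"outputs": []
},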
{
"cell_type": "code",
"metadata": {
"id": "L55RiNMrU8eg",
"colab_type": "code",
"outputId": "08cae49a-0bb3-4585-82e8-b99f23ac815b",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 712
}
},
"source": [
"# move the model to GPU, if available\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(\"device:\",device)\n",
"\n",
"vgg.to(device)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"device: cuda\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Sequential(\n",
" (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU(inplace=True)\n",
" (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (3): ReLU(inplace=True)\n",
" (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (6): ReLU(inplace=True)\n",
" (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (8): ReLU(inplace=True)\n",
" (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (11): ReLU(inplace=True)\n",
" (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (13): ReLU(inplace=True)\n",
" (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (15): ReLU(inplace=True)\n",
" (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (17): ReLU(inplace=True)\n",
" (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (20): ReLU(inplace=True)\n",
" (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (22): ReLU(inplace=True)\n",
" (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (24): ReLU(inplace=True)\n",
" (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (26): ReLU(inplace=True)\n",
" (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (29): ReLU(inplace=True)\n",
" (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (31): ReLU(inplace=True)\n",
" (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (33): ReLU(inplace=True)\n",
" (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (35): ReLU(inplace=True)\n",
" (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
")"
]
},
"metadata": {
"tags": []
},
"execution_count": 3
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "b-0KJ6GvU8eq",
"colab_type": "text"
},
"source": [
"### Load in Content and Style Images\n",
"\n",
"You can load in any images you want! Below, we've provided a helper function for loading in any type and size of image. The `load_image` function also converts images to normalized Tensors.\n",
"\n",
"Additionally, it will be easier to have smaller images and to squish the content and style images so that they are of the same size."
]
},
{
"cell_type": "code",
"metadata": {
"id": "dd1wFSbgU8es",
"colab_type": "code",
"colab": {}
},
"source": [
"def load_image(img_path, max_size=400, shape=None):\n",
" ''' Load in and transform an image, making sure the image\n",
" is <= 400 pixels in the x-y dims.'''\n",
" if \"http\" in img_path:\n",
" response = requests.get(img_path)\n",
" image = Image.open(BytesIO(response.content)).convert('RGB')\n",
" else:\n",
" image = Image.open(img_path).convert('RGB')\n",
" \n",
" # large images will slow down processing\n",
" if max(image.size) > max_size:\n",
" size = max_size\n",
" else:\n",
" size = max(image.size)\n",
" \n",
" if shape is not None:\n",
" size = shape\n",
" print('style size:',size)\n",
" \n",
" in_transform = transforms.Compose([\n",
" transforms.Resize(size),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.485, 0.456, 0.406), \n",
" (0.229, 0.224, 0.225))])\n",
"\n",
" # discard the transparent, alpha channel (that's the :3) and add the batch dimension\n",
" #print(max(image.size))\n",
" image = in_transform(image)[:3,:,:].unsqueeze(0)\n",
" \n",
" return image"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "PwLUlp80iHNr",
"colab_type": "code",
"outputId": "93667944-3bab-4211-c21c-ec4db4471462",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
}
},
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9MecQJGnU8e0",
"colab_type": "text"
},
"source": [
"Next, I'm loading in images by file name and forcing the style image to be the same size as the content image."
]
},
{
"cell_type": "code",
"metadata": {
"id": "Fv2wLXPUU8e2",
"colab_type": "code",
"outputId": "42adcde9-d790-4dbf-d80f-c9dee49aa995",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 52
}
},
"source": [
"# load in content and style image\n",
"content= load_image('/content/drive/My Drive/style-transfer/images/content5.jpg').to(device)\n",
"target= load_image('/content/drive/My Drive/style-transfer/images/content_noise.jpg',shape=content.shape[-2:]).to(device)\n",
"\n",
"#print(content.shape[-2:])\n",
"# Resize style to match content, makes code easier\n",
"style = load_image('/content/drive/My Drive/style-transfer/images/style8.jpeg', shape=content.shape[-2:]).to(device)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"style size: torch.Size([597, 400])\n",
"style size: torch.Size([597, 400])\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "RDqTWh8GU8e9",
"colab_type": "code",
"colab": {}
},
"source": [
"# helper function for un-normalizing an image \n",
"# and converting it from a Tensor image to a NumPy image for display\n",
"def im_convert(tensor):\n",
" \"\"\" Display a tensor as an image. \"\"\"\n",
" \n",
" image = tensor.to(\"cpu\").clone().detach()\n",
" image = image.numpy().squeeze()\n",
" image = image.transpose(1,2,0)\n",
" image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))\n",
" image = image.clip(0, 1)\n",
"\n",
" return image"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "DNqPDhhEU8fH",
"colab_type": "code",
"outputId": "0fdfd180-dad4-4dd1-a38c-a981f78c0fbd",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 558
}
},
"source": [
"# display the images\n",
"fig, (ax1, ax2,ax3) = plt.subplots(1, 3, figsize=(20, 10))\n",
"# content and style ims side-by-side\n",
"ax1.imshow(im_convert(content))\n",
"ax2.imshow(im_convert(target))\n",
"ax3.imshow(im_convert(style))\n",
"\n"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fa97a089eb8>"
]
},
"metadata": {
"tags": []
},
"execution_count": 8
},
{
"output_type": "display_data",
"data": {