Last active
January 7, 2020 20:24
Revisions
ProGamerGov revised this gist
Jan 7, 2020 . 1 changed file with 1 addition and 1 deletion.
@@ -1,6 +1,6 @@
# neural-style-pt with histogram loss

The code here is based on [genekogan](https://github.com/genekogan)'s [neural-style-pt histogram loss](https://github.com/genekogan/neural-style-pt/tree/histogram-loss) code. The CUDA code comes from [pierre-wilmot](https://github.com/pierre-wilmot)'s code here: https://github.com/pierre-wilmot/NeuralTextureSynthesis

### Histogram Loss Layers
ProGamerGov revised this gist
Jan 7, 2020 . 1 changed file with 3 additions and 3 deletions.
@@ -16,11 +16,11 @@ This histogram loss layers will only work on a GPU device.
You can download all 3 required files to your neural-style-pt directory with:

```
wget https://gist.githubusercontent.com/ProGamerGov/30e95ac9ff42f3e09288ae07dc012a76/raw/2683c4861dd47ba5f2066a35f9191a842dc2a6ea/neural_style_hist_loss.py

wget https://gist.githubusercontent.com/ProGamerGov/30e95ac9ff42f3e09288ae07dc012a76/raw/2683c4861dd47ba5f2066a35f9191a842dc2a6ea/histogram.cpp

wget https://gist.githubusercontent.com/ProGamerGov/30e95ac9ff42f3e09288ae07dc012a76/raw/2683c4861dd47ba5f2066a35f9191a842dc2a6ea/histogram.cu
```

## Usage
ProGamerGov revised this gist
Jan 7, 2020 . 1 changed file with 2 additions and 2 deletions.
@@ -103,7 +103,7 @@ void matchHistogram(at::Tensor &featureMaps, at::Tensor &targetHistogram)
{
  static std::map<unsigned int, at::Tensor> randomIndices;
  if (randomIndices[featureMaps.numel()].numel() != featureMaps.numel())
    randomIndices[featureMaps.numel()] = torch::randperm(featureMaps.numel(), torch::TensorOptions().dtype(at::kLong)).cuda();

  at::Tensor unsqueezed(featureMaps);
  if (unsqueezed.ndimension() == 1)
@@ -139,4 +139,4 @@ void matchHistogram(at::Tensor &featureMaps, at::Tensor &targetHistogram)

  cudaFree(linkMap);
  cudaFree(localIndexes);
}
ProGamerGov revised this gist
Jan 6, 2020 . 1 changed file with 2 additions and 2 deletions.
@@ -102,8 +102,8 @@ at::Tensor computeHistogram(at::Tensor const &t, unsigned int numBins)
void matchHistogram(at::Tensor &featureMaps, at::Tensor &targetHistogram)
{
  static std::map<unsigned int, at::Tensor> randomIndices;
  if (randomIndices[featureMaps.numel()].numel() != featureMaps.numel())
    randomIndices[featureMaps.numel()] = torch::randperm(featureMaps.numel()).to(at::kLong).cuda();

  at::Tensor unsqueezed(featureMaps);
  if (unsqueezed.ndimension() == 1)
ProGamerGov revised this gist
Jan 6, 2020 . 1 changed file with 2 additions and 2 deletions.
@@ -102,8 +102,8 @@ at::Tensor computeHistogram(at::Tensor const &t, unsigned int numBins)
void matchHistogram(at::Tensor &featureMaps, at::Tensor &targetHistogram)
{
  static std::map<unsigned int, at::Tensor> randomIndices;
  if (randomIndices[n*c].numel() != n*c)
    randomIndices[n*c] = torch::randperm(n*c, torch::TensorOptions().dtype(at::kLong)).cuda();

  at::Tensor unsqueezed(featureMaps);
  if (unsqueezed.ndimension() == 1)
ProGamerGov revised this gist
Jan 6, 2020 . 1 changed file with 1 addition and 1 deletion.
@@ -1,6 +1,6 @@
# neural-style-pt with histogram loss

The code here is based on [genekogan](https://github.com/genekogan)'s [neural-style-pt histogram loss](https://github.com/genekogan/neural-style-pt/tree/histogram-loss) code. The CUDA code come from here: https://github.com/pierre-wilmot/NeuralTextureSynthesis

### Histogram Loss Layers
ProGamerGov revised this gist
Dec 8, 2019 . 1 changed file with 11 additions and 0 deletions.
@@ -12,6 +12,17 @@ You may have to install `ninja` via `pip3 install ninja`.

This histogram loss layers will only work on a GPU device.

You can download all 3 required files to your neural-style-pt directory with:

```
wget https://gist.githubusercontent.com/ProGamerGov/30e95ac9ff42f3e09288ae07dc012a76/raw/f439a0137148a950d6c28b3b69543850ac55cab9/neural_style_hist_loss.py

wget https://gist.githubusercontent.com/ProGamerGov/30e95ac9ff42f3e09288ae07dc012a76/raw/f439a0137148a950d6c28b3b69543850ac55cab9/histogram.cpp

wget https://gist.githubusercontent.com/ProGamerGov/30e95ac9ff42f3e09288ae07dc012a76/raw/f439a0137148a950d6c28b3b69543850ac55cab9/histogram.cu
```

## Usage

Basic usage:
```
ProGamerGov revised this gist
Dec 8, 2019 . 1 changed file with 2 additions and 0 deletions.
@@ -1,3 +1,5 @@
# neural-style-pt with histogram loss

The code here is based on [genekogan](https://github.com/genekogan)'s [neural-style-pt histogram loss](https://github.com/genekogan/neural-style-pt/tree/histogram-loss) code.

### Histogram Loss Layers
ProGamerGov revised this gist
Dec 8, 2019 . 1 changed file with 8 additions and 5 deletions.
@@ -1,17 +1,20 @@
The code here is based on [genekogan](https://github.com/genekogan)'s [neural-style-pt histogram loss](https://github.com/genekogan/neural-style-pt/tree/histogram-loss) code.

### Histogram Loss Layers

Each histogram loss layer stores the style image's histogram as a target, and then uses that compute the difference to the image being stylized.

### Setup

You may have to install `ninja` via `pip3 install ninja`.

This histogram loss layers will only work on a GPU device.

## Usage

Basic usage:
```
python neural_style_hist_loss.py -style_image <image.jpg> -content_image <image.jpg>
```

**New Options**:
ProGamerGov revised this gist
Dec 8, 2019 . 1 changed file with 7 additions and 0 deletions.
@@ -1,6 +1,13 @@
The code here is based on [genekogan](https://github.com/genekogan)'s [neural-style-pt histogram loss](https://github.com/genekogan/neural-style-pt/tree/histogram-loss) code.

### Setup

You may have to install `ninja` via `pip3 install ninja`.

This histogram loss layers will only work on a GPU device.

### Histogram Loss Layers
ProGamerGov revised this gist
Dec 8, 2019 . 1 changed file with 9 additions and 2 deletions.
@@ -1,9 +1,16 @@
The code here is based on [genekogan](https://github.com/genekogan)'s [neural-style-pt histogram loss](https://github.com/genekogan/neural-style-pt/tree/histogram-loss) code.

### Histogram Loss Layers

Each histogram loss layer stores the style image's histogram as a target, and then uses that compute the difference to the image being stylized.

**New Options**:

* `-hist_weight`: How much to weight the histogram reconstruction term. Default is 1e2.

* `-hist_layers`: Comma-separated list of layer names to use for histogram reconstruction.
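As a rough illustration of how the two options listed above could be read, here is a small, self-contained `argparse` sketch. It is not taken from the script: `parse_layer_list` is a hypothetical helper, and dropping empty entries is an assumption made here because in plain Python `''.split(',')` returns `['']` rather than an empty list.

```python
import argparse


def parse_layer_list(spec):
    """Split a comma-separated layer spec, dropping empty entries (hypothetical helper)."""
    return [name.strip() for name in spec.split(',') if name.strip()]


parser = argparse.ArgumentParser()
# Same option names and defaults as documented in the README hunk.
parser.add_argument("-hist_weight", type=float, default=1e2)
parser.add_argument("-hist_layers", default='')

args = parser.parse_args(["-hist_layers", "relu2_1,relu3_1", "-hist_weight", "50"])
hist_layers = parse_layer_list(args.hist_layers)

# With no flags given, the empty default yields no histogram layers at all.
default_layers = parse_layer_list(parser.parse_args([]).hist_layers)
```

The layer names `relu2_1` and `relu3_1` are only example values; any names accepted by the model's layer list would be passed the same way.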
ProGamerGov revised this gist
Dec 8, 2019 . 1 changed file with 5 additions and 0 deletions.
@@ -1,3 +1,8 @@
### Histogram Loss

Each histogram loss layer stores the style image's histogram as a target, and then uses that compute the difference to the image being stylized.

**New Options**:

* `-hist_weight`: How much to weight the histogram reconstruction term. Default is 1e2.
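The idea in the description above (remap the current features so their distribution matches the style target, then penalise the difference) can be sketched on plain Python lists. This is a simplified, CPU-only illustration under stated assumptions: it uses exact rank-based matching instead of the binned CUDA approach in the gist, and `match_histogram` and `mse` are hypothetical names.

```python
def match_histogram(values, reference):
    """Remap `values` so that, rank for rank, they take on the values of `reference`."""
    if not values or not reference:
        raise ValueError("both inputs must be non-empty")
    ref_sorted = sorted(reference)
    # Indices of `values` ordered from smallest to largest element.
    order = sorted(range(len(values)), key=lambda i: values[i])
    matched = [0.0] * len(values)
    for rank, idx in enumerate(order):
        # Map the rank in the input onto the equivalent position in the reference.
        ref_pos = rank * len(ref_sorted) // len(values)
        matched[idx] = ref_sorted[ref_pos]
    return matched


def mse(a, b):
    """Mean squared error between two equally sized lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


features = [0.1, 0.9, 0.5, 0.3]        # stand-in for one channel of the image being stylized
style = [10.0, 20.0, 30.0, 40.0]       # stand-in for the same channel of the style image
target = match_histogram(features, style)
loss = mse(features, target)           # the histogram term, before any weighting
```

The remapped `target` keeps the spatial ordering of `features` (the largest feature stays the largest) while borrowing the style channel's value distribution, which is what makes the squared difference a usable loss.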
ProGamerGov revised this gist
Dec 8, 2019 . 4 changed files with 691 additions and 0 deletions.
File renamed without changes.
@@ -0,0 +1,12 @@
#include <torch/extension.h>
#include <iostream>

at::Tensor computeHistogram(at::Tensor const &t, unsigned int numBins = 256);
void matchHistogram(at::Tensor &featureMaps, at::Tensor &targetHistogram);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  m.def("computeHistogram", &computeHistogram, "ComputeHistogram");
  m.def("matchHistogram", &matchHistogram, "MatchHistogram");
}
@@ -0,0 +1,142 @@
#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
#include <math.h>

#define THREAD_COUNT 1024

__global__ void computeHistogram(float *tensor, float *histogram, float *minv, float *maxv, unsigned int channels, unsigned int tensorSize, unsigned int nBins)
{
  unsigned int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < channels * tensorSize)
    {
      // Compute which channel we're in
      unsigned int channel = index / tensorSize;
      // Normalize the value in range [0, numBins]
      float value = (tensor[index] - minv[channel]) / (maxv[channel] - minv[channel]) * float(nBins);
      // Compute bin index
      int bin = min((unsigned int)(value), nBins - 1);
      // Increment relevant bin
      atomicAdd(histogram + (channel * nBins) + bin, 1);
    }
}

// return cummulative histogram shifed to the right by 1
// ==> histogram[c][0] alweays == 0
__global__ void accumulateHistogram(float *histogram, unsigned int nBins)
{
  float t = 0;
  for (unsigned int i=0 ; i < nBins ; ++i)
    {
      float swap = histogram[i + blockIdx.x * nBins];
      histogram[i + blockIdx.x * nBins ] = t;
      t += swap;
    }
}

__global__ void buildSortedLinkmap(float *tensor, unsigned int *linkMap, float *cumulativeHistogram, unsigned int *localIndexes, long *indirection, float *minv, float *maxv, unsigned int channels, unsigned int tensorSize, unsigned int nBins)
{
  unsigned int index = threadIdx.x + blockIdx.x* blockDim.x;
  if (index < channels * tensorSize)
    {
      // Shuffle image -- Avoid the blurry top bug
      index = indirection[index];
      // Compute which channel we're in
      unsigned int channel = index / tensorSize;
      // Normalize the value in range [0, numBins]
      float value = (tensor[index] - minv[channel]) / (maxv[channel] - minv[channel]) * float(nBins);
      // Compute bin index
      int binIndex = min((unsigned int)(value), nBins - 1);
      // Increment and retrieve the number of pixel in said bin
      int localIndex = atomicAdd(&localIndexes[(channel * 256) + binIndex], 1);
      // Retrieve the number of pixel in all bin lower (in cummulative histogram)
      unsigned int lowerPixelCount = cumulativeHistogram[(channel * 256) + binIndex];
      // Set the linkmap for indes to it's position as "pseudo-sorted"
      linkMap[index] = lowerPixelCount + localIndex;
    }
}

__global__ void rebuild(float *tensor, unsigned int *linkMap, float *targetHistogram, float scale, unsigned int channels, unsigned int tensorSize)
{
  unsigned int index = threadIdx.x + blockIdx.x* blockDim.x;
  if (index < channels * tensorSize)
    {
      unsigned int channel = index / tensorSize;
      unsigned int value = 0;
      for (int i=0 ; i < 256 ; ++i)
        if (linkMap[index] >= targetHistogram[(channel * 256) + i] * scale) value = i;
      tensor[index] = (float)value;
    }
}

at::Tensor computeHistogram(at::Tensor const &t, unsigned int numBins)
{
  at::Tensor unsqueezed(t);
  unsqueezed = unsqueezed.cuda();

  if (unsqueezed.ndimension() == 1)
    unsqueezed.unsqueeze_(0);
  if (unsqueezed.ndimension() > 2)
    unsqueezed = unsqueezed.view({unsqueezed.size(0), -1});

  unsigned int c = unsqueezed.size(0); // Number od channels
  unsigned int n = unsqueezed.numel() / c; // Number of element per channel

  at::Tensor min = torch::min_values(unsqueezed, 1, true).cuda();
  at::Tensor max = torch::max_values(unsqueezed, 1, true).cuda();

  at::Tensor h = at::zeros({int(c), int(numBins)}, unsqueezed.type()).cuda();
  computeHistogram<<<(c*n) / THREAD_COUNT + 1, THREAD_COUNT>>>(unsqueezed.data<float>(), h.data<float>(), min.data<float>(), max.data<float>(), c, n, numBins);

  return h;
}

void matchHistogram(at::Tensor &featureMaps, at::Tensor &targetHistogram)
{
  static std::map<unsigned int, at::Tensor> randomIndices;
  if (randomIndices[featureMaps.numel()].numel() != featureMaps.numel())
    randomIndices[featureMaps.numel()] = torch::randperm(featureMaps.numel()).to(at::kLong).cuda();

  at::Tensor unsqueezed(featureMaps);
  if (unsqueezed.ndimension() == 1)
    unsqueezed.unsqueeze_(0);
  if (unsqueezed.ndimension() > 2)
    unsqueezed = unsqueezed.view({unsqueezed.size(0), -1});

  unsigned int nBins = targetHistogram.size(1);
  unsigned int c = unsqueezed.size(0); // Number of channels
  unsigned int n = unsqueezed.numel() / c; // Number of element per channel

  // Scale = numberOf Element in features / number of element in target
  float scale = float(featureMaps.numel()) / targetHistogram.sum().item<float>();

  at::Tensor featuresHistogram = computeHistogram(unsqueezed, nBins);
  accumulateHistogram<<<c, 1>>>(featuresHistogram.data<float>(), nBins);
  accumulateHistogram<<<c, 1>>>(targetHistogram.data<float>(), nBins);

  unsigned int *linkMap = NULL;
  cudaMalloc(&linkMap, c * n * sizeof(unsigned int));

  unsigned int *localIndexes = NULL;
  cudaMalloc(&localIndexes, c * nBins * sizeof(unsigned int));
  cudaMemset(localIndexes, 0, c * nBins * sizeof(unsigned int));

  at::Tensor min = torch::min_values(unsqueezed, 1, true).cuda();
  at::Tensor max = torch::max_values(unsqueezed, 1, true).cuda();

  buildSortedLinkmap<<<(c*n) / THREAD_COUNT + 1, THREAD_COUNT>>>(featureMaps.data<float>(), linkMap, featuresHistogram.data<float>(), localIndexes, randomIndices[featureMaps.numel()].data<long>(), min.data<float>(), max.data<float>(), c, n, nBins);
  rebuild<<<(c*n) / THREAD_COUNT + 1, THREAD_COUNT>>>(featureMaps.data<float>(), linkMap, targetHistogram.data<float>(), scale, c, n);

  featureMaps.div_(float(nBins));

  cudaFree(linkMap);
  cudaFree(localIndexes);
}
@@ -0,0 +1,537 @@
import os
import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

from torch.utils.cpp_extension import load
cpp = torch.utils.cpp_extension.load(name="histogram_cpp", sources=["histogram.cpp", "histogram.cu"])

from PIL import Image
from CaffeLoader import loadCaffemodel, ModelParallel

import argparse
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-style_blend_weights", default=None)
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu = c", default=0)

# Optimization options
parser.add_argument("-content_weight", type=float, default=5e0)
parser.add_argument("-style_weight", type=float, default=1e2)
parser.add_argument("-normalize_weights", action='store_true')
parser.add_argument("-hist_weight", type=float, default=1e2)
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=1000)
parser.add_argument("-init", choices=['random', 'image'], default='random')
parser.add_argument("-init_image", default=None)
parser.add_argument("-optimizer", choices=['lbfgs', 'adam'], default='lbfgs')
parser.add_argument("-learning_rate", type=float, default=1e0)
parser.add_argument("-lbfgs_num_correction", type=int, default=100)

# Output options
parser.add_argument("-print_iter", type=int, default=50)
parser.add_argument("-save_iter", type=int, default=100)
parser.add_argument("-output_image", default='out.png')

# Other options
parser.add_argument("-style_scale", type=float, default=1.0)
parser.add_argument("-original_colors", type=int, choices=[0, 1], default=0)
parser.add_argument("-pooling", choices=['avg', 'max'], default='max')
parser.add_argument("-model_file", type=str, default='models/vgg19-d01eb7cb.pth')
parser.add_argument("-disable_check", action='store_true')
parser.add_argument("-backend", choices=['nn', 'cudnn', 'mkl', 'mkldnn', 'openmp', 'mkl,cudnn', 'cudnn,mkl'], default='nn')
parser.add_argument("-cudnn_autotune", action='store_true')
parser.add_argument("-seed", type=int, default=-1)

parser.add_argument("-content_layers", help="layers for content", default='relu4_2')
parser.add_argument("-style_layers", help="layers for style", default='relu1_1,relu2_1,relu3_1,relu4_1,relu5_1')

parser.add_argument("-multidevice_strategy", default='4,7,29')
parser.add_argument("-hist_layers", help="layers for histogram", default='')
params = parser.parse_args()


Image.MAX_IMAGE_PIXELS = 1000000000 # Support gigapixel images


def main():
    dtype, multidevice, backward_device = setup_gpu()

    cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu, params.disable_check)

    content_image = preprocess(params.content_image, params.image_size).type(dtype)
    style_image_input = params.style_image.split(',')
    style_image_list, ext = [], [".jpg", ".jpeg", ".png", ".tiff"]
    for image in style_image_input:
        if os.path.isdir(image):
            images = (image + "/" + file for file in os.listdir(image) if os.path.splitext(file)[1].lower() in ext)
            style_image_list.extend(images)
        else:
            style_image_list.append(image)
    style_images_caffe = []
    for image in style_image_list:
        style_size = int(params.image_size * params.style_scale)
        img_caffe = preprocess(image, style_size).type(dtype)
        style_images_caffe.append(img_caffe)

    if params.init_image != None:
        image_size = (content_image.size(2), content_image.size(3))
        init_image = preprocess(params.init_image, image_size).type(dtype)

    # Handle style blending weights for multiple style inputs
    style_blend_weights = []
    if params.style_blend_weights == None:
        # Style blending not specified, so use equal weighting
        for i in style_image_list:
            style_blend_weights.append(1.0)
        for i, blend_weights in enumerate(style_blend_weights):
            style_blend_weights[i] = int(style_blend_weights[i])
    else:
        style_blend_weights = params.style_blend_weights.split(',')
        assert len(style_blend_weights) == len(style_image_list), \
          "-style_blend_weights and -style_images must have the same number of elements!"
    # Normalize the style blending weights so they sum to 1
    style_blend_sum = 0
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i])
        style_blend_sum = float(style_blend_sum) + style_blend_weights[i]
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i]) / float(style_blend_sum)

    content_layers = params.content_layers.split(',')
    style_layers = params.style_layers.split(',')
    hist_layers = params.hist_layers.split(',')

    # Set up the network, inserting style and content loss modules
    cnn = copy.deepcopy(cnn)
    content_losses, style_losses, tv_losses, hist_losses = [], [], [], []
    next_content_idx, next_style_idx, next_hist_idx = 1, 1, 1
    net = nn.Sequential()
    c, r = 0, 0
    if params.tv_weight > 0:
        tv_mod = TVLoss(params.tv_weight).type(dtype)
        net.add_module(str(len(net)), tv_mod)
        tv_losses.append(tv_mod)

    for i, layer in enumerate(list(cnn), 1):
        if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers) or next_hist_idx <= len(hist_layers):
            if isinstance(layer, nn.Conv2d):
                net.add_module(str(len(net)), layer)

                if layerList['C'][c] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = ContentLoss(params.content_weight)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)

                if layerList['C'][c] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = StyleLoss(params.style_weight)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                c+=1

            if isinstance(layer, nn.ReLU):
                net.add_module(str(len(net)), layer)

                if layerList['R'][r] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = ContentLoss(params.content_weight)
                    net.add_module(str(len(net)), loss_module)
                    content_losses.append(loss_module)
                    next_content_idx += 1

                if layerList['R'][r] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = StyleLoss(params.style_weight)
                    net.add_module(str(len(net)), loss_module)
                    style_losses.append(loss_module)
                    next_style_idx += 1

                if layerList['R'][r] in hist_layers:
                    print("Setting up histogram layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = HistLoss(params.hist_weight)
                    net.add_module(str(len(net)), loss_module)
                    hist_losses.append(loss_module)
                    next_hist_idx +=1
                r+=1

            if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
                net.add_module(str(len(net)), layer)

    if multidevice:
        net = setup_multi_device(net)

    print("Capturing histogram targets")
    for i in hist_losses:
        i.mode = 'captureS'
    net(style_images_caffe[0])

    for i in hist_losses:
        i.mode = 'None'

    # Capture content targets
    for i in content_losses:
        i.mode = 'capture'
    print("Capturing content targets")
    print_torch(net, multidevice)
    net(content_image)

    # Capture style targets
    for i in content_losses:
        i.mode = 'None'

    for i, image in enumerate(style_images_caffe):
        print("Capturing style target " + str(i+1))
        for j in style_losses:
            j.mode = 'capture'
            j.blend_weight = style_blend_weights[i]
        net(style_images_caffe[i])

    # Set all loss modules to loss mode
    for i in content_losses:
        i.mode = 'loss'
    for i in style_losses:
        i.mode = 'loss'
    for i in hist_losses:
        i.mode = 'loss'

    # Maybe normalize content and style weights
    if params.normalize_weights:
        normalize_weights(content_losses, style_losses, hist_losses)

    # Freeze the network in order to prevent
    # unnecessary gradient calculations
    for param in net.parameters():
        param.requires_grad = False

    # Initialize the image
    if params.seed >= 0:
        torch.manual_seed(params.seed)
        torch.cuda.manual_seed_all(params.seed)
        torch.backends.cudnn.deterministic=True
    if params.init == 'random':
        B, C, H, W = content_image.size()
        img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
    elif params.init == 'image':
        if params.init_image != None:
            img = init_image.clone()
        else:
            img = content_image.clone()
    img = nn.Parameter(img)

    def maybe_print(t, loss):
        if params.print_iter > 0 and t % params.print_iter == 0:
            print("Iteration " + str(t) + " / "+ str(params.num_iterations))
            for i, loss_module in enumerate(content_losses):
                print(" Content " + str(i+1) + " loss: " + str(loss_module.loss.item()))
            for i, loss_module in enumerate(style_losses):
                print(" Style " + str(i+1) + " loss: " + str(loss_module.loss.item()))
            for i, loss_module in enumerate(hist_losses):
                print(" Histogram " + str(i+1) + " loss: " + str(loss_module.loss.item()))
            print(" Total loss: " + str(loss.item()))

    def maybe_save(t):
        should_save = params.save_iter > 0 and t % params.save_iter == 0
        should_save = should_save or t == params.num_iterations
        if should_save:
            output_filename, file_extension = os.path.splitext(params.output_image)
            if t == params.num_iterations:
                filename = output_filename + str(file_extension)
            else:
                filename = str(output_filename) + "_" + str(t) + str(file_extension)
            disp = deprocess(img.clone())

            # Maybe perform postprocessing for color-independent style transfer
            if params.original_colors == 1:
                disp = original_colors(deprocess(content_image.clone()), disp)

            disp.save(str(filename))

    # Function to evaluate loss and gradient. We run the net forward and
    # backward to get the gradient, and sum up losses from the loss modules.
    # optim.lbfgs internally handles iteration and calls this function many
    # times, so we manually count the number of iterations to handle printing
    # and saving intermediate results.
    num_calls = [0]
    def feval():
        num_calls[0] += 1
        optimizer.zero_grad()
        net(img)
        loss = 0

        for mod in content_losses:
            loss += mod.loss.to(backward_device)
        for mod in style_losses:
            loss += mod.loss.to(backward_device)
        if params.tv_weight > 0:
            for mod in tv_losses:
                loss += mod.loss.to(backward_device)
        for mod in hist_losses:
            loss += mod.loss.to(backward_device)

        loss.backward()

        maybe_save(num_calls[0])
        maybe_print(num_calls[0], loss)

        return loss

    optimizer, loopVal = setup_optimizer(img)
    while num_calls[0] <= loopVal:
        optimizer.step(feval)


# Configure the optimizer
def setup_optimizer(img):
    if params.optimizer == 'lbfgs':
        print("Running optimization with L-BFGS")
        optim_state = {
            'max_iter': params.num_iterations,
            'tolerance_change': -1,
            'tolerance_grad': -1,
        }
        if params.lbfgs_num_correction != 100:
            optim_state['history_size'] = params.lbfgs_num_correction
        optimizer = optim.LBFGS([img], **optim_state)
        loopVal = 1
    elif params.optimizer == 'adam':
        print("Running optimization with ADAM")
        optimizer = optim.Adam([img], lr = params.learning_rate)
        loopVal = params.num_iterations - 1
    return optimizer, loopVal


def setup_gpu():
    def setup_cuda():
        if 'cudnn' in params.backend:
            torch.backends.cudnn.enabled = True
            if params.cudnn_autotune:
                torch.backends.cudnn.benchmark = True
        else:
            torch.backends.cudnn.enabled = False

    def setup_cpu():
        if 'mkl' in params.backend and 'mkldnn' not in params.backend:
            torch.backends.mkl.enabled = True
        elif 'mkldnn' in params.backend:
            raise ValueError("MKL-DNN is not supported yet.")
        elif 'openmp' in params.backend:
            torch.backends.openmp.enabled = True

    multidevice = False
    if "," in str(params.gpu):
        devices = params.gpu.split(',')
        multidevice = True

        if 'c' in str(devices[0]).lower():
            backward_device = "cpu"
            setup_cuda(), setup_cpu()
        else:
            backward_device = "cuda:" + devices[0]
            setup_cuda()
        dtype = torch.FloatTensor

    elif "c" not in str(params.gpu).lower():
        setup_cuda()
        dtype, backward_device = torch.cuda.FloatTensor, "cuda:" + str(params.gpu)
    else:
        setup_cpu()
        dtype, backward_device = torch.FloatTensor, "cpu"
    return dtype, multidevice, backward_device


def setup_multi_device(net):
    assert len(params.gpu.split(',')) - 1 == len(params.multidevice_strategy.split(',')), \
      "The number of -multidevice_strategy layer indices minus 1, must be equal to the number of -gpu devices."

    new_net = ModelParallel(net, params.gpu, params.multidevice_strategy)
    return new_net


# Preprocess an image before passing it to a model.
# We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR,
# and subtract the mean pixel.
def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size))*x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
    Normalize = transforms.Compose([transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1,1,1])])
    tensor = Normalize(rgb2bgr(Loader(image) * 256)).unsqueeze(0)
    return tensor


# Undo the above preprocessing.
def deprocess(output_tensor):
    Normalize = transforms.Compose([transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1,1,1])])
    bgr2rgb = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2,1,0])])])
    output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 256
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image


# Combine the Y channel of the generated image and the UV/CbCr channels of the
# content image to perform color-independent style transfer.
def original_colors(content, generated):
    content_channels = list(content.convert('YCbCr').split())
    generated_channels = list(generated.convert('YCbCr').split())
    content_channels[0] = generated_channels[0]
    return Image.merge('YCbCr', content_channels).convert('RGB')


# Print like Lua/Torch7
def print_torch(net, multidevice):
    if multidevice:
        return
    simplelist = ""
    for i, layer in enumerate(net, 1):
        simplelist = simplelist + "(" + str(i) + ") -> "
    print("nn.Sequential ( \n [input -> " + simplelist + "output]")

    def strip(x):
        return str(x).replace(", ",',').replace("(",'').replace(")",'') + ", "
    def n():
        return " (" + str(i) + "): " + "nn." + str(l).split("(", 1)[0]

    for i, l in enumerate(net, 1):
        if "2d" in str(l):
            ks, st, pd = strip(l.kernel_size), strip(l.stride), strip(l.padding)
            if "Conv2d" in str(l):
                ch = str(l.in_channels) + " -> " + str(l.out_channels)
                print(n() + "(" + ch + ", " + (ks).replace(",",'x', 1) + st + pd.replace(", ",')'))
            elif "Pool2d" in str(l):
                st = st.replace(" ",' ') + st.replace(", ",')')
                print(n() + "(" + ((ks).replace(",",'x' + ks, 1) + st).replace(", ",','))
        else:
            print(n())
    print(")")


# Divide weights by channel size
def normalize_weights(content_losses, style_losses, hist_losses=None):
    for n, i in enumerate(content_losses):
        i.strength = i.strength / max(i.target.size())
    for n, i in enumerate(style_losses):
        i.strength = i.strength / max(i.target.size())
    if hist_losses != None:
        for n, i in enumerate(hist_losses):
            i.strength = i.strength / max(i.target_size)


# Define an nn Module to compute content loss
class ContentLoss(nn.Module):

    def __init__(self, strength):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'

    def forward(self, input):
        if self.mode == 'loss':
            self.loss = self.crit(input, self.target) * self.strength
        elif self.mode == 'capture':
            self.target = input.detach()
        return input


class GramMatrix(nn.Module):

    def forward(self, input):
        B, C, H, W = input.size()
        x_flat = input.view(C, H * W)
        return torch.mm(x_flat, x_flat.t())


# Define an nn Module to compute style loss
class StyleLoss(nn.Module):

    def __init__(self, strength):
        super(StyleLoss, self).__init__()
        self.target = torch.Tensor()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.blend_weight = None

    def forward(self, input):
        self.G = self.gram(input)
        self.G = self.G.div(input.nelement())
        if self.mode == 'capture':
            if self.blend_weight == None:
                self.target = self.G.detach()
            elif self.target.nelement() == 0:
                self.target = self.G.detach().mul(self.blend_weight)
            else:
                self.target = self.target.add(self.blend_weight, self.G.detach())
        elif self.mode == 'loss':
            self.loss = self.strength * self.crit(self.G, self.target)
        return input


class TVLoss(nn.Module):

    def __init__(self, strength):
        super(TVLoss, self).__init__()
        self.strength = strength

    def forward(self, input):
        self.x_diff = input[:,:,1:,:] - input[:,:,:-1,:]
        self.y_diff = input[:,:,:,1:] - input[:,:,:,:-1]
        self.loss = self.strength * (torch.sum(torch.abs(self.x_diff)) + torch.sum(torch.abs(self.y_diff)))
        return input


# Define an nn Module to compute histogram loss
class HistLoss(nn.Module):

    def __init__(self, strength):
        super(HistLoss, self).__init__()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.target_max = None
        self.target_min = None
        self.strength = strength

    def minmax(self, input):
        return torch.min(input[0].view(input.shape[1], -1), 1)[0].data.clone(), \
            torch.max(input[0].view(input.shape[1], -1), 1)[0].data.clone()

    def calcHist(self, input, target, min_val, max_val):
        res = input.data.clone()
        cpp.matchHistogram(res, target.clone())
        for c in range(res.size(0)):
            res[c].mul_(max_val[c] - min_val[c])
            res[c].add_(min_val[c])
        return res.data.unsqueeze(0)

    def forward(self, input):
        if self.mode == 'captureS':
            self.target_min, self.target_max = self.minmax(input)
            self.target_hist = cpp.computeHistogram(input[0], 256)
            self.target_size = list(input.detach().size())
        elif self.mode == 'loss':
            target = self.calcHist(input[0], self.target_hist, self.target_min, self.target_max)
            self.loss = 0.01 * self.strength * self.crit(input, target)
        return input


if __name__ == "__main__":
    main()
ProGamerGov created this gist
Dec 8, 2019
@@ -0,0 +1,4 @@
**New Options**:

* `-hist_weight`: How much to weight the histogram reconstruction term. Default is 1e2.
* `-hist_layers`: Comma-separated list of layer names to use for histogram reconstruction.
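
The two options above could be read with `argparse` roughly as follows. This is a minimal sketch, not the script's actual parser: the helper name `parse_hist_options`, the empty default for `-hist_layers`, and the layer names in the usage line are assumptions made for illustration. Only the option names and the `1e2` default come from the list above.

```python
import argparse


def parse_hist_options(argv):
    """Hypothetical helper covering only the two histogram options."""
    parser = argparse.ArgumentParser()
    # Weight of the histogram reconstruction term (documented default: 1e2).
    parser.add_argument("-hist_weight", type=float, default=1e2)
    # Comma-separated layer names. The real default is not stated above,
    # so an empty string is used here as a placeholder.
    parser.add_argument("-hist_layers", type=str, default="")
    params = parser.parse_args(argv)
    # Split the list, dropping empty entries and stray whitespace.
    hist_layers = [name.strip() for name in params.hist_layers.split(",") if name.strip()]
    return params.hist_weight, hist_layers


# Illustrative layer names; substitute the layers you actually want to match.
weight, layers = parse_hist_options(["-hist_weight", "50", "-hist_layers", "relu1_1,relu2_1"])
print(weight, layers)  # prints: 50.0 ['relu1_1', 'relu2_1']
```

Splitting on commas and discarding empty entries keeps a trailing comma or an unset option from producing an empty layer name.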