import os
import copy
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

from PIL import Image
from CaffeLoader import loadCaffemodel, ModelParallel

import argparse
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument("-style_image", help="Style target image", default='examples/inputs/seated-nude.jpg')
parser.add_argument("-style_blend_weights", default=None)
parser.add_argument("-content_image", help="Content target image", default='examples/inputs/tubingen.jpg')
parser.add_argument("-image_size", help="Maximum height / width of generated image", type=int, default=512)
parser.add_argument("-gpu", help="Zero-indexed ID of the GPU to use; for CPU mode set -gpu = c", default=0)

# Optimization options
parser.add_argument("-content_weight", type=float, default=5e0)
parser.add_argument("-style_weight", type=float, default=1e2)
parser.add_argument("-normalize_weights", action='store_true')
parser.add_argument("-tv_weight", type=float, default=1e-3)
parser.add_argument("-num_iterations", type=int, default=10)
parser.add_argument("-init", choices=['random', 'image'], default='random')
parser.add_argument("-init_image", default=None)
parser.add_argument("-optimizer", choices=['lbfgs', 'adam'], default='lbfgs')
parser.add_argument("-learning_rate", type=float, default=1e0)
parser.add_argument("-lbfgs_num_correction", type=int, default=100)

# Output options
parser.add_argument("-print_iter", type=int, default=1)
parser.add_argument("-save_iter", type=int, default=1)
parser.add_argument("-output_image", default='out.png')

# Other options
parser.add_argument("-style_scale", type=float, default=1.0)
parser.add_argument("-original_colors", type=int, choices=[0, 1], default=0)
parser.add_argument("-pooling", choices=['avg', 'max'], default='max')
parser.add_argument("-model_file", type=str, default='models/vgg19-d01eb7cb.pth')
parser.add_argument("-disable_check", action='store_true')
parser.add_argument("-backend", choices=['nn', 'cudnn', 'mkl', 'mkldnn', 'openmp', 'mkl,cudnn', 'cudnn,mkl'], default='nn')
parser.add_argument("-cudnn_autotune", action='store_true')
parser.add_argument("-seed", type=int, default=-1)

parser.add_argument("-content_layers", help="layers for content", default='relu4_2')
parser.add_argument("-style_layers", help="layers for style", default='relu1_1,relu2_1,relu3_1,relu4_1,relu5_1')

parser.add_argument("-multidevice_strategy", default='4,7,29')

# Tile options
parser.add_argument("-tile_size", type=int, default=256)
parser.add_argument("-overlap_percent", type=float, default=0.5)
parser.add_argument("-tile_iter", type=int, default=0)
parser.add_argument("-print_tile_iter", type=int, default=0)
parser.add_argument("-print_tile", type=int, default=1)
parser.add_argument("-roll_image", action='store_true')
parser.add_argument("-jitter", type=int, default=0)
params = parser.parse_args()
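
# Example invocation (sketch; the script filename below is hypothetical):
#   python neural_style_tile.py -content_image examples/inputs/tubingen.jpg \
#       -style_image examples/inputs/seated-nude.jpg -image_size 512 \
#       -tile_size 256 -overlap_percent 0.5 -gpu 0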

Image.MAX_IMAGE_PIXELS = 1000000000  # Support gigapixel images


def main():
    dtype, multidevice, backward_device = setup_gpu()

    cnn, layerList = loadCaffemodel(params.model_file, params.pooling, params.gpu, params.disable_check)

    content_image = preprocess(params.content_image, params.image_size).type(dtype)
    style_image_input = params.style_image.split(',')
    style_image_list, ext = [], [".jpg", ".jpeg", ".png", ".tiff"]
    for image in style_image_input:
        if os.path.isdir(image):
            images = (image + "/" + file for file in os.listdir(image)
                      if os.path.splitext(file)[1].lower() in ext)
            style_image_list.extend(images)
        else:
            style_image_list.append(image)
    style_images_caffe = []
    for image in style_image_list:
        style_size = int(params.tile_size * params.style_scale)
        img_caffe = preprocess(image, style_size).type(dtype)
        style_images_caffe.append(img_caffe)

    if params.init_image is not None:
        image_size = (content_image.size(2), content_image.size(3))
        init_image = preprocess(params.init_image, image_size).type(dtype)

    # Handle style blending weights for multiple style inputs
    style_blend_weights = []
    if params.style_blend_weights is None:
        # Style blending not specified, so use equal weighting
        for i in style_image_list:
            style_blend_weights.append(1.0)
        for i, blend_weights in enumerate(style_blend_weights):
            style_blend_weights[i] = int(style_blend_weights[i])
    else:
        style_blend_weights = params.style_blend_weights.split(',')
        assert len(style_blend_weights) == len(style_image_list), \
            "-style_blend_weights and -style_images must have the same number of elements!"

    # Normalize the style blending weights so they sum to 1
    style_blend_sum = 0
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i])
        style_blend_sum = float(style_blend_sum) + style_blend_weights[i]
    for i, blend_weights in enumerate(style_blend_weights):
        style_blend_weights[i] = float(style_blend_weights[i]) / float(style_blend_sum)

    content_layers = params.content_layers.split(',')
    style_layers = params.style_layers.split(',')

    # Set up the network, inserting style and content loss modules
    cnn = copy.deepcopy(cnn)
    content_losses, style_losses, tv_losses = [], [], []
    next_content_idx, next_style_idx = 1, 1
    net_base = nn.Sequential()
    c, r = 0, 0

    if params.jitter > 0:
        jitter_mod = Jitter(params.jitter).type(dtype)
        net_base.add_module(str(len(net_base)), jitter_mod)
    if params.tv_weight > 0:
        tv_mod = TVLoss(params.tv_weight).type(dtype)
        net_base.add_module(str(len(net_base)), tv_mod)
        tv_losses.append(tv_mod)

    for i, layer in enumerate(list(cnn), 1):
        if next_content_idx <= len(content_layers) or next_style_idx <= len(style_layers):
            if isinstance(layer, nn.Conv2d):
                net_base.add_module(str(len(net_base)), layer)

                if layerList['C'][c] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = ContentLoss(params.content_weight)
                    net_base.add_module(str(len(net_base)), loss_module)
                    content_losses.append(loss_module)

                if layerList['C'][c] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['C'][c]))
                    loss_module = StyleLoss(params.style_weight)
                    net_base.add_module(str(len(net_base)), loss_module)
                    style_losses.append(loss_module)
                c += 1

            if isinstance(layer, nn.ReLU):
                net_base.add_module(str(len(net_base)), layer)

                if layerList['R'][r] in content_layers:
                    print("Setting up content layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = ContentLoss(params.content_weight)
                    net_base.add_module(str(len(net_base)), loss_module)
                    content_losses.append(loss_module)
                    next_content_idx += 1

                if layerList['R'][r] in style_layers:
                    print("Setting up style layer " + str(i) + ": " + str(layerList['R'][r]))
                    loss_module = StyleLoss(params.style_weight)
                    net_base.add_module(str(len(net_base)), loss_module)
                    style_losses.append(loss_module)
                    next_style_idx += 1
                r += 1

            if isinstance(layer, nn.MaxPool2d) or isinstance(layer, nn.AvgPool2d):
                net_base.add_module(str(len(net_base)), layer)

    if multidevice:
        net_base = setup_multi_device(net_base)

    print_torch(net_base, multidevice)

    if params.optimizer == 'lbfgs':
        print("Running optimization with L-BFGS")
    else:
        print("Running optimization with ADAM")

    if params.seed >= 0:
        torch.manual_seed(params.seed)
        torch.cuda.manual_seed_all(params.seed)
        torch.backends.cudnn.deterministic = True

    overlap_percent = params.overlap_percent / 100 if params.overlap_percent > 1 else params.overlap_percent
    if params.init_image is not None:
        init_image_tiles = tile_image(init_image.clone(), params.tile_size, overlap_percent)
    output_tiles = []
    total_content_losses, total_style_losses, total_loss = [], [], [0]
    content_tiles = tile_image(content_image.clone(), params.tile_size, overlap_percent)
    first_run = True
    h_roll, w_roll = 0, 0

    _, _, tile_pattern, num_tiles = tile_image(content_image.clone(), params.tile_size, overlap_percent, True)
    print('\nCreated ' + str(num_tiles) + ' tiles')
    print('Tile pattern: ' + str(tile_pattern[0]) + 'x' + str(tile_pattern[1]))

    if params.tile_iter <= 0:
        sub_iter = int(1000 / 3)  # alternatively: int(params.num_iterations / 3)
    else:
        sub_iter = params.tile_iter

    for n_iter in range(1, params.num_iterations + 1):
        for tile_num, c_tile in enumerate(content_tiles):
            net = copy.deepcopy(net_base)
            content_losses, style_losses, tv_losses = [], [], []
            for layer in net:
                if isinstance(layer, TVLoss):
                    tv_losses.append(layer)
                elif isinstance(layer, ContentLoss):
                    content_losses.append(layer)
                elif isinstance(layer, StyleLoss):
                    style_losses.append(layer)

            maybe_print_tile(tile_num, num_tiles)

            # Capture content targets
            for i in content_losses:
                i.mode = 'capture'
            net(c_tile)

            # Capture style targets; disable the content loss modules while doing so
            for i in content_losses:
                i.mode = 'None'

            for i, image in enumerate(style_images_caffe):
                for j in style_losses:
                    j.mode = 'capture'
                    j.blend_weight = style_blend_weights[i]
                net(style_images_caffe[i])

            # Set all loss modules to loss mode
            for i in content_losses:
                i.mode = 'loss'
            for i in style_losses:
                i.mode = 'loss'

            # Maybe normalize content and style weights
            if params.normalize_weights:
                normalize_weights(content_losses, style_losses)

            # Freeze the network in order to prevent
            # unnecessary gradient calculations
            for param in net.parameters():
                param.requires_grad = False

            # Initialize the image
            if params.init == 'random':
                B, C, H, W = c_tile.size()
                img = torch.randn(C, H, W).mul(0.001).unsqueeze(0).type(dtype)
            elif params.init == 'image':
                if params.init_image is not None:
                    img = init_image_tiles[tile_num].clone()
                else:
                    img = c_tile.clone()
            if not first_run:
                img = output_img_tiles[tile_num].clone()
            img = nn.Parameter(img)

            # Function to evaluate loss and gradient. We run the net forward and
            # backward to get the gradient, and sum up losses from the loss modules.
            # optim.lbfgs internally handles iteration and calls this function many
            # times, so we manually count the number of iterations to handle printing
            # and saving intermediate results.
            num_calls = [0]
            def feval():
                num_calls[0] += 1
                optimizer.zero_grad()
                net(img)
                loss = 0

                for mod in content_losses:
                    loss += mod.loss.to(backward_device)
                for mod in style_losses:
                    loss += mod.loss.to(backward_device)
                if params.tv_weight > 0:
                    for mod in tv_losses:
                        loss += mod.loss.to(backward_device)

                total_loss[0] += loss.item()

                loss.backward()

                maybe_print_tile_iter(num_calls[0], len(output_tiles), sub_iter)

                return loss

            optimizer, loopVal = setup_optimizer(img)
            while num_calls[0] <= sub_iter:
                optimizer.step(feval)

            if len(output_tiles) == 0:
                for mod in content_losses:
                    total_content_losses.append(mod.loss.item())
                for mod in style_losses:
                    total_style_losses.append(mod.loss.item())
            else:
                for c_loss, mod in enumerate(content_losses):
                    total_content_losses[c_loss] += mod.loss.item()
                for s_loss, mod in enumerate(style_losses):
                    total_style_losses[s_loss] += mod.loss.item()

            output_tiles.append(img.clone())

            if len(output_tiles) == len(content_tiles):
                first_run = False
                output_img = rebuild_image(output_tiles, content_image.clone(), params.tile_size, overlap_percent)
                output_tiles = []
                if params.roll_image:
                    output_img, _, _ = roll_tensor(output_img, -h_roll, -w_roll)

                maybe_print(n_iter, total_loss[0], total_content_losses, total_style_losses)
                maybe_save(n_iter, output_img)

                if params.roll_image:
                    output_img, h_roll, w_roll = roll_tensor(output_img.clone())
                output_img_tiles = tile_image(output_img.clone(), params.tile_size, overlap_percent)
                output_tiles = []
                total_content_losses, total_style_losses, total_loss = [], [], [0]


def maybe_save(t, save_img):
    should_save = params.save_iter > 0 and t % params.save_iter == 0
    should_save = should_save or t == params.num_iterations
    if should_save:
        output_filename, file_extension = os.path.splitext(params.output_image)
        if t == params.num_iterations:
            filename = output_filename + str(file_extension)
        else:
            filename = str(output_filename) + "_" + str(t) + str(file_extension)
        disp = deprocess(save_img.clone())

        # Maybe perform postprocessing for color-independent style transfer.
        # The content image is local to main(), so re-load it here.
        if params.original_colors == 1:
            content_image = preprocess(params.content_image, params.image_size)
            disp = original_colors(deprocess(content_image.clone()), disp)

        disp.save(str(filename))


def maybe_print(t, loss, content_losses, style_losses):
    if params.print_iter > 0 and t % params.print_iter == 0:
        print("Iteration " + str(t) + " / " + str(params.num_iterations))
        for i, loss_module in enumerate(content_losses):
            print("  Content " + str(i + 1) + " loss: " + str(loss_module))
        for i, loss_module in enumerate(style_losses):
            print("  Style " + str(i + 1) + " loss: " + str(loss_module))
        print("  Total loss: " + str(loss))


def maybe_print_tile_iter(t, n, total):
    if params.print_tile_iter > 0 and t % params.print_tile_iter == 0:
        print("Tile " + str(n + 1) + " iteration " + str(t) + " / " + str(total))


def maybe_print_tile(tile_num, num_tiles):
    if params.print_tile > 0 and (tile_num + 1) % params.print_tile == 0:
        print('Processing tile: ' + str(tile_num + 1) + ' of ' + str(num_tiles))


# Configure the optimizer
def setup_optimizer(img):
    if params.optimizer == 'lbfgs':
        optim_state = {
            'max_iter': 1,  # one L-BFGS iteration per optimizer.step() call; the caller loops
            'tolerance_change': -1,
            'tolerance_grad': -1,
        }
        if params.lbfgs_num_correction != 100:
            optim_state['history_size'] = params.lbfgs_num_correction
        optimizer = optim.LBFGS([img], **optim_state)
        loopVal = 1
    elif params.optimizer == 'adam':
        optimizer = optim.Adam([img], lr=params.learning_rate)
        loopVal = params.num_iterations - 1
    return optimizer, loopVal
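
# Note: with max_iter=1 above, each optimizer.step(feval) performs a single
# L-BFGS iteration, so the `while num_calls[0] <= sub_iter` loop in main()
# drives the per-tile iteration count for both L-BFGS and Adam. The returned
# loopVal is not used by the tiled loop in main().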


def setup_gpu():
    def setup_cuda():
        if 'cudnn' in params.backend:
            torch.backends.cudnn.enabled = True
            if params.cudnn_autotune:
                torch.backends.cudnn.benchmark = True
        else:
            torch.backends.cudnn.enabled = False

    def setup_cpu():
        if 'mkl' in params.backend and 'mkldnn' not in params.backend:
            torch.backends.mkl.enabled = True
        elif 'mkldnn' in params.backend:
            raise ValueError("MKL-DNN is not supported yet.")
        elif 'openmp' in params.backend:
            torch.backends.openmp.enabled = True

    multidevice = False
    if "," in str(params.gpu):
        devices = params.gpu.split(',')
        multidevice = True

        if 'c' in str(devices[0]).lower():
            backward_device = "cpu"
            setup_cuda()
            setup_cpu()
        else:
            backward_device = "cuda:" + devices[0]
            setup_cuda()
        dtype = torch.FloatTensor

    elif "c" not in str(params.gpu).lower():
        setup_cuda()
        dtype, backward_device = torch.cuda.FloatTensor, "cuda:" + str(params.gpu)
    else:
        setup_cpu()
        dtype, backward_device = torch.FloatTensor, "cpu"
    return dtype, multidevice, backward_device


def setup_multi_device(net):
    assert len(params.gpu.split(',')) - 1 == len(params.multidevice_strategy.split(',')), \
        "The number of -multidevice_strategy layer indices must be one less than the number of -gpu devices."

    new_net = ModelParallel(net, params.gpu, params.multidevice_strategy)
    return new_net


# Preprocess an image before passing it to a model.
# We need to rescale from [0, 1] to [0, 255], convert from RGB to BGR,
# and subtract the mean pixel. (The 256 scale factor below follows the
# original neural-style implementation.)
def preprocess(image_name, image_size):
    image = Image.open(image_name).convert('RGB')
    if type(image_size) is not tuple:
        image_size = tuple([int((float(image_size) / max(image.size)) * x) for x in (image.height, image.width)])
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    rgb2bgr = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])])
    Normalize = transforms.Compose([transforms.Normalize(mean=[103.939, 116.779, 123.68], std=[1, 1, 1])])
    tensor = Normalize(rgb2bgr(Loader(image) * 256)).unsqueeze(0)
    return tensor


# Undo the above preprocessing.
def deprocess(output_tensor):
    Normalize = transforms.Compose([transforms.Normalize(mean=[-103.939, -116.779, -123.68], std=[1, 1, 1])])
    bgr2rgb = transforms.Compose([transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])])])
    output_tensor = bgr2rgb(Normalize(output_tensor.squeeze(0).cpu())) / 256
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.cpu())
    return image
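
# Example (sketch): preprocess() and deprocess() round-trip an image, up to
# resizing, clamping, and rounding:
#   t = preprocess('examples/inputs/tubingen.jpg', 512)  # 1x3xHxW, BGR, mean-subtracted
#   pil = deprocess(t.clone())                           # back to a PIL RGB image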


# Combine the Y channel of the generated image and the UV/CbCr channels of the
# content image to perform color-independent style transfer.
def original_colors(content, generated):
    content_channels = list(content.convert('YCbCr').split())
    generated_channels = list(generated.convert('YCbCr').split())
    content_channels[0] = generated_channels[0]
    return Image.merge('YCbCr', content_channels).convert('RGB')


# Print like Lua/Torch7
def print_torch(net, multidevice):
    if multidevice:
        return
    simplelist = ""
    for i, layer in enumerate(net, 1):
        simplelist = simplelist + "(" + str(i) + ") -> "
    print("nn.Sequential ( \n  [input -> " + simplelist + "output]")

    def strip(x):
        return str(x).replace(", ", ',').replace("(", '').replace(")", '') + ", "

    def n():
        return "  (" + str(i) + "): " + "nn." + str(l).split("(", 1)[0]

    for i, l in enumerate(net, 1):
        if "2d" in str(l):
            ks, st, pd = strip(l.kernel_size), strip(l.stride), strip(l.padding)
            if "Conv2d" in str(l):
                ch = str(l.in_channels) + " -> " + str(l.out_channels)
                print(n() + "(" + ch + ", " + (ks).replace(",", 'x', 1) + st + pd.replace(", ", ')'))
            elif "Pool2d" in str(l):
                st = st.replace("  ", ' ') + st.replace(", ", ')')
                print(n() + "(" + ((ks).replace(",", 'x' + ks, 1) + st).replace(", ", ','))
        else:
            print(n())
    print(")")


# Divide weights by channel size
def normalize_weights(content_losses, style_losses):
    for n, i in enumerate(content_losses):
        i.strength = i.strength / max(i.target.size())
    for n, i in enumerate(style_losses):
        i.strength = i.strength / max(i.target.size())


# Define an nn Module to compute content loss
class ContentLoss(nn.Module):

    def __init__(self, strength):
        super(ContentLoss, self).__init__()
        self.strength = strength
        self.crit = nn.MSELoss()
        self.mode = 'None'

    def forward(self, input):
        if self.mode == 'loss':
            self.loss = self.crit(input, self.target) * self.strength
        elif self.mode == 'capture':
            self.target = input.detach()
        return input


class GramMatrix(nn.Module):

    def forward(self, input):
        # Assumes a batch size of 1: flatten each channel, then compute the
        # channel-by-channel correlation matrix.
        B, C, H, W = input.size()
        x_flat = input.view(C, H * W)
        return torch.mm(x_flat, x_flat.t())
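
# Shape example: for a (1, 64, 128, 128) activation, x_flat is (64, 16384) and
# the returned Gram matrix is (64, 64); StyleLoss below divides it by
# input.nelement() before comparing against the captured target.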


# Define an nn Module to compute style loss
class StyleLoss(nn.Module):

    def __init__(self, strength):
        super(StyleLoss, self).__init__()
        self.target = torch.Tensor()
        self.strength = strength
        self.gram = GramMatrix()
        self.crit = nn.MSELoss()
        self.mode = 'None'
        self.blend_weight = None

    def forward(self, input):
        self.G = self.gram(input)
        self.G = self.G.div(input.nelement())
        if self.mode == 'capture':
            if self.blend_weight is None:
                self.target = self.G.detach()
            elif self.target.nelement() == 0:
                self.target = self.G.detach().mul(self.blend_weight)
            else:
                self.target = self.target.add(self.G.detach(), alpha=self.blend_weight)
        elif self.mode == 'loss':
            self.loss = self.strength * self.crit(self.G, self.target)
        return input


class TVLoss(nn.Module):

    def __init__(self, strength):
        super(TVLoss, self).__init__()
        self.strength = strength

    def forward(self, input):
        self.x_diff = input[:, :, 1:, :] - input[:, :, :-1, :]
        self.y_diff = input[:, :, :, 1:] - input[:, :, :, :-1]
        self.loss = self.strength * (torch.sum(torch.abs(self.x_diff)) + torch.sum(torch.abs(self.y_diff)))
        return input
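
# TVLoss implements anisotropic total variation:
#   TV(x) = strength * ( sum |x[h+1, w] - x[h, w]| + sum |x[h, w+1] - x[h, w]| )
# which penalizes high-frequency noise in the generated image.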


# Shift tensor, possibly randomly.
def roll_tensor(tensor, h_shift=None, w_shift=None):
    # Random shifts are drawn over the spatial extent of dims 2 (height)
    # and 3 (width), matching the dims rolled below.
    if h_shift is None:
        h_shift = torch.LongTensor(10).random_(-tensor.size(2), tensor.size(2))[0].item()
    if w_shift is None:
        w_shift = torch.LongTensor(10).random_(-tensor.size(3), tensor.size(3))[0].item()
    tensor = torch.roll(torch.roll(tensor, shifts=h_shift, dims=2), shifts=w_shift, dims=3)
    return tensor, h_shift, w_shift
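
# Example: roll_tensor(t, h_shift=1, w_shift=2) moves every pixel down one row
# and right two columns with wrap-around, and roll_tensor(rolled, -1, -2)
# restores the original tensor; main() relies on this to undo -roll_image
# shifts before saving.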


# Apply blend masks to tiles. Each named side gets a linear ramp over its
# overlap region (0 -> 1 on leading edges, 1 -> 0 on trailing edges); the
# '-special' variants handle the final row/column, whose overlap with the
# previous tile can differ from the regular stride.
def mask_tile(tile, overlap, side='bottom'):
    h, w = tile.size(2), tile.size(3)
    top_overlap, bottom_overlap, right_overlap, left_overlap = overlap[0], overlap[1], overlap[2], overlap[3]
    # The CUDA and CPU code paths are identical apart from the device, so the
    # masks are built directly on the tile's device.
    device = tile.device

    if 'left' in side and 'left-special' not in side:
        lin_mask_left = torch.linspace(0, 1, left_overlap, device=device).repeat(h, 1).repeat(3, 1, 1).unsqueeze(0)
    if 'right' in side and 'right-special' not in side:
        lin_mask_right = torch.linspace(1, 0, right_overlap, device=device).repeat(h, 1).repeat(3, 1, 1).unsqueeze(0)
    if 'top' in side and 'top-special' not in side:
        lin_mask_top = torch.linspace(0, 1, top_overlap, device=device).repeat(w, 1).rot90(3).repeat(3, 1, 1).unsqueeze(0)
    if 'bottom' in side and 'bottom-special' not in side:
        lin_mask_bottom = torch.linspace(1, 0, bottom_overlap, device=device).repeat(w, 1).rot90(3).repeat(3, 1, 1).unsqueeze(0)

    if 'left-special' in side:
        lin_mask_left = torch.linspace(0, 1, left_overlap, device=device)
        zeros_mask = torch.zeros(w - (left_overlap * 2), device=device)
        ones_mask = torch.ones(left_overlap, device=device)
        lin_mask_left = torch.cat([zeros_mask, lin_mask_left, ones_mask], 0).repeat(h, 1).repeat(3, 1, 1).unsqueeze(0)
    if 'right-special' in side:
        lin_mask_right = torch.linspace(1, 0, right_overlap, device=device)
        ones_mask = torch.ones(w - right_overlap, device=device)
        lin_mask_right = torch.cat([ones_mask, lin_mask_right], 0).repeat(h, 1).repeat(3, 1, 1).unsqueeze(0)
    if 'top-special' in side:
        lin_mask_top = torch.linspace(0, 1, top_overlap, device=device)
        zeros_mask = torch.zeros(h - (top_overlap * 2), device=device)
        ones_mask = torch.ones(top_overlap, device=device)
        lin_mask_top = torch.cat([zeros_mask, lin_mask_top, ones_mask], 0).repeat(w, 1).rot90(3).repeat(3, 1, 1).unsqueeze(0)
    if 'bottom-special' in side:
        lin_mask_bottom = torch.linspace(1, 0, bottom_overlap, device=device)
        ones_mask = torch.ones(h - bottom_overlap, device=device)
        lin_mask_bottom = torch.cat([ones_mask, lin_mask_bottom], 0).repeat(w, 1).rot90(3).repeat(3, 1, 1).unsqueeze(0)

    base_mask = torch.ones_like(tile)

    if 'right' in side and 'right-special' not in side:
        base_mask[:, :, :, w - right_overlap:] = base_mask[:, :, :, w - right_overlap:] * lin_mask_right
    if 'left' in side and 'left-special' not in side:
        base_mask[:, :, :, :left_overlap] = base_mask[:, :, :, :left_overlap] * lin_mask_left
    if 'bottom' in side and 'bottom-special' not in side:
        base_mask[:, :, h - bottom_overlap:, :] = base_mask[:, :, h - bottom_overlap:, :] * lin_mask_bottom
    if 'top' in side and 'top-special' not in side:
        base_mask[:, :, :top_overlap, :] = base_mask[:, :, :top_overlap, :] * lin_mask_top

    if 'right-special' in side:
        base_mask = base_mask * lin_mask_right
    if 'left-special' in side:
        base_mask = base_mask * lin_mask_left
    if 'bottom-special' in side:
        base_mask = base_mask * lin_mask_bottom
    if 'top-special' in side:
        base_mask = base_mask * lin_mask_top
    return tile * base_mask
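
# Usage sketch: a tile masked on its 'bottom' edge fades to zero over its last
# `overlap[1]` rows, so summing it with the next tile's 'top'-masked rows in
# add_tiles() cross-fades the two tiles over the shared region.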


def get_tile_coords(d, tile_dim, overlap=0):
    # Convert the overlap fraction into a stride between tile starts. Callers
    # also use this value as the size of the blended overlap region, which
    # matches exactly at the default 50% overlap.
    stride = int(tile_dim * (1 - overlap))
    c, tile_start, coords = 1, 0, [0]
    while tile_start + tile_dim < d:
        tile_start = stride * c
        if tile_start + tile_dim >= d:
            coords.append(d - tile_dim)
        else:
            coords.append(tile_start)
        c += 1
    return coords, stride
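
# Worked example: get_tile_coords(512, 256, 0.5) yields a stride of
# int(256 * (1 - 0.5)) = 128 and start coordinates [0, 128, 256], i.e. three
# 256-pixel tiles covering the 512-pixel dimension with 50% overlap.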


def get_tiles(img, tile_coords, tile_size, info_only=False):
    tile_list = []
    for y in tile_coords[0]:
        for x in tile_coords[1]:
            tile = img[:, :, y:y + tile_size[0], x:x + tile_size[1]]
            tile_list.append(tile)
    if not info_only:
        return tile_list
    else:
        return tile_list[0].size(2), tile_list[0].size(3)


# Spacing between the last two tile start coordinates in each dimension; used
# as the blend overlap for the final row/column, which is clamped to the
# image edge and so may not follow the regular stride.
def final_overlap(tile_coords):
    r, c = len(tile_coords[0]), len(tile_coords[1])
    return (tile_coords[0][r - 1] - tile_coords[0][r - 2], tile_coords[1][c - 1] - tile_coords[1][c - 2])


def add_tiles(tiles, base_img, tile_coords, tile_size, overlap):
    f_ovlp = final_overlap(tile_coords)
    h, w = tiles[0].size(2), tiles[0].size(3)
    t = 0
    column, row = 0, 0
    for y in tile_coords[0]:
        for x in tile_coords[1]:
            mask_sides = ''
            c_overlap = overlap.copy()
            if len(tile_coords[0]) > 1:
                if row == 0:
                    if row == len(tile_coords[0]) - 2:
                        mask_sides += 'bottom-special'
                        c_overlap[1] = f_ovlp[0]  # Change bottom overlap
                    else:
                        mask_sides += 'bottom'
                elif row > 0 and row < len(tile_coords[0]) - 2:
                    mask_sides += 'bottom,top'
                elif row == len(tile_coords[0]) - 2:
                    if f_ovlp[0] > 0:
                        mask_sides += 'bottom-special,top'
                        c_overlap[1] = f_ovlp[0]  # Change bottom overlap
                    elif f_ovlp[0] <= 0:
                        mask_sides += 'bottom,top'
                elif row == len(tile_coords[0]) - 1:
                    if f_ovlp[0] > 0:
                        mask_sides += 'top-special'
                        c_overlap[0] = f_ovlp[0]  # Change top overlap
                    elif f_ovlp[0] <= 0:
                        mask_sides += 'top'

            if len(tile_coords[1]) > 1:
                if column == 0:
                    if column == len(tile_coords[1]) - 2:
                        mask_sides += ',right-special'
                        c_overlap[2] = f_ovlp[1]  # Change right overlap
                    else:
                        mask_sides += ',right'
                elif column > 0 and column < len(tile_coords[1]) - 2:
                    mask_sides += ',right,left'
                elif column == len(tile_coords[1]) - 2:
                    if f_ovlp[1] > 0:
                        mask_sides += ',right-special,left'
                        c_overlap[2] = f_ovlp[1]  # Change right overlap
                    elif f_ovlp[1] <= 0:
                        mask_sides += ',right,left'
                elif column == len(tile_coords[1]) - 1:
                    if f_ovlp[1] > 0:
                        mask_sides += ',left-special'
                        c_overlap[3] = f_ovlp[1]  # Change left overlap
                    elif f_ovlp[1] <= 0:
                        mask_sides += ',left'

            if t < len(tiles):
                tile = mask_tile(tiles[t], c_overlap, side=mask_sides)
                base_img[:, :, y:y + tile_size[0], x:x + tile_size[1]] = base_img[:, :, y:y + tile_size[0], x:x + tile_size[1]] + tile
            t += 1
            column += 1
        row += 1
        column = 0
    return base_img


def tile_setup(tile_size, overlap_percent, base_size):
    if type(tile_size) is not tuple and type(tile_size) is not list:
        tile_size = (tile_size, tile_size)
    if type(overlap_percent) is not tuple and type(overlap_percent) is not list:
        overlap_percent = (overlap_percent, overlap_percent)
    x_coords, x_ovlp = get_tile_coords(base_size[1], tile_size[1], overlap_percent[1])
    y_coords, y_ovlp = get_tile_coords(base_size[0], tile_size[0], overlap_percent[0])
    return (y_coords, x_coords), tile_size, [y_ovlp, y_ovlp, x_ovlp, x_ovlp]
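
# The overlap list returned above is ordered [top, bottom, right, left],
# matching how mask_tile() unpacks it.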


def tile_image(img, tile_size, overlap_percent, info_only=False):
    tile_coords, tile_size, _ = tile_setup(tile_size, overlap_percent, (img.size(2), img.size(3)))
    if not info_only:
        return get_tiles(img, tile_coords, tile_size)
    else:
        tile_size = get_tiles(img, tile_coords, tile_size, info_only)
        return tile_size[0], tile_size[1], (len(tile_coords[0]), len(tile_coords[1])), (len(tile_coords[0]) * len(tile_coords[1]))


def rebuild_image(tiles, base_img, tile_size, overlap_percent):
    base_img = torch.zeros_like(base_img)
    tile_coords, tile_size, overlap = tile_setup(tile_size, overlap_percent, (base_img.size(2), base_img.size(3)))
    return add_tiles(tiles, base_img, tile_coords, tile_size, overlap)
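
# Usage sketch: tiling and rebuilding round-trip an image (up to edge blending):
#   tiles = tile_image(img, 256, 0.5)
#   out = rebuild_image(tiles, img, 256, 0.5)  # same shape as img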


# Define an nn Module to apply jitter
class Jitter(torch.nn.Module):

    def __init__(self, jitter_val):
        super(Jitter, self).__init__()
        self.jitter_val = jitter_val

    def roll_tensor(self, input):
        h_shift = random.randint(-self.jitter_val, self.jitter_val)
        w_shift = random.randint(-self.jitter_val, self.jitter_val)
        return torch.roll(torch.roll(input, shifts=h_shift, dims=2), shifts=w_shift, dims=3)

    def forward(self, input):
        return self.roll_tensor(input)


if __name__ == "__main__":
    main()

# Workflow notes: run a normal style transfer up to this point; then re-run at
# 2048 image size with 1024 tile size and a style weight of 2500.
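# A plausible second-pass invocation given those notes (the script filename is
# hypothetical):
#   python neural_style_tile.py -image_size 2048 -tile_size 1024 -style_weight 2500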