Last active
June 22, 2018 07:40
-
-
Save ronekko/5ff82504d2ffb770a1b36936cd732d9f to your computer and use it in GitHub Desktop.
Rotation of images in chainer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
""" | |
Created on Fri Jun 22 15:40:40 2018 | |
@author: sakurai | |
""" | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import chainer | |
import chainer.functions as F | |
from chainer import cuda, Variable | |
def deg2rad(deg):
    """Convert angles from degrees to radians.

    Args:
        deg: An angle or batch of angles in degrees. It may be a NumPy/CuPy
            array or a chainer Variable; the result has the same kind.

    Returns:
        ``deg`` multiplied by ``pi / 180``.
    """
    array_module = cuda.get_array_module(deg)
    radians_per_degree = array_module.pi / 180.0
    return radians_per_degree * deg
def rotate_image(x, angle_radian):
    """Rotate each image in a batch by its own angle about the image center.

    A batch of 2x3 affine matrices (pure rotation, zero translation) is
    turned into a sampling grid with ``F.spatial_transformer_grid``. The
    input is then resampled on that grid with
    ``F.spatial_transformer_sampler``. Only ``chainer.functions`` are used,
    so the result stays in the computational graph.

    Args:
        x (Variable or ndarray):
            A batch of images of shape (B, C, H, W) with a floating point
            dtype.
        angle_radian (Variable or ndarray):
            A batch of rotation angles in radians, of shape (B,). It is cast
            to ``x.dtype``, so the angle dtype need not match the image
            dtype.

    Returns:
        A (B, C, H, W) shaped variable of the rotated images.

    Note:
        The rotation is applied in the normalized [-1, 1] x [-1, 1] grid
        coordinates (one linspace per axis). For non-square images
        (H != W), the result is therefore also stretched rather than being
        a pure rotation in pixel space.
    """
    xp = cuda.get_array_module(x)
    batch_size, _, height, width = x.shape

    # F.stack requires all of its inputs to share one dtype. Building every
    # matrix entry in the image dtype (instead of a hard-coded float32)
    # lets float16/float64 images or angles work as well.
    angle_radian = F.cast(angle_radian, x.dtype)
    cos = F.cos(angle_radian)
    sin = F.sin(angle_radian)
    zero = xp.zeros(batch_size, dtype=x.dtype)

    # theta[b] = [[cos, -sin, 0],
    #             [sin,  cos, 0]]
    theta_row0 = F.stack((cos, -sin, zero), axis=1)  # (B, 3)
    theta_row1 = F.stack((sin, cos, zero), axis=1)  # (B, 3)
    theta = F.stack((theta_row0, theta_row1), axis=1)  # (B, 2, 3)

    # For every output pixel at normalized location (x_t, y_t), the grid
    # holds the source location theta @ (x_t, y_t, 1). Rough NumPy sketch:
    #     xs, ys = np.meshgrid(np.linspace(-1, 1, width),
    #                          np.linspace(-1, 1, height))  # each (H, W)
    #     coords = np.stack((xs, ys, np.ones_like(xs)))      # (3, H, W)
    #     grid = np.einsum('bij,jhw->bihw', theta, coords)   # (B, 2, H, W)
    grid = F.spatial_transformer_grid(theta, (height, width))

    # Sample x at the rotated grid locations to produce the output images.
    return F.spatial_transformer_sampler(x, grid)
if __name__ == '__main__':
    # Demo: rotate the first few MNIST digits, each by a different angle,
    # and show every original/rotated pair side by side.
    angles_deg = [0, 30, 45, 60, 120]
    use_gpu = False

    batch_size = len(angles_deg)
    xp = cuda.cupy if use_gpu else np
    device = 0 if use_gpu else -1

    # ndim=3 keeps a channel axis, so the batch is 4-D as rotate_image
    # expects (B, C, H, W); channel 0 is plotted below.
    ds, _ = chainer.datasets.get_mnist(ndim=3)
    image, _ = chainer.dataset.concat_examples(ds[:batch_size], device)

    # Rotate the images.
    angle_degree = Variable(xp.asarray(angles_deg, dtype=np.float32))
    angle_radian = deg2rad(angle_degree)
    rotated_image = rotate_image(image, angle_radian)

    # Copy everything back to host memory for plotting; one figure per angle.
    for deg, img, img2 in zip(cuda.to_cpu(angle_degree.array),
                              cuda.to_cpu(image),
                              cuda.to_cpu(rotated_image.array)):
        plt.subplot(1, 2, 1)
        plt.matshow(img[0], cmap=plt.cm.gray, fignum=0)
        plt.axis('off')
        plt.title('Original')

        plt.subplot(1, 2, 2)
        plt.matshow(img2[0], cmap=plt.cm.gray, fignum=0)
        plt.axis('off')
        plt.title('Rotated ({} [deg])'.format(deg))

        plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment