Created
May 13, 2019 03:23
-
-
Save Z-Zheng/3ad65a26f5bac95b83f0f261d89af6db to your computer and use it in GitHub Desktop.
dataset class and transform class for lxy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from torch.utils.data import Dataset | |
import glob | |
import os | |
from skimage.io import imread | |
from PIL import Image | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from simplecv.util import tensor_util | |
from simplecv.interface import CVModule | |
from simplecv.data import preprocess | |
import torch | |
import torch.nn.functional as F | |
# Flat RGB lookup table in the layout ``PIL.Image.putpalette`` accepts:
# three consecutive integers (R, G, B) per palette index, so the color of
# index ``i`` is ``pallete[3 * i:3 * i + 3]``. It defines 19 colors.
pallete = [
    channel
    for rgb in (
        (220, 20, 60),
        (128, 64, 128),
        (70, 70, 70),
        (102, 102, 156),
        (190, 153, 153),
        (153, 153, 153),
        (250, 170, 30),
        (220, 220, 0),
        (107, 142, 35),
        (0, 130, 180),
        (255, 0, 0),
        (0, 0, 142),
        (0, 0, 70),
        (0, 60, 100),
        (244, 35, 232),
        (0, 80, 100),
        (119, 11, 32),
        (0, 0, 230),
        (152, 251, 152),
    )
    for channel in rgb
]
def get_color_pallete(npimg):
    """Convert a label map into a palettised PIL image.

    Parameters
    ----------
    npimg : numpy.ndarray
        Single channel label image with shape `H, W` or `H, W, 1`.
        Values are cast to ``uint8`` and used as palette indices.

    Returns
    -------
    out_img : PIL.Image
        Image carrying the module-level color pallete.
    """
    labels = npimg.astype('uint8')
    # Image.fromarray has no mode for a trailing singleton channel, so the
    # documented `H, W, 1` layout has to be reduced to `H, W` first.
    if labels.ndim == 3 and labels.shape[-1] == 1:
        labels = labels[:, :, 0]
    out_img = Image.fromarray(labels)
    out_img.putpalette(pallete)
    return out_img
def plot_mask(img, masks, alpha=0.5):
    """Blend segmentation masks onto an image.

    Parameters
    ----------
    img : numpy.ndarray
        Image with shape `H, W, 3`.
    masks : numpy.ndarray
        Stack of masks with shape `N, H, W`; every value greater than zero
        counts as foreground.
    alpha : float, optional, default 0.5
        Weight of the mask color in the blend (0 keeps the image, 1 paints
        the plain color).

    Returns
    -------
    numpy.ndarray
        ``uint8`` copy of the image with every mask blended in.
    """
    # Fixed seed: the i-th mask is painted with the same color on every call.
    rng = np.random.RandomState(567)
    blended = img
    for single_mask in masks:
        tint = rng.random_sample(3) * 255
        # (H, W, 1) boolean selector, broadcast across the color channels.
        foreground = (single_mask > 0)[:, :, np.newaxis]
        overlay = blended * (1 - alpha) + tint * alpha
        blended = np.where(foreground, overlay, blended)
    return blended.astype('uint8')
class DeepglobeRoad(Dataset):
    """Road segmentation dataset read from a flat directory.

    Every ``*_sat.jpg`` file directly under ``root`` is one sample. Its mask
    is expected next to it, under the same name with the ``sat.jpg`` suffix
    swapped for ``mask.png``.

    Args:
        root: directory holding the ``*_sat.jpg`` / ``*_mask.png`` pairs.
        training: stored on the instance; this class does not read it.
        transforms: optional callable mapping ``(image_np, mask_np)`` to
            ``(image_tensor, mask_tensor)``. When omitted the arrays are
            passed through ``tensor_util.to_tensor`` only.
    """

    _IMAGE_SUFFIX = 'sat.jpg'
    _MASK_SUFFIX = 'mask.png'

    def __init__(self, root, training=True, transforms=None):
        self.root = root
        self.transforms = transforms
        self.training = training
        # glob order depends on the filesystem; sorting keeps the
        # index -> sample mapping reproducible across runs and machines.
        self.im_path_list = sorted(glob.glob(os.path.join(root, '*' + '_' + self._IMAGE_SUFFIX)))
        # Swap the suffix only: str.replace would also rewrite a directory
        # component that happens to contain 'sat.jpg'.
        self.mask_path_list = [
            im_path[:-len(self._IMAGE_SUFFIX)] + self._MASK_SUFFIX
            for im_path in self.im_path_list
        ]

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            A pair ``(inputs, targets)`` where ``inputs`` is a dict with keys
            ``rgb`` (image tensor) and ``image_filename`` (base name of the
            image file) and ``targets`` is a dict with key ``cls`` (mask
            tensor).
        """
        image_np = imread(self.im_path_list[idx])
        mask_np = imread(self.mask_path_list[idx])
        if self.transforms is not None:
            image_tensor, mask_tensor = self.transforms(image_np, mask_np)
        else:
            image_tensor, mask_tensor = tensor_util.to_tensor([image_np, mask_np])
        return dict(rgb=image_tensor,
                    image_filename=os.path.basename(self.im_path_list[idx])), dict(cls=mask_tensor)

    def __len__(self):
        """Return the number of ``*_sat.jpg`` images found under ``root``."""
        return len(self.im_path_list)

    def show_image_mask(self, index, with_mask=True):
        """Draw sample ``index`` on the current matplotlib axes.

        Requires ``self.transforms`` to expose
        ``config.mean_std_normalize.mean`` / ``.std``; the image is
        de-normalized with those values before display.
        NOTE(review): assumes the transformed image is laid out as
        ``[3, H, W]`` -- confirm for transforms other than
        ``DeepglobeRoadTransform``.

        Args:
            index: sample index.
            with_mask: when True, overlay the mask on the image.
        """
        x, y = self[index]
        # denormalize
        norm_cfg = self.transforms.config.mean_std_normalize
        _mean = torch.tensor(norm_cfg.mean).reshape(3, 1, 1)
        _std = torch.tensor(norm_cfg.std).reshape(3, 1, 1)
        # Round and clamp before the uint8 cast: a bare cast truncates and
        # out-of-range values would wrap around.
        rgb = (x['rgb'] * _std + _mean).round().clamp(0, 255).byte()
        # to np, channels last for matplotlib
        image_np = rgb.permute((1, 2, 0)).numpy()
        if with_mask:
            mask_np = y['cls'].numpy()
            color_mask = np.asarray(get_color_pallete(mask_np))
            # plot_mask expects a stack of masks, hence the leading axis.
            vis_image = plot_mask(image_np, color_mask.reshape([1] + list(color_mask.shape)))
        else:
            vis_image = image_np
        plt.imshow(vis_image)
class DeepglobeRoadTransform(CVModule):
    """Convert an (image, mask) numpy pair into training tensors.

    The pipeline is: cast to float32, convert to tensors, mean/std normalize
    the image, optionally run the configured augmentation ops, then move the
    image channels first and cast the mask to int64.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(self, images, masks):
        """
        Args:
            images: 3-D array of shape [height, width, channel]
            masks: 2-D array of shape [height, width]
        Returns:
            images_tensor: 3-D float32 tensor of shape [channel, height, width]
            masks_tensor: 2-D int64 tensor of shape [height, width]
        """
        assert images.ndim == 3
        assert masks.ndim == 2

        float_images = images.astype(np.float32)
        float_masks = masks.astype(np.float32)
        images_tensor, masks_tensor = tensor_util.to_tensor([float_images, float_masks])

        norm_cfg = self.config.mean_std_normalize
        images_tensor = preprocess.mean_std_normalize(images_tensor, norm_cfg.mean, norm_cfg.std)

        # Augmentations run on the channels-last layout, before the permute.
        if self.config.training:
            for trans_op in self.config.transforms:
                images_tensor, masks_tensor = trans_op(images_tensor, masks_tensor)

        return images_tensor.permute((2, 0, 1)), masks_tensor.long()

    def set_defalut_config(self):
        # NOTE(review): the method name (including its spelling) is
        # presumably dictated by the CVModule interface -- confirm before
        # renaming.
        self.config.update({
            'training': True,
            'mean_std_normalize': {
                'mean': (123.675, 116.28, 103.53),
                'std': (58.395, 57.12, 57.375),
            },
            'transforms': [
                THRandomRotate90k(p=0.5),
                THRandomHorizontalFlip(p=0.5),
                THRandomVerticalFlip(p=0.5),
                THRandomScale(),
                THRandomCrop((512, 512)),
            ],
        })
# ---------------------------------------------------------------------------
# Augmentation ops below: worked examples for lxy
# ---------------------------------------------------------------------------
class THRandomRotate90k(object):
    """Randomly rotate an image (and its mask) by a multiple of 90 degrees.

    Args:
        p: probability of applying the rotation.
        k: number of 90 degree turns; when None it is drawn uniformly from
            {1, 2, 3} on every applied call.
    """

    def __init__(self, p=0.5, k=None):
        self.p = p
        self.k = k

    def __call__(self, images, masks=None):
        """ Rotate 90 * k degree for image and mask
        Args:
            images: 3-D tensor of shape [height, width, channel]
            masks: 2-D tensor of shape [height, width]
        Returns:
            images_tensor alone when ``masks`` is None, otherwise the list
            ``[images_tensor, masks_tensor]``
        """
        # Same gate as the flip transforms: skip with probability 1 - p.
        # Previously `p` was stored but ignored and the rotation always ran.
        if self.p < np.random.uniform():
            return images if masks is None else [images, masks]

        k = int(np.random.choice([1, 2, 3])) if self.k is None else self.k
        # Rotate in the (height, width) plane; channels stay last.
        images_tensor = torch.rot90(images, k, [0, 1])
        if masks is None:
            return images_tensor
        return [images_tensor, torch.rot90(masks, k, [0, 1])]
class THRandomHorizontalFlip(object):
    """Mirror an image (and its mask) along the width axis with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, images, masks=None):
        """
        Args:
            images: 3-D tensor of shape [height, width, channel]
            masks: 2-D tensor of shape [height, width]
        Returns:
            images_tensor alone when ``masks`` is None, otherwise the list
            ``[images_tensor, masks_tensor]``
        """
        skip = self.p < np.random.uniform()
        if not skip:
            # dim 1 is the width axis in the channels-last layout
            images = torch.flip(images, [1])
            if masks is not None:
                masks = torch.flip(masks, [1])
        if masks is None:
            return images
        return [images, masks]
class THRandomVerticalFlip(object):
    """Mirror an image (and its mask) along the height axis with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, images, masks=None):
        """
        Args:
            images: 3-D tensor of shape [height, width, channel]
            masks: 2-D tensor of shape [height, width]
        Returns:
            images_tensor alone when ``masks`` is None, otherwise the list
            ``[images_tensor, masks_tensor]``
        """
        skip = self.p < np.random.uniform()
        if not skip:
            # dim 0 is the height axis in the channels-last layout
            images = torch.flip(images, [0])
            if masks is not None:
                masks = torch.flip(masks, [0])
        if masks is None:
            return images
        return [images, masks]
class THRandomCrop(object):
    """Cut a random window of ``crop_size`` out of an image (and its mask).

    Inputs smaller than the window are zero padded at the bottom and right
    edges first, so the output always has exactly ``crop_size`` pixels.

    Args:
        crop_size: ``(height, width)`` of the window.
    """

    def __init__(self, crop_size=(512, 512)):
        self.crop_size = crop_size

    def __call__(self, images, masks=None):
        """
        Args:
            images: 3-D tensor of shape [height, width, channel]
            masks: 2-D tensor of shape [height, width]
        Returns:
            list ``[images_tensor]`` or ``[images_tensor, masks_tensor]``
            (always a list, even when ``masks`` is None)
        """
        im_h, im_w, _ = images.shape
        c_h, c_w = self.crop_size
        pad_h = max(c_h - im_h, 0)
        pad_w = max(c_w - im_w, 0)
        if pad_h or pad_w:
            # F.pad lists (before, after) pairs starting from the last dim:
            # channel untouched, then width, then height.
            images = F.pad(images, [0, 0, 0, pad_w, 0, pad_h], mode='constant', value=0)
            # The mask is optional; padding None would raise.
            if masks is not None:
                masks = F.pad(masks, [0, pad_w, 0, pad_h], mode='constant', value=0)
            im_h, im_w, _ = images.shape

        # Scalar draws: int() of a size-1 array is deprecated in NumPy.
        ymin = int(np.random.randint(0, im_h - c_h + 1))
        xmin = int(np.random.randint(0, im_w - c_w + 1))
        ymax = ymin + c_h
        xmax = xmin + c_w

        ret = [images[ymin:ymax, xmin:xmax, :]]
        if masks is not None:
            ret.append(masks[ymin:ymax, xmin:xmax])
        return ret
class THRandomScale(object):
    """Resize an image (and its mask) by a randomly chosen scale factor.

    Candidate factors are evenly spaced over ``scale_range`` (both ends
    included). A new factor is drawn on every call; the most recent one is
    kept in ``self.scale_factor``.

    Args:
        scale_range: ``(min_scale, max_scale)``.
        scale_step: approximate spacing between candidate factors.
    """

    def __init__(self, scale_range=(0.5, 2.0), scale_step=0.25):
        low, high = scale_range[0], scale_range[1]
        # The small epsilon keeps an exact multiple from being truncated one
        # step short by floating point error (e.g. 2.9999999 -> 2).
        num_factors = int((high - low) / scale_step + 1e-8) + 1
        self.scale_factors = np.linspace(low, high, num_factors)
        self.scale_factor = float(np.random.choice(self.scale_factors))

    def __call__(self, images, masks=None):
        """
        Args:
            images: 3-D tensor of shape [height, width, channel]
            masks: 2-D tensor of shape [height, width]
        Returns:
            list ``[images_tensor]`` or ``[images_tensor, masks_tensor]``
            (always a list, even when ``masks`` is None)
        """
        # Draw per call: sampling only in __init__ froze a single scale for
        # the whole lifetime of the transform.
        self.scale_factor = float(np.random.choice(self.scale_factors))

        # F.interpolate wants [N, C, H, W]; go channels first and add a batch axis.
        batched = images.permute(2, 0, 1)[None, :, :, :]
        resized = F.interpolate(batched, scale_factor=self.scale_factor,
                                mode='bilinear', align_corners=True)
        ret = [resized[0].permute(1, 2, 0)]
        if masks is not None:
            # Nearest neighbour so that no new label values are interpolated.
            masks_tensor = F.interpolate(masks[None, None, :, :],
                                         scale_factor=self.scale_factor, mode='nearest')[0][0]
            ret.append(masks_tensor)
        return ret
if __name__ == '__main__':
    # Manual smoke test: build the dataset with the default transform config
    # and display the first image with its mask overlaid.
    transform = DeepglobeRoadTransform({})
    dataset = DeepglobeRoad(r'D:\deepglobe\road_data\train_0_1', transforms=transform)
    dataset.show_image_mask(0)
    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment