Created
January 15, 2022 16:11
-
-
Save AdityaKane2001/c9fb4058a5ebaa8aee2c8765ec41f297 to your computer and use it in GitHub Desktop.
Training RegNets gists
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def random_flip(self, image: tf.Tensor, target: tf.Tensor) -> tuple:
    """
    Randomly mirror images left-to-right and pass the targets through.

    Only the horizontal flip is applied; no vertical flip is performed.

    Args:
        image: Image tensor handed to `tf.image.random_flip_left_right`
            (described by the caller-facing contract as a batch of images).
        target: Target tensor. Returned exactly as received.

    Returns:
        Tuple `(flipped_images, target)`; the images keep their dimensions.
    """
    return tf.image.random_flip_left_right(image), target
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _get_random_dims(self, area, height, width, max_iter=10):
    """
    Sample an Inception-style random crop window for a single image.

    Working logic:
    1. On each of `max_iter` attempts, draw a target area in
       `[self.area_factor, 1) * area` and an aspect ratio in `[3/4, 4/3)`,
       derive the integer crop width/height and randomly swap the two.
    2. Draw a random top-left corner for that attempt.
    3. If the crop is strictly smaller than the image in BOTH dimensions,
       cache the corner together with the crop size that produced it.
       Attempts that do not fit never overwrite a previously cached window.
    4. Return the cached window, or -1 for every value when no attempt
       produced a crop that fits.

    Design notes:
    - The loop always runs `max_iter` times (there is no `break`), so the
      amount of work does not depend on the sampled values.
    - The cache may be refreshed by several valid attempts; every cached
      window is valid, so which one is finally returned does not matter.

    Args:
        area: Image area (`height * width`) as a float scalar.
        height: Image height as an int32 scalar.
        width: Image width as an int32 scalar.
        max_iter: Number of sampling attempts.

    Returns:
        Tuple `(x, y, w_crop, h_crop)` of int32 scalars. `x` and `y` are
        both > -1 when a fitting crop was found; otherwise all four values
        are -1 and the caller is expected to fall back to another crop.
    """
    no_value = tf.constant(-1, dtype=tf.int32)
    x_cache = no_value
    y_cache = no_value
    w_cache = no_value
    h_cache = no_value

    for _ in tf.range(max_iter):
        target_area = tf.random.uniform(
            (), minval=self.area_factor, maxval=1) * area
        aspect_ratio = tf.random.uniform((), minval=3. / 4., maxval=4. / 3.)

        side_a = tf.cast(
            tf.math.round(tf.math.sqrt(target_area * aspect_ratio)), tf.int32)
        side_b = tf.cast(
            tf.math.round(tf.math.sqrt(target_area / aspect_ratio)), tf.int32)

        # Swap width and height with probability 0.5.
        prob = tf.random.uniform((), minval=0.0, maxval=1.0)
        swap = tf.math.greater(prob, tf.constant(0.5))
        w_crop = tf.where(swap, side_b, side_a)
        h_crop = tf.where(swap, side_a, side_b)

        fits = tf.math.logical_and(h_crop < height, w_crop < width)

        # tf.random.uniform needs maxval > minval for integer dtypes. Clamp
        # the bound so the draw is always legal; when the crop does not fit
        # the drawn corner is simply discarded below.
        y = tf.random.uniform(
            (), minval=0, maxval=tf.math.maximum(height - h_crop, 1),
            dtype=tf.int32)
        x = tf.random.uniform(
            (), minval=0, maxval=tf.math.maximum(width - w_crop, 1),
            dtype=tf.int32)

        # Only a fully valid attempt may update the cache, and the crop size
        # is cached with its corner so the returned values always belong to
        # the same attempt.
        x_cache = tf.where(fits, x, x_cache)
        y_cache = tf.where(fits, y, y_cache)
        w_cache = tf.where(fits, w_crop, w_cache)
        h_cache = tf.where(fits, h_crop, h_cache)

    return x_cache, y_cache, w_cache, h_cache
def _inception_style_crop_single(self, example, max_iter=10):
    """
    Apply an Inception-style random crop to one example, with a fallback.

    Working logic:
    1. Get a random crop window from `_get_random_dims` (see its docstring),
       giving it `max_iter` sampling attempts.
    2. If the window is good (both x and y > -1), crop that window and
       resize it to `(self.crop_size, self.crop_size)`.
    3. In all other cases apply the validation-style centre crop.

    Args:
        example: Mapping with keys "image", "height", "width", "filename",
            "label" and "synset".
        max_iter: Number of attempts allowed for sampling a random window.

    Returns:
        Dict with the same keys as `example`; "image" is the uint8 crop and
        "height"/"width" are both `self.crop_size`.
    """
    height = tf.cast(example["height"], tf.int32)
    width = tf.cast(example["width"], tf.int32)
    area = tf.cast(height * width, tf.float32)

    x, y, w_crop, h_crop = self._get_random_dims(
        area, height, width, max_iter=max_iter)

    img = tf.cast(example["image"], tf.uint8)

    # Both branches assign only `img`, and both yield a uint8 image.
    if tf.math.logical_and(x > -1, y > -1):
        # Inception-style random crop.
        img = img[y: y + h_crop, x: x + w_crop, :]
        img = tf.cast(tf.math.round(tf.image.resize(
            img, (self.crop_size, self.crop_size))), tf.uint8)
    else:
        # No usable random window was found: validation-style crop.
        img = self._validation_style_crop_single(img, height, width)

    return {
        "image": img,
        "height": self.crop_size,
        "width": self.crop_size,
        "filename": example["filename"],
        "label": example["label"],
        "synset": example["synset"],
    }

def _validation_style_crop_single(self, img, height, width):
    """
    Resize the shorter side to `self.resize_pre_crop`, then centre-crop.

    Args:
        img: Single image tensor indexed as [row, column, channel].
        height: Image height as an int32 scalar.
        width: Image width as an int32 scalar.

    Returns:
        uint8 image resized to `(self.crop_size, self.crop_size)`.
    """
    # The shorter side becomes `resize_pre_crop`; the longer side is scaled
    # by the aspect ratio and truncated to an integer by the cast.
    if width < height:
        w_resize = tf.cast(self.resize_pre_crop, tf.int32)
        h_resize = tf.cast(
            ((height / width) * self.resize_pre_crop), tf.int32)
    else:
        w_resize = tf.cast(
            ((width / height) * self.resize_pre_crop), tf.int32)
        h_resize = tf.cast(self.resize_pre_crop, tf.int32)

    img = tf.image.resize(img, (h_resize, w_resize))

    # Top-left corner of the centred crop window.
    x = tf.cast(tf.math.ceil((w_resize - self.crop_size) / 2), tf.int32)
    y = tf.cast(tf.math.ceil((h_resize - self.crop_size) / 2), tf.int32)
    img = img[y: (y + self.crop_size), x: (x + self.crop_size), :]

    # Final resize forces the output to crop_size x crop_size and lets the
    # result be rounded back to uint8.
    return tf.cast(tf.math.round(tf.image.resize(
        img, (self.crop_size, self.crop_size))), tf.uint8)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _mixup(self, image, label, alpha=0.2) -> Tuple:
    """
    Apply mixup: blend every element with its mirror along axis 0.

    The inputs are reversed along axis 0 to obtain the mixing partners, and
    a single coefficient drawn from Beta(alpha, alpha) is shared by all
    elements of the call.

    Args:
        image: Image tensor; cast to float32 before blending.
            NOTE(review): reversing axis 0 only pairs different samples if
            this is a batch - confirm against the caller.
        label: Label tensor blended with the same coefficient
            (presumably one-hot encoded floats - TODO confirm).
        alpha: Concentration used for both parameters of the Beta
            distribution.

    Returns:
        Tuple `(img, lab)`: the blended images, resized to
        `(self.crop_size, self.crop_size)` and rounded to uint8, and the
        blended labels.
    """
    partner_images = tf.reverse(image, axis=[0])
    partner_labels = tf.reverse(label, axis=[0])

    base_images = tf.cast(image, tf.float32)
    partner_images = tf.cast(partner_images, tf.float32)

    # Draw one mixing coefficient for the whole call.
    concentration = [alpha]
    lam = tfd.Beta(concentration, concentration).sample(1)[0][0]

    mixed_images = lam * base_images + (1 - lam) * partner_images
    mixed_labels = lam * label + (1 - lam) * partner_labels

    resized = tf.image.resize(
        mixed_images, (self.crop_size, self.crop_size))
    return tf.cast(tf.math.round(resized), tf.uint8), mixed_labels
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _pca_jitter(self, image, target):
    """
    Apply PCA colour jitter to images.

    A random per-image colour offset is built from `self.eigen_vals` and
    `self.eigen_vecs`, added to the images in the [0, 1] range, and the
    result is scaled back to [0, 255] and clipped.

    Args:
        image: Batch of images to jitter.
            NOTE(review): the noise is drawn with shape
            `(self.batch_size, 3)`, so the leading dimension presumably has
            to equal `self.batch_size` - confirm for partial batches.
        target: Target tensor. Returned exactly as received.

    Returns:
        Tuple `(aug_images, target)` where `aug_images` is uint8.
    """
    scaled = tf.cast(image, tf.float32) / 255.

    # One N(0, 0.1) weight triple per image, repeated along a new axis 1.
    noise = tf.random.normal((self.batch_size, 3), stddev=0.1)
    noise = tf.stack([noise] * 3, axis=1)

    weighted = noise * self.eigen_vals * self.eigen_vecs
    offset = tf.math.reduce_sum(weighted, axis=2)

    # Insert two singleton axes after the leading one so the offset
    # broadcasts over the images.
    offset = tf.expand_dims(tf.expand_dims(offset, axis=1), axis=1)

    jittered = (scaled + offset) * 255.
    return tf.cast(tf.clip_by_value(jittered, 0, 255), tf.uint8), target
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment