This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def augment(images, labels, | |
resize=None, # (width, height) tuple or None | |
horizontal_flip=False, | |
vertical_flip=False, | |
rotate=0, # Maximum rotation angle in degrees | |
crop_probability=0, # How often we do crops | |
crop_min_percent=0.6, # Minimum linear dimension of a crop | |
crop_max_percent=1., # Maximum linear dimension of a crop | |
mixup=0): # Mixup coeffecient, see https://arxiv.org/abs/1710.09412.pdf | |
if resize is not None: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import numpy as np | |
import tensorflow as tf | |
class GRU(tf.contrib.rnn.RNNCell): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def orthogonal_initializer(scale=1.0, seed=None, dtype=tf.float32):
    """Return an initializer that produces a scaled orthogonal matrix.

    The returned callable does the following:
    - draws a standard-normal matrix of shape ``(shape[0], prod(shape[1:]))``;
    - takes its reduced SVD and keeps whichever singular-vector factor has
      exactly that shape;
    - reshapes the factor to ``shape``, multiplies by ``scale`` and wraps it
      in ``tf.constant``.

    Args:
        scale: Multiplicative gain applied to the orthogonal matrix.
        seed: Optional integer seed. When given, a private
            ``np.random.RandomState(seed)`` drives every draw made by this
            initializer, so the sequence of matrices is reproducible. When
            ``None``, NumPy's global random state is used, so
            ``np.random.seed`` still applies.
        dtype: Dtype of the resulting constant when the caller does not pass
            one or passes ``None``.

    Returns:
        A callable ``_initializer(shape, dtype=dtype, partition_info=None)``
        that returns a ``tf.constant`` of the requested shape.

    The callable raises ``ValueError`` if ``shape`` has fewer than two
    dimensions.
    """
    default_dtype = dtype
    # `seed` used to be accepted but ignored. Honor it, while keeping the
    # seed=None behavior of drawing from NumPy's global RNG.
    rng = np.random if seed is None else np.random.RandomState(seed)

    def _initializer(shape, dtype=default_dtype, partition_info=None):
        # partition_info is accepted but unused. Presumably it is there to
        # match the TF1 initializer calling convention (TODO confirm callers).
        shape = tuple(int(d) for d in shape)
        if len(shape) < 2:
            # Rank 0/1 previously failed deep inside NumPy: np.prod(()) is
            # the float 1.0, and shape[1] does not exist.
            raise ValueError(
                "orthogonal_initializer requires a shape of rank >= 2, got %r"
                % (shape,))
        if dtype is None:
            # Avoid letting tf.constant infer float64 from the NumPy array.
            dtype = default_dtype
        # Collapse the trailing dimensions so the SVD runs on a 2-D matrix.
        flat = (shape[0], int(np.prod(shape[1:])))
        a = rng.normal(0.0, 1.0, flat)
        u, _, v = np.linalg.svd(a, full_matrices=False)
        # With full_matrices=False, u is (m, k) and v is (k, n), with
        # k = min(m, n). So u matches (m, n) when m >= n; otherwise v does.
        q = (u if u.shape == flat else v).reshape(shape)
        # q already has exactly `shape`, so no further slicing is needed.
        return tf.constant(scale * q, dtype=dtype)

    return _initializer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
# Copyright (C) 2017 by Akira TAMAMORI | |
# Copyright (C) 2016 by hardmaru | |
# Permission is hereby granted, free of charge, to any person obtaining a copy | |
# of this software and associated documentation files (the "Software"), to deal | |
# in the Software without restriction, including without limitation the rights | |
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""adapted from https://github.com/OlavHN/bnlstm to store separate population statistics per state""" | |
import tensorflow as tf, numpy as np | |
RNNCell = tf.nn.rnn_cell.RNNCell | |
class BNLSTMCell(RNNCell): | |
'''Batch normalized LSTM as described in arxiv.org/abs/1603.09025''' | |
def __init__(self, num_units, is_training_tensor, max_bn_steps, initial_scale=0.1, activation=tf.tanh, decay=0.95): | |
""" | |
* max bn steps is the maximum number of steps for which to store separate population stats | |
""" |