Created
December 5, 2019 10:00
-
-
Save danmou/bafa5c80356fdb2c843eaf38c8597f84 to your computer and use it in GitHub Desktop.
Mixin for `tf.keras.layers.Layer`s and subclasses to automatically define input and output specs the first time the model is called.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union | |
import tensorflow as tf | |
from tensorflow.keras import layers | |
# Generic type variable used to parameterize `Nested` below.
T = TypeVar('T')
# A value of type T, or a sequence/mapping of T — i.e. the kinds of
# nested structures that `tf.nest.map_structure` can traverse.
Nested = Union[T, Sequence[T], Mapping[Any, T]]
class AutoShapeMixin:
    """
    Mixin for `tf.keras.layers.Layer`s and subclasses to automatically define input and
    output specs the first time the model is called. Must be listed before
    `tf.keras.layers.Layer` when subclassing. Only works for models and layers with static
    input and output shapes. First `batch_dims` dimensions (default 1) are assumed to be
    batch dimensions.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Pop our own keyword first so the underlying keras layer never sees it.
        self.batch_dims: int = kwargs.pop('batch_dims', 1)
        super().__init__(*args, **kwargs)
        assert not getattr(self, 'dynamic'), 'AutoShapeMixin should not be used with dynamic layers!'
        self._input_spec: Optional[Nested[layers.InputSpec]] = None
        self._output_spec: Optional[Nested[layers.InputSpec]] = None
        # Flips to True after the first (spec-recording) call.
        self.built_with_input = False

    def build_with_input(self, input: Nested[tf.Tensor], *args: Any, **kwargs: Any) -> None:
        """Derive input/output specs by running one zero-valued dummy batch through the layer."""
        n = self.batch_dims

        def as_spec(t: tf.Tensor) -> layers.InputSpec:
            # Batch dimensions become None (any size); trailing dims stay static.
            return layers.InputSpec(shape=[None] * n + t.shape[n:], dtype=t.dtype)

        def as_dummy(t: tf.Tensor) -> tf.Tensor:
            # A small concrete batch (size 2 per batch dim) used only for tracing.
            return tf.zeros([2] * n + t.shape[n:], t.dtype)

        self._input_spec = tf.nest.map_structure(as_spec, input)
        dummy_output = super().__call__(tf.nest.map_structure(as_dummy, input), *args, **kwargs)
        self._output_spec = tf.nest.map_structure(as_spec, dummy_output)
        self.built_with_input = True

    def __call__(self, inputs: Nested[tf.Tensor], *args: Any, **kwargs: Any) -> Any:
        # Lazily record specs on the very first real call, then delegate to keras.
        if not self.built_with_input:
            self.build_with_input(inputs, *args, **kwargs)
        return super().__call__(inputs, *args, **kwargs)

    @property
    def input_spec(self) -> Optional[Nested[layers.InputSpec]]:
        return self._input_spec

    @input_spec.setter
    def input_spec(self, value: Optional[layers.InputSpec]) -> None:
        self._input_spec = value

    @property
    def output_spec(self) -> Optional[Nested[layers.InputSpec]]:
        return self._output_spec

    @output_spec.setter
    def output_spec(self, value: Optional[layers.InputSpec]) -> None:
        self._output_spec = value

    @property
    def input_shape(self) -> Nested[tf.TensorShape]:
        assert self.input_spec is not None, 'build_with_input has not been called; input shape is not defined'
        return tf.nest.map_structure(lambda s: s.shape, self.input_spec)

    @property
    def output_shape(self) -> Nested[tf.TensorShape]:
        assert self.output_spec is not None, 'build_with_input has not been called; output shape is not defined'
        return tf.nest.map_structure(lambda s: s.shape, self.output_spec)

    @property
    def input_dtype(self) -> Nested[tf.TensorShape]:
        assert self.input_spec is not None, 'build_with_input has not been called; input dtype is not defined'
        return tf.nest.map_structure(lambda s: s.dtype, self.input_spec)

    @property
    def output_dtype(self) -> Nested[tf.TensorShape]:
        assert self.output_spec is not None, 'build_with_input has not been called; output dtype is not defined'
        return tf.nest.map_structure(lambda s: s.dtype, self.output_spec)

    def compute_output_shape(self, input_shape: Nested[tf.TensorShape]) -> Nested[tf.TensorShape]:
        """Map an input shape to the matching output shape, reusing the traced specs."""
        if self.output_spec is None:
            # Specs not recorded yet — fall back to keras' own inference.
            return super().compute_output_shape(input_shape)
        n = self.batch_dims
        batch_shape = tf.nest.flatten(input_shape)[0][:n]
        return tf.nest.map_structure(lambda s: batch_shape + s[n:], self.output_shape)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List | |
import tensorflow as tf | |
from tensorflow.keras import layers | |
from auto_shape_mixin import AutoShapeMixin | |
### To use the standard Keras layers with auto shape, redefine them like this: | |
class Layer(AutoShapeMixin, layers.Layer):
    """`layers.Layer` with automatic input/output spec inference."""


class Dense(AutoShapeMixin, layers.Dense):
    """`layers.Dense` with automatic input/output spec inference."""


class Conv2D(AutoShapeMixin, layers.Conv2D):
    """`layers.Conv2D` with automatic input/output spec inference."""


class Flatten(AutoShapeMixin, layers.Flatten):
    """`layers.Flatten` with automatic input/output spec inference."""


class Concatenate(AutoShapeMixin, layers.Concatenate):
    """`layers.Concatenate` with automatic input/output spec inference."""


class Model(AutoShapeMixin, tf.keras.Model):
    """`tf.keras.Model` with automatic input/output spec inference."""


class Sequential(AutoShapeMixin, tf.keras.Sequential):
    """`tf.keras.Sequential` with automatic input/output spec inference."""


# etc
### For your own layers simply inherit from one of the above classes and also use them for all sub-layers, e.g.: | |
class ExampleNetwork(Model):
    """Toy model: conv-encode the first input, concatenate the remaining inputs, project to 100 units."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = Sequential([
            Conv2D(filters=32, kernel_size=3, strides=2, activation='relu'),
            Flatten(),
        ])
        self.concat = Concatenate(axis=-1)
        self.dense = Dense(units=100)

    def call(self, inputs: List[tf.Tensor]) -> tf.Tensor:
        # inputs[0] is encoded; any remaining tensors are concatenated onto the encoding.
        features = self.encoder(inputs[0])
        combined = self.concat([features, *inputs[1:]])
        return self.dense(combined)
# After the first time you call your model with an input, its input and output shapes and dtypes will be defined and `summary` will work as expected.
model = ExampleNetwork()
# One image-like input (batch of 1, 64x64x3) plus a flat 10-dim feature vector.
first_batch = [tf.zeros((1, 64, 64, 3)), tf.zeros((1, 10))]
model(first_batch)  # first call records the specs via AutoShapeMixin
model.summary()
# Model: "example_network_1"
# _________________________________________________________________
# Layer (type)                 Output Shape              Param #
# =================================================================
# sequential_1 (Sequential)    (None, 30752)             896
# _________________________________________________________________
# concatenate (Concatenate)    (None, 30762)             0
# _________________________________________________________________
# dense (Dense)                (None, 100)               3076300
# =================================================================
# Total params: 3,077,196
# Trainable params: 3,077,196
# Non-trainable params: 0
# _________________________________________________________________
print(model.input_shape)  # batch dim generalized to None by the mixin
# ListWrapper([TensorShape([None, 64, 64, 3]), TensorShape([None, 10])])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment