Skip to content

Instantly share code, notes, and snippets.

@hartikainen
Last active February 1, 2020 19:41
Show Gist options
  • Save hartikainen/7e466bd6ad4bb9cecc8bcbd481189e53 to your computer and use it in GitHub Desktop.
Save hartikainen/7e466bd6ad4bb9cecc8bcbd481189e53 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import tensorflow_probability as tfp
from softlearning.utils.tensorflow import nest
@tf.function(experimental_relax_shapes=True)
def another_test_function(inputs):
    """Pair `inputs` with a copy of itself shifted by 100.

    Args:
        inputs: A value supporting `inputs + 100`. `test_function` calls
            this once per leaf of a nested structure, so this is presumably
            a single tensor -- confirm against other callers.

    Returns:
        A dict with two entries: `'old'` holds `inputs` unchanged and
        `'new'` holds `inputs + 100`.
    """
    return {'old': inputs, 'new': inputs + 100}
@tf.function(experimental_relax_shapes=True)
def test_function(inputs):
    """Apply `another_test_function` to every leaf of `inputs`.

    Args:
        inputs: A nested structure accepted by `tf.nest.map_structure`;
            `main` passes a flat dict of tensors.

    Returns:
        A structure shaped like `inputs` in which each leaf has been
        replaced by the `{'old': ..., 'new': ...}` dict produced by
        `another_test_function` for that leaf.
    """
    return tf.nest.map_structure(another_test_function, inputs)
def main():
    """Repeatedly run `test_function` on random inputs and verify the result.

    Each of the 20 iterations builds a dict of two random tensors of shapes
    (3, 2) and (3, 6), passes it through `test_function`, and asserts that
    every output leaf carries the untouched input under `'old'` and the
    input plus 100 under `'new'`.
    """
    for _ in range(20):
        inputs = {
            'x': tf.random.uniform((3, 2)),
            'y': tf.random.uniform((3, 6)),
        }
        outputs = test_function(inputs)

        # `outputs` nests one level deeper than `inputs` (each leaf became a
        # dict), so the traversal is driven by the structure of `inputs`.
        # Presumably `nest.map_structure_up_to(shallow, fn, *trees)` stops
        # descending at the leaves of `shallow` -- confirm against the
        # softlearning helper.
        nest.map_structure_up_to(
            inputs,
            lambda input_, output: tf.debugging.assert_equal(
                input_ + 100, output['new']),
            inputs,
            outputs)
        nest.map_structure_up_to(
            inputs,
            lambda input_, output: tf.debugging.assert_equal(
                input_, output['old']),
            inputs,
            outputs)

        # NOTE(review): assumed to run once per iteration -- confirm the
        # intended nesting of this print.
        print(inputs, outputs)
# Run the check only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment