Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save hartikainen/67a4cb20b0f7afc145ef475943c105a0 to your computer and use it in GitHub Desktop.
import tensorflow as tf
class TestModule(tf.Module):
    """Minimal ``tf.Module`` that owns a single ``tf.Variable``.

    Args:
        value: Initial value used to create ``self.variable``.
        name: Optional module name forwarded to ``tf.Module.__init__``.
            ``None`` (the default) lets ``tf.Module`` pick its own name.
    """

    def __init__(self, value, name=None):
        # Initialise the tf.Module base class first. Without this call the
        # module's ``name`` / ``name_scope`` state is never set up, and
        # accessing those properties raises AttributeError.
        super().__init__(name=name)
        # Stored as an attribute so tf.Module's variable collection (the
        # script checks ``trainable_variables``) and SavedModel export can
        # discover it.
        self.variable = tf.Variable(value)
# Repro: export a one-variable module as a SavedModel, reload it, and compare
# the original object against the restored one.
module_1 = TestModule(value=9000)
export_dir = "./foo"
tf.saved_model.save(module_1, export_dir)
module_2 = tf.saved_model.load(export_dir)

# The restored variable must hold the same value as the original one.
assert module_1.variable.numpy() == module_2.variable.numpy()

# Each object should report exactly its own variable as trainable.
# NOTE(review): the object returned by tf.saved_model.load is not necessarily
# a TestModule / tf.Module instance, so the module_2 iteration is presumably
# the check this repro probes -- confirm against the TF version in use.
for module in (module_1, module_2):
    assert module.trainable_variables == (module.variable,)

# Finally, the two trainable-variable tuples should compare equal.
assert module_1.trainable_variables == module_2.trainable_variables
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment