@simonespa
Created June 11, 2024 08:14
TensorFlow Distribute Strategy
```python
# https://www.tensorflow.org/tutorials/distribute/keras
import tensorflow as tf

# Pick a tf.distribute strategy based on the accelerators TensorFlow can see.
if tf.config.list_logical_devices('GPU'):
    # Synchronous data parallelism across all local GPUs.
    distribution_strategy = tf.distribute.MirroredStrategy()
elif tf.config.list_logical_devices('TPU'):
    # Assumes the TPU is attached to this VM (TPU VM), hence tpu='local'.
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    distribution_strategy = tf.distribute.TPUStrategy(tpu)
else:
    print('Warning: no GPU or TPU found; training on a single device will be really slow.')
    # Fall back to the first logical device (usually /device:CPU:0).
    distribution_strategy = tf.distribute.OneDeviceStrategy(tf.config.list_logical_devices()[0].name)
```
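Choosing a strategy only matters once the model is built under it. The sketch below follows the pattern from the linked Keras tutorial: scale the per-replica batch size by `num_replicas_in_sync` to get the global batch size, and create and compile the model inside `strategy.scope()` so its variables are placed according to the strategy. It assumes TensorFlow 2.x. The synthetic data, layer sizes and `PER_REPLICA_BATCH_SIZE` are hypothetical, and `MirroredStrategy()` is used as a stand-in so the block runs on its own (on a CPU-only machine it falls back to a single CPU replica).

```python
import numpy as np
import tensorflow as tf

# Stand-in for the strategy chosen above (hypothetical choice for a self-contained run).
distribution_strategy = tf.distribute.MirroredStrategy()

# Illustrative value; the global batch is split evenly across replicas.
PER_REPLICA_BATCH_SIZE = 64
GLOBAL_BATCH_SIZE = PER_REPLICA_BATCH_SIZE * distribution_strategy.num_replicas_in_sync

# Hypothetical synthetic binary-classification data so the sketch needs no downloads.
x = np.random.rand(1024, 10).astype('float32')
y = np.random.randint(0, 2, size=(1024,)).astype('float32')
dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(1024).batch(GLOBAL_BATCH_SIZE)

# Create and compile the model under the strategy's scope so its variables
# are managed by the strategy (e.g. mirrored across GPUs).
with distribution_strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(10,)),
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ])
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# model.fit distributes each global batch across the replicas.
history = model.fit(dataset, epochs=2, verbose=0)
print(distribution_strategy.num_replicas_in_sync, GLOBAL_BATCH_SIZE)
```

Only model construction and `compile` need to sit inside the scope; `fit` can be called outside it, and Keras handles splitting each batch across replicas.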