You can use these environment variables to run a Python process on a subset of the TPU cores on a Cloud TPU VM. This allows running multiple TPU processes at the same time, since only one process can access a given TPU core at a time. Note that in JAX, 1 TPU core = 1 TpuDevice as reported by `jax.devices()`.
import os

# 4x 1 chip (2 cores) per process:
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] = "1,1,1"
os.environ["TPU_HOST_BOUNDS"] = "1,1,1"
# Different per process:
os.environ["TPU_VISIBLE_DEVICES"] = "0"  # "1", "2", "3"
# Pick a unique port per process
os.environ["TPU_MESH_CONTROLLER_ADDRESS"] = "localhost:8476"
os.environ["TPU_MESH_CONTROLLER_PORT"] = "8476"
# 1-liner for bash: TPU_CHIPS_PER_HOST_BOUNDS=1,1,1 TPU_HOST_BOUNDS=1,1,1 TPU_VISIBLE_DEVICES=0 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476

# 2x 2 chips (4 cores) per process:
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] = "1,2,1"
os.environ["TPU_HOST_BOUNDS"] = "1,1,1"
# Different per process:
os.environ["TPU_VISIBLE_DEVICES"] = "0,1"  # "2,3"
# Pick a unique port per process
os.environ["TPU_MESH_CONTROLLER_ADDRESS"] = "localhost:8476"
os.environ["TPU_MESH_CONTROLLER_PORT"] = "8476"
# 1-liner for bash: TPU_CHIPS_PER_HOST_BOUNDS=1,2,1 TPU_HOST_BOUNDS=1,1,1 TPU_VISIBLE_DEVICES=0,1 TPU_MESH_CONTROLLER_ADDRESS=localhost:8476 TPU_MESH_CONTROLLER_PORT=8476
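
Because only `TPU_VISIBLE_DEVICES` and the port differ between processes, the two configurations above could be generated by a small helper. The following is a minimal sketch, not part of the original snippet: the function name `tpu_env_for_process`, the `base_port` parameter, the "base port plus process index" rule for choosing unique ports, and the assumption of 4 chips per host are all illustrative choices. It covers only the two layouts shown above.

```python
# Hypothetical helper (not part of the original gist): builds the environment
# variables for one process, covering only the two layouts shown above.
CHIPS_PER_HOST = 4  # assumption: 4 chips per host, as in "4x 1 chip" / "2x 2 chips"

# Chip bounds taken verbatim from the two configurations above.
CHIP_BOUNDS = {1: "1,1,1", 2: "1,2,1"}


def tpu_env_for_process(process_index, chips_per_process, base_port=8476):
    """Return a dict of TPU env vars for the given process.

    process_index: 0-based index of the process on this host.
    chips_per_process: 1 or 2 (the only layouts covered here).
    base_port: port used by process 0; later processes get base_port + index
               (an assumed convention for picking a unique port per process).
    """
    if chips_per_process not in CHIP_BOUNDS:
        raise ValueError("only 1 or 2 chips per process are covered by this sketch")
    num_processes = CHIPS_PER_HOST // chips_per_process
    if not 0 <= process_index < num_processes:
        raise ValueError(f"process_index must be in [0, {num_processes})")

    # Each process gets a contiguous, non-overlapping range of chip indices.
    first_chip = process_index * chips_per_process
    visible = ",".join(str(first_chip + k) for k in range(chips_per_process))
    port = base_port + process_index

    return {
        "TPU_CHIPS_PER_HOST_BOUNDS": CHIP_BOUNDS[chips_per_process],
        "TPU_HOST_BOUNDS": "1,1,1",
        "TPU_VISIBLE_DEVICES": visible,
        "TPU_MESH_CONTROLLER_ADDRESS": f"localhost:{port}",
        "TPU_MESH_CONTROLLER_PORT": str(port),
    }


if __name__ == "__main__":
    # Show the settings for the second of two 2-chip processes.
    for name, value in tpu_env_for_process(1, 2).items():
        print(f"{name}={value}")
```

A process would typically apply the result with `os.environ.update(tpu_env_for_process(0, 1))` before the TPU runtime is initialized, or pass the dict as the `env` argument when launching a child process.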