Created
September 15, 2023 09:24
-
-
Save antferdom/5f68a8618e9c41ae51bbb0c5b2c49219 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# see https://github.com/pytorch/torchsnapshot/blob/main/benchmarks/fsdp/main.py | |
import torch | |
from torch import distributed as dist, nn | |
def create_model(
    *,
    d_model: int = 864,
    num_encoder_layers: int = 1,
    num_decoder_layers: int = 20,
    nhead: int = 12,
    dim_feedforward: int = 50257,
) -> nn.Module:
    """Build and return the ``nn.Transformer`` used as the benchmark model.

    With the defaults the model has 1 encoder layer and 20 decoder layers,
    ``d_model=864``, 12 heads and ``dim_feedforward=50257``. That is about
    1.95B parameters, or roughly 7.8 GB of weights in the default float32
    dtype.

    A much larger variant is ``create_model(d_model=4000,
    num_decoder_layers=40, nhead=40)``. It has about 21.7B parameters, or
    roughly 87 GB (81 GiB) in float32.

    Args:
        d_model: Embedding width. ``nn.MultiheadAttention`` requires it to be
            divisible by ``nhead``.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        nhead: Number of attention heads in each attention block.
        dim_feedforward: Hidden width of each layer's feed-forward block.

    Returns:
        A freshly initialised ``nn.Transformer``.
    """
    # The parameters are keyword-only, so the original zero-argument call
    # ``create_model()`` builds exactly the same model as before.
    model = nn.Transformer(
        d_model=d_model,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
    )
    # Return the model; the original built it and implicitly returned None.
    return model
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment