# from monai.networks.nets import UNETR as UNETR_monai
from self_attention_cv import UNETR  # kept for the commented-out UNETR alternative below

import torch
from monai.networks.nets import UNet

device = torch.device("cuda:0")

num_heads = 10   # 12 normally
embed_dim = 512  # 768 normally
# model = UNETR(img_shape=tuple(roi_size), input_dim=4, output_dim=3,  # roi_size assumed defined elsewhere
#               embed_dim=embed_dim, patch_size=16, num_heads=num_heads,
#               ext_layers=[3, 6, 9, 12], norm='instance',
#               base_filters=16,
#               dim_linear_block=2048).to(device)
model = UNet(
    spatial_dims=3,  # 'dimensions' in older MONAI releases; renamed to 'spatial_dims'
    in_channels=4,
    out_channels=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
).to(device)
# model = UNETR_monai(
#     in_channels=4,
#     out_channels=3,
#     img_size=tuple(roi_size),
#     feature_size=16,
#     hidden_size=embed_dim,
#     mlp_dim=3072,
#     num_heads=12,
#     pos_embed="perceptron",
#     norm_name="instance",
#     res_block=True,
#     dropout_rate=0.0,
# ).to(device)
pytorch_total_params = sum(p.numel() for p in model.parameters()) / 1e6
print(f'Parameters in millions: {pytorch_total_params:.2f}')