""" | |
Eisen EMIDEC Challenge starter kit | |
NOTE: you need to register to the challenge, download and unpack the data in | |
order to be able to run the following example. | |
Find more info here: http://emidec.com/ | |
This is released under MIT license. Do what you want with this code. | |
""" | |
import os

from eisen.datasets import EMIDEC
from eisen.models.segmentation import VNet
from eisen.io import LoadNiftyFromFilename
from eisen.transforms import (
    ResampleNiftiVolumes,
    NiftiToNumpy,
    CropCenteredSubVolumes,
    AddChannelDimension,
    MapValues,
    FixedMeanStdNormalization,
    LabelMapToOneHot,
    StackImagesChannelwise,
    FilterFields
)
from eisen.ops.losses import DiceLoss
from eisen.ops.metrics import DiceMetric
from eisen.utils import EisenModuleWrapper
from eisen.utils.workflows import Training
from eisen.utils.logging import LoggingHook
from eisen.utils.logging import TensorboardSummaryHook
from eisen.utils.artifacts import SaveTorchModelHook

from torchvision.transforms import Compose
from torch.utils.data import DataLoader
from torch.optim import Adam
""" | |
Constants defining important parameters of the algorithm. | |
CHANGE HERE WHAT SHOULD BE CHANGED TO FIT YOUR CONFIG. | |
""" | |
# Defining some constants
PATH_DATA = './emidec_data'  # path of the data as unpacked from the challenge files
PATH_ARTIFACTS = './results'  # path for model results

if not os.path.exists(PATH_ARTIFACTS):
    os.mkdir(PATH_ARTIFACTS)

NUM_EPOCHS = 100
BATCH_SIZE = 4
VOLUMES_RESOLUTION = [4, 4, 1]  # original EMIDEC data has 1 cubic mm voxel spacing
VOLUMES_PIXEL_SIZE = [64, 64, 16]
CLASSES = [1, 2]
""" | |
Define Readers and Transforms | |
In order to load data and prepare it for being used by the network, we need to operate | |
I/O operations and define transforms to standardize data. | |
You can add transforms or change the existing ones by editing this | |
""" | |
# readers: for images and labels
read_tform = LoadNiftyFromFilename(['image', 'label'], PATH_DATA)  # load content of the image and label fields of the dataset

# image manipulation transforms
resample_tform_img = ResampleNiftiVolumes(
    ['image'],
    VOLUMES_RESOLUTION,
    'linear'
)  # resamples volumes with linear interpolation

resample_tform_lbl = ResampleNiftiVolumes(
    ['label'],
    VOLUMES_RESOLUTION,
    'nearest'
)  # resamples labels with nearest-neighbor interpolation

to_numpy_tform = NiftiToNumpy(['image', 'label'])  # converts to numpy

crop = CropCenteredSubVolumes(['image', 'label'], size=VOLUMES_PIXEL_SIZE)  # crops volumes to size

map_intensities = MapValues(['image'], min_value=0.0, max_value=1.0)  # maps intensities to [0, 1]

add_channel_dim = AddChannelDimension(['image'])  # adds a singleton channel dimension to images

map_labels = LabelMapToOneHot(['label'], CLASSES)  # maps labels to one class per channel according to CLASSES

preserve_only_fields = FilterFields(['image', 'label'])  # filters out all fields in the dictionary apart from image and label
# create a transform to manipulate and load data
tform = Compose([
    read_tform,
    resample_tform_img,
    resample_tform_lbl,
    to_numpy_tform,
    crop,
    map_intensities,
    add_channel_dim,
    map_labels,
    preserve_only_fields
])
# create a dataset from the training set of the EMIDEC challenge
dataset = EMIDEC(
    PATH_DATA,
    training=True,
    transform=tform  # the transform is passed here
)
# Data loader: a PyTorch DataLoader is used here to loop through the data as provided by the dataset.
data_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4
)
""" | |
Building blocks: we define here: | |
* model | |
* loss | |
* metric | |
* optimizer | |
These components are used during training. | |
These blocks will be joined together in a workflow (Eg. training workflow). | |
""" | |
# specify model and loss (building blocks)
model = EisenModuleWrapper(
    module=VNet(input_channels=1, output_channels=len(CLASSES)),
    input_names=['image'],  # the inputs the network takes from the batch are called "image"
    output_names=['predictions']  # the output of the network is called "predictions"
)
# CHANGE TASK HERE if needed!!
loss = EisenModuleWrapper(
    module=DiceLoss(dim=[2, 3, 4]),
    input_names=['predictions', 'label'],  # the loss compares predictions and label
    output_names=['dice_loss']  # the output is called dice_loss
)
# CHANGE TASK HERE if needed!!
metric = EisenModuleWrapper(
    module=DiceMetric(dim=[2, 3, 4]),
    input_names=['predictions', 'label'],
    output_names=['dice_metric']
)
optimizer = Adam(model.parameters(), lr=0.001)
# join all blocks into a workflow (training workflow)
training_workflow = Training(
    model=model,
    losses=[loss],
    data_loader=data_loader,
    optimizer=optimizer,
    metrics=[metric],
    gpu=False  # set to True to train on a GPU
)
# create hooks to monitor training and save models
training_logging_hook = LoggingHook(training_workflow.id, 'Training', PATH_ARTIFACTS)
training_summary_hook = TensorboardSummaryHook(training_workflow.id, 'Training', PATH_ARTIFACTS)
save_model_hook = SaveTorchModelHook(training_workflow.id, 'Training', PATH_ARTIFACTS)
# run optimization for NUM_EPOCHS
for i in range(NUM_EPOCHS):
    training_workflow.run()

# TODO: VALIDATION and INFERENCE code (a minimal sketch of both follows below)
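"""
VALIDATION: what follows is a minimal sketch, not part of the original
starter kit. It assumes your Eisen version exposes a Validation workflow in
eisen.utils.workflows with a constructor mirroring Training (check the docs
for your version). The 80/20 random split is an illustrative choice, not the
challenge protocol; in a real setup you would build the split BEFORE the
training DataLoader above, so the model never trains on the held-out volumes.
"""
from torch.utils.data import random_split
from eisen.utils.workflows import Validation

num_validation = len(dataset) // 5  # hold out 20% of the training set
train_subset, validation_subset = random_split(
    dataset,
    [len(dataset) - num_validation, num_validation]
)

validation_loader = DataLoader(
    validation_subset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4
)

validation_workflow = Validation(
    model=model,
    losses=[loss],
    data_loader=validation_loader,
    metrics=[metric],
    gpu=False
)

validation_logging_hook = LoggingHook(validation_workflow.id, 'Validation', PATH_ARTIFACTS)

# evaluate on the held-out subset; this could also be called once per epoch
# inside the training loop above
validation_workflow.run()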
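"""
INFERENCE: also a sketch, not part of the original starter kit. Assumptions
to verify against your Eisen version: EMIDEC(..., training=False) yields the
unlabeled test images, so the label transforms used above must be dropped;
the wrapped VNet can be reached through model.module and called directly on
an image tensor; its outputs lie in [0, 1] (otherwise apply a sigmoid before
thresholding).
"""
import torch

test_tform = Compose([
    LoadNiftyFromFilename(['image'], PATH_DATA),
    ResampleNiftiVolumes(['image'], VOLUMES_RESOLUTION, 'linear'),
    NiftiToNumpy(['image']),
    CropCenteredSubVolumes(['image'], size=VOLUMES_PIXEL_SIZE),
    MapValues(['image'], min_value=0.0, max_value=1.0),
    AddChannelDimension(['image']),
    FilterFields(['image'])
])

test_dataset = EMIDEC(PATH_DATA, training=False, transform=test_tform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1)

model.eval()
with torch.no_grad():
    for batch in test_loader:
        predictions = model.module(batch['image'])  # one channel per entry of CLASSES
        segmentation = (predictions > 0.5).float()
        # resample the segmentation back to the original spacing and save it
        # as NIfTI to prepare a challenge submission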