Created
January 29, 2021 10:17
-
-
Save L-Ramos/13b06493b70395776a456c37e36c8467 to your computer and use it in GitHub Desktop.
data_loader.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "data_loader.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyMCzrac7mM9kfO4ItoOgnAh", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/L-Ramos/13b06493b70395776a456c37e36c8467/data_loader.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "y9Rtrn_72TGw", | |
"outputId": "4f1756c0-5ecf-4530-eab5-28181a59559d" | |
}, | |
"source": [ | |
"from google.colab import drive\r\n", | |
"drive.mount('/content/drive')\r\n" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Mounted at /content/drive\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "dysIc_bC3Que", | |
"outputId": "a1344517-5a6c-44db-efac-0687f4640aa0" | |
}, | |
"source": [ | |
"#Install sitk library to read dicom files\r\n", | |
"!pip install SimpleITK \r\n", | |
"import pandas as pd\r\n", | |
"import os\r\n", | |
"import SimpleITK as sitk\r\n", | |
"import torch\r\n", | |
"from torchvision import transforms, datasets\r\n", | |
"from torch.utils.data import Dataset, DataLoader\r\n", | |
"import numpy as np\r\n", | |
"df = pd.read_csv(r\"/content/drive/MyDrive/data_loader/test.csv\")" | |
], | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting SimpleITK\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/cc/85/6a7ce61f07cdaca722dd64f028b5678fb0a9e1bf66f534c2f8dd2eb78490/SimpleITK-2.0.2-cp36-cp36m-manylinux2010_x86_64.whl (47.4MB)\n", | |
"\u001b[K |████████████████████████████████| 47.4MB 89kB/s \n", | |
"\u001b[?25hInstalling collected packages: SimpleITK\n", | |
"Successfully installed SimpleITK-2.0.2\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "8ABwdrvQ3k6q" | |
}, | |
"source": [ | |
"df = df.drop(['Sex','SmokingStatus'],axis=1)" | |
], | |
"execution_count": 18, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "JBHfiDJ82_ZD" | |
}, | |
"source": [ | |
"def read_dicom_image(path):\r\n", | |
" reader = sitk.ImageSeriesReader()\r\n", | |
"\r\n", | |
" dicom_names = reader.GetGDCMSeriesFileNames(path)\r\n", | |
" reader.SetFileNames(dicom_names)\r\n", | |
"\r\n", | |
" image = reader.Execute()\r\n", | |
"\r\n", | |
" image = sitk.GetArrayFromImage(image)\r\n", | |
" #some images in the dataset had a different size, this would have to be \r\n", | |
" #pre-processed using registration or simply cropping as done below\r\n", | |
" image = image[0:20,0:512,0:512]\r\n", | |
" \r\n", | |
" return(image)\r\n", | |
"\r\n", | |
"class CombineDataset(Dataset):\r\n", | |
"\r\n", | |
" def __init__(self, frame, id_col, label_name, path_imgs, transform=None):\r\n", | |
" \"\"\"\r\n", | |
" Args:\r\n", | |
" frame (pd.DataFrame): Frame with the tabular data.\r\n", | |
" id_col (string): Name of the column that connects image to tabular data\r\n", | |
" label_name (string): Name of the column with the label to be predicted\r\n", | |
" path_imgs (string): path to the folder where the images are.\r\n", | |
" transform (callable, optional): Optional transform to be applied\r\n", | |
" on a sample, you need to implement a transform to use this.\r\n", | |
" \"\"\"\r\n", | |
" self.frame = frame\r\n", | |
" self.id_col = id_col\r\n", | |
" self.label_name = label_name\r\n", | |
" self.path_imgs = path_imgs\r\n", | |
" #self.transform = transform\r\n", | |
"\r\n", | |
" def __len__(self):\r\n", | |
" return (self.frame.shape[0])\r\n", | |
"\r\n", | |
" def __getitem__(self, idx):\r\n", | |
" if torch.is_tensor(idx):\r\n", | |
" idx = idx.tolist()\r\n", | |
" #complete image path and read\r\n", | |
" img_name = self.frame[self.id_col].iloc[idx]\r\n", | |
" path = os.path.join(self.path_imgs,img_name)\r\n", | |
" image = read_dicom_image(path)\r\n", | |
" image = torch.from_numpy(image.astype(np.float32))\r\n", | |
"\r\n", | |
" #get the other features to be used as training data\r\n", | |
" feats = [feat for feat in self.frame.columns if feat not in [self.label_name,self.id_col]]\r\n", | |
" feats = np.array(self.frame[feats].iloc[idx])\r\n", | |
" feats = torch.from_numpy(feats.astype(np.float32))\r\n", | |
" \r\n", | |
" \r\n", | |
" #get label\r\n", | |
" label = np.array(self.frame[self.label_name].iloc[idx])\r\n", | |
" label = torch.from_numpy(label.astype(np.float32))\r\n", | |
" #label = torch.as_tensor(np.array(label).astype('float'))\r\n", | |
"\r\n", | |
" return image, feats, label" | |
], | |
"execution_count": 44, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "hHcM0Roc6fBZ" | |
}, | |
"source": [ | |
"path_imgs = r\"/content/drive/MyDrive/data_loader/test\"\r\n", | |
"train_set = CombineDataset(df,'Patient','FVC',path_imgs)\r\n", | |
"\r\n", | |
"loader_trainer = DataLoader(\r\n", | |
" train_set,\r\n", | |
" batch_size = 2,\r\n", | |
" shuffle = True,\r\n", | |
" num_workers = 0,\r\n", | |
" drop_last=True\r\n", | |
")\r\n" | |
], | |
"execution_count": 45, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "jxZg2ZSB9aWM", | |
"outputId": "b2c09904-0553-4061-d433-bfde3136d3c0" | |
}, | |
"source": [ | |
"for data in loader_trainer:\r\n", | |
" x,z, y = data\r\n", | |
" \r\n", | |
" print('Image Loaded Size = ', x.shape)\r\n", | |
" print('Tabular data = ', z)\r\n", | |
" print('Label = ', y)\r\n" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Image Loaded Size = torch.Size([2, 20, 512, 512])\n", | |
"Tabular data = tensor([[ 0.0000, 71.8250, 73.0000],\n", | |
" [17.0000, 79.2589, 72.0000]])\n", | |
"Label = tensor([2925., 3294.])\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "8RkanpeepfXj" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment