Created
August 20, 2020 11:32
-
-
Save hanwinbi/644238d047d5e5881cb70e7ea7094a48 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import cv2 as cv | |
import torch | |
import numpy as np | |
from KITS19Rank4 import config | |
import json | |
from torch.utils.data import Dataset, DataLoader | |
from torchvision import transforms | |
# Directory holding the per-split JSON case lists. Module code appends file
# names to it with plain string concatenation, so it presumably must end with
# a path separator — confirm in config.
# NOTE(review): this name shadows the builtin dir() for the rest of the module.
dir = config.json_path
# Load the case list of one split (train / validation / test) from a JSON file.
def load_data(path):
    """Return the elements of the JSON array stored at *path* as a new list.

    Args:
        path: Path to a JSON file whose top-level value is an array
            (presumably of slice file paths — confirm against the file producer).

    Returns:
        A list holding the array's elements in file order.
    """
    # Decode explicitly as UTF-8 (the encoding RFC 8259 mandates for JSON)
    # instead of the platform's locale encoding, so non-ASCII paths load
    # identically on every OS.
    with open(path, 'r', encoding='utf-8') as load_f:
        load_dict = json.load(load_f)
    # For a JSON array this is the same element-by-element copy the original
    # index loop produced.
    return list(load_dict)
# Case lists of every split, read once when the module is imported
# (train first, then test, then validation).
trainData, testData, validationData = (
    load_data(dir + split_file)
    for split_file in ('trainData.json', 'testData.json', 'validationData.json')
)
# Expand one case entry into the paths of all of its consecutive slices.
def load_case_data(case, depth=None):
    """Return the paths of ``depth`` consecutive slices starting at *case*.

    *case* is a slice path whose last eight characters are a four-digit slice
    index followed by a four-character extension (e.g. ``.../case_0010.bmp``).
    The returned paths keep the prefix before the index and count upward from
    it. Indices are zero-padded to at least four digits, and every path ends
    in ``.bmp``.

    Args:
        case: Path of the first slice.
        depth: Number of slices to return; defaults to ``config.depth``
            (the original comment says 32 consecutive slices).

    Returns:
        A list of ``depth`` slice paths, in increasing index order.

    Raises:
        ValueError: If the four characters before the extension are not an integer.
    """
    print('current case:', case)
    if depth is None:
        depth = config.depth
    # Parse the path once rather than on every iteration.
    prefix = case[:-8]          # everything before the 4-digit slice index
    start = int(case[-8:-4])    # index of the first slice
    return [prefix + '%04d' % (start + offset) + '.bmp' for offset in range(depth)]
class DataSets(Dataset):
    """Dataset over the consecutive slices of a single case.

    Every item is a tuple ``(imgs, gts)``:

    * ``imgs``: a float32 tensor of shape (config.depth, 160, 160).
    * ``gts``: an int64 tensor of the same shape holding the labels.
    """

    # (width, height) passed to cv.resize for both images and masks.
    SLICE_SIZE = (160, 160)

    def __init__(self, casedata):
        """
        Args:
            casedata: Path of the case's first ground-truth slice; expanded
                into all slice paths by load_case_data.
        """
        # Single-channel ImageNet-style statistics, defined for intensities in [0, 1].
        self.transform = transforms.Compose(
            [transforms.Normalize(mean=(0.485,), std=(0.229,))]
        )
        print('case_path', casedata)
        self.GT = load_case_data(casedata)  # ground-truth slice paths
        self.slice = config.depth  # number of slices taken per case

    def __len__(self):
        # NOTE(review): this equals the slice count, but __getitem__ ignores idx
        # and returns the whole stack, so one epoch reads the same volume
        # len(self.GT) times — confirm this is intended.
        return len(self.GT)

    def __getitem__(self, idx):
        """Return the full (images, labels) stack for the case; *idx* is unused."""
        gts = torch.from_numpy(np.array(self.getGroundTruth(), dtype='int64'))
        imgs = np.array(self.getOriginImage(), dtype='float32')
        # Bring 8-bit intensities into [0, 1] before Normalize. The mean/std
        # above are defined on that scale; normalizing raw 0-255 values yielded
        # inputs of magnitude ~1000.
        imgs /= 255.0
        imgs = self.transform(torch.from_numpy(imgs))
        return imgs, gts

    def _read_slice(self, path, interpolation):
        """Read *path* as a grayscale image and resize it to SLICE_SIZE."""
        pic = cv.imread(path, cv.IMREAD_GRAYSCALE)
        # cv.imread signals a missing/unreadable file by returning None.
        if pic is None:
            raise FileNotFoundError('cannot read slice: ' + path)
        return cv.resize(pic, self.SLICE_SIZE, interpolation=interpolation)

    def getOriginImage(self):
        """Return the image slice paired with each ground-truth slice.

        The image path is the GT path with 'GT' replaced by 'Images'.
        """
        # NOTE(review): str.replace swaps *every* 'GT' in the path, including
        # any in parent directories or file names — confirm the directory layout.
        # INTER_LINEAR is cv.resize's default, as used originally for images.
        return [
            self._read_slice(self.GT[i].replace('GT', 'Images'), cv.INTER_LINEAR)
            for i in range(config.depth)
        ]

    def getGroundTruth(self):
        """Return the label slices, with gray levels divided by 127."""
        gts = []
        for i in range(config.depth):
            # Nearest-neighbour keeps only gray levels already in the mask. The
            # default bilinear filter blended label boundaries into intermediate
            # values that then truncated to the wrong class.
            gt = self._read_slice(self.GT[i], cv.INTER_NEAREST)
            # presumably labels are stored as 0 / 127 / 255 -> 0 / 1 / 2 after
            # int64 truncation in __getitem__ — TODO confirm against the masks
            gts.append(gt / 127)
        return gts
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment