Created
September 1, 2020 09:33
-
-
Save hanwinbi/c94d05014a79648ffa7cbdba6a53976b to your computer and use it in GitHub Desktop.
KITS19测试代码
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
# Number of consecutive slices that make up one network input volume.
depth = 48

# Root directory of the image data (the variant that includes augmented slices).
# Other locations that were used previously:
#   '/datasets/KITS2020/TEST/'
#   '/datasets/users/bihanwen/temp/'   (original, non-augmented images)
data_root = '/datasets/users/bihanwen/data20/kits/'

# Directory that holds the saved model checkpoints.
model_path = '/datasets/users/bihanwen/model/pth_48slice/'

# Directory for results (not referenced by the code visible in this file).
result_path = '/datasets/users/bihanwen/result/'

# Absolute directory, with a trailing slash, where the JSON index files live.
json_path = os.path.abspath('./data/no_aug_48slice') + '/'

# Directory handed to the TensorBoard SummaryWriter.
log_path = './log/log_48slice/'
def gene_dir(_dir):
    """Create the directory ``_dir`` (and any missing parents) if needed.

    Calling this for a directory that already exists is a no-op.  Creation
    is done in a single call so that another process creating the same
    directory between a "does it exist" check and the creation cannot make
    this function fail.

    :param _dir: path of the directory to create
    :raises OSError: if the path cannot be created (for example because a
        regular file with that name already exists)
    """
    os.makedirs(_dir, exist_ok=True)
# Make sure every output directory exists as soon as the config is imported.
for _out_dir in (model_path, json_path, log_path):
    gene_dir(_out_dir)

# Echo the two locations the other scripts depend on.
print(json_path, data_root)
# 用于配置测试的路径 | |
# rela_path = '/datasets/users/bihanwen/temp' | |
# abs_path = os.path.abspath(rela_path) + '/' | |
# abs_data_root = os.path.abspath(data_root) + '/' | |
# print(rela_path, abs_path) | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import config | |
import json | |
import cv2 as cv | |
from collections import OrderedDict | |
# Accumulates everything that is finally written to dataset.json.
json_dict = OrderedDict()

# Folder that receives the JSON file describing the whole dataset.
output_json_folder = config.json_path
rela_path = config.json_path

# Absolute path of the dataset root.
data_path = config.data_root
# 获取数据集信息的方法 | |
def dataset_info(path):
    """Scan every case folder below ``path`` and write ``dataset.json``.

    Each entry of ``path`` is treated as one case that contains a ``GT/``
    and an ``Images/`` sub-folder.  Sub-paths are built by plain string
    concatenation, so ``path`` has to end with a path separator.

    For every case the following is recorded in the module-level
    ``json_dict`` under the key ``'case'``:

    * ``GT`` / ``Images``: the two folder paths (with trailing slash),
    * ``Total Image Num``: number of files in the GT folder,
    * ``Slice num``: number of GT files whose name starts with ``'0'``
      (per the original author's note these are the non-augmented slices),
    * ``Img Size``: ``shape`` of the first GT image as read by OpenCV.

    The averages over all cases are stored as ``'Average pic size'`` (mean of
    the first shape entry) and ``'Average slice num'``; both are 0 for an
    empty dataset.  The result is dumped to ``dataset.json`` inside
    ``output_json_folder``.

    :param path: dataset root directory, ending with a separator
    :raises ValueError: if a GT folder is empty or its first image cannot be
        read
    """
    cases = sorted(os.listdir(path))
    print('cases:', cases)
    json_dict['case num'] = len(cases)
    json_dict['case'] = []

    total_size = 0        # running sum of the first shape entry
    total_slice_num = 0   # running sum of the non-augmented slice counts
    for case in cases:
        gt_dir = path + case + '/GT/'
        images_dir = path + case + '/Images/'
        slice_names = sorted(os.listdir(gt_dir))
        total_image_num = len(slice_names)
        print('slice name', slice_names)
        if not slice_names:
            raise ValueError('no ground-truth slices found in ' + gt_dir)

        # Files whose name starts with '0' are counted as one un-augmented series.
        count = sum(1 for name in slice_names if name[0] == '0')

        # Read a single image of the case to obtain its shape.
        dirfile = gt_dir + slice_names[0]
        print('dirfile:', dirfile)
        img = cv.imread(dirfile)
        if img is None:
            # cv.imread signals failure by returning None instead of raising.
            raise ValueError('cannot read image ' + dirfile)
        size = img.shape
        print(size)

        total_size += size[0]
        total_slice_num += count
        print("sum_size:{0},sum_slice:{1}".format(total_size, total_slice_num))

        case_info = {'GT': gt_dir, "Images": images_dir, "Total Image Num": total_image_num,
                     "Slice num": count, "Img Size": size}
        json_dict['case'].append(case_info)
        print(case)

    case_count = len(cases)
    json_dict['Average pic size'] = total_size / case_count if case_count else 0
    json_dict['Average slice num'] = total_slice_num / case_count if case_count else 0

    with open(os.path.join(output_json_folder, "dataset.json"), 'w') as f:
        json.dump(json_dict, f, indent=4, sort_keys=True)
# Build dataset.json for the configured data root whenever this script runs.
dataset_info(path=data_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import cv2 | |
import json | |
import torch | |
import random | |
import config | |
from collections import OrderedDict | |
# Holds the generated train / test / validation lists before they are written out.
json_dict = OrderedDict()

# Directory (with trailing slash) that contains dataset.json and receives the split files.
json_dir = config.json_path
# Split the cases listed in the dataset index (dataset.json) into training, validation and test sets with a 6:2:2 ratio.
def randomDiv(seeds_path, size):
    """Randomly split the cases of a dataset index into train/validation/test.

    The index file is expected to contain a ``'case'`` list (as written by the
    dataset-info script).  ``int(size[0] * n)`` cases go to the training set,
    ``int(size[1] * n)`` of the remaining ones to the validation set and all
    cases left over form the test set (``size[2]`` is not read).

    The start-slice lists are stored in the module-level ``json_dict`` and
    written to ``trainData.json``, ``testData.json`` and
    ``validationData.json`` inside ``json_dir``.

    :param seeds_path: path of the dataset index JSON file
    :param size: sequence of fractions ``(train, validation, test)``
    :return: ``(train_path, validation_path, test_path)`` of the written files
    """
    with open(seeds_path, 'r') as load_f:
        load_dict = json.load(load_f)
    seeds = load_dict['case']
    print(seeds)

    case_total = len(seeds)
    train_size = int(size[0] * case_total)
    validation_size = int(size[1] * case_total)

    # Draw case indices without replacement.
    all_indices = list(range(case_total))
    train_indices = random.sample(all_indices, train_size)
    # random.sample() needs a sequence (sets are rejected from Python 3.11 on),
    # so the remaining indices are kept as sorted lists.
    remaining = sorted(set(all_indices) - set(train_indices))
    validation_indices = random.sample(remaining, validation_size)
    test_indices = sorted(set(remaining) - set(validation_indices))

    json_dict['train case'] = get_slice_include_aug(train_indices, seeds)
    json_dict['test case'] = get_origin_slice(test_indices, seeds)
    json_dict['validation case'] = get_origin_slice(validation_indices, seeds)

    trainData_path = os.path.join(json_dir, 'trainData.json')
    testData_path = os.path.join(json_dir, 'testData.json')
    validationData_path = os.path.join(json_dir, 'validationData.json')

    outputs = (
        (trainData_path, 'train case'),
        (testData_path, 'test case'),
        (validationData_path, 'validation case'),
    )
    for out_path, key in outputs:
        with open(out_path, 'w') as f:
            json.dump(json_dict[key], f, indent=4)

    print('train data path', trainData_path)
    # Order matches how the result is unpacked: train, validation, test.
    return trainData_path, validationData_path, testData_path
# 训练集中包括数据增强部分 | |
def get_slice_include_aug(random_seed, seeds):
    """Collect shuffled start-slice paths for the training cases.

    For every case index in ``random_seed`` the GT folder is listed and, once
    per pass, three start slices are added: the file at sorted position 1,
    the file ``config.depth`` positions before the end, and a centred start
    position shifted by ``pass * Slice num``.  The number of passes is
    ``Total Image Num / Slice num`` of the first entry of ``seeds`` and is
    reused for every case.

    NOTE(review): the first two picks do not depend on the pass number, so
    they are added repeatedly; the "first slice" uses index 1, not 0 —
    confirm both are intended.

    :param random_seed: iterable of indices into ``seeds``
    :param seeds: list of case dictionaries from the dataset index
    :return: shuffled list of GT slice paths
    """
    first_case = seeds[0]
    passes = int(first_case['Total Image Num'] / first_case['Slice num'])

    picked = []
    for case_idx in random_seed:
        names = sorted(os.listdir(seeds[case_idx]['GT']))
        per_series = seeds[case_idx]['Slice num']
        # Start index that centres a window of config.depth slices in one series.
        centre_start = int((seeds[case_idx]['Slice num'] - config.depth) / 2)
        for rep in range(passes):
            # fixed pick near the beginning
            picked.append(seeds[case_idx]['GT'] + names[1])
            # last possible window start
            picked.append(seeds[case_idx]['GT'] + names[-config.depth])
            # centred window of the rep-th series
            picked.append(seeds[case_idx]['GT'] + names[centre_start + per_series * rep])

    random.shuffle(picked)
    return picked
# 测试集和验证集中不包括数据增强 | |
def get_origin_slice(random_seed, seeds):
    """Collect shuffled start-slice paths for validation / test cases.

    For every case index in ``random_seed`` three start slices are taken from
    the sorted GT folder listing: the centred window start, the very first
    file, and the file ``config.depth`` positions before the end.

    :param random_seed: iterable of indices into ``seeds``
    :param seeds: list of case dictionaries from the dataset index
    :return: shuffled list of GT slice paths
    """
    picked = []
    for case_idx in random_seed:
        case = seeds[case_idx]
        names = sorted(os.listdir(case['GT']))
        # Start index that centres a window of config.depth slices.
        centre_start = int((case['Slice num'] - config.depth) / 2)
        picked.append(case['GT'] + names[centre_start])
        # plus the first and the last window of the series
        picked.append(case['GT'] + names[0])
        picked.append(case['GT'] + names[-config.depth])

    random.shuffle(picked)
    return picked
# Generate the split files: 60 % training, 20 % validation, 20 % test.
trainData, validationData, testData = randomDiv(
    seeds_path=json_dir + 'dataset.json',
    size=(0.6, 0.2, 0.2),
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from nnunet.utilities.nd_softmax import softmax_helper | |
from nnunet.utilities.tensor_utilities import sum_tensor | |
from torch import nn | |
from utils import make_one_hot_3d | |
class SoftDiceLoss(nn.Module):
    """Soft Dice loss between a network output and an integer label map.

    ``forward`` one-hot encodes the labels along dimension 1, optionally drops
    the first (background) channel of both tensors and delegates to
    ``soft_dice_per_batch_2`` (``batch_dice=True``) or ``soft_dice``.
    """

    def __init__(self, smooth=1., apply_nonlin=None, batch_dice=True, do_bg=False, smooth_in_nom=True,
                 background_weight=1, rebalance_weights=None, square_nominator=False, square_denom=False):
        """
        :param smooth: smoothing constant added to the denominator
        :param apply_nonlin: optional callable applied to the network output
            first (e.g. a softmax)
        :param batch_dice: if True the dice is computed over the whole batch
        :param do_bg: if False the first channel is excluded from the loss
        :param smooth_in_nom: if True ``smooth`` is also added to the
            numerator, otherwise 0 is added
        :param background_weight: weight applied to the first dice entry; must
            be 1 when ``do_bg`` is False
        :param rebalance_weights: optional per-class weights passed through to
            the batch dice
        :param square_nominator: passed through to the dice function
        :param square_denom: passed through to the dice function
        """
        super(SoftDiceLoss, self).__init__()
        self.square_denom = square_denom
        self.square_nominator = square_nominator
        if not do_bg:
            assert background_weight == 1, "if there is no bg, then set background weight to 1 you dummy"
        self.rebalance_weights = rebalance_weights
        self.background_weight = background_weight
        self.smooth_in_nom = smooth if smooth_in_nom else 0
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.y_onehot = None  # never assigned elsewhere in this class; kept as an attribute

    def forward(self, x, y):
        """Return the dice loss for output ``x`` and labels ``y``.

        Per the original notes, ``x`` has shape (B, C, X, Y(, Z)) and ``y`` has
        shape (B, 1, X, Y(, Z)); a ``y`` with one dimension less gets the
        singleton channel dimension inserted.

        :raises NotImplementedError: when ``batch_dice`` is False but
            ``background_weight`` or ``rebalance_weights`` are set
        """
        with torch.no_grad():
            y = y.long()
        shp_x = x.shape
        shp_y = y.shape
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        if len(shp_x) != len(shp_y):
            # bring the labels to the same number of dimensions as the output
            y = y.view((shp_y[0], 1, *shp_y[1:]))

        # One-hot encode the labels directly on the device of the output.
        y_onehot = torch.zeros(shp_x, device=x.device)
        y_onehot.scatter_(1, y, 1)

        if not self.do_bg:
            # drop channel 0 (background) from prediction and target
            x = x[:, 1:]
            y_onehot = y_onehot[:, 1:]

        if not self.batch_dice:
            if self.background_weight != 1 or (self.rebalance_weights is not None):
                raise NotImplementedError("nah son")
            return soft_dice(x, y_onehot, self.smooth, self.smooth_in_nom,
                             self.square_nominator, self.square_denom)

        # Forward the squaring options as well so that they are honoured on
        # this path too (both default to False).
        return soft_dice_per_batch_2(x, y_onehot, self.smooth, self.smooth_in_nom,
                                     background_weight=self.background_weight,
                                     rebalance_weights=self.rebalance_weights,
                                     square_nominator=self.square_nominator,
                                     square_denom=self.square_denom)
def soft_dice_per_batch_2(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1, rebalance_weights=None,
                          square_nominator=False, square_denom=False):
    """Batch-wise dice turned into a weighted ``(-log dice) ** 0.3`` loss.

    True positives, false negatives and false positives are summed over the
    batch dimension and all dimensions after the channel dimension, giving one
    value per channel.  The per-channel dice is multiplied by ``weights``
    (all ones except entry 0, which is ``background_weight``).  The returned
    loss combines only the first two channels:
    ``(-log d0) ** 0.3 * 0.4 + (-log d1) ** 0.3 * 0.6``.

    NOTE(review): at least two channels are required and any further channels
    are ignored — presumably the two foreground classes; confirm.

    :param net_output: prediction, channel dimension at index 1
    :param gt: one-hot target of the same shape
    :param smooth: constant added to the denominator
    :param smooth_in_nom: constant added to the numerator
    :param background_weight: weight of the first channel
    :param rebalance_weights: optional numpy array multiplied into tp and fn;
        its first entry is dropped when its length differs from the number of
        channels of ``gt``
    :param square_nominator: square the true positives in the numerator
    :param square_denom: square tp, fp and fn in the denominator
    :return: scalar loss tensor
    """
    uses_cuda = net_output.device.type == "cuda"

    if rebalance_weights is not None and len(rebalance_weights) != gt.shape[1]:
        # background channel was removed by the caller, drop its weight too
        rebalance_weights = rebalance_weights[1:]

    # reduce over the batch dimension and everything after the channel dimension
    reduce_axes = (0,) + tuple(range(2, net_output.dim()))
    tp = sum_tensor(net_output * gt, reduce_axes, keepdim=False)
    fn = sum_tensor((1 - net_output) * gt, reduce_axes, keepdim=False)
    fp = sum_tensor(net_output * (1 - gt), reduce_axes, keepdim=False)

    weights = torch.ones(tp.shape)
    weights[0] = background_weight
    if uses_cuda:
        weights = weights.cuda(net_output.device.index)

    if rebalance_weights is not None:
        rebalance_weights = torch.from_numpy(rebalance_weights).float()
        if uses_cuda:
            rebalance_weights = rebalance_weights.cuda(net_output.device.index)
        tp = tp * rebalance_weights
        fn = fn * rebalance_weights

    nominator = tp ** 2 if square_nominator else tp
    if square_denom:
        denom = 2 * tp ** 2 + fp ** 2 + fn ** 2
    else:
        denom = 2 * tp + fp + fn

    dice = ((2 * nominator + smooth_in_nom) / (denom + smooth)) * weights
    first_term = torch.pow((-torch.log(dice[0])), 0.3) * 0.4
    second_term = torch.pow((-torch.log(dice[1])), 0.3) * 0.6
    return first_term + second_term
def soft_dice(net_output, gt, smooth=1., smooth_in_nom=1., square_nominator=False, square_denom=False):
    """Negative soft dice, computed per sample and channel, then averaged.

    The intersection and the denominator are summed over every dimension
    after the channel dimension (index 1).

    :param net_output: prediction tensor, channel dimension at index 1
    :param gt: one-hot target of the same shape
    :param smooth: constant added to the denominator
    :param smooth_in_nom: constant added to the numerator
    :param square_nominator: if True the element-wise product is squared
        before it is summed (the flag used to act the other way round)
    :param square_denom: if True both tensors are squared in the denominator
    :return: scalar tensor ``-mean((2 * intersect + smooth_in_nom) / (denom + smooth))``
    """
    axes = tuple(range(2, net_output.dim()))

    def _reduce(tensor):
        # With no dimensions after the channel dimension nothing is summed.
        return tensor.sum(dim=axes) if axes else tensor

    overlap = net_output * gt
    if square_nominator:
        overlap = overlap ** 2
    intersect = _reduce(overlap)

    if square_denom:
        denom = _reduce(net_output ** 2 + gt ** 2)
    else:
        denom = _reduce(net_output + gt)

    return (-((2 * intersect + smooth_in_nom) / (denom + smooth))).mean()
class DC_and_CE_loss(nn.Module):
    """Sum of the soft dice loss and a class-weighted cross-entropy loss.

    With ``mssu=True`` the network output is a sequence of predictions; the
    first five are each compared against the same target and combined with
    the fixed weights 0.4, 0.2, 0.1, 0.1 and 0.1.
    """

    # class weights of the cross-entropy term
    _CE_CLASS_WEIGHTS = (0.28, 0.28, 0.44)
    # weights of the individual outputs when ``mssu`` is enabled
    _LAYER_WEIGHTS = (0.4, 0.2, 0.1, 0.1, 0.1)

    def __init__(self, aggregate="sum", mssu=False):
        """
        :param aggregate: how both losses are combined; only ``"sum"`` exists
        :param mssu: True when the network returns several outputs
        """
        super(DC_and_CE_loss, self).__init__()
        self.aggregate = aggregate
        self.ce = CrossentropyND()
        self.dc = SoftDiceLoss(apply_nonlin=softmax_helper)
        self.mssu = mssu

    def forward(self, net_output, target):
        """Return ``cross_entropy + dice`` for ``net_output`` and ``target``.

        :raises NotImplementedError: if ``aggregate`` is not ``"sum"``
        :raises ValueError: if ``mssu`` is set and fewer than five outputs
            are given
        """
        if self.aggregate != "sum":
            raise NotImplementedError("nah son")  # reserved for other stuff (later)

        # Place the class weights on the device of the prediction, so the
        # loss works on the CPU and on any GPU rather than only on the
        # current CUDA device.
        reference = net_output[0] if self.mssu else net_output
        ce_weights = torch.tensor(self._CE_CLASS_WEIGHTS, device=reference.device)
        ce_1 = CrossentropyND(weight=ce_weights)

        if not self.mssu:
            dc_loss = self.dc(net_output, target)
            ce_loss = ce_1(net_output, target)
            return ce_loss + dc_loss

        if len(net_output) < len(self._LAYER_WEIGHTS):
            raise ValueError("mssu expects at least %d outputs, got %d"
                             % (len(self._LAYER_WEIGHTS), len(net_output)))

        dc_loss = 0
        ce_loss = 0
        for layer_weight, layer_output in zip(self._LAYER_WEIGHTS, net_output):
            dc_loss = dc_loss + self.dc(layer_output, target) * layer_weight
            ce_loss = ce_loss + ce_1(layer_output, target) * layer_weight
        return ce_loss + dc_loss
class CrossentropyND(torch.nn.CrossEntropyLoss):
    """
    Cross-entropy for outputs with any number of trailing dimensions.

    The network output has to be raw logits (NO NONLINEARITY applied).
    """

    def forward(self, inp, target):
        """Flatten ``inp`` to (N, C) and ``target`` to (N,), then apply the parent loss.

        :param inp: logits with the class dimension at index 1
        :param target: label tensor with as many elements as ``inp`` has
            positions; it is converted to ``long``
        """
        class_count = inp.size(1)
        # Move the class dimension to the end while keeping the order of all others.
        channel_last = inp.permute(0, *range(2, inp.dim()), 1)
        flat_logits = channel_last.contiguous().view(-1, class_count)
        flat_target = target.long().view(-1)
        return super(CrossentropyND, self).forward(flat_logits, flat_target)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
from torchsummary import summary | |
# [conv3d+IN+Leaky Relu+conv3d+IN], | |
def Conv_IN_LeRU_2s(in_dim, out_dim, kernel_size, stride, padding, activation):
    """Build Conv3d -> InstanceNorm3d -> activation -> Conv3d -> InstanceNorm3d.

    Both convolutions share ``kernel_size``, ``stride`` and ``padding``; the
    second one maps ``out_dim`` to ``out_dim`` channels.

    :param activation: module instance placed between the two conv stages
    :return: the five layers wrapped in an ``nn.Sequential``
    """
    layers = [
        nn.Conv3d(in_dim, out_dim, kernel_size, stride, padding),
        nn.InstanceNorm3d(out_dim),
        activation,
        nn.Conv3d(out_dim, out_dim, kernel_size, stride, padding),
        nn.InstanceNorm3d(out_dim),
    ]
    return nn.Sequential(*layers)
# 跨步卷积 | |
def stride_conv(in_dim, out_dim, kernel_size, stride, padding):
    """Return a single ``nn.Conv3d`` wrapped in an ``nn.Sequential``.

    Used as the strided (down-sampling) convolution between encoder levels.
    """
    conv = nn.Conv3d(in_dim, out_dim, kernel_size, stride, padding)
    return nn.Sequential(conv)
# 残差网络 | |
def ResNet(raw, processed):
    """Residual connection: element-wise sum of the block input and its output."""
    return torch.add(raw, processed)
# 反卷积 | |
def conv_trans(in_dim, out_dim, kernel_size, stride, padding):
    """Return the ``nn.ConvTranspose3d`` used for up-sampling in the decoder."""
    up_conv = nn.ConvTranspose3d(in_dim, out_dim, kernel_size, stride, padding)
    return up_conv
def de_conv_in_relu_2s(in_dim, out_dim, kernel_size, stride, padding, activation):
    """Build Conv3d -> InstanceNorm3d -> activation -> 1x1x1 Conv3d.

    The closing convolution keeps ``out_dim`` channels and uses a kernel of
    (1, 1, 1) with stride 1.

    :param activation: module instance placed after the normalisation
    :return: the four layers wrapped in an ``nn.Sequential``
    """
    layers = [
        nn.Conv3d(in_dim, out_dim, kernel_size, stride, padding),
        nn.InstanceNorm3d(out_dim),
        activation,
        nn.Conv3d(out_dim, out_dim, kernel_size=(1, 1, 1), stride=1),
    ]
    return nn.Sequential(*layers)
# 三线性插值 | |
def tri_inter(input, size, mode):
    """Resample ``input`` to the spatial ``size`` with the given interpolation ``mode``.

    Thin wrapper around ``torch.nn.functional.interpolate`` (meant for
    ``mode='trilinear'`` according to the original note).
    """
    resample = nn.functional.interpolate
    return resample(input=input, size=size, mode=mode)
class UNetStage1(nn.Module):
    """Residual 3D U-Net with six encoder levels and five decoder levels.

    The input has one channel and the output has three channels.  The first
    down-sampling step uses stride (1, 2, 2), i.e. it keeps the depth and
    halves height and width; the other four steps halve all three spatial
    dimensions.  The decoder mirrors this with transposed convolutions.

    Every level computes ``activation(block_input + conv_block(block_input))``.

    The shape notes in the comments follow the original author's example of a
    (N, 1, 32, 160, 160) input.
    """

    def __init__(self):
        super(UNetStage1, self).__init__()
        # A single convolution is applied right after the input: 1 -> 30 channels.
        self.init = nn.Conv3d(1, 30, 3, 1, 1)

        # ---- encoder ----
        # level 1
        self.encoder1 = Conv_IN_LeRU_2s(30, 30, 3, 1, 1, nn.LeakyReLU())
        self.encoder1_1 = nn.LeakyReLU()  # activation after the residual sum
        # level 2 (down-sampling keeps the depth: stride (1, 2, 2))
        self.stride_conv1 = stride_conv(30, 60, (3, 3, 3), (1, 2, 2), 1)
        self.encoder2 = Conv_IN_LeRU_2s(60, 60, 3, 1, 1, nn.LeakyReLU())
        self.encoder2_1 = nn.LeakyReLU()
        # level 3
        self.stride_conv2 = stride_conv(60, 120, 3, 2, 1)
        self.encoder3 = Conv_IN_LeRU_2s(120, 120, 3, 1, 1, nn.LeakyReLU())
        self.encoder3_1 = nn.LeakyReLU()
        # level 4
        self.stride_conv3 = stride_conv(120, 240, 3, 2, 1)
        self.encoder4 = Conv_IN_LeRU_2s(240, 240, 3, 1, 1, nn.LeakyReLU())
        self.encoder4_1 = nn.LeakyReLU()
        # level 5
        self.stride_conv4 = stride_conv(240, 480, 3, 2, 1)
        self.encoder5 = Conv_IN_LeRU_2s(480, 480, 3, 1, 1, nn.LeakyReLU())
        self.encoder5_1 = nn.LeakyReLU()
        # level 6 (bottleneck)
        self.stride_conv5 = stride_conv(480, 960, 3, 2, 1)
        self.encoder6 = Conv_IN_LeRU_2s(960, 960, 3, 1, 1, nn.LeakyReLU())
        self.encoder6_1 = nn.LeakyReLU()

        # ---- decoder ----
        # Transposed convolution output size:
        #   out = (in - 1) * stride - 2 * padding + kernel_size
        # Each level: up-sample, concatenate with the encoder feature map,
        # halve the channels with a 1x1x1 convolution, conv block, residual sum.
        self.decoder1 = conv_trans(960, 480, kernel_size=2, stride=2, padding=0)
        self.decoder1_1 = nn.Conv3d(960, 480, 1, 1, 0)  # channel reduction after the concat
        self.decoder1_2 = de_conv_in_relu_2s(480, 480, 3, 1, 1, nn.LeakyReLU())
        self.decoder1_3 = nn.LeakyReLU()
        # res1..res4 are not called in the forward pass (their use is commented
        # out upstream); they stay registered so that the parameter set and
        # existing checkpoints keep matching.
        self.res1 = nn.Conv3d(480, 3, 3, 1, 1)

        self.decoder2 = conv_trans(480, 240, kernel_size=2, stride=2, padding=0)
        self.decoder2_1 = nn.Conv3d(480, 240, 1, 1, 0)
        self.decoder2_2 = de_conv_in_relu_2s(240, 240, 3, 1, 1, nn.LeakyReLU())
        self.decoder2_3 = nn.LeakyReLU()
        self.res2 = nn.Conv3d(240, 3, 3, 1, 1)

        self.decoder3 = conv_trans(240, 120, kernel_size=2, stride=2, padding=0)
        self.decoder3_1 = nn.Conv3d(240, 120, 1, 1, 0)
        self.decoder3_2 = de_conv_in_relu_2s(120, 120, 3, 1, 1, nn.LeakyReLU())
        self.decoder3_3 = nn.LeakyReLU()
        self.res3 = nn.Conv3d(120, 3, 3, 1, 1)

        self.decoder4 = conv_trans(120, 60, kernel_size=2, stride=2, padding=0)
        self.decoder4_1 = nn.Conv3d(120, 60, 1, 1, 0)
        self.decoder4_2 = de_conv_in_relu_2s(60, 60, 3, 1, 1, nn.LeakyReLU())
        self.decoder4_3 = nn.LeakyReLU()
        self.res4 = nn.Conv3d(60, 3, 3, 1, 1)

        # the last up-sampling keeps the depth, mirroring stride_conv1
        self.decoder5 = conv_trans(60, 30, kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0)
        self.decoder5_1 = nn.Conv3d(60, 30, 1, 1, 0)
        self.decoder5_2 = de_conv_in_relu_2s(30, 30, 3, 1, 1, nn.LeakyReLU())
        self.decoder5_3 = nn.LeakyReLU()
        # final projection to the three output channels
        self.end = nn.Conv3d(30, 3, 3, 1, 1)

    # Down-sampling path.
    #   encoderN      : conv -> instance norm -> activation -> conv -> instance norm
    #   encoderN_1    : activation applied to the residual sum
    #   stride_convN  : strided 3x3x3 convolution, padding 1
    #                   (stride (1, 2, 2) for the first one, 2 for the others)
    def get_features(self, x):
        """Run the encoder.

        :param x: input tensor with one channel
        :return: tuple ``(sync6, sync5, sync4, sync3, sync2, sync1)``, i.e. the
            feature maps ordered from the deepest level to the first one
        """
        enc0 = self.init(x)                          # [N 1 32 160 160] -> [N 30 32 160 160]
        sync1 = self.encoder1_1(ResNet(enc0, self.encoder1(enc0)))

        enc2 = self.stride_conv1(sync1)              # -> [N 60 32 80 80]
        sync2 = self.encoder2_1(ResNet(enc2, self.encoder2(enc2)))

        enc3 = self.stride_conv2(sync2)              # -> [N 120 16 40 40]
        sync3 = self.encoder3_1(ResNet(enc3, self.encoder3(enc3)))

        enc4 = self.stride_conv3(sync3)              # -> [N 240 8 20 20]
        sync4 = self.encoder4_1(ResNet(enc4, self.encoder4(enc4)))

        enc5 = self.stride_conv4(sync4)              # -> [N 480 4 10 10]
        sync5 = self.encoder5_1(ResNet(enc5, self.encoder5(enc5)))

        enc6 = self.stride_conv5(sync5)              # -> [N 960 2 5 5]
        sync6 = self.encoder6_1(ResNet(enc6, self.encoder6(enc6)))

        return sync6, sync5, sync4, sync3, sync2, sync1

    # Up-sampling path.
    #   decoderN    : transposed convolution
    #   skip_conN   : concatenation with the matching encoder feature map
    #   decoderN_1  : 1x1x1 convolution halving the channels of the concat
    #   decoderN_2  : conv -> instance norm -> activation -> 1x1x1 conv
    #   decoderN_3  : activation applied to the residual sum
    def upSample(self, enc):
        """Run the decoder on the encoder feature maps.

        :param enc: sequence as returned by ``get_features`` (deepest first)
        :return: output of the final convolution (three channels)
        """
        dec1 = self.decoder1(enc[0])                          # [N 960 2 5 5] -> [N 480 4 10 10]
        skip_con1 = torch.cat((enc[1], dec1), dim=1)          # -> [N 960 4 10 10]
        dec1_1 = self.decoder1_1(skip_con1)                   # -> [N 480 4 10 10]
        dec1_2 = self.decoder1_2(dec1_1)
        resnet1 = self.decoder1_3(ResNet(dec1_1, dec1_2))

        dec2 = self.decoder2(resnet1)                         # -> [N 240 8 20 20]
        skip_con2 = torch.cat((enc[2], dec2), dim=1)          # -> [N 480 8 20 20]
        dec2_1 = self.decoder2_1(skip_con2)                   # -> [N 240 8 20 20]
        dec2_2 = self.decoder2_2(dec2_1)
        resnet2 = self.decoder2_3(ResNet(dec2_1, dec2_2))

        dec3 = self.decoder3(resnet2)                         # -> [N 120 16 40 40]
        skip_con3 = torch.cat((enc[3], dec3), dim=1)          # -> [N 240 16 40 40]
        dec3_1 = self.decoder3_1(skip_con3)                   # -> [N 120 16 40 40]
        dec3_2 = self.decoder3_2(dec3_1)
        # The residual sum adds the block input (dec3_1) to the block output,
        # exactly as in decoder levels 1 and 2.
        resnet3 = self.decoder3_3(ResNet(dec3_1, dec3_2))

        dec4 = self.decoder4(resnet3)                         # -> [N 60 32 80 80]
        skip_con4 = torch.cat((enc[4], dec4), dim=1)          # -> [N 120 32 80 80]
        dec4_1 = self.decoder4_1(skip_con4)                   # -> [N 60 32 80 80]
        dec4_2 = self.decoder4_2(dec4_1)
        resnet4 = self.decoder4_3(ResNet(dec4_1, dec4_2))

        dec5 = self.decoder5(resnet4)                         # -> [N 30 32 160 160]
        skip_con5 = torch.cat((enc[5], dec5), dim=1)          # -> [N 60 32 160 160]
        dec5_1 = self.decoder5_1(skip_con5)                   # -> [N 30 32 160 160]
        dec5_2 = self.decoder5_2(dec5_1)
        resnet5 = self.decoder5_3(ResNet(dec5_1, dec5_2))

        # Only the full-resolution output is returned (no multi-scale outputs).
        return self.end(resnet5)

    def forward(self, x):
        """Encode ``x`` and decode the features into the three-channel output."""
        enc = self.get_features(x)
        res = self.upSample(enc)
        return res
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch | |
# model = UNetStage1().to(device) | |
# summary(model, input_size=(1, 32, 160, 160), batch_size=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import cv2 as cv | |
import torch | |
import numpy as np | |
import config | |
import json | |
from torch.utils.data import Dataset, DataLoader | |
from torchvision import transforms | |
# Directory that holds the train / test / validation JSON lists.
# NOTE(review): the name shadows the builtin `dir`; it is kept because the
# module-level loads further down build their paths from it.
dir = config.json_path
# 载入测试集[验证集、测试集]案例 | |
def load_data(path):
    """Load a JSON file whose top-level value is a list and return a copy of it.

    Used for the train / validation / test files, each of which stores a list
    of start-slice paths.

    :param path: path of the JSON file
    :return: new list with the entries of the file, in file order
    :raises ValueError: if the top-level JSON value is not a list
    """
    with open(path, 'r', encoding='utf-8') as load_f:
        load_dict = json.load(load_f)
    if not isinstance(load_dict, list):
        raise ValueError('expected a JSON list in ' + str(path))
    return list(load_dict)
# Load the three start-slice lists once at import time
# (order of loading: training, test, validation).
trainData, testData, validationData = (
    load_data(dir + file_name)
    for file_name in ('trainData.json', 'testData.json', 'validationData.json')
)
# caseData = load_case_data(dir+'trainData.json', 0) | |
class DataSets(Dataset):
    """Dataset that yields one stack of ``config.depth`` consecutive slices per item.

    Each entry of ``casedata`` is the path of the first ground-truth slice of
    a stack.  The file name must end in a four-digit index followed by a
    four-character extension; the following slices are addressed by
    increasing that index and using the ``.bmp`` extension.
    """

    # every slice is resized to this (width, height)
    _SLICE_SIZE = (160, 160)

    def __init__(self, casedata):
        """
        :param casedata: list of start-slice paths inside a ``GT`` folder
        """
        self.transform = transforms.Compose(
            [transforms.Normalize(mean=(0.485,), std=(0.229,))]
        )
        self.GT = casedata
        self.slice = config.depth  # number of slices taken per item

    def __len__(self):
        return len(self.GT)

    def __getitem__(self, idx):
        """Return ``(images, labels)`` for the stack starting at entry ``idx``.

        ``images`` is a normalised float32 tensor and ``labels`` an int64
        tensor; both are built from ``config.depth`` slices of 160 x 160.
        """
        case_data = self.load_case_data(self.GT[idx])
        imgs = self.getOriginImage(case_data)
        gts = self.getGroundTruth(case_data)
        gts = torch.from_numpy(np.array(gts, dtype='int64'))
        imgs = torch.from_numpy(np.array(imgs, dtype='float32'))
        imgs = self.transform(imgs)
        return imgs, gts

    @staticmethod
    def _read_resized(path, interpolation):
        """Read ``path`` as a grey-scale image and resize it to the slice size."""
        pic = cv.imread(path, 0)
        if pic is None:
            # cv.imread returns None instead of raising when a file is missing
            # or unreadable; fail with the offending path.
            raise FileNotFoundError('cannot read image: ' + path)
        return cv.resize(pic, DataSets._SLICE_SIZE, interpolation=interpolation)

    def getOriginImage(self, casedata):
        """Load the input images that belong to the ground-truth paths in ``casedata``."""
        imgs = []
        for gt_path in casedata[:config.depth]:
            # NOTE(review): replaces every occurrence of 'GT' in the path, not
            # only the folder name — confirm no other path component contains it.
            origin_image = gt_path.replace('GT', 'Images')
            imgs.append(self._read_resized(origin_image, cv.INTER_LINEAR))
        return imgs

    def getGroundTruth(self, casedata):
        """Load the label slices and map the grey values to class indices (value / 127)."""
        gts = []
        for gt_path in casedata[:config.depth]:
            # Nearest-neighbour resizing keeps the label values intact; an
            # interpolating resize would create in-between grey values at
            # class boundaries, i.e. wrong labels.
            gt = self._read_resized(gt_path, cv.INTER_NEAREST)
            gts.append(gt / 127)
        return gts

    def load_case_data(self, case):
        """Return the paths of the ``config.depth`` consecutive slices starting at ``case``."""
        start = int(case[-8:-4])  # index of the first slice
        prefix = case[0:-8]
        return [prefix + str("%04d" % (start + i)) + '.bmp' for i in range(config.depth)]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import NetModel | |
from utils import * | |
import torch.optim as optim | |
from torch.autograd import Variable | |
from torch.optim import lr_scheduler | |
from torch.utils.data import DataLoader | |
from Loss import DC_and_CE_loss | |
from torch.cuda.amp import GradScaler, autocast | |
from PrepareData import DataSets | |
from PrepareData import trainData, validationData, testData | |
from torch.utils.tensorboard import SummaryWriter | |
# TensorBoard writer that logs into the configured log directory.
writer = SummaryWriter(config.log_path)

import warnings

# Silence all warnings for the rest of the run.
warnings.filterwarnings("ignore")

# Learning-rate settings: Adam start value, and threshold / patience of the
# plateau scheduler that is created inside train().
lr_scheduler_eps = 1e-3
lr_scheduler_patience = 20
initial_lr = 3e-4

# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Network, loss (soft dice + cross-entropy) and optimizer shared by train().
model = NetModel.UNetStage1().to(device)
criterion_dc_ce = DC_and_CE_loss().to(device)
optimizer = optim.Adam(model.parameters(), initial_lr)
def _checkpoint_epoch(filename):
    """Return the epoch number at the end of a checkpoint name, or -1 if there is none."""
    stem = os.path.splitext(filename)[0]
    try:
        return int(stem.rsplit('_', 1)[-1])
    except ValueError:
        return -1


def _latest_checkpoint(filenames):
    """Pick the checkpoint with the highest epoch number (ties and unparsable names: by name)."""
    return max(filenames, key=lambda name: (_checkpoint_epoch(name), name))


def train(Epoches,mpth):
    """Train the global ``model`` and validate / save it every third epoch.

    If ``mpth`` already contains checkpoints, training resumes from the one
    with the highest epoch number.  Checkpoints are dictionaries with the
    keys ``'model'``, ``'optimizer'`` and ``'epoch'``.

    Per epoch the mean training loss and the learning rate are written to
    TensorBoard.  Every third epoch the model is saved, the validation loss
    and the kidney / tumor dice are logged, and the plateau scheduler is
    stepped with the summed validation loss.

    :param Epoches: epoch number at which training stops (exclusive)
    :param mpth: checkpoint directory, ending with a path separator
    """
    start_epoch = 0
    pthList = sorted(os.listdir(mpth))
    print(pthList)
    if not pthList:
        print('starting train:')
    else:
        print('Continue training:')
        # Compare the numeric epoch: plain string order would rank
        # '..._99.pth' after '..._102.pth' and resume from an old state.
        pth = _latest_checkpoint(pthList)
        # map_location lets a checkpoint written on a GPU be loaded on the CPU too.
        checkPoint = torch.load(mpth + pth, map_location=device)
        model.load_state_dict(checkPoint['model'])
        optimizer.load_state_dict(checkPoint['optimizer'])
        start_epoch = checkPoint['epoch'] + 1

    scaler = GradScaler()
    # Halve the learning rate when the validation loss stops improving.
    plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=lr_scheduler_patience,verbose=True, threshold=lr_scheduler_eps, threshold_mode="abs")

    for epoch in range(start_epoch, Epoches):
        print('------------------------')
        print('this is {} epoch.'.format(epoch))
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())))

        # ---- training ----
        model.train()
        oneCaseLoss = 0
        data = DataSets(trainData)
        dataLoader = DataLoader(data, batch_size=1, shuffle=False, num_workers=0)
        start_time = time.time()
        for i, (x, yy) in enumerate(dataLoader):
            optimizer.zero_grad()
            x = x.to(device)
            y = yy.to(device)
            x = x.unsqueeze(1)  # add the channel dimension
            output = model(x)
            loss = criterion_dc_ce(output, y)  # soft dice + cross-entropy
            iterLoss = loss.item()
            print('loss[%d]:' % i, iterLoss)
            oneCaseLoss += iterLoss
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        end_time = time.time()
        print('run time:{0},one case loss:{1}'.format(end_time-start_time, oneCaseLoss/len(trainData)))

        lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('lr', lr, epoch)
        writer.add_scalar('loss/tr_loss', oneCaseLoss/len(trainData), epoch)

        if epoch % 3 != 0:
            continue

        # ---- save a checkpoint ----
        savepPth = mpth + 'v10_NotAll_' + str('%.2d' % epoch) + '.pth'
        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        torch.save(state, savepPth)

        # ---- validation ----
        model.eval()
        oneCaseValLoss = 0
        kidneyDice = 0
        tumorDice = 0
        data = DataSets(validationData)
        dataLoader = DataLoader(data, batch_size=1, shuffle=False, num_workers=4)
        with torch.no_grad():
            for x, yy in dataLoader:
                x = x.to(device)
                y = yy.to(device)
                x = x.unsqueeze(1)
                output = model(x)
                loss = criterion_dc_ce(output, y)
                oneCaseValLoss += loss.item()
                # dice per class on the softmax probabilities
                output = torch.softmax(output, dim=1)
                dice_kidney, dice_tumor = Dice(y, output)
                kidneyDice += dice_kidney
                tumorDice += dice_tumor
        plateau_scheduler.step(oneCaseValLoss)  # may lower the learning rate

        print('********************************************************************')
        print('**** lr: {:.8f} ****'.format(lr))
        print('**** val loss: {:.8f} ****'.format(oneCaseValLoss/len(validationData)))
        print('**** kidneyDice: {:.8f} ****'.format(kidneyDice/len(validationData)))
        print('**** tumorDice: {:.8f} ****'.format(tumorDice/len(validationData)))
        writer.add_scalar('loss/val_loss', oneCaseValLoss/len(validationData), epoch)
        writer.add_scalar('Dice/kidney_dice', kidneyDice/len(validationData), epoch)
        writer.add_scalar('Dice/tumor_dice', tumorDice/len(validationData), epoch)
if __name__ == "__main__":
    # Script entry point: print a start timestamp, train up to epoch 800
    # with checkpoints in the configured model directory, then flush the
    # TensorBoard writer.
    modelPath = config.model_path
    t = time.strftime("%Y--%m--%d %H:%M:%S", time.localtime(time.time()))
    print(t)
    train(800, modelPath)
    writer.flush()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import os | |
import json | |
import torch | |
import datetime | |
import config | |
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def make_one_hot_3d(x, n):
    """One-hot encode every voxel of a label volume along a new channel axis.

    Args:
        x: Label tensor whose first axis is the batch, e.g. ``(N, D, H, W)``.
            Values must lie in ``[0, n)``. Any integer dtype is accepted; the
            labels are cast to ``int64`` because ``scatter_`` only takes
            64-bit indices.
        n: Number of classes, i.e. the size of the inserted channel axis.

    Returns:
        A tensor of shape ``(N, n, D, H, W)`` in the default float dtype,
        living on the same device as ``x``, that holds 1 in the channel of
        each voxel's class and 0 everywhere else.
    """
    index = x.unsqueeze(1).long()  # (N, 1, D, H, W)
    # Allocate directly on the labels' device instead of building the zeros
    # on the CPU and copying them over afterwards.
    result = torch.zeros((index.shape[0], n, *index.shape[2:]), device=index.device)
    return result.scatter_(1, index, 1)
def expand_as_one_hot(input, C, ignore_index=None):
    """
    Converts NxDxHxW label image to NxCxDxHxW, where each label gets converted to its corresponding one-hot vector.

    The result is always allocated on the same device as ``input``, so the
    function works for both CPU and CUDA label tensors regardless of whether
    ``ignore_index`` is given.

    :param input: 4D input image (NxDxHxW) of int64 labels
    :param C: number of channels/labels
    :param ignore_index: ignore index to be kept during the expansion; voxels
        carrying this label get ``ignore_index`` written into every channel
    :return: 5D output image (NxCxDxHxW)
    """
    assert input.dim() == 4

    # expand the input tensor to Nx1xDxHxW before scattering
    input = input.unsqueeze(1)
    # shape of the result tensor (NxCxDxHxW)
    shape = list(input.size())
    shape[1] = C

    if ignore_index is None:
        # scatter to get the one-hot tensor; the zeros must live on the
        # index tensor's device or scatter_ raises a device-mismatch error
        return torch.zeros(shape, device=input.device).scatter_(1, input, 1)

    # remember where the ignored label sits, broadcast over all channels
    mask = input.expand(shape) == ignore_index
    # clone so the caller's tensor is untouched, then map the ignored label
    # to a valid class index so scatter_ does not go out of bounds
    input = input.clone()
    input[input == ignore_index] = 0
    # scatter to get the one-hot tensor
    result = torch.zeros(shape, device=input.device).scatter_(1, input, 1)
    # bring back the ignore_index in the result
    result[mask] = ignore_index
    return result
def dice_coeff(pred, target):
    """Compute a smoothed Dice coefficient between two NumPy arrays.

    Args:
        pred: NumPy array of predicted scores or a binary mask.
        target: NumPy array with the same number of elements as ``pred``.

    Returns:
        A 0-dim torch tensor ``(2 * sum(pred * target) + 1) /
        (sum(pred) + sum(target) + 1)``; the smoothing term of 1 keeps the
        ratio defined when both arrays are all zero.
    """
    smooth = 1.
    pred_t = torch.from_numpy(pred)
    target_t = torch.from_numpy(target)

    # Flatten everything behind the leading axis; the sums below are global.
    rows = pred_t.size(0)
    flat_pred = pred_t.view(rows, -1)
    flat_target = target_t.view(rows, -1)

    overlap = (flat_pred * flat_target).sum()
    numerator = 2. * overlap + smooth
    denominator = flat_pred.sum() + flat_target.sum() + smooth
    return numerator / denominator
# Compute the Dice scores
def Dice(y, output):
    """Return the kidney and tumor Dice scores of a prediction.

    Args:
        y: Integer label volume of shape ``(N, D, H, W)`` with classes
            0 = background, 1 = kidney, 2 = tumor.
        output: Per-class scores of shape ``(N, 3, D, H, W)``, e.g.
            ``(1, 3, 32, 160, 160)``. An unbatched ``(3, D, H, W)`` tensor
            is accepted as well.

    Returns:
        ``(dice_kidney, dice_tumor)``, each the value returned by
        ``dice_coeff`` for that class, pooled over every sample in the batch.
    """
    scores = output.detach().cpu()
    if scores.dim() == 4:
        # Unbatched prediction: add the batch axis so the class axis is
        # always axis 1 below.
        scores = scores.unsqueeze(0)
    scores = scores.numpy()  # (N, 3, D, H, W)

    target = make_one_hot_3d(y, 3).detach().cpu().numpy()  # (N, 3, D, H, W)

    # Index the class axis explicitly rather than squeezing the batch axis:
    # squeeze(0) is a no-op for N > 1, and indexing axis 0 would then pick
    # samples instead of classes. The background score is not returned, so
    # it is not computed.
    dice_kidney = dice_coeff(scores[:, 1], target[:, 1])
    dice_tumor = dice_coeff(scores[:, 2], target[:, 2])
    return dice_kidney, dice_tumor
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment