Code snippet for porting a TensorFlow-trained model to PyTorch
import numpy as np
from PIL import Image
np.random.seed(2)
import torchvision
import torch
# torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import tensorflow as tf  # requires TF 1.x for tf.contrib.slim
from common_layers import Stage  # project-local module, not part of this gist
import matplotlib.pyplot as plt

slim = tf.contrib.slim
def read_images_from_disk(filename):
    # Per-channel BGR means (Caffe VGG convention).
    img_mean = np.array((104.00698793, 116.66876762, 122.67891434), dtype=np.float32)
    img_contents = tf.read_file(filename)
    img = tf.image.decode_image(img_contents, channels=3)
    img.set_shape((None, None, 3))
    # Convert RGB -> BGR.
    img_r, img_g, img_b = tf.split(axis=2, num_or_size_splits=3, value=img)
    img = tf.cast(tf.concat(axis=2, values=[img_b, img_g, img_r]), dtype=tf.float32)
    # Subtract the mean.
    img -= img_mean
    img = tf.expand_dims(img, 0)
    return img
def pt_read_im(filename):
    image = Image.open(filename)
    image = np.array(image)
    # image = image[:, :, ::-1]
    print(image.shape)
    to_tensor = transforms.ToTensor()
    # normalize = transforms.Normalize((104.00698793, 116.66876762, 122.67891434), (1., 1., 1.))
    # RGB means (the BGR means above, reversed), since the PyTorch model consumes RGB.
    normalize = transforms.Normalize((122.67891434, 116.66876762, 104.00698793), (1., 1., 1.))
    # Casting to float before ToTensor skips its uint8 -> [0, 1] rescaling,
    # so values stay in the original [0, 255] range.
    return normalize(to_tensor(image - 255.).float() + 255.).unsqueeze(0).float()
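
# Hedged sketch (not part of the original gist): an equivalent, more explicit
# preprocessing path without the -255/+255 ToTensor trick above.
# `pt_read_im_explicit` is a hypothetical helper added here for illustration only.
def pt_read_im_explicit(filename):
    image = np.array(Image.open(filename)).astype(np.float32)  # HWC, RGB, [0, 255]
    x = torch.from_numpy(image.transpose(2, 0, 1))             # CHW
    # Subtract per-channel RGB means (the reversed BGR means used on the TF side).
    x -= torch.tensor([122.67891434, 116.66876762, 104.00698793]).view(3, 1, 1)
    return x.unsqueeze(0)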
def vgg_16_deeplab_st(inputs,
                      num_classes=21,
                      is_training=True,
                      dropout_keep_prob=0.5,
                      scope='vgg_16'):
    """VGG-16 DeepLab-LargeFOV model for a single task.

    Args:
        inputs: a tensor of size [batch_size, height, width, channels].
        num_classes: number of predicted classes.
        is_training: whether or not the model is being trained.
        dropout_keep_prob: the probability that activations are kept in the
            dropout layers during training.
        scope: optional scope for the variables.

    Returns:
        The last op containing the log predictions, and an end_points dict.
    """
    with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc:
        end_points_collection = sc.name + '_end_points'
        # Collect outputs for conv2d, fully_connected and max_pool2d.
        with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
                            outputs_collections=end_points_collection):
            net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
            net = slim.max_pool2d(net, [3, 3], stride=2, padding='SAME', scope='pool1')
            net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
            net = slim.max_pool2d(net, [3, 3], stride=2, padding='SAME', scope='pool2')
            net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
            net = slim.max_pool2d(net, [3, 3], stride=2, padding='SAME', scope='pool3')
            net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
            net = slim.max_pool2d(net, [3, 3], stride=1, padding='SAME', scope='pool4')
            # net = slim.repeat(net, 3, conv2d_same, 512, [3, 3], stride=1, rate=2, scope='conv5')
            net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], rate=2, scope='conv5')
            net = slim.max_pool2d(net, [3, 3], stride=1, padding='SAME', scope='pool5_max')
            net = slim.avg_pool2d(net, [3, 3], stride=1, padding='SAME', scope='pool5_avg')
            # Use conv2d instead of fully_connected layers.
            rate = 12
            net = slim.conv2d(net, 1024, [3, 3], rate=rate, padding='SAME', scope='fc6')
            net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                               scope='dropout6')
            net = slim.conv2d(net, 1024, [1, 1], padding='SAME', scope='fc7')
            net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                               scope='dropout7')
            net = slim.conv2d(net, num_classes, [1, 1],
                              activation_fn=None,
                              normalizer_fn=None,
                              scope='fc8_voc12')
            # Convert end_points_collection into an end_points dict.
            end_points = slim.utils.convert_collection_to_dict(end_points_collection)
            return net, end_points
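
# Hedged sketch (not in the original gist): why the PyTorch model below pads
# with nn.ConstantPad2d((0, 1, 0, 1), 0) before its stride-2 3x3 max pools.
# For even input sizes, TF 'SAME' pooling pads only on the bottom/right, while
# nn.MaxPool2d(3, stride=2, padding=1) would pad symmetrically and shift the
# pooling windows. `check_same_padding` is a hypothetical helper for this check.
def check_same_padding():
    x = torch.arange(16.).view(1, 1, 4, 4)
    # Pad one extra row/column at the bottom/right, mimicking TF 'SAME'.
    y = nn.MaxPool2d(3, stride=2)(F.pad(x, (0, 1, 0, 1)))
    print(y.shape)  # -> torch.Size([1, 1, 2, 2]), matching TF's 'SAME' output size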
class DeepLabLargeFOV(nn.Module):
    def __init__(self, in_dim, out_dim, weights='ImageNet', *args, **kwargs):
        super(DeepLabLargeFOV, self).__init__(*args, **kwargs)
        self.stages = []
        layers = []
        stage = [
            nn.Conv2d(in_dim, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConstantPad2d((0, 1, 0, 1), 0),  # TensorFlow 'SAME' behavior
            nn.MaxPool2d(3, stride=2)
        ]
        layers += stage
        self.stages.append(Stage(64, stage))

        stage = [
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConstantPad2d((0, 1, 0, 1), 0),  # TensorFlow 'SAME' behavior
            nn.MaxPool2d(3, stride=2)
        ]
        layers += stage
        self.stages.append(Stage(128, stage))

        stage = [
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConstantPad2d((0, 1, 0, 1), 0),  # TensorFlow 'SAME' behavior
            nn.MaxPool2d(3, stride=2)
        ]
        layers += stage
        self.stages.append(Stage(256, stage))

        stage = [
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=1, padding=1)
        ]
        layers += stage
        self.stages.append(Stage(512, stage))

        stage = [
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=1, padding=1)
        ]
        layers += stage
        self.stages.append(Stage(512, stage))

        self.stages = nn.ModuleList(self.stages)
        self.features = nn.Sequential(*layers)

        head = [
            # Must use count_include_pad=False so the result matches TF average pooling.
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=12, dilation=12),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(1024, out_dim, kernel_size=1)
        ]
        self.head = nn.Sequential(*head)
        self.weights = weights
        self.init_weights()
    def forward(self, x):
        N, C, H, W = x.size()
        for stage in self.stages:
            x = stage(x)
        x = self.head(x)
        # x = F.interpolate(x, (H, W), mode='bilinear', align_corners=True)
        return x

    def _forward(self, x):
        # Debugging helper: run only the first three stages, e.g. to compare
        # intermediate activations against the TF end_points.
        x = self.stages[0](x)
        x = self.stages[1](x)
        x = self.stages[2](x)
        return x
    def init_weights(self):
        for layer in self.head.children():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, a=1)
                nn.init.constant_(layer.bias, 0)
        if self.weights == 'ImageNet':
            vgg = torchvision.models.vgg16(pretrained=True)
            state_vgg = vgg.features.state_dict()
            self.features.load_state_dict(state_vgg)
        elif self.weights == 'DeepLab':
            pretrained_dict = torch.load('weights/vgg_deeplab_lfov/model_final.pkl')
            model_dict = self.state_dict()
            # 1. Filter out unnecessary keys.
            pretrained_dict = {k.replace('classifier', 'head'): v for k, v in pretrained_dict.items()}
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'head.7' not in k}
            # 2. Overwrite entries in the existing state dict.
            model_dict.update(pretrained_dict)
            # 3. Load the new state dict.
            self.load_state_dict(model_dict)
        elif self.weights == 'TFDeepLab':
            # TODO: Check with a sample input, TF feedforward vs PyTorch feedforward.
            checkpoint_path = tf.train.latest_checkpoint('weights/nyu_v2_tf/slim_finetune_seg')
            # checkpoint_path = 'weights/vgg_deeplab_lfov_tf/model.ckpt-init-slim'
            tf_input = read_images_from_disk('0002.png')
            # tf_input = tf.convert_to_tensor(input_image, dtype=tf.float32)
            net, end_points = vgg_16_deeplab_st(tf_input, num_classes=40, is_training=False, dropout_keep_prob=1.0)
            # Which variables to load.
            restore_var = tf.global_variables()
            sess = tf.Session()
            init = tf.global_variables_initializer()
            sess.run(init)
            sess.run(tf.local_variables_initializer())
            tf.train.Saver(var_list=restore_var).restore(sess, checkpoint_path)
            print("Restored model parameters from {}".format(checkpoint_path))
            tf_vars = tf.trainable_variables()
            pt_vars = list(self.named_parameters())
            for tf_var, (pt_var_k, pt_var_v) in zip(tf_vars, pt_vars):
                if 'weight' in tf_var.name:
                    weight = tf_var.eval(session=sess)
                    # TF stores kernels as [kh, kw, c_in, c_out]; PyTorch expects
                    # [c_out, c_in, kh, kw] (see the layout sketch after this class).
                    weight = weight.transpose((3, 2, 0, 1))
                    if 'conv1_1' in tf_var.name:
                        # Flip the first conv layer's input channels: the TF model
                        # was trained on BGR data, but we want RGB in PyTorch.
                        print("flipping weights")
                        weight = np.flip(weight, axis=1).copy()
                        print(weight.shape)
                    weight = torch.from_numpy(weight).float()
                    pt_var_v.data = weight
                else:
                    assert 'bias' in tf_var.name
                    bias = tf_var.eval(session=sess)
                    bias = torch.from_numpy(bias).float()
                    pt_var_v.data = bias
                # print(tf_var.name, pt_var_k, pt_var_v.data.size())

            # Check that the copied weights reproduce the TF forward pass.
            self.eval()
            # pt_input = tf_input.eval(session=sess)
            # pt_input = torch.tensor(pt_input.transpose((0, 3, 1, 2))).float()
            pt_input = pt_read_im('0002.png')
            # print(torch.abs(pt_input2 - pt_input).mean())
            tf_result = end_points['vgg_16/fc8_voc12'].eval(session=sess)
            # tf_result = end_points['vgg_16/pool3'].eval(session=sess)
            pt_result = self.forward(pt_input).detach().numpy().transpose((0, 2, 3, 1))
            # pt_result = self._forward(pt_input).detach().numpy().transpose((0, 2, 3, 1))
            # # Compare final predictions.
            # print(tf_result.squeeze().shape)
            # tf_pred = tf_result.squeeze().argmax(axis=2)
            # pt_pred = pt_result.squeeze().argmax(axis=2)
            # diff = np.dstack([tf_pred, pt_pred, np.zeros_like(tf_pred)]).astype('float')
            # diff /= diff.max()
            # mask = (tf_pred - pt_pred) != 0
            # tf_pred[mask] = -10
            # pt_pred[mask] = -10
            # plt.matshow(tf_pred)
            # plt.colorbar()
            # plt.title('tf_pred')
            # plt.matshow(pt_pred)
            # plt.colorbar()
            # plt.title('pt_pred')
            # plt.matshow(mask.astype('int'))
            # plt.colorbar()
            # # plt.imshow(diff)
            # plt.title('diff')
            print(tf_result.shape)
            print(tf_result.max(), pt_result.max())
            print(tf_result.mean(), pt_result.mean())
            diff = np.abs(tf_result - pt_result).squeeze()
            # plt.matshow(tf_result.squeeze().mean(axis=2))
            # plt.colorbar()
            # plt.title('tf')
            # plt.matshow(pt_result.squeeze().mean(axis=2))
            # plt.colorbar()
            # plt.title('pt')
            # plt.matshow(diff.mean(axis=2))
            # plt.colorbar()
            # plt.title('diff')
            # plt.show()
            print(diff.max())
        elif self.weights == '':
            pass
        else:
            raise NotImplementedError
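
# Hedged sketch (not in the original gist): the kernel-layout mapping behind
# transpose((3, 2, 0, 1)) in init_weights above. TF slim stores conv kernels
# as [kh, kw, c_in, c_out]; PyTorch nn.Conv2d expects [c_out, c_in, kh, kw].
# `_demo_layout` and its shapes are illustrative only.
def _demo_layout():
    w_tf = np.zeros((3, 3, 64, 128), dtype=np.float32)      # hypothetical TF kernel
    w_pt = torch.from_numpy(w_tf.transpose((3, 2, 0, 1)))   # reorder axes for PyTorch
    print(w_pt.shape)  # -> torch.Size([128, 64, 3, 3])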
if __name__ == "__main__":
    net = DeepLabLargeFOV(3, 40, weights='TFDeepLab')
    # print(net)
    in_ten = torch.randn(1, 3, 321, 321)
    torch.save(net.state_dict(), 'weights/nyu_v2/tf_finetune_seg.pth')
    # torch.save(net.state_dict(), 'weights/vgg_deeplab_lfov/tf_deeplab.pth')
    # print(out.size())
    # print(net.stages[1])
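
# Hedged usage sketch (not in the original gist): once the conversion above has
# run and the state dict is saved, the ported model can be restored without
# TensorFlow installed; the path mirrors the one used in __main__ above.
#
#   net = DeepLabLargeFOV(3, 40, weights='')
#   net.load_state_dict(torch.load('weights/nyu_v2/tf_finetune_seg.pth'))
#   net.eval()
#   logits = net(pt_read_im('0002.png'))  # coarse logits at ~1/8 input resolution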
A messy example of converting the TensorFlow-trained models from NDDR-CNN to the corresponding PyTorch models.