import numpy as np
import tensorflow as tf


def load_initial_weights(self, session):
    """
    As the weights from http://www.cs.toronto.edu/~guerzhoy/tf_alexnet/ come
    as a dict of lists (e.g. weights['conv1'] is a list) and not as a dict of
    dicts (e.g. weights['conv1'] is a dict with keys 'weights' & 'biases'),
    we need a special load function.
    """
    # Load the pretrained weights into memory
    weights_dict = np.load(self.WEIGHTS_PATH, encoding='bytes').item()

    # Loop over all layer names stored in the weights dict
    for op_name in weights_dict:

        # Check if the layer is one of the layers that should be reinitialized:
        # layers listed in SKIP_LAYER stay trainable, all others are frozen
        train_bool = True
        if op_name not in self.SKIP_LAYER:
            train_bool = False

        with tf.variable_scope(op_name, reuse=True):

            # Loop over the list of weights/biases for this layer and assign
            # them to their corresponding tf variable
            for data in weights_dict[op_name]:

                # Biases are rank-1 arrays
                if len(data.shape) == 1:
                    var = tf.get_variable('biases', trainable=train_bool)
                    session.run(var.assign(data))

                # Weights are higher-rank arrays
                else:
                    var = tf.get_variable('weights', trainable=train_bool)
                    session.run(var.assign(data))
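The docstring's point about the file layout is easier to see by inspecting the pretrained file directly. The sketch below is illustrative and not part of the gist: the filename 'bvlc_alexnet.npy' is an assumption (it is the name used on the Toronto page referenced above), and allow_pickle=True is only needed on newer NumPy versions because the file stores a pickled dict.

import numpy as np

# Illustrative sketch: peek at the layout of the pretrained weights file.
# 'bvlc_alexnet.npy' is an assumed filename; download it first from the
# page referenced in the docstring.
weights_dict = np.load('bvlc_alexnet.npy', encoding='bytes', allow_pickle=True).item()

for op_name, params in weights_dict.items():
    # Each layer name maps to a plain list [weights, biases], so the loader
    # above has to tell the two apart by tensor rank rather than by key.
    for data in params:
        kind = 'biases' if data.ndim == 1 else 'weights'
        print(op_name, kind, data.shape)

Note that in this version of the loader every layer receives the pretrained values; membership in SKIP_LAYER only changes the trainable flag passed to tf.get_variable.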
I think there is an indentation error: this whole block needs to be indented. For the rest, it looks good to me.