Created
February 22, 2017 10:30
-
-
Save ianlini/61214534562388383ab64ea8b11da723 to your computer and use it in GitHub Desktop.
Using a SciPy sparse matrix as the input to TensorFlow
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import scipy.sparse as ss | |
import numpy as np | |
def scipy_run(sess, result, X, tf_indices, tf_id_values, tf_weight_values,
              tf_dense_shape):
    """Evaluate ``result`` in ``sess``, feeding a SciPy CSR matrix as sparse inputs.

    The matrix is decomposed into the four feeds that the caller's
    placeholders expect:

    * ``tf_indices``: an ``(nnz, 2)`` int64 array of ``[row, position]``
      pairs, where ``position`` is the rank of the stored entry within its
      row (0, 1, 2, ...), *not* its column in ``X``.
    * ``tf_id_values``: the column index of every stored entry (int64).
    * ``tf_weight_values``: the value of every stored entry (``X.data``).
    * ``tf_dense_shape``: ``X.shape``.

    Args:
        sess: Object exposing ``run(fetches, feed_dict)``.
        result: The fetch to evaluate.
        X: Sparse matrix in CSR layout (must expose ``indptr``, ``indices``,
            ``data`` and ``shape``).
        tf_indices: Feed key for the ``[row, position]`` index pairs.
        tf_id_values: Feed key for the column ids.
        tf_weight_values: Feed key for the stored values.
        tf_dense_shape: Feed key for the dense shape.

    Returns:
        Whatever ``sess.run`` returns for ``result``.
    """
    # Number of stored entries in each row.
    row_nnz = np.diff(X.indptr).astype(np.int64)

    # Row id of every stored entry: row r is repeated row_nnz[r] times.
    row_ids = np.repeat(np.arange(row_nnz.size, dtype=np.int64), row_nnz)

    # Position of every stored entry inside its own row: its global rank
    # minus the rank at which its row starts.
    row_offsets = np.cumsum(row_nnz) - row_nnz
    positions = (np.arange(row_ids.size, dtype=np.int64)
                 - np.repeat(row_offsets, row_nnz))

    # Stacking keeps the (nnz, 2) shape even when nnz == 0; building the
    # array from an empty Python list would collapse it to shape (0,).
    indices = np.stack([row_ids, positions], axis=1)

    ids = X.indices.astype(np.int64)
    weights = X.data

    feed_dict = {
        tf_indices: indices,
        tf_id_values: ids,
        tf_weight_values: weights,
        tf_dense_shape: X.shape,
    }
    return sess.run(result, feed_dict)
def main():
    """Compare ``X @ W + b`` computed by SciPy with the TensorFlow equivalent.

    The TensorFlow version expresses the sparse-dense product as a weighted
    ``embedding_lookup_sparse`` with ``combiner='sum'``: each row of ``X``
    selects rows of ``W`` by column id and sums them weighted by the stored
    values. Both results are printed, followed by a line saying whether they
    agree.
    """
    X = ss.csr_matrix([[1, 0, 0],
                       [1, 0, 1],
                       [0, 0, 1],
                       [0, 2, 0]], dtype=np.float32)
    W = np.asarray([[.1, .2, .3],
                    [.2, .3, .1],
                    [.3, .2, .2]], dtype=np.float32)
    b = np.asarray([[.1, .2, .3]], dtype=np.float32)

    # scipy version
    direct_result = X @ W + b
    print(direct_result)

    # tensorflow version
    tf_indices = tf.placeholder(tf.int64, [None, 2])
    tf_id_values = tf.placeholder(tf.int64, [None])
    tf_weight_values = tf.placeholder(tf.float32, [None])
    tf_dense_shape = tf.placeholder(tf.int64, [2])
    sp_ids = tf.SparseTensor(tf_indices, tf_id_values, tf_dense_shape)
    sp_weights = tf.SparseTensor(tf_indices, tf_weight_values, tf_dense_shape)

    # dtype must be passed by keyword: the second positional parameter of
    # tf.Variable is `trainable`, not `dtype`.
    tf_W = tf.Variable(W, dtype=tf.float32)
    tf_b = tf.Variable(b, dtype=tf.float32)
    result = tf.nn.embedding_lookup_sparse(tf_W, sp_ids, sp_weights,
                                           combiner='sum') + tf_b
    init = tf.global_variables_initializer()

    # The context manager closes the session and frees its resources.
    with tf.Session() as sess:
        sess.run(init)
        tf_result = scipy_run(sess, result, X, tf_indices, tf_id_values,
                              tf_weight_values, tf_dense_shape)
    print(tf_result)

    # The two float32 results come from different code paths and may differ
    # in the last bits, so compare with a tolerance rather than exact ==.
    if np.allclose(direct_result, tf_result):
        print("They are the same!")
    else:
        print("They are different!")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment