Created
December 12, 2018 08:12
-
-
Save serser/3aca33e44049e1c1d7014e980ce35196 to your computer and use it in GitHub Desktop.
Use the GPU for TensorFlow computation when available
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
import tensorflow as tf | |
import sys | |
import os | |
class TestGPU(unittest.TestCase):
    """Smoke test that runs a small matrix multiplication on the first GPU."""

    def setUp(self):
        # Restrict CUDA to device 0 for this process.
        # NOTE(review): this only takes effect if CUDA has not been
        # initialised yet by the time the test runs -- confirm.
        os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    def test_gpu(self):
        """Multiply two constant matrices on ``/device:GPU:0`` and check the product.

        The test is reported as skipped (instead of silently passing) when
        TensorFlow was built without CUDA or no GPU is available.
        """
        if not tf.test.is_built_with_cuda():
            self.skipTest('cuda not supported')
        if not tf.test.is_gpu_available():
            self.skipTest('gpu not available')

        # Build the graph with every op pinned to the first GPU.
        with tf.device('/device:GPU:0'):
            a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
            b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b')
            c = tf.matmul(a, b)

        # log_device_placement=True makes TensorFlow log which device runs
        # each op; the context manager closes the session when done.
        config = tf.ConfigProto(log_device_placement=True)
        with tf.Session(config=config) as sess:
            result = sess.run(c)

        print(result)
        # [[1, 2, 3], [4, 5, 6]] x [[1, 2], [3, 4], [5, 6]]
        self.assertEqual(result.tolist(), [[22.0, 28.0], [49.0, 64.0]])
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment