summaryrefslogtreecommitdiff
path: root/python/nsatf.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/nsatf.py')
-rw-r--r--python/nsatf.py14
1 files changed, 10 insertions, 4 deletions
diff --git a/python/nsatf.py b/python/nsatf.py
index 796f448..6d54765 100644
--- a/python/nsatf.py
+++ b/python/nsatf.py
@@ -1,8 +1,6 @@
-
# coding: utf-8
-# In[4]:
-
+from tensorflow.python.client import device_lib
import tensorflow as tf
import librosa
import os
@@ -29,6 +27,14 @@ if len(sys.argv) == 5:
else:
ALPHA = 1e-3
+device_ids = [device.name for device in device_lib.list_local_devices()]
+
+if '/gpu:0' in device_ids:
+ DEVICE = '/gpu:0'
+else:
+ DEVICE = '/cpu:0'
+
+print DEVICE
# In[6]:
@@ -110,7 +116,7 @@ std = np.sqrt(2) * np.sqrt(2.0 / ((N_CHANNELS + N_FILTERS) * 10))
kernel = np.random.randn(1, 10, N_CHANNELS, N_FILTERS)*std
g = tf.Graph()
-with g.as_default(), g.device('/cpu:0'), tf.Session() as sess:
+with g.as_default(), g.device(DEVICE), tf.Session() as sess:
# data shape is "[batch, in_height, in_width, in_channels]",
x = tf.placeholder('float32', [1,1,N_SAMPLES,N_CHANNELS], name="x")