summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/nsatf.py14
-rw-r--r--python/sleep.py4
2 files changed, 14 insertions, 4 deletions
diff --git a/python/nsatf.py b/python/nsatf.py
index 796f448..6d54765 100644
--- a/python/nsatf.py
+++ b/python/nsatf.py
@@ -1,8 +1,6 @@
-
# coding: utf-8
-# In[4]:
-
+from tensorflow.python.client import device_lib
import tensorflow as tf
import librosa
import os
@@ -29,6 +27,14 @@ if len(sys.argv) == 5:
else:
ALPHA = 1e-3
+device_ids = [device.name for device in device_lib.list_local_devices()]
+
+if '/gpu:0' in device_ids:
+ DEVICE = '/gpu:0'
+else:
+ DEVICE = '/cpu:0'
+
+print DEVICE
# In[6]:
@@ -110,7 +116,7 @@ std = np.sqrt(2) * np.sqrt(2.0 / ((N_CHANNELS + N_FILTERS) * 10))
kernel = np.random.randn(1, 10, N_CHANNELS, N_FILTERS)*std
g = tf.Graph()
-with g.as_default(), g.device('/cpu:0'), tf.Session() as sess:
+with g.as_default(), g.device(DEVICE), tf.Session() as sess:
# data shape is "[batch, in_height, in_width, in_channels]",
x = tf.placeholder('float32', [1,1,N_SAMPLES,N_CHANNELS], name="x")
diff --git a/python/sleep.py b/python/sleep.py
new file mode 100644
index 0000000..9c981a5
--- /dev/null
+++ b/python/sleep.py
@@ -0,0 +1,4 @@
+import time
+
+time.sleep(2)
+print "slept 2 seconds"