diff options
Diffstat (limited to 'python/nsatf.py')
| -rw-r--r-- | python/nsatf.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/python/nsatf.py b/python/nsatf.py index 796f448..6d54765 100644 --- a/python/nsatf.py +++ b/python/nsatf.py @@ -1,8 +1,6 @@ - # coding: utf-8 -# In[4]: - +from tensorflow.python.client import device_lib import tensorflow as tf import librosa import os @@ -29,6 +27,14 @@ if len(sys.argv) == 5: else: ALPHA = 1e-3 +device_ids = [device.name for device in device_lib.list_local_devices()] + +if '/gpu:0' in device_ids: + DEVICE = '/gpu:0' +else: + DEVICE = '/cpu:0' + +print DEVICE # In[6]: @@ -110,7 +116,7 @@ std = np.sqrt(2) * np.sqrt(2.0 / ((N_CHANNELS + N_FILTERS) * 10)) kernel = np.random.randn(1, 10, N_CHANNELS, N_FILTERS)*std g = tf.Graph() -with g.as_default(), g.device('/cpu:0'), tf.Session() as sess: +with g.as_default(), g.device(DEVICE), tf.Session() as sess: # data shape is "[batch, in_height, in_width, in_channels]", x = tf.placeholder('float32', [1,1,N_SAMPLES,N_CHANNELS], name="x") |
