From d41070c7b00fafc974a1a6e7b6d1b42391fa57ed Mon Sep 17 00:00:00 2001 From: Jules Laplace Date: Fri, 21 Jul 2017 04:48:52 +0200 Subject: all async paths working --- python/nsatf.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) (limited to 'python/nsatf.py') diff --git a/python/nsatf.py b/python/nsatf.py index 796f448..6d54765 100644 --- a/python/nsatf.py +++ b/python/nsatf.py @@ -1,8 +1,6 @@ - # coding: utf-8 -# In[4]: - +from tensorflow.python.client import device_lib import tensorflow as tf import librosa import os @@ -29,6 +27,14 @@ if len(sys.argv) == 5: else: ALPHA = 1e-3 +device_ids = [device.name for device in device_lib.list_local_devices()] + +if '/gpu:0' in device_ids: + DEVICE = '/gpu:0' +else: + DEVICE = '/cpu:0' + +print DEVICE # In[6]: @@ -110,7 +116,7 @@ std = np.sqrt(2) * np.sqrt(2.0 / ((N_CHANNELS + N_FILTERS) * 10)) kernel = np.random.randn(1, 10, N_CHANNELS, N_FILTERS)*std g = tf.Graph() -with g.as_default(), g.device('/cpu:0'), tf.Session() as sess: +with g.as_default(), g.device(DEVICE), tf.Session() as sess: # data shape is "[batch, in_height, in_width, in_channels]", x = tf.placeholder('float32', [1,1,N_SAMPLES,N_CHANNELS], name="x") -- cgit v1.2.3-70-g09d2