summaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/nsatf.py235
1 files changed, 235 insertions, 0 deletions
diff --git a/python/nsatf.py b/python/nsatf.py
new file mode 100644
index 0000000..a35cf51
--- /dev/null
+++ b/python/nsatf.py
@@ -0,0 +1,235 @@
+
+# coding: utf-8
+
+# In[4]:
+
+import tensorflow as tf
+import librosa
+import os
+# from IPython.display import Audio, display
+import numpy as np
+# import matplotlib.pyplot as plt
+import sys
+# get_ipython().magic(u'matplotlib inline')
+
+
+# ### Load style and content
+
+# In[5]:
+
+if len(sys.argv) < 4:
+ print "python nsatf.py content.wav style.wav output.wav alpha"
+ sys.exit()
+
+CONTENT_FILENAME = sys.argv[1]
+STYLE_FILENAME = sys.argv[2]
+OUTPUT_FILENAME = sys.argv[3]
+if len(sys.argv) == 5:
+ ALPHA = float(sys.argv[4] or "1e-3")
+else:
+ ALPHA = 1e-3
+
+
+# In[6]:
+
+# display(Audio(CONTENT_FILENAME))
+# display(Audio(STYLE_FILENAME))
+
+
+# In[7]:
+
+# Reads wav file and produces spectrum
+# Fourier phases are ignored
+N_FFT = 2048
+def read_audio_spectum(filename):
+ print 'load ' + filename
+ x, fs = librosa.load(filename, 44100)
+ S = librosa.stft(x, N_FFT)
+ p = np.angle(S)
+
+ S = np.log1p(np.abs(S[:,:1020]))
+ return S, fs
+
+
+# In[8]:
+
+a_content, fs = read_audio_spectum(CONTENT_FILENAME)
+a_style, fs = read_audio_spectum(STYLE_FILENAME)
+
+hs = a_content.shape[1]
+ms = a_style.shape[1]
+
+if hs > ms:
+ a_style = np.lib.pad(a_style, ((0,0), (0, hs - ms)), 'constant', constant_values=(0, 0))
+else:
+ a_content = np.lib.pad(a_content, ((0,0), (0, ms - hs)), 'constant', constant_values=(0, 0))
+
+print a_content.shape
+print a_style.shape
+
+hs = a_content.shape[0]
+ms = a_style.shape[0]
+
+if hs > ms:
+ a_style = np.lib.pad(a_style, ((0, hs - ms), (0,0)), 'constant', constant_values=(0, 0))
+else:
+ a_content = np.lib.pad(a_content, ((0, ms - hs), (0,0)), 'constant', constant_values=(0, 0))
+
+print a_content.shape
+print a_style.shape
+
+N_SAMPLES = a_style.shape[1]
+N_CHANNELS = a_style.shape[0]
+
+# ### Visualize spectrograms for content and style tracks
+
+# In[9]:
+
+"""
+plt.figure(figsize=(10, 5))
+plt.subplot(1, 2, 1)
+plt.title('Content')
+plt.imshow(a_content[:400,:])
+plt.subplot(1, 2, 2)
+plt.title('Style')
+plt.imshow(a_style[:400,:])
+plt.show()
+"""
+
+
+# ### Compute content and style feats
+
+# In[10]:
+
+N_FILTERS = 4096
+
+a_content_tf = np.ascontiguousarray(a_content.T[None,None,:,:])
+a_style_tf = np.ascontiguousarray(a_style.T[None,None,:,:])
+
+# filter shape is "[filter_height, filter_width, in_channels, out_channels]"
+std = np.sqrt(2) * np.sqrt(2.0 / ((N_CHANNELS + N_FILTERS) * 10))
+kernel = np.random.randn(1, 10, N_CHANNELS, N_FILTERS)*std
+
+g = tf.Graph()
+with g.as_default(), g.device('/cpu:0'), tf.Session() as sess:
+ # data shape is "[batch, in_height, in_width, in_channels]",
+ x = tf.placeholder('float32', [1,1,N_SAMPLES,N_CHANNELS], name="x")
+
+ kernel_tf = tf.constant(kernel, name="kernel", dtype='float32')
+ conv = tf.nn.conv2d(
+ x,
+ kernel_tf,
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ name="conv")
+
+ net = tf.nn.relu(conv)
+
+ content_features = net.eval(feed_dict={x: a_content_tf})
+ style_features = net.eval(feed_dict={x: a_style_tf})
+
+ features = np.reshape(style_features, (-1, N_FILTERS))
+ style_gram = np.matmul(features.T, features) / N_SAMPLES
+
+
+# ### Optimize
+
+# In[14]:
+
+from sys import stderr
+
+learning_rate= 1e-3
+iterations = 100
+
+result = None
+with tf.Graph().as_default():
+
+ # Build graph with variable input
+# x = tf.Variable(np.zeros([1,1,N_SAMPLES,N_CHANNELS], dtype=np.float32), name="x")
+ x = tf.Variable(np.random.randn(1,1,N_SAMPLES,N_CHANNELS).astype(np.float32)*1e-3, name="x")
+
+ kernel_tf = tf.constant(kernel, name="kernel", dtype='float32')
+ conv = tf.nn.conv2d(
+ x,
+ kernel_tf,
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ name="conv")
+
+
+ net = tf.nn.relu(conv)
+
+ content_loss = ALPHA * 2 * tf.nn.l2_loss( net - content_features)
+
+ style_loss = 0
+
+ _, height, width, number = map(lambda i: i.value, net.get_shape())
+
+ size = height * width * number
+ feats = tf.reshape(net, (-1, number))
+ gram = tf.matmul(tf.transpose(feats), feats) / N_SAMPLES
+ style_loss = 2 * tf.nn.l2_loss(gram - style_gram)
+
+ # Overall loss
+ loss = content_loss + style_loss
+
+ opt = tf.contrib.opt.ScipyOptimizerInterface(
+ loss, method='L-BFGS-B', options={'maxiter': iterations})
+
+ # Optimization
+ with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
+
+ print('Started optimization.')
+ opt.minimize(sess)
+
+ print 'Final loss:', loss.eval()
+ result = x.eval()
+
+
+# ### Invert spectrogram and save the result
+
+# In[15]:
+
+a = np.zeros_like(a_content)
+a[:N_CHANNELS,:] = np.exp(result[0,0].T) - 1
+
+# This code is supposed to do phase reconstruction
+p = 2 * np.pi * np.random.random_sample(a.shape) - np.pi
+for i in range(500):
+ S = a * np.exp(1j*p)
+ x = librosa.istft(S)
+ p = np.angle(librosa.stft(x, N_FFT))
+
+librosa.output.write_wav(OUTPUT_FILENAME, x, fs)
+
+
+# In[16]:
+
+#print OUTPUT_FILENAME
+#display(Audio(OUTPUT_FILENAME))
+
+
+# ### Visualize spectrograms
+
+# In[17]:
+
+"""
+plt.figure(figsize=(15,5))
+plt.subplot(1,3,1)
+plt.title('Content')
+plt.imshow(a_content[:400,:])
+plt.subplot(1,3,2)
+plt.title('Style')
+plt.imshow(a_style[:400,:])
+plt.subplot(1,3,3)
+plt.title('Result')
+plt.imshow(a[:400,:])
+plt.show()
+"""
+
+
+# In[ ]:
+
+
+