diff options
Diffstat (limited to 'python/nsatf.py')
| -rw-r--r-- | python/nsatf.py | 235 |
1 files changed, 235 insertions, 0 deletions
diff --git a/python/nsatf.py b/python/nsatf.py new file mode 100644 index 0000000..a35cf51 --- /dev/null +++ b/python/nsatf.py @@ -0,0 +1,235 @@ + +# coding: utf-8 + +# In[4]: + +import tensorflow as tf +import librosa +import os +# from IPython.display import Audio, display +import numpy as np +# import matplotlib.pyplot as plt +import sys +# get_ipython().magic(u'matplotlib inline') + + +# ### Load style and content + +# In[5]: + +if len(sys.argv) < 4: + print "python nsatf.py content.wav style.wav output.wav alpha" + sys.exit() + +CONTENT_FILENAME = sys.argv[1] +STYLE_FILENAME = sys.argv[2] +OUTPUT_FILENAME = sys.argv[3] +if len(sys.argv) == 5: + ALPHA = float(sys.argv[4] or "1e-3") +else: + ALPHA = 1e-3 + + +# In[6]: + +# display(Audio(CONTENT_FILENAME)) +# display(Audio(STYLE_FILENAME)) + + +# In[7]: + +# Reads wav file and produces spectrum +# Fourier phases are ignored +N_FFT = 2048 +def read_audio_spectum(filename): + print 'load ' + filename + x, fs = librosa.load(filename, 44100) + S = librosa.stft(x, N_FFT) + p = np.angle(S) + + S = np.log1p(np.abs(S[:,:1020])) + return S, fs + + +# In[8]: + +a_content, fs = read_audio_spectum(CONTENT_FILENAME) +a_style, fs = read_audio_spectum(STYLE_FILENAME) + +hs = a_content.shape[1] +ms = a_style.shape[1] + +if hs > ms: + a_style = np.lib.pad(a_style, ((0,0), (0, hs - ms)), 'constant', constant_values=(0, 0)) +else: + a_content = np.lib.pad(a_content, ((0,0), (0, ms - hs)), 'constant', constant_values=(0, 0)) + +print a_content.shape +print a_style.shape + +hs = a_content.shape[0] +ms = a_style.shape[0] + +if hs > ms: + a_style = np.lib.pad(a_style, ((0, hs - ms), (0,0)), 'constant', constant_values=(0, 0)) +else: + a_content = np.lib.pad(a_content, ((0, ms - hs), (0,0)), 'constant', constant_values=(0, 0)) + +print a_content.shape +print a_style.shape + +N_SAMPLES = a_style.shape[1] +N_CHANNELS = a_style.shape[0] + +# ### Visualize spectrograms for content and style tracks + +# In[9]: + +""" +plt.figure(figsize=(10, 5)) +plt.subplot(1, 2, 1) +plt.title('Content') +plt.imshow(a_content[:400,:]) +plt.subplot(1, 2, 2) +plt.title('Style') +plt.imshow(a_style[:400,:]) +plt.show() +""" + + +# ### Compute content and style feats + +# In[10]: + +N_FILTERS = 4096 + +a_content_tf = np.ascontiguousarray(a_content.T[None,None,:,:]) +a_style_tf = np.ascontiguousarray(a_style.T[None,None,:,:]) + +# filter shape is "[filter_height, filter_width, in_channels, out_channels]" +std = np.sqrt(2) * np.sqrt(2.0 / ((N_CHANNELS + N_FILTERS) * 10)) +kernel = np.random.randn(1, 10, N_CHANNELS, N_FILTERS)*std + +g = tf.Graph() +with g.as_default(), g.device('/cpu:0'), tf.Session() as sess: + # data shape is "[batch, in_height, in_width, in_channels]", + x = tf.placeholder('float32', [1,1,N_SAMPLES,N_CHANNELS], name="x") + + kernel_tf = tf.constant(kernel, name="kernel", dtype='float32') + conv = tf.nn.conv2d( + x, + kernel_tf, + strides=[1, 1, 1, 1], + padding="VALID", + name="conv") + + net = tf.nn.relu(conv) + + content_features = net.eval(feed_dict={x: a_content_tf}) + style_features = net.eval(feed_dict={x: a_style_tf}) + + features = np.reshape(style_features, (-1, N_FILTERS)) + style_gram = np.matmul(features.T, features) / N_SAMPLES + + +# ### Optimize + +# In[14]: + +from sys import stderr + +learning_rate= 1e-3 +iterations = 100 + +result = None +with tf.Graph().as_default(): + + # Build graph with variable input +# x = tf.Variable(np.zeros([1,1,N_SAMPLES,N_CHANNELS], dtype=np.float32), name="x") + x = tf.Variable(np.random.randn(1,1,N_SAMPLES,N_CHANNELS).astype(np.float32)*1e-3, name="x") + + kernel_tf = tf.constant(kernel, name="kernel", dtype='float32') + conv = tf.nn.conv2d( + x, + kernel_tf, + strides=[1, 1, 1, 1], + padding="VALID", + name="conv") + + + net = tf.nn.relu(conv) + + content_loss = ALPHA * 2 * tf.nn.l2_loss( net - content_features) + + style_loss = 0 + + _, height, width, number = map(lambda i: i.value, net.get_shape()) + + size = height * width * number + feats = tf.reshape(net, (-1, number)) + gram = tf.matmul(tf.transpose(feats), feats) / N_SAMPLES + style_loss = 2 * tf.nn.l2_loss(gram - style_gram) + + # Overall loss + loss = content_loss + style_loss + + opt = tf.contrib.opt.ScipyOptimizerInterface( + loss, method='L-BFGS-B', options={'maxiter': iterations}) + + # Optimization + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + + print('Started optimization.') + opt.minimize(sess) + + print 'Final loss:', loss.eval() + result = x.eval() + + +# ### Invert spectrogram and save the result + +# In[15]: + +a = np.zeros_like(a_content) +a[:N_CHANNELS,:] = np.exp(result[0,0].T) - 1 + +# This code is supposed to do phase reconstruction +p = 2 * np.pi * np.random.random_sample(a.shape) - np.pi +for i in range(500): + S = a * np.exp(1j*p) + x = librosa.istft(S) + p = np.angle(librosa.stft(x, N_FFT)) + +librosa.output.write_wav(OUTPUT_FILENAME, x, fs) + + +# In[16]: + +#print OUTPUT_FILENAME +#display(Audio(OUTPUT_FILENAME)) + + +# ### Visualize spectrograms + +# In[17]: + +""" +plt.figure(figsize=(15,5)) +plt.subplot(1,3,1) +plt.title('Content') +plt.imshow(a_content[:400,:]) +plt.subplot(1,3,2) +plt.title('Style') +plt.imshow(a_style[:400,:]) +plt.subplot(1,3,3) +plt.title('Result') +plt.imshow(a[:400,:]) +plt.show() +""" + + +# In[ ]: + + + |
