Diffstat
 -rw-r--r--  .gitignore                 |   8
 -rw-r--r--  best-epoch.sh              |   1
 -rwxr-xr-x  bin/crossfade_cat.sh       | 129
 -rwxr-xr-x  bin/loop.sh                |  20
 -rwxr-xr-x  bin/xfade.sh               |  23
 -rwxr-xr-x  datasets/count_subdirs.sh  |   1
 -rwxr-xr-x  datasets/generate.sh       |  34
 -rwxr-xr-x  datasets/split44k.sh       |  24
 -rwxr-xr-x  datasets/spread.sh         |  25
 -rw-r--r--  generate.py                | 365
 -rw-r--r--  glass_test2.sh             |  85
 -rw-r--r--  glass_test3.sh             | 112
 -rw-r--r--  glass_tests.sh             | 202
 -rw-r--r--  kick_test.sh               |  51
 -rwxr-xr-x  mix.sh                     |  20
 -rw-r--r--  train.py                   |   2
 -rw-r--r--  train_basic_22k.sh         |  15
 -rw-r--r--  train_basic_32k.sh         |  15
 -rw-r--r--  train_basic_44k.sh         |  15
 -rw-r--r--  train_drums.sh             |  44
 -rw-r--r--  train_test_generate.sh     |  93
 -rw-r--r--  trainer/__init__.py        |   9
 -rw-r--r--  trainer/plugins.py         |   8

23 files changed, 1296 insertions(+), 5 deletions(-)
diff --git a/.gitignore b/.gitignore
index 9c17115..3bb05ff 100644
--- a/.gitignore
+++ b/.gitignore
@@ -92,3 +92,11 @@ ENV/
*~
*.swp
*.swo
+
+*.wav
+*.mp3
+*.aif
+*.aiff
+
+results/
+
diff --git a/best-epoch.sh b/best-epoch.sh
new file mode 100644
index 0000000..230daa0
--- /dev/null
+++ b/best-epoch.sh
@@ -0,0 +1 @@
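+# List the 20 most recently modified "best" checkpoints across all
+# experiment result directories.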
+ls -lat results/*/checkpoints/best* | head -n 20
diff --git a/bin/crossfade_cat.sh b/bin/crossfade_cat.sh
new file mode 100755
index 0000000..13a8e90
--- /dev/null
+++ b/bin/crossfade_cat.sh
@@ -0,0 +1,129 @@
+#!/bin/bash
+#
+# crossfade_cat.sh
+#
+# Concatenates two files together with a crossfade of $1 seconds.
+# Filenames are specified as $2 and $3.
+#
+# $4 is optional and specifies whether a fadeout should be performed
+# on the first file.
+# $5 is optional and specifies whether a fadein should be performed
+# on the second file.
+#
+# Example: crossfade_cat.sh 10 infile1.wav infile2.wav auto auto
+#
+# By default, the script attempts to guess whether the audio files
+# already have a fadein/out on them, or whether their volumes are
+# low enough not to cause clipping when mixing. If neither is
+# detected, the script performs a fade in/out to prevent clipping.
+#
+# The user may specify "yes" or "no" to force the fade in/out
+# to occur. They can also specify "auto" which is the default.
+#
+# Crossfaded file is created as "mix.wav".
+#
+# Original script from Kester Clegg. Mods by Chris Bagwell to show
+# more examples of sox features.
+#
+
+SOX=sox
+SOXI=soxi
+
+if [ "$3" == "" ]; then
+ echo "Usage: $0 crossfade_seconds first_file second_file [ fadeout ] [ fadein ]"
+ echo
+ echo "If a fadeout or fadein is not desired then specify \"no\" for that option. \"yes\" will force a fade and \"auto\" will try to detect if a fade should occur."
+ echo
+ echo "Example: $0 10 infile1.wav infile2.wav auto auto"
+ exit 1
+fi
+
+fade_length=$1
+first_file=$2
+second_file=$3
+
+fade_first="auto"
+if [ "$4" != "" ]; then
+ fade_first=$4
+fi
+
+fade_second="auto"
+if [ "$5" != "" ]; then
+ fade_second=$5
+fi
+
+fade_first_opts=
+if [ "$fade_first" != "no" ]; then
+ fade_first_opts="fade t 0 0:0:$fade_length 0:0:$fade_length"
+fi
+
+fade_second_opts=
+if [ "$fade_second" != "no" ]; then
+ fade_second_opts="fade t 0:0:$fade_length"
+fi
+
+echo "crossfade and concatenate files"
+echo
+echo "Finding length of $first_file..."
+first_length=`$SOX "$first_file" 2>&1 -n stat | grep Length | cut -d : -f 2 | cut -f 1`
+echo "Length is $first_length seconds"
+
+trim_length=`echo "$first_length - $fade_length" | bc`
+
+# Get crossfade section from first file and optionally do the fade out
+echo "Obtaining $fade_length seconds of fade out portion from $first_file..."
+$SOX "$first_file" -e signed-integer -b 16 fadeout1.wav trim $trim_length
+
+# When the user specifies "auto", try to guess whether a fadeout is
+# needed. "RMS amplitude" from the stat effect is effectively an
+# average value of the samples over the whole fade-length file. If it
+# seems quiet, assume a fadeout has already been done. The RMS
+# threshold of 0.1 was obtained by trial and error.
+if [ "$fade_first" == "auto" ]; then
+ RMS=`$SOX fadeout1.wav 2>&1 -n stat | grep RMS | grep amplitude | cut -d : -f 2 | cut -f 1`
+ should_fade=`echo "$RMS > 0.1" | bc`
+ if [ $should_fade == 0 ]; then
+ echo "Auto mode decided not to fadeout with RMS of $RMS"
+ fade_first_opts=""
+ else
+ echo "Auto mode will fadeout"
+ fi
+fi
+
+$SOX fadeout1.wav fadeout2.wav $fade_first_opts
+
+# Get the crossfade section from the second file and optionally do the fade in
+echo "Obtaining $fade_length seconds of fade in portion from $second_file..."
+$SOX "$second_file" -e signed-integer -b 16 fadein1.wav trim 0 $fade_length
+
+# For auto, do similar thing as for fadeout.
+if [ "$fade_second" == "auto" ]; then
+ RMS=`$SOX fadein1.wav 2>&1 -n stat | grep RMS | grep amplitude | cut -d : -f 2 | cut -f 1`
+ should_fade=`echo "$RMS > 0.1" | bc`
+ if [ $should_fade == 0 ]; then
+ echo "Auto mode decided not to fadein with RMS of $RMS"
+ fade_second_opts=""
+ else
+ echo "Auto mode will fadein"
+ fi
+fi
+
+$SOX fadein1.wav fadein2.wav $fade_second_opts
+
+# Mix the crossfade files together at full volume
+echo "Crossfading..."
+$SOX -m -v 1.0 fadeout2.wav -v 1.0 fadein2.wav crossfade.wav
+
+echo "Trimming off crossfade sections from original files..."
+
+$SOX "$first_file" -e signed-integer -b 16 song1.wav trim 0 $trim_length
+$SOX "$second_file" -e signed-integer -b 16 song2.wav trim $fade_length
+$SOX song1.wav crossfade.wav song2.wav mix.wav
+
+echo -e "Removing temporary files...\n"
+rm fadeout1.wav fadeout2.wav fadein1.wav fadein2.wav crossfade.wav song1.wav song2.wav
+mins=`echo "$trim_length / 60" | bc`
+secs=`echo "$trim_length % 60" | bc`
+echo "The crossfade in mix.wav occurs at around $mins mins $secs secs"
+
diff --git a/bin/loop.sh b/bin/loop.sh
new file mode 100755
index 0000000..beb1286
--- /dev/null
+++ b/bin/loop.sh
@@ -0,0 +1,20 @@
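+#!/bin/bash
+#
+# loop.sh
+#
+# Generates a mix.sh script that chains every *.wav in the current
+# directory through ffmpeg's acrossfade filter. Sketch of the
+# intended use (the generated filtergraph may still need a -map for
+# its final output label):
+#
+#   ./loop.sh && bash mix.sh   # writes out.wav
+#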
+rm fades.txt
+rm files.txt
+rm mix.sh
+
+ITER=0
+NEXT=1
+for i in `ls -v *.wav`
+do
+echo "[0$ITER][$NEXT]acrossfade=ns=5:o=1:c1=tri:c2=tri[0$NEXT];" >> fades.txt
+printf -- '-i %s ' "$i" >> files.txt
+ITER=$(expr $ITER + 1)
+NEXT=$(expr $NEXT + 1)
+done
+
+printf "ffmpeg " >> mix.sh
+cat files.txt >> mix.sh
+printf " -filter_complex \"" >> mix.sh
+cat fades.txt >> mix.sh
+echo '\" out.wav' >> mix.sh
+
diff --git a/bin/xfade.sh b/bin/xfade.sh
new file mode 100755
index 0000000..5e861e0
--- /dev/null
+++ b/bin/xfade.sh
@@ -0,0 +1,23 @@
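+#!/bin/bash
+#
+# xfade.sh
+#
+# Joins every *.wav in the current directory (version-sorted) into
+# mix.wav, splicing each file onto the running mix with a short
+# (0.01 s) sox splice at the seam.
+#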
+crossfade_dur=1
+i=0
+limit=10000
+
+for file in `ls *.wav | sort -V`
+do
+ i=$((i+1))
+ if [ $i -eq $limit ]
+ then
+ break
+ fi
+
+ if [ $i -eq 1 ]
+ then
+ cp $file mix.wav
+ else
+ # ../../crossfade_cat.sh $crossfade_dur mix.wav $file yes yes
+ echo $file
+ sox mix.wav "$file" out.wav splice $(soxi -D mix.wav),0.01
+ mv out.wav mix.wav
+ fi
+done
+
diff --git a/datasets/count_subdirs.sh b/datasets/count_subdirs.sh
new file mode 100755
index 0000000..3999b3c
--- /dev/null
+++ b/datasets/count_subdirs.sh
@@ -0,0 +1 @@
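+# Print each immediate subdirectory alongside its recursive file count.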
+find . -maxdepth 1 -type d | sort | while read -r dir; do printf "%s:\t" "$dir"; find "$dir" -type f | wc -l; done
diff --git a/datasets/generate.sh b/datasets/generate.sh
new file mode 100755
index 0000000..335928c
--- /dev/null
+++ b/datasets/generate.sh
@@ -0,0 +1,34 @@
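+#!/bin/bash
+#
+# Builds 44.1 kHz datasets from the source wavs listed at the bottom:
+# each input is spread into many slightly speed-shifted copies
+# (spread.sh) and then cut into 8-second chunks (split44k.sh) under
+# a 44k_<name> directory.
+#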
+function process () {
+ echo "____________________________________________________"
+ echo "process $1"
+ name=$1
+ in="${name}.wav"
+ out="s_${in}"
+ ./spread.sh $in $out 0.99 0.01 1.01
+ ./split44k.sh $out 8 "44k_$name"
+ rm $out
+}
+function ease_process () {
+ echo "____________________________________________________"
+ echo "ease_process $1"
+ name=$1
+ step=$2
+ in="${name}.wav"
+ sout="o_${in}"
+ out="s_${in}"
+ sox -v 0.95 $in $sout
+ ./spread.sh $sout $out 0.999 $step 1.001
+ ./split44k.sh $out 8 "44k_$name"
+ rm $sout
+ rm $out
+}
+#ease_process '' 0.0000
+ease_process 'blblbl' 0.00001515
+ease_process 'faty-scrub1' 0.0000285
+ease_process 'faty-medieval' 0.00003
+ease_process 'faty-crystals' 0.0000111
+ease_process 'faty-vocal1' 0.000013
+ease_process 'faty-vocal2' 0.000028145
+ease_process 'faty-scrub2' 0.00000466
+ease_process 'siren' 0.0000275
+
diff --git a/datasets/split44k.sh b/datasets/split44k.sh
new file mode 100755
index 0000000..4884af1
--- /dev/null
+++ b/datasets/split44k.sh
@@ -0,0 +1,24 @@
+#!/bin/sh
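+#
+# Converts <filename.wav> to mono 44.1 kHz, then slices it into
+# <chunk size>-second files named 0.wav, 1.wav, ... inside
+# <dataset path>; any fractional trailing chunk is dropped.
+#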
+
+if [ "$#" -ne 3 ]; then
+ echo "Usage: $0 <filename.wav> <chunk size in seconds> <dataset path>"
+ exit
+fi
+
+fn=$1
+chunk_size=$2
+dataset_path=$3
+
+converted=".temp2.wav"
+rm -f $converted
+ffmpeg -i $fn -ac 1 -ar 44100 $converted
+
+mkdir $dataset_path
+length=$(ffprobe -i $converted -show_entries format=duration -v quiet -of csv="p=0")
+end=$(echo "$length / $chunk_size - 1" | bc)
+echo "splitting..."
+for i in $(seq 0 $end); do
+ ffmpeg -hide_banner -loglevel error -ss $(($i * $chunk_size)) -t $chunk_size -i $converted "$dataset_path/$i.wav"
+done
+echo "done"
+rm -f $converted
diff --git a/datasets/spread.sh b/datasets/spread.sh
new file mode 100755
index 0000000..bec1da3
--- /dev/null
+++ b/datasets/spread.sh
@@ -0,0 +1,25 @@
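+#!/bin/bash
+#
+# Renders speed-shifted copies of <in.wav> at every rate from
+# <rate_min> to <rate_max> in <rate_step> increments (sox "speed"),
+# then concatenates the copies into <out.wav>. Used as a simple
+# data-augmentation pass before splitting into chunks.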
+
+if [ "$#" -ne 5 ]; then
+ echo "Usage: $0 <in.wav> <out.wav> <rate_min> <rate_step> <rate_max>"
+ exit
+fi
+
+FN_IN=$1
+FN_OUT=$2
+RATE=$3
+STEP=$4
+MAX=$5
+
+ITER=0
+while true; do
+ if (( $(echo "$RATE > $MAX" | bc -l) )); then
+ break
+ fi
+ let ITER+=1
+ # Render at the current rate before stepping, so rate_min is
+ # included and rate_max is never overshot.
+ sox $FN_IN "tmp_$ITER.wav" speed $RATE
+ RATE=`echo "$RATE+$STEP" | bc`
+done
+
+sox tmp_* $FN_OUT
+rm tmp_*
+
diff --git a/generate.py b/generate.py
new file mode 100644
index 0000000..92a930f
--- /dev/null
+++ b/generate.py
@@ -0,0 +1,365 @@
+# CometML needs to be imported first.
+try:
+ import comet_ml
+except ImportError:
+ pass
+
+from model import SampleRNN, Predictor
+from optim import gradient_clipping
+from nn import sequence_nll_loss_bits
+from trainer import Trainer
+from trainer.plugins import (
+ TrainingLossMonitor, ValidationPlugin, AbsoluteTimeMonitor, SaverPlugin,
+ GeneratorPlugin, StatsPlugin
+)
+from dataset import FolderDataset, DataLoader
+
+import torch
+from torch.utils.trainer.plugins import Logger
+
+from natsort import natsorted
+
+from functools import reduce
+import os
+import shutil
+import sys
+from glob import glob
+import re
+import argparse
+
+
+default_params = {
+ # model parameters
+ 'n_rnn': 1,
+ 'dim': 1024,
+ 'learn_h0': True,
+ 'q_levels': 256,
+ 'seq_len': 1024,
+ 'weight_norm': True,
+ 'batch_size': 128,
+ 'val_frac': 0.1,
+ 'test_frac': 0.1,
+
+ # training parameters
+ 'keep_old_checkpoints': False,
+ 'datasets_path': 'datasets',
+ 'results_path': 'results',
+ 'epoch_limit': 1000,
+ 'resume': True,
+ 'sample_rate': 16000,
+ 'n_samples': 1,
+ 'sample_length': 80000,
+ 'loss_smoothing': 0.99,
+ 'cuda': True,
+ 'comet_key': None
+}
+
+tag_params = [
+ 'exp', 'frame_sizes', 'n_rnn', 'dim', 'learn_h0', 'q_levels', 'seq_len',
+ 'batch_size', 'dataset', 'val_frac', 'test_frac'
+]
+
+def param_to_string(value):
+ if isinstance(value, bool):
+ return 'T' if value else 'F'
+ elif isinstance(value, list):
+ return ','.join(map(param_to_string, value))
+ else:
+ return str(value)
+
+def make_tag(params):
+ return '-'.join(
+ key + ':' + param_to_string(params[key])
+ for key in tag_params
+ if key not in default_params or params[key] != default_params[key]
+ )
+
+def setup_results_dir(params):
+ def ensure_dir_exists(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ tag = make_tag(params)
+ results_path = os.path.abspath(params['results_path'])
+ ensure_dir_exists(results_path)
+ results_path = os.path.join(results_path, tag)
+ if not os.path.exists(results_path):
+ os.makedirs(results_path)
+ elif not params['resume']:
+ shutil.rmtree(results_path)
+ os.makedirs(results_path)
+
+ for subdir in ['checkpoints', 'samples']:
+ ensure_dir_exists(os.path.join(results_path, subdir))
+
+ return results_path
+
+def load_last_checkpoint(checkpoints_path):
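+ # Pick the newest checkpoint by natural filename order and recover
+ # its epoch/iteration counters from the SaverPlugin name pattern.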
+ checkpoints_pattern = os.path.join(
+ checkpoints_path, SaverPlugin.last_pattern.format('*', '*')
+ )
+ checkpoint_paths = natsorted(glob(checkpoints_pattern))
+ if len(checkpoint_paths) > 0:
+ checkpoint_path = checkpoint_paths[-1]
+ checkpoint_name = os.path.basename(checkpoint_path)
+ match = re.match(
+ SaverPlugin.last_pattern.format(r'(\d+)', r'(\d+)'),
+ checkpoint_name
+ )
+ epoch = int(match.group(1))
+ iteration = int(match.group(2))
+ return (torch.load(checkpoint_path), epoch, iteration)
+ else:
+ return None
+
+def tee_stdout(log_path):
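+ # Mirror everything written to stdout into a line-buffered log file.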
+ log_file = open(log_path, 'a', 1)
+ stdout = sys.stdout
+
+ class Tee:
+
+ def write(self, string):
+ log_file.write(string)
+ stdout.write(string)
+
+ def flush(self):
+ log_file.flush()
+ stdout.flush()
+
+ sys.stdout = Tee()
+
+def make_data_loader(overlap_len, params):
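+ # Returns a closure producing a DataLoader over the [split_from,
+ # split_to) fraction of the dataset files.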
+ path = os.path.join(params['datasets_path'], params['dataset'])
+ def data_loader(split_from, split_to, eval):
+ dataset = FolderDataset(
+ path, overlap_len, params['q_levels'], split_from, split_to
+ )
+ return DataLoader(
+ dataset,
+ batch_size=params['batch_size'],
+ seq_len=params['seq_len'],
+ overlap_len=overlap_len,
+ shuffle=(not eval),
+ drop_last=(not eval)
+ )
+ return data_loader
+
+def init_comet(params, trainer):
+ if params['comet_key'] is not None:
+ from comet_ml import Experiment
+ from trainer.plugins import CometPlugin
+ experiment = Experiment(api_key=params['comet_key'], log_code=False)
+ hyperparams = {
+ name: param_to_string(params[name]) for name in tag_params
+ }
+ experiment.log_multiple_params(hyperparams)
+ trainer.register_plugin(CometPlugin(
+ experiment, [
+ ('training_loss', 'epoch_mean'),
+ 'validation_loss',
+ 'test_loss'
+ ]
+ ))
+
+def main(exp, frame_sizes, dataset, **params):
+ params = dict(
+ default_params,
+ exp=exp, frame_sizes=frame_sizes, dataset=dataset,
+ **params
+ )
+
+ results_path = setup_results_dir(params)
+ tee_stdout(os.path.join(results_path, 'log'))
+
+ model = SampleRNN(
+ frame_sizes=params['frame_sizes'],
+ n_rnn=params['n_rnn'],
+ dim=params['dim'],
+ learn_h0=params['learn_h0'],
+ q_levels=params['q_levels'],
+ weight_norm=params['weight_norm']
+ )
+ predictor = Predictor(model)
+ if params['cuda']:
+ model = model.cuda()
+ predictor = predictor.cuda()
+
+ optimizer = gradient_clipping(torch.optim.Adam(predictor.parameters()))
+
+ data_loader = make_data_loader(model.lookback, params)
+ test_split = 1 - params['test_frac']
+ val_split = test_split - params['val_frac']
+
+ trainer = Trainer(
+ predictor, sequence_nll_loss_bits, optimizer,
+ data_loader(0, val_split, eval=False),
+ cuda=params['cuda']
+ )
+
+ checkpoints_path = os.path.join(results_path, 'checkpoints')
+ checkpoint_data = load_last_checkpoint(checkpoints_path)
+ if checkpoint_data is not None:
+ (state_dict, epoch, iteration) = checkpoint_data
+ trainer.epochs = epoch
+ trainer.iterations = iteration
+ predictor.load_state_dict(state_dict)
+ print("epochs: {} iterations: {}".format(epoch, iteration))
+
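+ # Training-time plugins (loss monitor, saver, logger, stats, comet)
+ # are kept below for reference but disabled with bare string blocks;
+ # this script only replays plugin events to generate samples.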
+ """
+ trainer.register_plugin(TrainingLossMonitor(
+ smoothing=params['loss_smoothing']
+ ))
+ """
+ trainer.register_plugin(ValidationPlugin(
+ data_loader(val_split, test_split, eval=True),
+ data_loader(test_split, 1, eval=True)
+ ))
+ trainer.register_plugin(AbsoluteTimeMonitor())
+ """
+ trainer.register_plugin(SaverPlugin(
+ checkpoints_path, params['keep_old_checkpoints']
+ ))
+ """
+ trainer.register_plugin(GeneratorPlugin(
+ os.path.join(results_path, 'samples'), params['n_samples'],
+ params['sample_length'], params['sample_rate']
+ ))
+ """
+ trainer.register_plugin(
+ Logger([
+ 'training_loss',
+ 'validation_loss',
+ 'test_loss',
+ 'time'
+ ])
+ )
+ trainer.register_plugin(StatsPlugin(
+ results_path,
+ iteration_fields=[
+ 'training_loss',
+ ('training_loss', 'running_avg'),
+ 'time'
+ ],
+ epoch_fields=[
+ 'validation_loss',
+ 'test_loss',
+ 'time'
+ ],
+ plots={
+ 'loss': {
+ 'x': 'iteration',
+ 'ys': [
+ 'training_loss',
+ ('training_loss', 'running_avg'),
+ 'validation_loss',
+ 'test_loss',
+ ],
+ 'log_y': True
+ }
+ }
+ ))
+ init_comet(params, trainer)
+ """
+ trainer.generate(int(params['epoch_limit']))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ argument_default=argparse.SUPPRESS
+ )
+
+ def parse_bool(arg):
+ arg = arg.lower()
+ if 'true'.startswith(arg):
+ return True
+ elif 'false'.startswith(arg):
+ return False
+ else:
+ raise ValueError()
+
+ parser.add_argument('--exp', required=True, help='experiment name')
+ parser.add_argument(
+ '--frame_sizes', nargs='+', type=int, required=True,
+ help='frame sizes in terms of the number of lower tier frames, \
+ starting from the lowest RNN tier'
+ )
+ parser.add_argument(
+ '--dataset', required=True,
+ help='dataset name - name of a directory in the datasets path \
+ (settable by --datasets_path)'
+ )
+ parser.add_argument(
+ '--n_rnn', type=int, help='number of RNN layers in each tier'
+ )
+ parser.add_argument(
+ '--dim', type=int, help='number of neurons in every RNN and MLP layer'
+ )
+ parser.add_argument(
+ '--learn_h0', type=parse_bool,
+ help='whether to learn the initial states of RNNs'
+ )
+ parser.add_argument(
+ '--q_levels', type=int,
+ help='number of bins in quantization of audio samples'
+ )
+ parser.add_argument(
+ '--seq_len', type=int,
+ help='how many samples to include in each truncated BPTT pass'
+ )
+ parser.add_argument(
+ '--weight_norm', type=parse_bool,
+ help='whether to use weight normalization'
+ )
+ parser.add_argument('--batch_size', type=int, help='batch size')
+ parser.add_argument(
+ '--val_frac', type=float,
+ help='fraction of data to go into the validation set'
+ )
+ parser.add_argument(
+ '--test_frac', type=float,
+ help='fraction of data to go into the test set'
+ )
+ parser.add_argument(
+ '--keep_old_checkpoints', type=parse_bool,
+ help='whether to keep checkpoints from past epochs'
+ )
+ parser.add_argument(
+ '--datasets_path', help='path to the directory containing datasets'
+ )
+ parser.add_argument(
+ '--results_path', help='path to the directory to save the results to'
+ )
+ parser.add_argument('--epoch_limit', help='how many epochs to run')
+ parser.add_argument(
+ '--resume', type=parse_bool, default=True,
+ help='whether to resume training from the last checkpoint'
+ )
+ parser.add_argument(
+ '--sample_rate', type=int,
+ help='sample rate of the training data and generated sound'
+ )
+ parser.add_argument(
+ '--n_samples', type=int,
+ help='number of samples to generate in each epoch'
+ )
+ parser.add_argument(
+ '--sample_length', type=int,
+ help='length of each generated sample (in samples)'
+ )
+ parser.add_argument(
+ '--loss_smoothing', type=float,
+ help='smoothing parameter of the exponential moving average over \
+ training loss, used in the log and in the loss plot'
+ )
+ parser.add_argument(
+ '--cuda', type=parse_bool,
+ help='whether to use CUDA'
+ )
+ parser.add_argument(
+ '--comet_key', help='comet.ml API key'
+ )
+
+ parser.set_defaults(**default_params)
+
+ main(**vars(parser.parse_args()))
diff --git a/glass_test2.sh b/glass_test2.sh
new file mode 100644
index 0000000..072c451
--- /dev/null
+++ b/glass_test2.sh
@@ -0,0 +1,85 @@
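+# Sweep over quantization depths (q_levels) on the 44k_glass_space1
+# dataset; each experiment name encodes the value under test.
+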
+function runq () {
+ dataset=$1
+ sample_rate=$2
+ duration=$3
+ let sample_length=$2*$3
+ fs1=$4
+ fs2=$5
+ qs=$6
+
+ exp_name="space_q$qs"
+
+ echo "__________________________________________"
+ echo "__________________________________________"
+ echo "__________________________________________"
+ echo ""
+ echo ">> running $exp_name"
+ echo ""
+ echo "__________________________________________"
+ echo "__________________________________________"
+ echo "__________________________________________"
+
+ python train.py \
+ --exp $exp_name --dataset $dataset \
+ --frame_sizes $fs1 $fs2 \
+ --n_rnn 2 --dim 1024 --q_levels $qs \
+ --seq_len 1024 --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --keep_old_checkpoints False \
+ --epoch_limit 6 \
+ --resume True
+}
+function generateq () {
+ dataset=$1
+ sample_rate=$2
+ duration=$3
+ let sample_length=$2*$3
+ fs1=$4
+ fs2=$5
+ qs=$6
+
+ exp_name="space_q$qs"
+
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ echo ""
+ echo ">> generating $exp_name"
+ echo ""
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+
+ python generate.py \
+ --exp $exp_name --dataset $dataset \
+ --frame_sizes $fs1 $fs2 \
+ --n_rnn 2 --dim 1024 --q_levels $qs \
+ --seq_len 1024 --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --keep_old_checkpoints False \
+ --epoch_limit 6 \
+ --resume True
+}
+
+#runq 44k_glass_space1 44100 5 8 2 32
+runq 44k_glass_space1 44100 5 8 2 1
+runq 44k_glass_space1 44100 5 8 2 2
+runq 44k_glass_space1 44100 5 8 2 4
+runq 44k_glass_space1 44100 5 8 2 8
+runq 44k_glass_space1 44100 5 8 2 16
+runq 44k_glass_space1 44100 5 8 2 32
+#runq 44k_glass_space1 44100 5 8 2 2048
+
+#generateq 44k_glass_space1 44100 5 8 2 32
+#generateq 44k_glass_space1 44100 5 8 2 1
+#generateq 44k_glass_space1 44100 5 8 2 2
+#generateq 44k_glass_space1 44100 5 8 2 4
+#generateq 44k_glass_space1 44100 5 8 2 8
+#generateq 44k_glass_space1 44100 5 8 2 16
+#generateq 44k_glass_space1 44100 5 8 2 32
+#generateq 44k_glass_space1 44100 5 8 2 2048
+
diff --git a/glass_test3.sh b/glass_test3.sh
new file mode 100644
index 0000000..029cf58
--- /dev/null
+++ b/glass_test3.sh
@@ -0,0 +1,112 @@
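+# Sweeps over seq_len, dim, and q_levels on 44k_glass_space1, one
+# hyperparameter at a time.
+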
+function runq () {
+ dataset=$1
+ sample_rate=$2
+ duration=$3
+ let sample_length=$2*$3
+ qs=$4
+
+ exp_name="space_qs$dim"
+
+ echo "__________________________________________"
+ echo "__________________________________________"
+ echo "__________________________________________"
+ echo ""
+ echo ">> running $exp_name"
+ echo ""
+ echo "__________________________________________"
+ echo "__________________________________________"
+ echo "__________________________________________"
+
+ python train.py \
+ --exp $exp_name --dataset $dataset \
+ --frame_sizes 8 2 \
+ --n_rnn 2 --dim 1024 --q_levels $qs \
+ --seq_len 1024 --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --keep_old_checkpoints False \
+ --epoch_limit 4 \
+ --resume True
+}
+
+function rundim () {
+ dataset=$1
+ sample_rate=$2
+ duration=$3
+ let sample_length=$2*$3
+ dim=$4
+
+ exp_name="space_dim$dim"
+
+ echo "__________________________________________"
+ echo "__________________________________________"
+ echo "__________________________________________"
+ echo ""
+ echo ">> running $exp_name"
+ echo ""
+ echo "__________________________________________"
+ echo "__________________________________________"
+ echo "__________________________________________"
+
+ python train.py \
+ --exp $exp_name --dataset $dataset \
+ --frame_sizes 8 2 \
+ --n_rnn 2 --dim $dim --q_levels 256 \
+ --seq_len 1024 --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --keep_old_checkpoints False \
+ --epoch_limit 8 \
+ --resume True
+}
+
+function runseq () {
+ dataset=$1
+ sample_rate=$2
+ duration=$3
+ let sample_length=$2*$3
+ seq=$4
+
+ exp_name="space_seq$seq"
+
+ echo "__________________________________________"
+ echo "__________________________________________"
+ echo "__________________________________________"
+ echo ""
+ echo ">> running $exp_name"
+ echo ""
+ echo "__________________________________________"
+ echo "__________________________________________"
+ echo "__________________________________________"
+
+ python train.py \
+ --exp $exp_name --dataset $dataset \
+ --frame_sizes 8 2 \
+ --n_rnn 2 --dim 1024 --q_levels 256 \
+ --seq_len $seq --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --keep_old_checkpoints False \
+ --epoch_limit 4 \
+ --resume True
+}
+
+#rundim 44k_glass_space1 44100 5 4096
+runseq 44k_glass_space1 44100 5 512
+rundim 44k_glass_space1 44100 5 512
+#rundim 44k_glass_space1 44100 5 256
+#rundim 44k_glass_space1 44100 5 128
+#rundim 44k_glass_space1 44100 5 64
+#rundim 44k_glass_space1 44100 5 32
+#rundim 44k_glass_space1 44100 5 16
+#rundim 44k_glass_space1 44100 5 8
+#rundim 44k_glass_space1 44100 5 4
+#rundim 44k_glass_space1 44100 5 2
+#rundim 44k_glass_space1 44100 5 1
+
+runq 44k_glass_space1 44100 5 1024
+runq 44k_glass_space1 44100 5 128
+
diff --git a/glass_tests.sh b/glass_tests.sh
new file mode 100644
index 0000000..62d8932
--- /dev/null
+++ b/glass_tests.sh
@@ -0,0 +1,202 @@
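+# Sweeps over frame-size tier configurations (2- to 6-tier models)
+# on 44k_glass_space1; sample_length = sample_rate * duration.
+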
+function run2 () {
+ dataset=$1
+ sample_rate=$2
+ duration=$3
+ let sample_length=$2*$3
+ fs1=$4
+ fs2=$5
+
+ exp_name=space_2
+
+ echo "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
+ echo "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
+ echo "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
+ echo ""
+ echo ">> running $exp_name $4 $5"
+ echo ""
+ echo "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
+ echo "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
+ echo "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
+ python train.py \
+ --exp $exp_name --dataset $dataset \
+ --frame_sizes $fs1 $fs2 \
+ --n_rnn 2 --dim 1024 --q_levels 256 \
+ --seq_len 1024 --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --keep_old_checkpoints False \
+ --epoch_limit 2 \
+ --resume True
+}
+function generate2 () {
+ dataset=$1
+ sample_rate=$2
+ duration=$3
+ let sample_length=$2*$3
+ fs1=$4
+ fs2=$5
+
+ exp_name=space_2
+
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ echo ""
+ echo ">> generating $exp_name $4 $5"
+ echo ""
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ python generate.py \
+ --exp $exp_name --dataset $dataset \
+ --frame_sizes $fs1 $fs2 \
+ --n_rnn 2 --dim 1024 --q_levels 256 \
+ --seq_len 1024 --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --keep_old_checkpoints False \
+ --epoch_limit 1 \
+ --n_samples 6 \
+ --resume True
+}
+function run3 () {
+ dataset=$1
+ sample_rate=$2
+ duration=$3
+ let sample_length=$2*$3
+ fs1=$4
+ fs2=$5
+ fs3=$6
+
+ exp_name=space_3
+
+ echo ">> running $exp_name $4 $5 $6"
+ python train.py \
+ --exp $exp_name --dataset $dataset \
+ --frame_sizes $fs1 $fs2 $fs3 \
+ --n_rnn 3 --dim 1024 --q_levels 256 \
+ --seq_len 1024 --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --keep_old_checkpoints False \
+ --epoch_limit 1 \
+ --resume True
+}
+
+function run4 () {
+ dataset=$1
+ sample_rate=$2
+ duration=$3
+ let sample_length=$2*$3
+ fs1=$4
+ fs2=$5
+ fs3=$6
+ fs4=$7
+
+ exp_name=space_4
+
+ echo ">> running $exp_name $4 $5 $6 $7"
+ python train.py \
+ --exp $exp_name --dataset $dataset \
+ --frame_sizes $fs1 $fs2 $fs3 $fs4 \
+ --n_rnn 4 --dim 1024 --q_levels 256 \
+ --seq_len 1024 --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --keep_old_checkpoints False \
+ --epoch_limit 4 \
+ --resume True
+}
+function run5 () {
+ dataset=$1
+ sample_rate=$2
+ duration=$3
+ let sample_length=$2*$3
+ fs1=$4
+ fs2=$5
+ fs3=$6
+ fs4=$7
+ fs5=$8
+
+ exp_name=space_5
+
+ echo ">> running $exp_name $4 $5 $6 $7 $8"
+ python train.py \
+ --exp $exp_name --dataset $dataset \
+ --frame_sizes $fs1 $fs2 $fs3 $fs4 $fs5 \
+ --n_rnn 5 --dim 1024 --q_levels 256 \
+ --seq_len 1024 --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --keep_old_checkpoints False \
+ --epoch_limit 4 \
+ --resume True
+}
+function run6 () {
+ dataset=$1
+ sample_rate=$2
+ duration=$3
+ let sample_length=$2*$3
+ fs1=$4
+ fs2=$5
+ fs3=$6
+ fs4=$7
+ fs5=$8
+ fs6=$9
+
+ exp_name=space_6
+
+ echo ">> running $exp_name $4 $5 $6 $7 $8 $9"
+ python train.py \
+ --exp $exp_name --dataset $dataset \
+ --frame_sizes $fs1 $fs2 $fs3 $fs4 $fs5 $fs6 \
+ --n_rnn 5 --dim 1024 --q_levels 256 \
+ --seq_len 1024 --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --keep_old_checkpoints False \
+ --epoch_limit 4 \
+ --resume True
+}
+
+run4 44k_glass_space1 44100 5 2 2 2 2
+run5 44k_glass_space1 44100 5 2 2 2 2 2
+run6 44k_glass_space1 44100 5 2 2 2 2 2 2
+
+run4 44k_glass_space1 44100 5 8 2 2 2
+run5 44k_glass_space1 44100 5 8 2 2 2 2
+run6 44k_glass_space1 44100 5 8 2 2 2 2 2
+#generate2 44k_glass_space1 44100 5 8 2
+#run2 44k_glass_space1 44100 5 8 4
+#generate2 44k_glass_space1 44100 5 8 4
+#run2 44k_glass_space1 44100 5 8 1
+#generate2 44k_glass_space1 44100 5 8 1
+#run2 44k_glass_space1 44100 5 16 8
+#run2 44k_glass_space1 44100 5 16 4
+#run2 44k_glass_space1 44100 5 16 2
+#generate2 44k_glass_space1 44100 5 16 2
+#run2 44k_glass_space1 44100 5 16 1
+#generate2 44k_glass_space1 44100 5 16 1
+
+#run2 44k_glass_space1 44100 5 4 2
+#run2 44k_glass_space1 44100 5 4 1
+#run2 44k_glass_space1 44100 5 2 1
+
+#run3 44k_glass_space1 44100 5 8 2 1
+#run3 44k_glass_space1 44100 5 8 4 2
+#run3 44k_glass_space1 44100 5 16 8 2
+#run3 44k_glass_space1 44100 5 16 4 2
+#run3 44k_glass_space1 44100 5 16 2 1
+
+#run2 44k_glass_space1 44100 5 6 2
+#run2 44k_glass_space1 44100 5 6 3
+#run3 44k_glass_space1 44100 5 5 3 2
+#run3 44k_glass_space1 44100 5 10 5 2
+#run3 44k_glass_space1 44100 5 8 5 2
+
diff --git a/kick_test.sh b/kick_test.sh
new file mode 100644
index 0000000..c3754a4
--- /dev/null
+++ b/kick_test.sh
@@ -0,0 +1,51 @@
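+# Train briefly and then generate one-second kick samples from the
+# "kiq" dataset at two quantization depths (256 and 512 levels).
+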
+function generaterrr () {
+ exp_name=$1
+ dataset=$2
+ epoch_limit=$3
+ sample_rate=$4
+ duration=$5
+ let sample_length=$4*$5
+ qlev=$6
+
+ echo ">> generating $exp_name"
+ python generate.py \
+ --exp $exp_name --dataset $dataset \
+ --frame_sizes 8 2 \
+ --n_rnn 2 --dim 1024 --q_levels $qlev \
+ --seq_len 1024 --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --keep_old_checkpoints False \
+ --epoch_limit $epoch_limit \
+ --resume True
+}
+function runrrr () {
+ exp_name=$1
+ dataset=$2
+ epoch_limit=$3
+ sample_rate=$4
+ duration=$5
+ let sample_length=$4*$5
+ qlev=$6
+
+ echo ">> running $exp_name"
+ python train.py \
+ --exp $exp_name --dataset $dataset \
+ --frame_sizes 8 2 \
+ --n_rnn 2 --dim 1024 --q_levels $qlev \
+ --seq_len 1024 --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --keep_old_checkpoints False \
+ --epoch_limit $epoch_limit \
+ --resume True
+}
+
+runrrr kiq256 kiq 1 44100 1 256
+generaterrr kiq256 kiq 6 44100 1 256
+
+runrrr kiq512 kiq 1 44100 1 512
+generaterrr kiq512 kiq 10 44100 1 512
+
diff --git a/mix.sh b/mix.sh
new file mode 100755
index 0000000..d95c4ef
--- /dev/null
+++ b/mix.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
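+#
+# Collects an experiment's generated samples: crossfades them into a
+# single mix.wav, copies it to output/<render_name>.wav, archives the
+# parts under samples/<render_name>, and opens the training log.
+#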
+
+if [ "$#" -ne 2 ]; then
+ echo "Usage: $0 <results/exp\:experiment> render_name"
+ exit
+fi
+
+dir=$1
+name=$2
+now=`date +'%Y%m%d'`
+
+cd "$dir/samples"
+mkdir "$name"
+../../xfade.sh
+cp mix.wav "../../../output/$name.wav"
+mv *.wav "$name"
+cd ..
+vim log
+cd ../..
+
diff --git a/train.py b/train.py
index e3061c9..b808cee 100644
--- a/train.py
+++ b/train.py
@@ -255,7 +255,7 @@ def main(exp, frame_sizes, dataset, **params):
init_comet(params, trainer)
- trainer.run(params['epoch_limit'])
+ trainer.run(int(params['epoch_limit']))
if __name__ == '__main__':
diff --git a/train_basic_22k.sh b/train_basic_22k.sh
new file mode 100644
index 0000000..5104034
--- /dev/null
+++ b/train_basic_22k.sh
@@ -0,0 +1,15 @@
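+# Usage: ./train_basic_22k.sh <exp_name> <dataset> <epoch_limit>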
+exp_name=$1
+dataset=$2
+epoch_limit=$3
+python train.py \
+ --exp $exp_name --dataset $dataset \
+ --frame_sizes 8 2 \
+ --n_rnn 2 --dim 1024 --q_levels 256 \
+ --seq_len 1024 --batch_size 64 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate 22050 \
+ --sample_length 110250 \
+ --keep_old_checkpoints False \
+ --epoch_limit $epoch_limit \
+ --resume True
+
diff --git a/train_basic_32k.sh b/train_basic_32k.sh
new file mode 100644
index 0000000..c4d1319
--- /dev/null
+++ b/train_basic_32k.sh
@@ -0,0 +1,15 @@
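+# Usage: ./train_basic_32k.sh <exp_name> <dataset> <epoch_limit>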
+exp_name=$1
+dataset=$2
+epoch_limit=$3
+python train.py \
+ --exp $exp_name --dataset $dataset \
+ --frame_sizes 8 2 \
+ --n_rnn 2 --dim 1024 --q_levels 256 \
+ --seq_len 1024 --batch_size 64 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate 32000 \
+ --sample_length 160000 \
+ --keep_old_checkpoints False \
+ --epoch_limit $epoch_limit \
+ --resume True
+
diff --git a/train_basic_44k.sh b/train_basic_44k.sh
new file mode 100644
index 0000000..1908f5d
--- /dev/null
+++ b/train_basic_44k.sh
@@ -0,0 +1,15 @@
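+# Usage: ./train_basic_44k.sh <exp_name> <dataset> <epoch_limit>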
+exp_name=$1
+dataset=$2
+epoch_limit=$3
+python train.py \
+ --exp $exp_name --dataset $dataset \
+ --frame_sizes 8 2 \
+ --n_rnn 2 --dim 1024 --q_levels 256 \
+ --seq_len 1024 --batch_size 64 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate 44100 \
+ --sample_length 220500 \
+ --keep_old_checkpoints False \
+ --epoch_limit $epoch_limit \
+ --resume True
+
diff --git a/train_drums.sh b/train_drums.sh
new file mode 100644
index 0000000..2157fc7
--- /dev/null
+++ b/train_drums.sh
@@ -0,0 +1,44 @@
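+# Render short one-shots (sample_length is given directly in raw
+# samples, not seconds) from previously trained models.
+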
+function generate () {
+ exp_name=$1
+ epoch_limit=$2
+ sample_rate=$3
+ sample_length=$4
+
+ echo ">> generating $exp_name"
+ python generate.py \
+ --exp $exp_name --dataset $exp_name \
+ --frame_sizes 8 2 \
+ --n_rnn 2 --dim 1024 --q_levels 256 \
+ --seq_len 1024 --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --keep_old_checkpoints False \
+ --epoch_limit $epoch_limit \
+ --resume True
+}
+function run () {
+ exp_name=$1
+ epoch_limit=$2
+ sample_rate=$3
+ sample_length=$4
+
+ echo ">> running $exp_name"
+ python train.py \
+ --exp $exp_name --dataset $exp_name \
+ --frame_sizes 8 2 \
+ --n_rnn 2 --dim 1024 --q_levels 256 \
+ --seq_len 1024 --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --keep_old_checkpoints False \
+ --epoch_limit $epoch_limit \
+ --resume True
+}
+
+#generate 44k_jlin3 64 44100 4134
+generate 44k_jlin4 64 44100 4134
+generate 44k_clouds2 64 44100 4134
+generate 44k_glassmix 64 44100 4134
+
diff --git a/train_test_generate.sh b/train_test_generate.sh
new file mode 100644
index 0000000..af4f1c7
--- /dev/null
+++ b/train_test_generate.sh
@@ -0,0 +1,93 @@
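+# Train-and-sample presets: standard/quick/fast differ only in epoch
+# budget and generated-sample duration; run/generate do the work.
+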
+function generate () {
+ exp_name=$1
+ n_samples=$2
+ sample_rate=$3
+ duration=$4
+ let sample_length=$3*$4
+
+ echo ""
+ echo "###################################################"
+ echo "###################################################"
+ echo "###################################################"
+ echo ""
+ echo ">> generating $exp_name"
+ echo ""
+ echo "###################################################"
+ echo "###################################################"
+ echo "###################################################"
+ echo ""
+ python generate.py \
+ --exp $exp_name --dataset $exp_name \
+ --frame_sizes 8 2 \
+ --n_rnn 2 --dim 1024 --q_levels 256 \
+ --seq_len 1024 --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --keep_old_checkpoints False \
+ --n_samples $n_samples \
+ --epoch_limit 1 \
+ --resume True
+}
+function run () {
+ exp_name=$1
+ epoch_limit=$2
+ n_samples=$3
+ sample_rate=$4
+ duration=$5
+ let sample_length=$4*$5
+
+ echo ""
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ echo ""
+ echo ">> running $exp_name"
+ echo ""
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ echo ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
+ echo ""
+ python train.py \
+ --exp $exp_name --dataset $exp_name \
+ --frame_sizes 8 2 \
+ --n_rnn 2 --dim 1024 --q_levels 256 \
+ --seq_len 1024 --batch_size 128 \
+ --val_frac 0.1 --test_frac 0.1 \
+ --sample_rate $sample_rate \
+ --sample_length $sample_length \
+ --n_samples $n_samples \
+ --keep_old_checkpoints False \
+ --epoch_limit $epoch_limit \
+ --resume True
+}
+function standard () {
+ dataset=$1
+ run $1 6 6 44100 5
+}
+function quick () {
+ dataset=$1
+ run $1 4 6 44100 5
+}
+function fast () {
+ dataset=$1
+ run $1 1 6 44100 10
+}
+
+standard 44k_blblbl2
+standard 44k_faty-scrub2
+standard 44k_faty-vocal2
+
+quick 44k_siren
+quick 44k_whatifvocode
+quick 44k_jlin-faty
+
+fast 44k_lipnoise
+quick 44k_jlin-faty
+quick 44k_jlin-faty
+
+run 44k_dances 8 6 44100 5
+run 44k_jlin4 4 3 44100 10
+
+run 44k_sundae 4 3 44100 10
+
diff --git a/trainer/__init__.py b/trainer/__init__.py
index 7e2ea18..1f39506 100644
--- a/trainer/__init__.py
+++ b/trainer/__init__.py
@@ -56,6 +56,15 @@ class Trainer(object):
self.train()
self.call_plugins('epoch', self.epochs)
+ def generate(self, epochs=1):
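+ # Like run(), but with training skipped: advance the epoch counter
+ # and fire the plugin events so GeneratorPlugin writes samples.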
+ for q in self.plugin_queues.values():
+ heapq.heapify(q)
+
+ for self.epochs in range(self.epochs + 1, self.epochs + epochs + 1):
+ # self.train()
+ self.call_plugins('update', self.iterations, self.model)
+ self.call_plugins('epoch', self.epochs)
+
def train(self):
for (self.iterations, data) in \
enumerate(self.dataset, self.iterations + 1):
diff --git a/trainer/plugins.py b/trainer/plugins.py
index f8c299b..0126870 100644
--- a/trainer/plugins.py
+++ b/trainer/plugins.py
@@ -141,7 +141,7 @@ class SaverPlugin(Plugin):
class GeneratorPlugin(Plugin):
- pattern = 'ep{}-s{}.wav'
+ pattern = 'd-{}-ep{}-s{}.wav'
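+ # A time.time() prefix keeps output from successive runs distinct.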
def __init__(self, samples_path, n_samples, sample_length, sample_rate):
super().__init__([(1, 'epoch')])
@@ -159,7 +159,7 @@ class GeneratorPlugin(Plugin):
for i in range(self.n_samples):
write_wav(
os.path.join(
- self.samples_path, self.pattern.format(epoch_index, i + 1)
+ self.samples_path, self.pattern.format(int(time.time()), epoch_index, i + 1)
),
samples[i, :], sr=self.sample_rate, norm=True
)
@@ -168,7 +168,7 @@ class GeneratorPlugin(Plugin):
class StatsPlugin(Plugin):
data_file_name = 'stats.pkl'
- plot_pattern = '{}.svg'
+ plot_pattern = 'd-{}-{}.svg'
def __init__(self, results_path, iteration_fields, epoch_fields, plots):
super().__init__([(1, 'iteration'), (1, 'epoch')])
@@ -252,7 +252,7 @@ class StatsPlugin(Plugin):
pyplot.legend()
pyplot.savefig(
- os.path.join(self.results_path, self.plot_pattern.format(name))
+ os.path.join(self.results_path, self.plot_pattern.format(int(time.time()), name))
)
@staticmethod