diff options
| author | jules@lens <julescarbon@gmail.com> | 2018-09-05 12:00:28 +0200 |
|---|---|---|
| committer | jules@lens <julescarbon@gmail.com> | 2018-09-05 12:00:28 +0200 |
| commit | 9abfa16dc059d042c21f1636ecc8797ef29a030d (patch) | |
| tree | d0583cb5dae01de1abc57ed8f7587d23242ed6f0 /app/relay/modules/pix2pixhd.js | |
| parent | 0a3c6743543dd3dfcb876f5ce735b72d050e981d (diff) | |
| parent | 15eb6806b6e216255f33abcb885f6cdbc38a7663 (diff) | |
Merge branch 'master' of asdf.us:live-cortex
Diffstat (limited to 'app/relay/modules/pix2pixhd.js')
| -rw-r--r-- | app/relay/modules/pix2pixhd.js | 73 |
1 files changed, 59 insertions, 14 deletions
diff --git a/app/relay/modules/pix2pixhd.js b/app/relay/modules/pix2pixhd.js index 821be71..a90fc15 100644 --- a/app/relay/modules/pix2pixhd.js +++ b/app/relay/modules/pix2pixhd.js @@ -81,20 +81,64 @@ const generate = { type: 'pytorch', script: 'test.py', params: (task) => { + let epoch = 0 + const dataset = task.dataset.toLowerCase() + const datasets_path = path.join(cwd, 'datasets', dataset) + const checkpoints_path = path.join(cwd, 'checkpoints', dataset) + const iter_txt = path.join(checkpoints_path, 'iter.txt') + console.log(dataset, iter_txt) + if (fs.existsSync(iter_txt)) { + const iter = fs.readFileSync(iter_txt).toString().split('\n'); + console.log(iter) + epoch = iter[0] || 0 + console.log(task.module, dataset, '=>', epoch, task.epochs) + } else { + console.log(task.module, dataset, '=>', 'starting new training') + } return [ - '--dataroot', '/sequences/' + task.dataset, + '--dataroot', datasets_path, '--module_name', task.module, - '--name', task.dataset, - '--start_img', '/sequences/' + task.dataset + '/frame_00001.png', - '--how_many', 1000, - '--model', 'test', - '--aspect_ratio', 1.777777, - '--which_model_netG', 'unet_256', - '--which_direction', 'AtoB', - '--dataset_mode', 'test', - '--loadSize', 256, - '--fineSize', 256, - '--norm', 'batch' + '--name', dataset, + '--model', 'pix2pixHD', + '--label_nc', 0, '--no_instance', + '--niter', task.epochs, + '--niter_decay', 0, + '--save_epoch_freq', 1, + ] + }, + after: 'render', +} +const augment = { + type: 'pytorch', + script: 'augment.py', + params: (task) => { + let epoch = 0 + const dataset = task.dataset.toLowerCase() + const datasets_path = path.join(cwd, 'datasets', dataset) + const checkpoints_path = path.join(cwd, 'checkpoints', dataset) + // supply render_dir + return [ + '--dataroot', datasets_path, + '--results_dir', './recursive', + '--module_name', task.module, + '--name', dataset, + '--model', 'pix2pixHD', + '--label_nc', 0, '--no_instance', + '--augment-take', task.opt.augment_take, + '--augment-make', task.opt.augment_make, + '--augment-name', task.opt.augment_name, + '--which_epoch', task.opt.epoch, + ] + }, +} +const clear_recursive = { + type: 'pytorch', + script: 'clear_recursive.py', + params: (task) => { + const dataset = task.dataset.toLowerCase() + return [ + '--name', dataset, + '--epoch', epoch, ] }, } @@ -113,7 +157,6 @@ const live = { '--name', task.checkpoint, '--module_name', 'pix2pixHD', '--sequence-name', task.dataset, - '--recursive', '--recursive-frac', 0.1, '--sequence', '--sequence-frac', 0.3, '--process-frac', 0.5, '--label_nc', '0', '--no_instance', @@ -121,7 +164,7 @@ const live = { '--just-copy', '--poll_delay', opt.poll_delay || 0.09, '--which_epoch', 'latest', '--norm', 'batch', - '--store_b', // comment this line to store all live output + '--store_b', // uncomment this line to store all live output ] }, listen: (task, res, i) => { @@ -159,6 +202,8 @@ export default { build, train, generate, + augment, + clear_recursive, live, render, } |
