From 25815ab4916dc8c9e3256cbfe53bea0535930f30 Mon Sep 17 00:00:00 2001 From: Jules Laplace Date: Tue, 4 Sep 2018 13:53:30 +0200 Subject: fix loading bars --- app/relay/modules/pix2pixhd.js | 67 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 54 insertions(+), 13 deletions(-) (limited to 'app/relay/modules/pix2pixhd.js') diff --git a/app/relay/modules/pix2pixhd.js b/app/relay/modules/pix2pixhd.js index 821be71..73b49ca 100644 --- a/app/relay/modules/pix2pixhd.js +++ b/app/relay/modules/pix2pixhd.js @@ -81,20 +81,60 @@ const generate = { type: 'pytorch', script: 'test.py', params: (task) => { + let epoch = 0 + const dataset = task.dataset.toLowerCase() + const datasets_path = path.join(cwd, 'datasets', dataset) + const checkpoints_path = path.join(cwd, 'checkpoints', dataset) + const iter_txt = path.join(checkpoints_path, 'iter.txt') + console.log(dataset, iter_txt) + if (fs.existsSync(iter_txt)) { + const iter = fs.readFileSync(iter_txt).toString().split('\n'); + console.log(iter) + epoch = iter[0] || 0 + console.log(task.module, dataset, '=>', epoch, task.epochs) + } else { + console.log(task.module, dataset, '=>', 'starting new training') + } return [ - '--dataroot', '/sequences/' + task.dataset, + '--dataroot', datasets_path, '--module_name', task.module, - '--name', task.dataset, - '--start_img', '/sequences/' + task.dataset + '/frame_00001.png', - '--how_many', 1000, - '--model', 'test', - '--aspect_ratio', 1.777777, - '--which_model_netG', 'unet_256', - '--which_direction', 'AtoB', - '--dataset_mode', 'test', - '--loadSize', 256, - '--fineSize', 256, - '--norm', 'batch' + '--name', dataset, + '--model', 'pix2pixHD', + '--label_nc', 0, '--no_instance', + '--niter', task.epochs, + '--niter_decay', 0, + '--save_epoch_freq', 1, + ] + }, + after: 'render', +} +const augment = { + type: 'pytorch', + script: 'test.py', + params: (task) => { + let epoch = 0 + const dataset = task.dataset.toLowerCase() + const datasets_path = path.join(cwd, 'datasets', dataset) + const checkpoints_path = path.join(cwd, 'checkpoints', dataset) + const iter_txt = path.join(checkpoints_path, 'iter.txt') + console.log(dataset, iter_txt) + if (fs.existsSync(iter_txt)) { + const iter = fs.readFileSync(iter_txt).toString().split('\n'); + console.log(iter) + epoch = iter[0] || 0 + console.log(task.module, dataset, '=>', epoch, task.epochs) + } else { + console.log(task.module, dataset, '=>', 'starting new training') + } + return [ + '--dataroot', datasets_path, + '--module_name', task.module, + '--name', dataset, + '--model', 'pix2pixHD', + '--label_nc', 0, '--no_instance', + '--niter', task.epochs, + '--niter_decay', 0, + '--save_epoch_freq', 1, ] }, } @@ -121,7 +161,7 @@ const live = { '--just-copy', '--poll_delay', opt.poll_delay || 0.09, '--which_epoch', 'latest', '--norm', 'batch', - '--store_b', // comment this line to store all live output + '--store_b', // uncomment this line to store all live output ] }, listen: (task, res, i) => { @@ -159,6 +199,7 @@ export default { build, train, generate, + augment, live, render, } -- cgit v1.2.3-70-g09d2