summaryrefslogtreecommitdiff
path: root/app/relay/modules/pix2pixhd.js
diff options
context:
space:
mode:
Diffstat (limited to 'app/relay/modules/pix2pixhd.js')
-rw-r--r--app/relay/modules/pix2pixhd.js67
1 files changed, 54 insertions, 13 deletions
diff --git a/app/relay/modules/pix2pixhd.js b/app/relay/modules/pix2pixhd.js
index 821be71..73b49ca 100644
--- a/app/relay/modules/pix2pixhd.js
+++ b/app/relay/modules/pix2pixhd.js
@@ -81,20 +81,60 @@ const generate = {
type: 'pytorch',
script: 'test.py',
params: (task) => {
+ let epoch = 0
+ const dataset = task.dataset.toLowerCase()
+ const datasets_path = path.join(cwd, 'datasets', dataset)
+ const checkpoints_path = path.join(cwd, 'checkpoints', dataset)
+ const iter_txt = path.join(checkpoints_path, 'iter.txt')
+ console.log(dataset, iter_txt)
+ if (fs.existsSync(iter_txt)) {
+ const iter = fs.readFileSync(iter_txt).toString().split('\n');
+ console.log(iter)
+ epoch = iter[0] || 0
+ console.log(task.module, dataset, '=>', epoch, task.epochs)
+ } else {
+ console.log(task.module, dataset, '=>', 'starting new training')
+ }
return [
- '--dataroot', '/sequences/' + task.dataset,
+ '--dataroot', datasets_path,
'--module_name', task.module,
- '--name', task.dataset,
- '--start_img', '/sequences/' + task.dataset + '/frame_00001.png',
- '--how_many', 1000,
- '--model', 'test',
- '--aspect_ratio', 1.777777,
- '--which_model_netG', 'unet_256',
- '--which_direction', 'AtoB',
- '--dataset_mode', 'test',
- '--loadSize', 256,
- '--fineSize', 256,
- '--norm', 'batch'
+ '--name', dataset,
+ '--model', 'pix2pixHD',
+ '--label_nc', 0, '--no_instance',
+ '--niter', task.epochs,
+ '--niter_decay', 0,
+ '--save_epoch_freq', 1,
+ ]
+ },
+ after: 'render',
+}
+const augment = {
+ type: 'pytorch',
+ script: 'test.py',
+ params: (task) => {
+ let epoch = 0
+ const dataset = task.dataset.toLowerCase()
+ const datasets_path = path.join(cwd, 'datasets', dataset)
+ const checkpoints_path = path.join(cwd, 'checkpoints', dataset)
+ const iter_txt = path.join(checkpoints_path, 'iter.txt')
+ console.log(dataset, iter_txt)
+ if (fs.existsSync(iter_txt)) {
+ const iter = fs.readFileSync(iter_txt).toString().split('\n');
+ console.log(iter)
+ epoch = iter[0] || 0
+ console.log(task.module, dataset, '=>', epoch, task.epochs)
+ } else {
+ console.log(task.module, dataset, '=>', 'starting new training')
+ }
+ return [
+ '--dataroot', datasets_path,
+ '--module_name', task.module,
+ '--name', dataset,
+ '--model', 'pix2pixHD',
+ '--label_nc', 0, '--no_instance',
+ '--niter', task.epochs,
+ '--niter_decay', 0,
+ '--save_epoch_freq', 1,
]
},
}
@@ -121,7 +161,7 @@ const live = {
'--just-copy', '--poll_delay', opt.poll_delay || 0.09,
'--which_epoch', 'latest',
'--norm', 'batch',
- '--store_b', // comment this line to store all live output
+ '--store_b', // uncomment this line to store all live output
]
},
listen: (task, res, i) => {
@@ -159,6 +199,7 @@ export default {
build,
train,
generate,
+ augment,
live,
render,
}