diff options
Diffstat (limited to 'app/relay/modules/pix2pixhd.js')
| -rw-r--r-- | app/relay/modules/pix2pixhd.js | 137 |
1 files changed, 135 insertions, 2 deletions
diff --git a/app/relay/modules/pix2pixhd.js b/app/relay/modules/pix2pixhd.js index 9aa30d0..c5f9be5 100644 --- a/app/relay/modules/pix2pixhd.js +++ b/app/relay/modules/pix2pixhd.js @@ -1,5 +1,138 @@ import path from 'path' +import fs from 'fs' + +const name = 'pix2pixhd' +const cwd = process.env.PIX2PIXHD_CWD || path.join(process.env.HOME, 'code/' + name + '/') + +const fetch = { + type: 'perl', + script: 'get.pl', + params: (task) => { + console.log(task) + return [ task.module, task.opt.url ] + }, + listen: (task, res, i) => { + // relay the new dataset name from youtube-dl or w/e + const lines = res.split('\n') + for (let line of lines) { + console.log(line) + if ( line.match(/^created dataset: /) ) { + let tag = line.split(': ')[1].trim() + task.dataset = tag + // task.opt.filename = filename + console.log(">>>>>> created dataset", tag) + return { type: 'progress', action: 'resolve_dataset', task } + } + } + return null + }, + after: 'build', +} +const build = { + type: 'perl', + script: 'build_dataset.pl', + params: (task) => { + return [ + task.dataset, + ] + } +} +const train = { + type: 'pytorch', + script: 'train.py', + params: (task) => { + let epoch = 0 + + const datasets_path = path.join(cwd, 'datasets', task.dataset) + const checkpoints_path = path.join(cwd, 'checkpoints', task.dataset) + if (fs.existsSync(checkpoints_path)) { + try { + const checkpoints = fs.readdirSync(checkpoints_path) + checkpoints.forEach(c => { + epoch = Math.max(parseInt(c.name) || 0, epoch) + }) + console.log(task.module, task.dataset, epoch, task.epochs) + } catch (e) { } + } + let args = [ + '--dataroot', datasets_path, + '--module_name', task.module, + '--name', task.dataset, + '--model', 'pix2pixhd', + '--label_nc', 0, '--no_instance', + '--niter', task.epochs, + '--niter_decay', 0, + ] + if (epoch) { + args = args.concat([ + '--epoch_count', task.epoch + task.epochs + 1, + '--which_epoch', 'latest', + '--continue_train', + ]) + } + return args + }, +} +const generate = { + type: 'pytorch', + script: 'test.py', + params: (task) => { + return [ + '--dataroot', '/sequences/' + task.module + '/' + task.dataset, + '--module_name', task.module, + '--name', task.dataset, + '--start_img', '/sequences/' + task.module + '/' + task.dataset + '/frame_00001.png', + '--how_many', 1000, + '--model', 'test', + '--aspect_ratio', 1.777777, + '--which_model_netG', 'unet_256', + '--which_direction', 'AtoB', + '--dataset_mode', 'test', + '--loadSize', 256, + '--fineSize', 256, + '--norm', 'batch' + ] + }, +} +const live = { + type: 'pytorch', + script: 'live-mogrify.py', + params: (task) => { + console.log(task) + const opt = task.opt || {} + return [ + '--dataroot', path.join(cwd, 'sequences', task.module, task.dataset), + '--start_img', path.join(cwd, 'sequences', task.module, task.dataset, 'frame_00001.png'), + '--checkpoint-name', task.checkpoint, + '--experiment', task.checkpoint, + '--name', task.checkpoint, + '--module_name', task.module, + '--sequence-name', task.dataset, + '--recursive', '--recursive-frac', 0.1, + '--sequence', '--sequence-frac', 0.3, + '--process-frac', 0.5, + '--nThreads', 0, + '--transition-min', 0.05, + '--how_many', 1000000, '--transition-period', 1000, + '--loadSize', 256, '--fineSize', 256, + '--just-copy', '--poll_delay', opt.poll_delay || 0.09, + '--model', 'test', + '--which_model_netG', 'unet_256', + '--which_direction', 'AtoB', + '--dataset_mode', 'recursive', + '--which_epoch', 'latest', + '--norm', 'batch', + ] + }, +} export default { - enabled: false, -}
\ No newline at end of file + name, cwd, + activities: { + fetch, + build, + train, + generate, + live, + } +} |
