diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2018-09-05 11:58:34 +0200 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-09-05 11:58:34 +0200 |
| commit | 15eb6806b6e216255f33abcb885f6cdbc38a7663 (patch) | |
| tree | 9629d267611a48f9492f53c7ec74cf68c930deb4 /app | |
| parent | 7910767b7283e62f03dec5f86e08a796c792080f (diff) | |
relay stuff
Diffstat (limited to 'app')
| -rw-r--r-- | app/client/modules/pix2pixhd/pix2pixhd.tasks.js | 18 | ||||
| -rw-r--r-- | app/client/modules/pix2pixhd/views/pix2pixhd.train.js | 13 | ||||
| -rw-r--r-- | app/relay/modules/pix2pixhd.js | 34 |
3 files changed, 35 insertions, 30 deletions
diff --git a/app/client/modules/pix2pixhd/pix2pixhd.tasks.js b/app/client/modules/pix2pixhd/pix2pixhd.tasks.js index 92c0ff4..bd51f2b 100644 --- a/app/client/modules/pix2pixhd/pix2pixhd.tasks.js +++ b/app/client/modules/pix2pixhd/pix2pixhd.tasks.js @@ -53,34 +53,28 @@ export const live_task = (sequence, checkpoint, opt) => dispatch => { return actions.queue.add_task(task) } -export const augment_task = (opt) => dispatch => { +export const augment_task = (dataset, opt) => dispatch => { const task = { module: module.name, activity: 'augment', - dataset: sequence, - checkpoint, + dataset, opt: { ...opt, - poll_delay: 0.01, } } console.log(task) - console.log('add live task') + console.log('add augment task') return actions.queue.add_task(task) } -export const clear_recursive_task = (opt) => dispatch => { +export const clear_recursive_task = (dataset) => dispatch => { const task = { module: module.name, activity: 'clear_recursive', - dataset: sequence, - checkpoint, - opt: { - ...opt, - } + dataset, } console.log(task) - console.log('add live task') + console.log('add clear recursive task') return actions.queue.add_task(task) } diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js index cd6507f..df3a1f2 100644 --- a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js +++ b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js @@ -28,6 +28,7 @@ class Pix2PixHDTrain extends Component { this.state = { checkpoint_name: '', epoch: 'latest', + augment_name: '', augment_take: 100, augment_make: 20, } @@ -131,10 +132,16 @@ class Pix2PixHDTrain extends Component { min="1" max="1000" /> + <TextInput + name="augment_name" + title="Tag this epoch" + value={this.state.augment_name} + onChange={this.handleChange} + /> <Button title="Augment dataset" value="Augment" - onClick={() => remote.augment_task(dataset, pix2pixhd.folder_id, 1)} + onClick={() => remote.augment_task(this.state.checkpoint_name, this.state)} /> </Group> @@ -142,7 +149,7 @@ class Pix2PixHDTrain extends Component { <Button title="Train one epoch" value="Train" - onClick={() => remote.train_task(dataset, pix2pixhd.folder_id, 1)} + onClick={() => remote.train_task(this.state.checkpoint_name, pix2pixhd.folder_id, 1)} /> </Group> @@ -150,7 +157,7 @@ class Pix2PixHDTrain extends Component { <Button title="Delete recursive frames" value="Clear" - onClick={() => remote.clear_recursive_task(dataset, pix2pixhd.folder_id, 1)} + onClick={() => remote.clear_recursive_task(this.state.checkpoint_name)} /> </Group> </div> diff --git a/app/relay/modules/pix2pixhd.js b/app/relay/modules/pix2pixhd.js index 73b49ca..a90fc15 100644 --- a/app/relay/modules/pix2pixhd.js +++ b/app/relay/modules/pix2pixhd.js @@ -110,31 +110,35 @@ const generate = { } const augment = { type: 'pytorch', - script: 'test.py', + script: 'augment.py', params: (task) => { let epoch = 0 const dataset = task.dataset.toLowerCase() const datasets_path = path.join(cwd, 'datasets', dataset) const checkpoints_path = path.join(cwd, 'checkpoints', dataset) - const iter_txt = path.join(checkpoints_path, 'iter.txt') - console.log(dataset, iter_txt) - if (fs.existsSync(iter_txt)) { - const iter = fs.readFileSync(iter_txt).toString().split('\n'); - console.log(iter) - epoch = iter[0] || 0 - console.log(task.module, dataset, '=>', epoch, task.epochs) - } else { - console.log(task.module, dataset, '=>', 'starting new training') - } + // supply render_dir return [ '--dataroot', datasets_path, + '--results_dir', './recursive', '--module_name', task.module, '--name', dataset, '--model', 'pix2pixHD', '--label_nc', 0, '--no_instance', - '--niter', task.epochs, - '--niter_decay', 0, - '--save_epoch_freq', 1, + '--augment-take', task.opt.augment_take, + '--augment-make', task.opt.augment_make, + '--augment-name', task.opt.augment_name, + '--which_epoch', task.opt.epoch, + ] + }, +} +const clear_recursive = { + type: 'pytorch', + script: 'clear_recursive.py', + params: (task) => { + const dataset = task.dataset.toLowerCase() + return [ + '--name', dataset, + '--epoch', epoch, ] }, } @@ -153,7 +157,6 @@ const live = { '--name', task.checkpoint, '--module_name', 'pix2pixHD', '--sequence-name', task.dataset, - '--recursive', '--recursive-frac', 0.1, '--sequence', '--sequence-frac', 0.3, '--process-frac', 0.5, '--label_nc', '0', '--no_instance', @@ -200,6 +203,7 @@ export default { train, generate, augment, + clear_recursive, live, render, } |
