summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-09-05 11:58:34 +0200
committerJules Laplace <julescarbon@gmail.com>2018-09-05 11:58:34 +0200
commit15eb6806b6e216255f33abcb885f6cdbc38a7663 (patch)
tree9629d267611a48f9492f53c7ec74cf68c930deb4
parent7910767b7283e62f03dec5f86e08a796c792080f (diff)
relay stuff
-rw-r--r--app/client/modules/pix2pixhd/pix2pixhd.tasks.js18
-rw-r--r--app/client/modules/pix2pixhd/views/pix2pixhd.train.js13
-rw-r--r--app/relay/modules/pix2pixhd.js34
3 files changed, 35 insertions, 30 deletions
diff --git a/app/client/modules/pix2pixhd/pix2pixhd.tasks.js b/app/client/modules/pix2pixhd/pix2pixhd.tasks.js
index 92c0ff4..bd51f2b 100644
--- a/app/client/modules/pix2pixhd/pix2pixhd.tasks.js
+++ b/app/client/modules/pix2pixhd/pix2pixhd.tasks.js
@@ -53,34 +53,28 @@ export const live_task = (sequence, checkpoint, opt) => dispatch => {
return actions.queue.add_task(task)
}
-export const augment_task = (opt) => dispatch => {
+export const augment_task = (dataset, opt) => dispatch => {
const task = {
module: module.name,
activity: 'augment',
- dataset: sequence,
- checkpoint,
+ dataset,
opt: {
...opt,
- poll_delay: 0.01,
}
}
console.log(task)
- console.log('add live task')
+ console.log('add augment task')
return actions.queue.add_task(task)
}
-export const clear_recursive_task = (opt) => dispatch => {
+export const clear_recursive_task = (dataset) => dispatch => {
const task = {
module: module.name,
activity: 'clear_recursive',
- dataset: sequence,
- checkpoint,
- opt: {
- ...opt,
- }
+ dataset,
}
console.log(task)
- console.log('add live task')
+ console.log('add clear recursive task')
return actions.queue.add_task(task)
}
diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
index cd6507f..df3a1f2 100644
--- a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
+++ b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
@@ -28,6 +28,7 @@ class Pix2PixHDTrain extends Component {
this.state = {
checkpoint_name: '',
epoch: 'latest',
+ augment_name: '',
augment_take: 100,
augment_make: 20,
}
@@ -131,10 +132,16 @@ class Pix2PixHDTrain extends Component {
min="1"
max="1000"
/>
+ <TextInput
+ name="augment_name"
+ title="Tag this epoch"
+ value={this.state.augment_name}
+ onChange={this.handleChange}
+ />
<Button
title="Augment dataset"
value="Augment"
- onClick={() => remote.augment_task(dataset, pix2pixhd.folder_id, 1)}
+ onClick={() => remote.augment_task(this.state.checkpoint_name, this.state)}
/>
</Group>
@@ -142,7 +149,7 @@ class Pix2PixHDTrain extends Component {
<Button
title="Train one epoch"
value="Train"
- onClick={() => remote.train_task(dataset, pix2pixhd.folder_id, 1)}
+ onClick={() => remote.train_task(this.state.checkpoint_name, pix2pixhd.folder_id, 1)}
/>
</Group>
@@ -150,7 +157,7 @@ class Pix2PixHDTrain extends Component {
<Button
title="Delete recursive frames"
value="Clear"
- onClick={() => remote.clear_recursive_task(dataset, pix2pixhd.folder_id, 1)}
+ onClick={() => remote.clear_recursive_task(this.state.checkpoint_name)}
/>
</Group>
</div>
diff --git a/app/relay/modules/pix2pixhd.js b/app/relay/modules/pix2pixhd.js
index 73b49ca..a90fc15 100644
--- a/app/relay/modules/pix2pixhd.js
+++ b/app/relay/modules/pix2pixhd.js
@@ -110,31 +110,35 @@ const generate = {
}
const augment = {
type: 'pytorch',
- script: 'test.py',
+ script: 'augment.py',
params: (task) => {
let epoch = 0
const dataset = task.dataset.toLowerCase()
const datasets_path = path.join(cwd, 'datasets', dataset)
const checkpoints_path = path.join(cwd, 'checkpoints', dataset)
- const iter_txt = path.join(checkpoints_path, 'iter.txt')
- console.log(dataset, iter_txt)
- if (fs.existsSync(iter_txt)) {
- const iter = fs.readFileSync(iter_txt).toString().split('\n');
- console.log(iter)
- epoch = iter[0] || 0
- console.log(task.module, dataset, '=>', epoch, task.epochs)
- } else {
- console.log(task.module, dataset, '=>', 'starting new training')
- }
+ // supply render_dir
return [
'--dataroot', datasets_path,
+ '--results_dir', './recursive',
'--module_name', task.module,
'--name', dataset,
'--model', 'pix2pixHD',
'--label_nc', 0, '--no_instance',
- '--niter', task.epochs,
- '--niter_decay', 0,
- '--save_epoch_freq', 1,
+ '--augment-take', task.opt.augment_take,
+ '--augment-make', task.opt.augment_make,
+ '--augment-name', task.opt.augment_name,
+ '--which_epoch', task.opt.epoch,
+ ]
+ },
+}
+const clear_recursive = {
+ type: 'pytorch',
+ script: 'clear_recursive.py',
+ params: (task) => {
+ const dataset = task.dataset.toLowerCase()
+ return [
+ '--name', dataset,
+ '--epoch', epoch,
]
},
}
@@ -153,7 +157,6 @@ const live = {
'--name', task.checkpoint,
'--module_name', 'pix2pixHD',
'--sequence-name', task.dataset,
- '--recursive', '--recursive-frac', 0.1,
'--sequence', '--sequence-frac', 0.3,
'--process-frac', 0.5,
'--label_nc', '0', '--no_instance',
@@ -200,6 +203,7 @@ export default {
train,
generate,
augment,
+ clear_recursive,
live,
render,
}