summaryrefslogtreecommitdiff
path: root/app/client/modules/pix2pixhd
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-09-16 16:03:41 +0200
committerJules Laplace <julescarbon@gmail.com>2018-09-16 16:03:41 +0200
commit14652eecb0fb4ebcb14e830504bfb02017bd010e (patch)
treebe42c8cd41fdad6387554af0ba05a546fc3121a9 /app/client/modules/pix2pixhd
parentc7f0268ad3d02a72e3639e289ab706fef1bb2645 (diff)
augmentation grid
Diffstat (limited to 'app/client/modules/pix2pixhd')
-rw-r--r--app/client/modules/pix2pixhd/views/pix2pixhd.train.js104
1 files changed, 34 insertions, 70 deletions
diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
index 9c8aacc..05ad638 100644
--- a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
+++ b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
@@ -11,7 +11,8 @@ import {
FileList, FileRow,
Select, SelectGroup, Group, Param, Button,
TextInput, NumberInput,
- CurrentTask, TaskList
+ CurrentTask, TaskList,
+ AugmentationGrid,
} from '../../../common'
import DatasetForm from '../../../dataset/dataset.form'
import NewDatasetForm from '../../../dataset/dataset.new'
@@ -32,30 +33,6 @@ class Pix2PixHDTrain extends Component {
augment_take: 100,
augment_make: 20,
}
- this.short_presets = [
- { augment_take: 100, augment_make: 5 },
- { augment_take: 200, augment_make: 5 },
- { augment_take: 200, augment_make: 3 },
- { augment_take: 50, augment_make: 10 },
- { augment_take: 100, augment_make: 10 },
- { augment_take: 1000, augment_make: 1 },
- ]
- this.medium_presets = [
- { augment_take: 30, augment_make: 20 },
- { augment_take: 20, augment_make: 30 },
- { augment_take: 30, augment_make: 30 },
- { augment_take: 50, augment_make: 20 },
- { augment_take: 20, augment_make: 50 },
- { augment_take: 15, augment_make: 70 },
- ]
- this.long_presets = [
- { augment_take: 2, augment_make: 100 },
- { augment_take: 2, augment_make: 200 },
- { augment_take: 5, augment_make: 100 },
- { augment_take: 5, augment_make: 200 },
- { augment_take: 10, augment_make: 100 },
- { augment_take: 10, augment_make: 200 },
- ]
}
componentWillMount(){
const id = this.props.match.params.id || localStorage.getItem('pix2pixhd.last_id')
@@ -84,7 +61,6 @@ class Pix2PixHDTrain extends Component {
interrupt(){
this.props.actions.queue.stop_task('gpu')
}
-
render(){
if (this.props.pix2pixhd.loading) {
return <Loading progress={this.props.pix2pixhd.progress} />
@@ -164,52 +140,40 @@ class Pix2PixHDTrain extends Component {
<Button
title="Make a movie without augmenting"
value="Generate"
- onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, { ...this.state, no_symlinks: true, mov: true, folder_id: this.props.pix2pixhd.data.resultsFolder.id })}
- />
- </Group>
- <Group title='Augmentation Presets'>
- <Param title="Short Recursion">
- <div>
- {this.short_presets.map(p => (
- <button onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, p)}>
- {p.augment_take}{'x'}{p.augment_make}
- </button>
- ))}
- </div>
- </Param>
- <Param title="Medium Recursion">
- <div>
- {this.medium_presets.map(p => (
- <button onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, p)}>
- {p.augment_take}{'x'}{p.augment_make}
- </button>
- ))}
- </div>
- </Param>
- <Param title="Long Recursion">
- <div>
- {this.long_presets.map(p => (
- <button onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, p)}>
- {p.augment_take}{'x'}{p.augment_make}
- </button>
- ))}
- </div>
- </Param>
- </Group>
-
- <Group title='Train'>
- <Button
- title="Train one epoch"
- value="Train"
- onClick={() => this.props.remote.train_task(this.state.checkpoint_name, pix2pixhd.folder_id, 1)}
+ onClick={() => {
+ this.props.remote.augment_task(this.state.checkpoint_name, {
+ ...this.state,
+ no_symlinks: true,
+ mov: true,
+ folder_id: this.props.pix2pixhd.data.resultsFolder.id
+ })
+ }}
/>
</Group>
-
- <Group title='Clear'>
- <Button
- title="Delete recursive frames"
- value="Clear"
- onClick={() => this.props.remote.clear_recursive_task(this.state.checkpoint_name)}
+ <Group title='Augmentation Grid'>
+ <AugmentationGrid
+ take={[1,2,3,4,5,10,15,20,25,50,75,100,200,300,400,500,1000]}
+ make={[1,2,3,4,5,10,15,20,25,50,75,100,200,]}
+ onAugment={(augment_take, augment_make) => {
+ this.props.remote.augment_task(this.state.checkpoint_name, {
+ ...this.state,
+ augment_take,
+ augment_make,
+ })
+ }}
+ onTrain={() => {
+ this.props.remote.train_task(this.state.checkpoint_name, pix2pixhd.folder_id, 1)
+ setTimeout(() => { // auto-generate epoch demo
+ this.props.remote.augment_task(this.state.checkpoint_name, {
+ ...this.state,
+ augment_take: 10,
+ augment_make: 150,
+ no_symlinks: true,
+ mov: true,
+ folder_id: this.props.pix2pixhd.data.resultsFolder.id
+ })
+ }, 250)
+ }}
/>
</Group>