summaryrefslogtreecommitdiff
path: root/app
diff options
context:
space:
mode:
Diffstat (limited to 'app')
-rw-r--r--app/client/modules/pix2pixhd/views/pix2pixhd.train.js212
1 files changed, 127 insertions, 85 deletions
diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
index 4399a60..8ccc7f8 100644
--- a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
+++ b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
@@ -9,7 +9,7 @@ import * as pix2pixhdTasks from '../pix2pixhd.tasks'
import {
Loading,
FileList, FileRow,
- Select, SelectGroup, Group, Button,
+ Select, SelectGroup, Group, Param, Button,
TextInput, NumberInput,
CurrentTask, TaskList
} from '../../../common'
@@ -26,12 +26,36 @@ class Pix2PixHDTrain extends Component {
super(props)
this.handleChange = this.handleChange.bind(this)
this.state = {
- checkpoint_name: '',
+ checkpoint_name: 'PLACEHOLDER',
epoch: 'latest',
augment_name: '',
augment_take: 100,
augment_make: 20,
}
+ this.short_presets = [
+ { augment_take: 100, augment_make: 5 },
+ { augment_take: 200, augment_make: 5 },
+ { augment_take: 200, augment_make: 3 },
+ { augment_take: 50, augment_make: 10 },
+ { augment_take: 100, augment_make: 10 },
+ { augment_take: 1000, augment_make: 1 },
+ ]
+ this.medium_presets = [
+ { augment_take: 30, augment_make: 20 },
+ { augment_take: 20, augment_make: 30 },
+ { augment_take: 30, augment_make: 30 },
+ { augment_take: 50, augment_make: 20 },
+ { augment_take: 20, augment_make: 50 },
+ { augment_take: 15, augment_make: 70 },
+ ]
+ this.long_presets = [
+ { augment_take: 2, augment_make: 100 },
+ { augment_take: 2, augment_make: 200 },
+ { augment_take: 5, augment_make: 100 },
+ { augment_take: 5, augment_make: 200 },
+ { augment_take: 10, augment_make: 100 },
+ { augment_take: 10, augment_make: 200 },
+ ]
}
componentWillMount(){
const id = this.props.match.params.id || localStorage.getItem('pix2pixhd.last_id')
@@ -48,14 +72,7 @@ class Pix2PixHDTrain extends Component {
}
}
componentDidUpdate(prevProps, prevState){
- if (! prevProps.pix2pixhd.data && this.props.pix2pixhd.data) {
- console.log('set checkpoint_name')
- this.setState({
- checkpoint_name: this.props.pix2pixhd.data.sequences[0].name,
- epoch: 'latest'
- })
- }
- else if (prevState.checkpoint_name !== this.state.checkpoint_name) {
+ if (prevState.checkpoint_name !== this.state.checkpoint_name) {
this.setState({ epoch: 'latest' })
this.props.actions.list_epochs(this.state.checkpoint_name)
}
@@ -95,91 +112,116 @@ class Pix2PixHDTrain extends Component {
// console.log('state', this.props.pix2pixhd.data.epochs)
// console.log(this.state.checkpoint_name, this.state.epoch)
- console.log(queue)
+ // console.log(queue)
return (
<div className='app pix2pixhd'>
- <div class='heading'>
+ <div className='heading'>
<h1>pix2pixhd training</h1>
</div>
- <div class='heading'>
- <SelectGroup
- name='checkpoint_name'
- title='Checkpoint'
- options={checkpointGroups}
- onChange={this.handleChange}
- value={this.state.checkpoint_name}
- />
- <Select
- title="Epoch"
- name="epoch"
- options={this.props.pix2pixhd.data.epochs}
- onChange={this.handleChange}
- value={this.state.epoch}
- />
- <br/>
- <Group title='Augment'>
- <NumberInput
- name="augment_take"
- title="Pick N random frames"
- value={this.state.augment_take}
- onChange={this.handleChange}
- type="int"
- min="1"
- max="1000"
- />
- <NumberInput
- name="augment_make"
- title="Generate N recursively"
- value={this.state.augment_make}
- onChange={this.handleChange}
- type="int"
- min="1"
- max="1000"
- />
- <TextInput
- name="augment_name"
- title="Tag this epoch"
- value={this.state.augment_name}
- onChange={this.handleChange}
- className='small'
- />
- <Button
- title="Augment dataset"
- value="Augment"
- onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, this.state)}
- />
- </Group>
-
- <Group title='Train'>
- <Button
- title="Train one epoch"
- value="Train"
- onClick={() => this.props.remote.train_task(this.state.checkpoint_name, pix2pixhd.folder_id, 1)}
- />
- </Group>
+ <div className='columns'>
+ <div className='column'>
+ <Group title='Dataset'>
+ <SelectGroup
+ name='checkpoint_name'
+ title='Dataset'
+ options={checkpointGroups}
+ onChange={this.handleChange}
+ placeholder='Pick a dataset'
+ value={this.state.checkpoint_name}
+ />
+ <Select
+ title="Epoch"
+ name="epoch"
+ options={this.props.pix2pixhd.data.epochs}
+ onChange={this.handleChange}
+ value={this.state.epoch}
+ />
+ </Group>
+ <Group title='Augment'>
+ <NumberInput
+ name="augment_take"
+ title="Pick N random frames"
+ value={this.state.augment_take}
+ onChange={this.handleChange}
+ type="int"
+ min="1"
+ max="1000"
+ />
+ <NumberInput
+ name="augment_make"
+ title="Generate N recursively"
+ value={this.state.augment_make}
+ onChange={this.handleChange}
+ type="int"
+ min="1"
+ max="1000"
+ />
+ <Button
+ title="Augment dataset"
+ value="Augment"
+ onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, this.state)}
+ />
+ </Group>
+ <Group title='Augmentation Presets'>
+ <Param title="Short Recursion">
+ <div>
+ {this.short_presets.map(p => (
+ <button onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, p)}>
+ {p.augment_take}{'x'}{p.augment_make}
+ </button>
+ ))}
+ </div>
+ </Param>
+ <Param title="Medium Recursion">
+ <div>
+ {this.medium_presets.map(p => (
+ <button onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, p)}>
+ {p.augment_take}{'x'}{p.augment_make}
+ </button>
+ ))}
+ </div>
+ </Param>
+ <Param title="Long Recursion">
+ <div>
+ {this.long_presets.map(p => (
+ <button onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, p)}>
+ {p.augment_take}{'x'}{p.augment_make}
+ </button>
+ ))}
+ </div>
+ </Param>
+ </Group>
- <Group title='Clear'>
- <Button
- title="Delete recursive frames"
- value="Clear"
- onClick={() => this.props.remote.clear_recursive_task(this.state.checkpoint_name)}
- />
- </Group>
+ <Group title='Train'>
+ <Button
+ title="Train one epoch"
+ value="Train"
+ onClick={() => this.props.remote.train_task(this.state.checkpoint_name, pix2pixhd.folder_id, 1)}
+ />
+ </Group>
- <Group title='Status'>
- <Button
- title="GPU"
- value={this.props.runner.gpu.status === 'IDLE' ? "Idle" : "Interrupt"}
- onClick={() => this.interrupt()}
- />
- <CurrentTask />
- </Group>
+ <Group title='Clear'>
+ <Button
+ title="Delete recursive frames"
+ value="Clear"
+ onClick={() => this.props.remote.clear_recursive_task(this.state.checkpoint_name)}
+ />
+ </Group>
- {!!queue.queue.length &&
+ <Group title='Status'>
+ <Button
+ title="GPU"
+ value={this.props.runner.gpu.status === 'IDLE' ? "Idle" : "Interrupt"}
+ onClick={() => this.interrupt()}
+ />
+ <CurrentTask />
+ </Group>
+ </div>
+ <div className='column'>
<Group title='Upcoming Tasks'>
<TaskList tasks={queue.queue.map(id => queue.tasks[id])} sort="date asc" />
</Group>
- }
+ </div>
</div>
</div>
)