diff options
Diffstat (limited to 'app')
| -rw-r--r-- | app/client/common/augmentationGrid.component.js | 37 | ||||
| -rw-r--r-- | app/client/common/buttonGrid.component.js | 32 | ||||
| -rw-r--r-- | app/client/common/index.js | 3 | ||||
| -rw-r--r-- | app/client/modules/pix2pixhd/views/pix2pixhd.train.js | 104 |
4 files changed, 106 insertions, 70 deletions
diff --git a/app/client/common/augmentationGrid.component.js b/app/client/common/augmentationGrid.component.js new file mode 100644 index 0000000..69bdc8a --- /dev/null +++ b/app/client/common/augmentationGrid.component.js @@ -0,0 +1,37 @@ +import { h, Component } from 'preact' + +import Group from './group.component' +import Param from './param.component' +import Button from './button.component' +import ButtonGrid from './buttonGrid.component' + +export default class AugmentationGrid extends Component { + state = { + x: 0, y: 0, sum: 0, + } + render() { + let rows = [] + return ( + <Group className='augmentationGrid'> + <ButtonGrid + x={this.props.make} + y={this.props.take} + max={5000} + onHover={(x, y) => this.setState({ x, y })} + onClick={(x, y) => { + this.setState({ sum: this.state.sum + x * y }) + this.props.onAugment(y, x) + }} + /> + <Param title='Take'>{this.state.y}</Param> + <Param title='Make'>{this.state.x}</Param> + <Param title='Will add to dataset'>{this.state.x * this.state.y}</Param> + <Param title='Total added this epoch'>{this.state.sum}</Param> + <Button onClick={() => { + this.setState({ sum: 0 }) + this.props.onTrain() + }}>Train</Button> + </Group> + ) + } +} diff --git a/app/client/common/buttonGrid.component.js b/app/client/common/buttonGrid.component.js new file mode 100644 index 0000000..4b86d62 --- /dev/null +++ b/app/client/common/buttonGrid.component.js @@ -0,0 +1,32 @@ +import { h, Component } from 'preact' + +export default function ButtonGrid(props) { + const max = props.max || Infinity + return ( + <table className='buttonGrid'> + <tr className='row'> + <th>{" "}</th> + {props.x.map(x => ( + <th>{x}</th> + ))} + </tr> + {props.y.map(y => ( + <tr className='row'> + <th>{y}</th> + {props.x.map(x => ( + <td> + {x * y > max ? " " : + <button + onClick={() => props.onClick(x, y)} + onMouseEnter={() => props.onHover(x, y)} + > + {" "} + </button> + } + </td> + ))} + </tr> + ))} + </table> + ) +} diff --git a/app/client/common/index.js b/app/client/common/index.js index 13b3189..7448104 100644 --- a/app/client/common/index.js +++ b/app/client/common/index.js @@ -1,4 +1,6 @@ +import AugmentationGrid from './augmentationGrid.component' import Button from './button.component' +import ButtonGrid from './buttonGrid.component' import Checkbox from './checkbox.component' import CurrentTask from './currentTask.component' import { FileList, FileRow } from './fileList.component' @@ -29,4 +31,5 @@ export { TextInput, NumberInput, Slider, Select, SelectGroup, Button, Checkbox, CurrentTask, TaskList, + ButtonGrid, AugmentationGrid, }
\ No newline at end of file diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js index 9c8aacc..05ad638 100644 --- a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js +++ b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js @@ -11,7 +11,8 @@ import { FileList, FileRow, Select, SelectGroup, Group, Param, Button, TextInput, NumberInput, - CurrentTask, TaskList + CurrentTask, TaskList, + AugmentationGrid, } from '../../../common' import DatasetForm from '../../../dataset/dataset.form' import NewDatasetForm from '../../../dataset/dataset.new' @@ -32,30 +33,6 @@ class Pix2PixHDTrain extends Component { augment_take: 100, augment_make: 20, } - this.short_presets = [ - { augment_take: 100, augment_make: 5 }, - { augment_take: 200, augment_make: 5 }, - { augment_take: 200, augment_make: 3 }, - { augment_take: 50, augment_make: 10 }, - { augment_take: 100, augment_make: 10 }, - { augment_take: 1000, augment_make: 1 }, - ] - this.medium_presets = [ - { augment_take: 30, augment_make: 20 }, - { augment_take: 20, augment_make: 30 }, - { augment_take: 30, augment_make: 30 }, - { augment_take: 50, augment_make: 20 }, - { augment_take: 20, augment_make: 50 }, - { augment_take: 15, augment_make: 70 }, - ] - this.long_presets = [ - { augment_take: 2, augment_make: 100 }, - { augment_take: 2, augment_make: 200 }, - { augment_take: 5, augment_make: 100 }, - { augment_take: 5, augment_make: 200 }, - { augment_take: 10, augment_make: 100 }, - { augment_take: 10, augment_make: 200 }, - ] } componentWillMount(){ const id = this.props.match.params.id || localStorage.getItem('pix2pixhd.last_id') @@ -84,7 +61,6 @@ class Pix2PixHDTrain extends Component { interrupt(){ this.props.actions.queue.stop_task('gpu') } - render(){ if (this.props.pix2pixhd.loading) { return <Loading progress={this.props.pix2pixhd.progress} /> @@ -164,52 +140,40 @@ class Pix2PixHDTrain extends Component { <Button title="Make a movie without augmenting" value="Generate" - onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, { ...this.state, no_symlinks: true, mov: true, folder_id: this.props.pix2pixhd.data.resultsFolder.id })} - /> - </Group> - <Group title='Augmentation Presets'> - <Param title="Short Recursion"> - <div> - {this.short_presets.map(p => ( - <button onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, p)}> - {p.augment_take}{'x'}{p.augment_make} - </button> - ))} - </div> - </Param> - <Param title="Medium Recursion"> - <div> - {this.medium_presets.map(p => ( - <button onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, p)}> - {p.augment_take}{'x'}{p.augment_make} - </button> - ))} - </div> - </Param> - <Param title="Long Recursion"> - <div> - {this.long_presets.map(p => ( - <button onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, p)}> - {p.augment_take}{'x'}{p.augment_make} - </button> - ))} - </div> - </Param> - </Group> - - <Group title='Train'> - <Button - title="Train one epoch" - value="Train" - onClick={() => this.props.remote.train_task(this.state.checkpoint_name, pix2pixhd.folder_id, 1)} + onClick={() => { + this.props.remote.augment_task(this.state.checkpoint_name, { + ...this.state, + no_symlinks: true, + mov: true, + folder_id: this.props.pix2pixhd.data.resultsFolder.id + }) + }} /> </Group> - - <Group title='Clear'> - <Button - title="Delete recursive frames" - value="Clear" - onClick={() => this.props.remote.clear_recursive_task(this.state.checkpoint_name)} + <Group title='Augmentation Grid'> + <AugmentationGrid + take={[1,2,3,4,5,10,15,20,25,50,75,100,200,300,400,500,1000]} + make={[1,2,3,4,5,10,15,20,25,50,75,100,200,]} + onAugment={(augment_take, augment_make) => { + this.props.remote.augment_task(this.state.checkpoint_name, { + ...this.state, + augment_take, + augment_make, + }) + }} + onTrain={() => { + this.props.remote.train_task(this.state.checkpoint_name, pix2pixhd.folder_id, 1) + setTimeout(() => { // auto-generate epoch demo + this.props.remote.augment_task(this.state.checkpoint_name, { + ...this.state, + augment_take: 10, + augment_make: 150, + no_symlinks: true, + mov: true, + folder_id: this.props.pix2pixhd.data.resultsFolder.id + }) + }, 250) + }} /> </Group> |
