diff options
Diffstat (limited to 'app/client/modules')
| -rw-r--r-- | app/client/modules/pix2pixhd/pix2pixhd.actions.js | 26 | ||||
| -rw-r--r-- | app/client/modules/pix2pixhd/pix2pixhd.reducer.js | 10 | ||||
| -rw-r--r-- | app/client/modules/pix2pixhd/views/pix2pixhd.train.js | 106 |
3 files changed, 72 insertions, 70 deletions
diff --git a/app/client/modules/pix2pixhd/pix2pixhd.actions.js b/app/client/modules/pix2pixhd/pix2pixhd.actions.js index c1cd2b1..a17eeab 100644 --- a/app/client/modules/pix2pixhd/pix2pixhd.actions.js +++ b/app/client/modules/pix2pixhd/pix2pixhd.actions.js @@ -200,4 +200,30 @@ export const list_epochs = (checkpoint_name) => (dispatch) => { }, }) }) +} + +export const count_dataset = (checkpoint_name) => (dispatch) => { + const module = pix2pixhdModule.name + util.allProgress([ + actions.socket.count_directory({ module, dir: 'sequences/' + checkpoint_name + '/' }), + actions.socket.count_directory({ module, dir: 'datasets/' + checkpoint_name + '/train_A/' }), + ], (percent, i, n) => { + console.log('pix2pixhd load progress', i, n) + dispatch({ + type: types.app.load_progress, + progress: { i, n }, + data: { module: 'pix2pixhd' }, + }) + }).then(res => { + const [sequenceCount, datasetCount] = res //, datasets, results, output, datasetUsage, lossReport] = res + console.log(sequenceCount, datasetCount) + dispatch({ + type: types.pix2pixhd.load_dataset_count, + data: { + name: checkpoint_name, + sequenceCount, + datasetCount, + } + }) + }) }
\ No newline at end of file diff --git a/app/client/modules/pix2pixhd/pix2pixhd.reducer.js b/app/client/modules/pix2pixhd/pix2pixhd.reducer.js index c3d52a3..5a2afc0 100644 --- a/app/client/modules/pix2pixhd/pix2pixhd.reducer.js +++ b/app/client/modules/pix2pixhd/pix2pixhd.reducer.js @@ -8,6 +8,11 @@ const pix2pixhdInitialState = { folder_id: 0, data: null, results: null, + checkpoint: { + name: '', + sequenceCount: 0, + datasetCount: 0, + } } const pix2pixhdReducer = (state = pix2pixhdInitialState, action) => { @@ -21,6 +26,11 @@ const pix2pixhdReducer = (state = pix2pixhdInitialState, action) => { ...state, results: action.results, } + case types.pix2pixhd.load_dataset_count: + return { + ...state, + checkpoint: action.data, + } case types.file.destroy: console.log('file destroy', state.results) return { diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js index 9c8aacc..06caa5a 100644 --- a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js +++ b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js @@ -11,7 +11,8 @@ import { FileList, FileRow, Select, SelectGroup, Group, Param, Button, TextInput, NumberInput, - CurrentTask, TaskList + CurrentTask, TaskList, + AugmentationGrid, } from '../../../common' import DatasetForm from '../../../dataset/dataset.form' import NewDatasetForm from '../../../dataset/dataset.new' @@ -32,30 +33,6 @@ class Pix2PixHDTrain extends Component { augment_take: 100, augment_make: 20, } - this.short_presets = [ - { augment_take: 100, augment_make: 5 }, - { augment_take: 200, augment_make: 5 }, - { augment_take: 200, augment_make: 3 }, - { augment_take: 50, augment_make: 10 }, - { augment_take: 100, augment_make: 10 }, - { augment_take: 1000, augment_make: 1 }, - ] - this.medium_presets = [ - { augment_take: 30, augment_make: 20 }, - { augment_take: 20, augment_make: 30 }, - { augment_take: 30, augment_make: 30 }, - { augment_take: 50, augment_make: 20 }, - { augment_take: 20, augment_make: 50 }, - { augment_take: 15, augment_make: 70 }, - ] - this.long_presets = [ - { augment_take: 2, augment_make: 100 }, - { augment_take: 2, augment_make: 200 }, - { augment_take: 5, augment_make: 100 }, - { augment_take: 5, augment_make: 200 }, - { augment_take: 10, augment_make: 100 }, - { augment_take: 10, augment_make: 200 }, - ] } componentWillMount(){ const id = this.props.match.params.id || localStorage.getItem('pix2pixhd.last_id') @@ -75,6 +52,7 @@ class Pix2PixHDTrain extends Component { if (prevState.checkpoint_name !== this.state.checkpoint_name) { this.setState({ epoch: 'latest' }) this.props.actions.list_epochs(this.state.checkpoint_name) + this.props.actions.count_dataset(this.state.checkpoint_name) } } handleChange(name, value){ @@ -84,7 +62,6 @@ class Pix2PixHDTrain extends Component { interrupt(){ this.props.actions.queue.stop_task('gpu') } - render(){ if (this.props.pix2pixhd.loading) { return <Loading progress={this.props.pix2pixhd.progress} /> @@ -164,52 +141,41 @@ class Pix2PixHDTrain extends Component { <Button title="Make a movie without augmenting" value="Generate" - onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, { ...this.state, no_symlinks: true, mov: true, folder_id: this.props.pix2pixhd.data.resultsFolder.id })} - /> - </Group> - <Group title='Augmentation Presets'> - <Param title="Short Recursion"> - <div> - {this.short_presets.map(p => ( - <button onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, p)}> - {p.augment_take}{'x'}{p.augment_make} - </button> - ))} - </div> - </Param> - <Param title="Medium Recursion"> - <div> - {this.medium_presets.map(p => ( - <button onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, p)}> - {p.augment_take}{'x'}{p.augment_make} - </button> - ))} - </div> - </Param> - <Param title="Long Recursion"> - <div> - {this.long_presets.map(p => ( - <button onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, p)}> - {p.augment_take}{'x'}{p.augment_make} - </button> - ))} - </div> - </Param> - </Group> - - <Group title='Train'> - <Button - title="Train one epoch" - value="Train" - onClick={() => this.props.remote.train_task(this.state.checkpoint_name, pix2pixhd.folder_id, 1)} + onClick={() => { + this.props.remote.augment_task(this.state.checkpoint_name, { + ...this.state, + no_symlinks: true, + mov: true, + folder_id: this.props.pix2pixhd.data.resultsFolder.id + }) + }} /> </Group> - - <Group title='Clear'> - <Button - title="Delete recursive frames" - value="Clear" - onClick={() => this.props.remote.clear_recursive_task(this.state.checkpoint_name)} + <Group title='Augmentation Grid'> + <AugmentationGrid + checkpoint={this.props.pix2pixhd.checkpoint} + take={[1,2,3,4,5,10,15,20,25,50,75,100,200,300,400,500,1000]} + make={[1,2,3,4,5,10,15,20,25,50,75,100,200,]} + onAugment={(augment_take, augment_make) => { + this.props.remote.augment_task(this.state.checkpoint_name, { + ...this.state, + augment_take, + augment_make, + }) + }} + onTrain={() => { + this.props.remote.train_task(this.state.checkpoint_name, pix2pixhd.folder_id, 1) + setTimeout(() => { // auto-generate epoch demo + this.props.remote.augment_task(this.state.checkpoint_name, { + ...this.state, + augment_take: 10, + augment_make: 150, + no_symlinks: true, + mov: true, + folder_id: this.props.pix2pixhd.data.resultsFolder.id + }) + }, 250) + }} /> </Group> |
