summaryrefslogtreecommitdiff
path: root/app/client/modules
diff options
context:
space:
mode:
Diffstat (limited to 'app/client/modules')
-rw-r--r--app/client/modules/pix2pixhd/pix2pixhd.actions.js26
-rw-r--r--app/client/modules/pix2pixhd/pix2pixhd.reducer.js10
-rw-r--r--app/client/modules/pix2pixhd/views/pix2pixhd.train.js106
3 files changed, 72 insertions, 70 deletions
diff --git a/app/client/modules/pix2pixhd/pix2pixhd.actions.js b/app/client/modules/pix2pixhd/pix2pixhd.actions.js
index c1cd2b1..a17eeab 100644
--- a/app/client/modules/pix2pixhd/pix2pixhd.actions.js
+++ b/app/client/modules/pix2pixhd/pix2pixhd.actions.js
@@ -200,4 +200,30 @@ export const list_epochs = (checkpoint_name) => (dispatch) => {
},
})
})
+}
+
+export const count_dataset = (checkpoint_name) => (dispatch) => {
+ const module = pix2pixhdModule.name
+ util.allProgress([
+ actions.socket.count_directory({ module, dir: 'sequences/' + checkpoint_name + '/' }),
+ actions.socket.count_directory({ module, dir: 'datasets/' + checkpoint_name + '/train_A/' }),
+ ], (percent, i, n) => {
+ console.log('pix2pixhd load progress', i, n)
+ dispatch({
+ type: types.app.load_progress,
+ progress: { i, n },
+ data: { module: 'pix2pixhd' },
+ })
+ }).then(res => {
+ const [sequenceCount, datasetCount] = res //, datasets, results, output, datasetUsage, lossReport] = res
+ console.log(sequenceCount, datasetCount)
+ dispatch({
+ type: types.pix2pixhd.load_dataset_count,
+ data: {
+ name: checkpoint_name,
+ sequenceCount,
+ datasetCount,
+ }
+ })
+ })
} \ No newline at end of file
diff --git a/app/client/modules/pix2pixhd/pix2pixhd.reducer.js b/app/client/modules/pix2pixhd/pix2pixhd.reducer.js
index c3d52a3..5a2afc0 100644
--- a/app/client/modules/pix2pixhd/pix2pixhd.reducer.js
+++ b/app/client/modules/pix2pixhd/pix2pixhd.reducer.js
@@ -8,6 +8,11 @@ const pix2pixhdInitialState = {
folder_id: 0,
data: null,
results: null,
+ checkpoint: {
+ name: '',
+ sequenceCount: 0,
+ datasetCount: 0,
+ }
}
const pix2pixhdReducer = (state = pix2pixhdInitialState, action) => {
@@ -21,6 +26,11 @@ const pix2pixhdReducer = (state = pix2pixhdInitialState, action) => {
...state,
results: action.results,
}
+ case types.pix2pixhd.load_dataset_count:
+ return {
+ ...state,
+ checkpoint: action.data,
+ }
case types.file.destroy:
console.log('file destroy', state.results)
return {
diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
index 9c8aacc..06caa5a 100644
--- a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
+++ b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
@@ -11,7 +11,8 @@ import {
FileList, FileRow,
Select, SelectGroup, Group, Param, Button,
TextInput, NumberInput,
- CurrentTask, TaskList
+ CurrentTask, TaskList,
+ AugmentationGrid,
} from '../../../common'
import DatasetForm from '../../../dataset/dataset.form'
import NewDatasetForm from '../../../dataset/dataset.new'
@@ -32,30 +33,6 @@ class Pix2PixHDTrain extends Component {
augment_take: 100,
augment_make: 20,
}
- this.short_presets = [
- { augment_take: 100, augment_make: 5 },
- { augment_take: 200, augment_make: 5 },
- { augment_take: 200, augment_make: 3 },
- { augment_take: 50, augment_make: 10 },
- { augment_take: 100, augment_make: 10 },
- { augment_take: 1000, augment_make: 1 },
- ]
- this.medium_presets = [
- { augment_take: 30, augment_make: 20 },
- { augment_take: 20, augment_make: 30 },
- { augment_take: 30, augment_make: 30 },
- { augment_take: 50, augment_make: 20 },
- { augment_take: 20, augment_make: 50 },
- { augment_take: 15, augment_make: 70 },
- ]
- this.long_presets = [
- { augment_take: 2, augment_make: 100 },
- { augment_take: 2, augment_make: 200 },
- { augment_take: 5, augment_make: 100 },
- { augment_take: 5, augment_make: 200 },
- { augment_take: 10, augment_make: 100 },
- { augment_take: 10, augment_make: 200 },
- ]
}
componentWillMount(){
const id = this.props.match.params.id || localStorage.getItem('pix2pixhd.last_id')
@@ -75,6 +52,7 @@ class Pix2PixHDTrain extends Component {
if (prevState.checkpoint_name !== this.state.checkpoint_name) {
this.setState({ epoch: 'latest' })
this.props.actions.list_epochs(this.state.checkpoint_name)
+ this.props.actions.count_dataset(this.state.checkpoint_name)
}
}
handleChange(name, value){
@@ -84,7 +62,6 @@ class Pix2PixHDTrain extends Component {
interrupt(){
this.props.actions.queue.stop_task('gpu')
}
-
render(){
if (this.props.pix2pixhd.loading) {
return <Loading progress={this.props.pix2pixhd.progress} />
@@ -164,52 +141,41 @@ class Pix2PixHDTrain extends Component {
<Button
title="Make a movie without augmenting"
value="Generate"
- onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, { ...this.state, no_symlinks: true, mov: true, folder_id: this.props.pix2pixhd.data.resultsFolder.id })}
- />
- </Group>
- <Group title='Augmentation Presets'>
- <Param title="Short Recursion">
- <div>
- {this.short_presets.map(p => (
- <button onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, p)}>
- {p.augment_take}{'x'}{p.augment_make}
- </button>
- ))}
- </div>
- </Param>
- <Param title="Medium Recursion">
- <div>
- {this.medium_presets.map(p => (
- <button onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, p)}>
- {p.augment_take}{'x'}{p.augment_make}
- </button>
- ))}
- </div>
- </Param>
- <Param title="Long Recursion">
- <div>
- {this.long_presets.map(p => (
- <button onClick={() => this.props.remote.augment_task(this.state.checkpoint_name, p)}>
- {p.augment_take}{'x'}{p.augment_make}
- </button>
- ))}
- </div>
- </Param>
- </Group>
-
- <Group title='Train'>
- <Button
- title="Train one epoch"
- value="Train"
- onClick={() => this.props.remote.train_task(this.state.checkpoint_name, pix2pixhd.folder_id, 1)}
+ onClick={() => {
+ this.props.remote.augment_task(this.state.checkpoint_name, {
+ ...this.state,
+ no_symlinks: true,
+ mov: true,
+ folder_id: this.props.pix2pixhd.data.resultsFolder.id
+ })
+ }}
/>
</Group>
-
- <Group title='Clear'>
- <Button
- title="Delete recursive frames"
- value="Clear"
- onClick={() => this.props.remote.clear_recursive_task(this.state.checkpoint_name)}
+ <Group title='Augmentation Grid'>
+ <AugmentationGrid
+ checkpoint={this.props.pix2pixhd.checkpoint}
+ take={[1,2,3,4,5,10,15,20,25,50,75,100,200,300,400,500,1000]}
+ make={[1,2,3,4,5,10,15,20,25,50,75,100,200,]}
+ onAugment={(augment_take, augment_make) => {
+ this.props.remote.augment_task(this.state.checkpoint_name, {
+ ...this.state,
+ augment_take,
+ augment_make,
+ })
+ }}
+ onTrain={() => {
+ this.props.remote.train_task(this.state.checkpoint_name, pix2pixhd.folder_id, 1)
+ setTimeout(() => { // auto-generate epoch demo
+ this.props.remote.augment_task(this.state.checkpoint_name, {
+ ...this.state,
+ augment_take: 10,
+ augment_make: 150,
+ no_symlinks: true,
+ mov: true,
+ folder_id: this.props.pix2pixhd.data.resultsFolder.id
+ })
+ }, 250)
+ }}
/>
</Group>