summaryrefslogtreecommitdiff
path: root/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
diff options
context:
space:
mode:
Diffstat (limited to 'app/client/modules/pix2pixhd/views/pix2pixhd.train.js')
-rw-r--r--app/client/modules/pix2pixhd/views/pix2pixhd.train.js156
1 files changed, 156 insertions, 0 deletions
diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
new file mode 100644
index 0000000..8901ee8
--- /dev/null
+++ b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
@@ -0,0 +1,156 @@
+import { h, Component } from 'preact'
+import { bindActionCreators } from 'redux'
+import { connect } from 'react-redux'
+import util from '../../../util'
+
+import * as pix2pixhdActions from '../pix2pixhd.actions'
+import * as pix2pixhdTasks from '../pix2pixhd.tasks'
+
+import {
+ Loading,
+ FileList, FileRow,
+ Select, SelectGroup, Group, Button,
+ TextInput,
+ CurrentTask, TaskList
+} from '../../../common'
+import DatasetForm from '../../../dataset/dataset.form'
+import NewDatasetForm from '../../../dataset/dataset.new'
+import UploadStatus from '../../../dataset/upload.status'
+
+import DatasetComponent from '../../../dataset/dataset.component'
+
+import pix2pixhdModule from '../pix2pixhd.module'
+
+class Pix2PixHDTrain extends Component {
+ constructor(props){
+ super(props)
+ this.handleChange = this.handleChange.bind(this)
+ this.state = {
+ sequence: '',
+ epoch: 'latest',
+ augment_take: 100,
+ augment_make: 20,
+ }
+ }
+ componentWillMount(){
+ const id = this.props.match.params.id || localStorage.getItem('pix2pixhd.last_id')
+ console.log('load dataset:', id)
+ const { match, pix2pixhd, actions } = this.props
+ if (id === 'new') return
+ if (id) {
+ if (parseInt(id)) localStorage.setItem('pix2pixhd.last_id', id)
+ if (! pix2pixhd.folder || pix2pixhd.folder.id !== id) {
+ actions.load_directories(id)
+ }
+ } else {
+ this.props.history.push('/pix2pixhd/new/')
+ }
+ }
+ handleChange(value, name){
+ this.setState({ [name]: value })
+ }
+ changeCheckpoint(name){
+ // this.props.actions.list_epochs('pix2pixhd', nextProps.opt.checkpoint_name)
+ }
+ render(){
+ if (this.props.pix2pixhd.loading) {
+ return <Loading progress={this.props.pix2pixhd.progress} />
+ }
+ const { pix2pixhd, match, history, queue } = this.props
+ const { folderLookup, datasetLookup } = (pix2pixhd.data || {})
+ const folder = (folderLookup || {})[pix2pixhd.folder_id] || {}
+ console.log(pix2pixhd)
+
+ const checkpointGroups = Object.keys(folderLookup).map(id => {
+ const folder = this.props.pix2pixhd.data.folderLookup[id]
+ if (folder.name === 'results') return
+ const datasets = folder.datasets.map(name => {
+ const dataset = datasetLookup[name]
+ if (dataset.checkpoints.length) {
+ return name
+ }
+ return null
+ }).filter(n => !!n)
+ return {
+ name: folder.name,
+ options: datasets.sort(),
+ }
+ }).filter(n => !!n && !!n.options.length).sort((a,b) => a.name.localeCompare(b.name))
+
+ return (
+ <div className='app pix2pixhd'>
+ <h1>pix2pixhd training</h1>
+ <div class='heading'>
+ <SelectGroup live
+ name='checkpoint_name'
+ title='Checkpoint'
+ options={checkpointGroups}
+ onChange={this.changeCheckpoint}
+ />
+ <Select
+ title="Epoch"
+ value={this.state.epoch}
+ />
+ <br/>
+ <Group title='Augment'>
+ <TextInput
+ type="number"
+ name="augment_take"
+ title="Pick N random frames"
+ value={this.state.augment_take}
+ onInput={this.handleChange}
+ />
+ <TextInput
+ type="number"
+ name="augment_make"
+ title="Generate N recursively"
+ value={this.state.augment_make}
+ onInput={this.handleChange}
+ />
+ <Button
+ title="Augment dataset"
+ value="Augment"
+ onClick={() => remote.augment_task(dataset, pix2pixhd.folder_id, 1)}
+ />
+ </Group>
+
+ <Group title='Train'>
+ <Button
+ title="Train one epoch"
+ value="Train"
+ onClick={() => remote.train_task(dataset, pix2pixhd.folder_id, 1)}
+ />
+ </Group>
+
+ <Group title='Clear'>
+ <Button
+ title="Delete recursive frames"
+ value="Clear"
+ onClick={() => remote.clear_recursive_task(dataset, pix2pixhd.folder_id, 1)}
+ />
+ </Group>
+ </div>
+ <div>
+ <CurrentTask />
+ {!!queue.queue.length &&
+ <Group title='Upcoming Tasks'>
+ <TaskList tasks={queue.queue.map(id => queue.tasks[id])} />
+ </Group>
+ }
+ </div>
+ </div>
+ )
+ }
+}
+
+const mapStateToProps = state => ({
+ pix2pixhd: state.module.pix2pixhd,
+ queue: state.queue,
+})
+
+const mapDispatchToProps = (dispatch, ownProps) => ({
+ actions: bindActionCreators(pix2pixhdActions, dispatch),
+ remote: bindActionCreators(pix2pixhdTasks, dispatch),
+})
+
+export default connect(mapStateToProps, mapDispatchToProps)(Pix2PixHDTrain)