summaryrefslogtreecommitdiff
path: root/app/client/modules/pix2pixhd
diff options
context:
space:
mode:
Diffstat (limited to 'app/client/modules/pix2pixhd')
-rw-r--r--app/client/modules/pix2pixhd/index.js1
-rw-r--r--app/client/modules/pix2pixhd/pix2pixhd.actions.js12
-rw-r--r--app/client/modules/pix2pixhd/pix2pixhd.tasks.js32
-rw-r--r--app/client/modules/pix2pixhd/views/pix2pixhd.live.js3
-rw-r--r--app/client/modules/pix2pixhd/views/pix2pixhd.train.js156
5 files changed, 201 insertions, 3 deletions
diff --git a/app/client/modules/pix2pixhd/index.js b/app/client/modules/pix2pixhd/index.js
index b33ce00..cbd3136 100644
--- a/app/client/modules/pix2pixhd/index.js
+++ b/app/client/modules/pix2pixhd/index.js
@@ -26,6 +26,7 @@ class router {
<Route exact path='/pix2pixhd/new/' component={Pix2PixHDNew} />
<Route exact path='/pix2pixhd/sequences/' component={Pix2PixHDShow} />
<Route exact path='/pix2pixhd/sequences/:id/' component={Pix2PixHDShow} />
+ <Route exact path='/pix2pixhd/train/' component={Pix2PixHDTrain} />
<Route exact path='/pix2pixhd/results/' component={Pix2PixHDResults} />
<Route exact path='/pix2pixhd/live/' component={Pix2PixHDLive} />
</section>
diff --git a/app/client/modules/pix2pixhd/pix2pixhd.actions.js b/app/client/modules/pix2pixhd/pix2pixhd.actions.js
index 8e481d3..6459794 100644
--- a/app/client/modules/pix2pixhd/pix2pixhd.actions.js
+++ b/app/client/modules/pix2pixhd/pix2pixhd.actions.js
@@ -21,7 +21,11 @@ export const load_directories = (id) => (dispatch) => {
// actions.socket.disk_usage({ module, dir: 'datasets' }),
], (percent, i, n) => {
console.log('pix2pixhd load progress', i, n)
- dispatch({ type: types.app.load_progress, progress: { i, n }})
+ dispatch({
+ type: types.app.load_progress,
+ progress: { i, n },
+ data: { module: 'pix2pixhd' },
+ })
}).then(res => {
const [datasetApiReport, sequences, datasets, checkpoints] = res //, datasets, results, output, datasetUsage, lossReport] = res
const {
@@ -157,7 +161,11 @@ export const load_results = (id) => (dispatch) => {
actions.socket.list_directory({ module, dir: 'renders' }),
], (percent, i, n) => {
console.log('pix2pixhd load progress', i, n)
- dispatch({ type: types.app.load_progress, progress: { i, n }})
+ dispatch({
+ type: types.app.load_progress,
+ progress: { i, n },
+ data: { module: 'pix2pixhd' },
+ })
}).then(res => {
const [folders, files, results, renders] = res //, datasets, results, output, datasetUsage, lossReport] = res
console.log(files, results, renders)
diff --git a/app/client/modules/pix2pixhd/pix2pixhd.tasks.js b/app/client/modules/pix2pixhd/pix2pixhd.tasks.js
index f3c5342..92c0ff4 100644
--- a/app/client/modules/pix2pixhd/pix2pixhd.tasks.js
+++ b/app/client/modules/pix2pixhd/pix2pixhd.tasks.js
@@ -52,3 +52,35 @@ export const live_task = (sequence, checkpoint, opt) => dispatch => {
console.log('add live task')
return actions.queue.add_task(task)
}
+
+export const augment_task = (opt) => dispatch => {
+ const task = {
+ module: module.name,
+ activity: 'augment',
+ dataset: sequence,
+ checkpoint,
+ opt: {
+ ...opt,
+ poll_delay: 0.01,
+ }
+ }
+ console.log(task)
+ console.log('add live task')
+ return actions.queue.add_task(task)
+}
+
+export const clear_recursive_task = (opt) => dispatch => {
+ const task = {
+ module: module.name,
+ activity: 'clear_recursive',
+ dataset: sequence,
+ checkpoint,
+ opt: {
+ ...opt,
+ }
+ }
+ console.log(task)
+ console.log('add live task')
+ return actions.queue.add_task(task)
+}
+
diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.live.js b/app/client/modules/pix2pixhd/views/pix2pixhd.live.js
index 41ff7e5..52b4b61 100644
--- a/app/client/modules/pix2pixhd/views/pix2pixhd.live.js
+++ b/app/client/modules/pix2pixhd/views/pix2pixhd.live.js
@@ -100,7 +100,7 @@ class Pix2PixHDLive extends Component {
render(){
// console.log(this.props)
if (this.props.pix2pixhd.loading) {
- return <Loading />
+ return <Loading progress={this.props.pix2pixhd.progress} />
}
const { folderLookup, datasetLookup, sequences } = this.props.pix2pixhd.data
@@ -140,6 +140,7 @@ class Pix2PixHDLive extends Component {
options: datasets.sort(),
}
}).filter(n => !!n && !!n.options.length).sort((a,b) => a.name.localeCompare(b.name))
+
return (
<div className='app live centered'>
<Player width={424} height={256} fullscreen={this.props.fullscreen} />
diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
new file mode 100644
index 0000000..8901ee8
--- /dev/null
+++ b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
@@ -0,0 +1,156 @@
+import { h, Component } from 'preact'
+import { bindActionCreators } from 'redux'
+import { connect } from 'react-redux'
+import util from '../../../util'
+
+import * as pix2pixhdActions from '../pix2pixhd.actions'
+import * as pix2pixhdTasks from '../pix2pixhd.tasks'
+
+import {
+ Loading,
+ FileList, FileRow,
+ Select, SelectGroup, Group, Button,
+ TextInput,
+ CurrentTask, TaskList
+} from '../../../common'
+import DatasetForm from '../../../dataset/dataset.form'
+import NewDatasetForm from '../../../dataset/dataset.new'
+import UploadStatus from '../../../dataset/upload.status'
+
+import DatasetComponent from '../../../dataset/dataset.component'
+
+import pix2pixhdModule from '../pix2pixhd.module'
+
+class Pix2PixHDTrain extends Component {
+ constructor(props){
+ super(props)
+ this.handleChange = this.handleChange.bind(this)
+ this.state = {
+ sequence: '',
+ epoch: 'latest',
+ augment_take: 100,
+ augment_make: 20,
+ }
+ }
+ componentWillMount(){
+ const id = this.props.match.params.id || localStorage.getItem('pix2pixhd.last_id')
+ console.log('load dataset:', id)
+ const { match, pix2pixhd, actions } = this.props
+ if (id === 'new') return
+ if (id) {
+ if (parseInt(id)) localStorage.setItem('pix2pixhd.last_id', id)
+ if (! pix2pixhd.folder || pix2pixhd.folder.id !== id) {
+ actions.load_directories(id)
+ }
+ } else {
+ this.props.history.push('/pix2pixhd/new/')
+ }
+ }
+ handleChange(value, name){
+ this.setState({ [name]: value })
+ }
+ changeCheckpoint(name){
+ // this.props.actions.list_epochs('pix2pixhd', nextProps.opt.checkpoint_name)
+ }
+ render(){
+ if (this.props.pix2pixhd.loading) {
+ return <Loading progress={this.props.pix2pixhd.progress} />
+ }
+ const { pix2pixhd, match, history, queue } = this.props
+ const { folderLookup, datasetLookup } = (pix2pixhd.data || {})
+ const folder = (folderLookup || {})[pix2pixhd.folder_id] || {}
+ console.log(pix2pixhd)
+
+ const checkpointGroups = Object.keys(folderLookup).map(id => {
+ const folder = this.props.pix2pixhd.data.folderLookup[id]
+ if (folder.name === 'results') return
+ const datasets = folder.datasets.map(name => {
+ const dataset = datasetLookup[name]
+ if (dataset.checkpoints.length) {
+ return name
+ }
+ return null
+ }).filter(n => !!n)
+ return {
+ name: folder.name,
+ options: datasets.sort(),
+ }
+ }).filter(n => !!n && !!n.options.length).sort((a,b) => a.name.localeCompare(b.name))
+
+ return (
+ <div className='app pix2pixhd'>
+ <h1>pix2pixhd training</h1>
+ <div class='heading'>
+ <SelectGroup live
+ name='checkpoint_name'
+ title='Checkpoint'
+ options={checkpointGroups}
+ onChange={this.changeCheckpoint}
+ />
+ <Select
+ title="Epoch"
+ value={this.state.epoch}
+ />
+ <br/>
+ <Group title='Augment'>
+ <TextInput
+ type="number"
+ name="augment_take"
+ title="Pick N random frames"
+ value={this.state.augment_take}
+ onInput={this.handleChange}
+ />
+ <TextInput
+ type="number"
+ name="augment_make"
+ title="Generate N recursively"
+ value={this.state.augment_make}
+ onInput={this.handleChange}
+ />
+ <Button
+ title="Augment dataset"
+ value="Augment"
+ onClick={() => remote.augment_task(dataset, pix2pixhd.folder_id, 1)}
+ />
+ </Group>
+
+ <Group title='Train'>
+ <Button
+ title="Train one epoch"
+ value="Train"
+ onClick={() => remote.train_task(dataset, pix2pixhd.folder_id, 1)}
+ />
+ </Group>
+
+ <Group title='Clear'>
+ <Button
+ title="Delete recursive frames"
+ value="Clear"
+ onClick={() => remote.clear_recursive_task(dataset, pix2pixhd.folder_id, 1)}
+ />
+ </Group>
+ </div>
+ <div>
+ <CurrentTask />
+ {!!queue.queue.length &&
+ <Group title='Upcoming Tasks'>
+ <TaskList tasks={queue.queue.map(id => queue.tasks[id])} />
+ </Group>
+ }
+ </div>
+ </div>
+ )
+ }
+}
+
+const mapStateToProps = state => ({
+ pix2pixhd: state.module.pix2pixhd,
+ queue: state.queue,
+})
+
+const mapDispatchToProps = (dispatch, ownProps) => ({
+ actions: bindActionCreators(pix2pixhdActions, dispatch),
+ remote: bindActionCreators(pix2pixhdTasks, dispatch),
+})
+
+export default connect(mapStateToProps, mapDispatchToProps)(Pix2PixHDTrain)