summaryrefslogtreecommitdiff
path: root/app/client/modules/pix2pixhd
diff options
context:
space:
mode:
authorjules@lens <julescarbon@gmail.com>2018-09-05 12:00:28 +0200
committerjules@lens <julescarbon@gmail.com>2018-09-05 12:00:28 +0200
commit9abfa16dc059d042c21f1636ecc8797ef29a030d (patch)
treed0583cb5dae01de1abc57ed8f7587d23242ed6f0 /app/client/modules/pix2pixhd
parent0a3c6743543dd3dfcb876f5ce735b72d050e981d (diff)
parent15eb6806b6e216255f33abcb885f6cdbc38a7663 (diff)
Merge branch 'master' of asdf.us:live-cortex
Diffstat (limited to 'app/client/modules/pix2pixhd')
-rw-r--r--app/client/modules/pix2pixhd/index.js5
-rw-r--r--app/client/modules/pix2pixhd/pix2pixhd.actions.js33
-rw-r--r--app/client/modules/pix2pixhd/pix2pixhd.tasks.js26
-rw-r--r--app/client/modules/pix2pixhd/views/pix2pixhd.live.js52
-rw-r--r--app/client/modules/pix2pixhd/views/pix2pixhd.show.js1
-rw-r--r--app/client/modules/pix2pixhd/views/pix2pixhd.train.js187
6 files changed, 294 insertions, 10 deletions
diff --git a/app/client/modules/pix2pixhd/index.js b/app/client/modules/pix2pixhd/index.js
index ea224f3..cbd3136 100644
--- a/app/client/modules/pix2pixhd/index.js
+++ b/app/client/modules/pix2pixhd/index.js
@@ -8,6 +8,7 @@ import util from '../../util'
import Pix2PixHDNew from './views/pix2pixhd.new'
import Pix2PixHDShow from './views/pix2pixhd.show'
import Pix2PixHDResults from './views/pix2pixhd.results'
+import Pix2PixHDTrain from './views/pix2pixhd.train'
import Pix2PixHDLive from './views/pix2pixhd.live'
class router {
@@ -25,6 +26,7 @@ class router {
<Route exact path='/pix2pixhd/new/' component={Pix2PixHDNew} />
<Route exact path='/pix2pixhd/sequences/' component={Pix2PixHDShow} />
<Route exact path='/pix2pixhd/sequences/:id/' component={Pix2PixHDShow} />
+ <Route exact path='/pix2pixhd/train/' component={Pix2PixHDTrain} />
<Route exact path='/pix2pixhd/results/' component={Pix2PixHDResults} />
<Route exact path='/pix2pixhd/live/' component={Pix2PixHDLive} />
</section>
@@ -34,8 +36,9 @@ class router {
function links(){
return [
- { url: '/pix2pixhd/new/', name: 'new' },
+ { url: '/pix2pixhd/new/', name: 'folders' },
{ url: '/pix2pixhd/sequences/', name: 'sequences' },
+ { url: '/pix2pixhd/train/', name: 'checkpoints' },
{ url: '/pix2pixhd/results/', name: 'results' },
{ url: '/pix2pixhd/live/', name: 'live' },
]
diff --git a/app/client/modules/pix2pixhd/pix2pixhd.actions.js b/app/client/modules/pix2pixhd/pix2pixhd.actions.js
index 8e481d3..c1cd2b1 100644
--- a/app/client/modules/pix2pixhd/pix2pixhd.actions.js
+++ b/app/client/modules/pix2pixhd/pix2pixhd.actions.js
@@ -21,7 +21,11 @@ export const load_directories = (id) => (dispatch) => {
// actions.socket.disk_usage({ module, dir: 'datasets' }),
], (percent, i, n) => {
console.log('pix2pixhd load progress', i, n)
- dispatch({ type: types.app.load_progress, progress: { i, n }})
+ dispatch({
+ type: types.app.load_progress,
+ progress: { i, n },
+ data: { module: 'pix2pixhd' },
+ })
}).then(res => {
const [datasetApiReport, sequences, datasets, checkpoints] = res //, datasets, results, output, datasetUsage, lossReport] = res
const {
@@ -157,7 +161,11 @@ export const load_results = (id) => (dispatch) => {
actions.socket.list_directory({ module, dir: 'renders' }),
], (percent, i, n) => {
console.log('pix2pixhd load progress', i, n)
- dispatch({ type: types.app.load_progress, progress: { i, n }})
+ dispatch({
+ type: types.app.load_progress,
+ progress: { i, n },
+ data: { module: 'pix2pixhd' },
+ })
}).then(res => {
const [folders, files, results, renders] = res //, datasets, results, output, datasetUsage, lossReport] = res
console.log(files, results, renders)
@@ -171,4 +179,25 @@ export const load_results = (id) => (dispatch) => {
}
})
})
+}
+
+const G_NET_REGEXP = new RegExp('_net_G.pth$')
+
+export const list_epochs = (checkpoint_name) => (dispatch) => {
+ const module = pix2pixhdModule.name
+ actions.socket.list_directory({ module, dir: 'checkpoints/' + checkpoint_name }).then(files => {
+ // console.log(files)
+ const epochs = files.map(f => {
+ if (!f.name.match(G_NET_REGEXP)) return null
+ return f.name.replace(G_NET_REGEXP, '')
+ }).filter(f => !!f)
+ // console.log(epochs)
+ dispatch({
+ type: types.dataset.list_epochs,
+ data: {
+ epochs,
+ module
+ },
+ })
+ })
} \ No newline at end of file
diff --git a/app/client/modules/pix2pixhd/pix2pixhd.tasks.js b/app/client/modules/pix2pixhd/pix2pixhd.tasks.js
index f3c5342..bd51f2b 100644
--- a/app/client/modules/pix2pixhd/pix2pixhd.tasks.js
+++ b/app/client/modules/pix2pixhd/pix2pixhd.tasks.js
@@ -52,3 +52,29 @@ export const live_task = (sequence, checkpoint, opt) => dispatch => {
console.log('add live task')
return actions.queue.add_task(task)
}
+
+export const augment_task = (dataset, opt) => dispatch => {
+ const task = {
+ module: module.name,
+ activity: 'augment',
+ dataset,
+ opt: {
+ ...opt,
+ }
+ }
+ console.log(task)
+ console.log('add augment task')
+ return actions.queue.add_task(task)
+}
+
+export const clear_recursive_task = (dataset) => dispatch => {
+ const task = {
+ module: module.name,
+ activity: 'clear_recursive',
+ dataset,
+ }
+ console.log(task)
+ console.log('add clear recursive task')
+ return actions.queue.add_task(task)
+}
+
diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.live.js b/app/client/modules/pix2pixhd/views/pix2pixhd.live.js
index b127e23..52b4b61 100644
--- a/app/client/modules/pix2pixhd/views/pix2pixhd.live.js
+++ b/app/client/modules/pix2pixhd/views/pix2pixhd.live.js
@@ -4,7 +4,7 @@ import { connect } from 'react-redux'
import {
ParamGroup, Param, Player, Group,
- Slider, Select, TextInput, Button, Loading
+ Slider, SelectGroup, Select, TextInput, Button, Loading
} from '../../../common/'
import { startRecording, stopRecording, saveFrame, toggleFPS } from '../../../live/player'
@@ -39,6 +39,7 @@ class Pix2PixHDLive extends Component {
}
componentWillUpdate(nextProps) {
if (nextProps.opt.checkpoint_name && nextProps.opt.checkpoint_name !== this.props.opt.checkpoint_name) {
+ console.log('listing epochs')
this.props.actions.live.list_epochs('pix2pixhd', nextProps.opt.checkpoint_name)
}
}
@@ -99,8 +100,47 @@ class Pix2PixHDLive extends Component {
render(){
// console.log(this.props)
if (this.props.pix2pixhd.loading) {
- return <Loading />
+ return <Loading progress={this.props.pix2pixhd.progress} />
}
+ const { folderLookup, datasetLookup, sequences } = this.props.pix2pixhd.data
+
+ const sequenceLookup = sequences.reduce((a,b) => {
+ a[b.name] = true
+ return a
+ }, {})
+
+ const sequenceGroups = Object.keys(folderLookup).map(id => {
+ const folder = this.props.pix2pixhd.data.folderLookup[id]
+ if (folder.name === 'results') return
+ const datasets = folder.datasets.map(name => {
+ const sequence = sequenceLookup[name]
+ if (sequence) {
+ return name
+ }
+ return null
+ }).filter(n => !!n)
+ return {
+ name: folder.name,
+ options: datasets.sort(),
+ }
+ }).filter(n => !!n && !!n.options.length).sort((a,b) => a.name.localeCompare(b.name))
+
+ const checkpointGroups = Object.keys(folderLookup).map(id => {
+ const folder = this.props.pix2pixhd.data.folderLookup[id]
+ if (folder.name === 'results') return
+ const datasets = folder.datasets.map(name => {
+ const dataset = datasetLookup[name]
+ if (dataset.checkpoints.length) {
+ return name
+ }
+ return null
+ }).filter(n => !!n)
+ return {
+ name: folder.name,
+ options: datasets.sort(),
+ }
+ }).filter(n => !!n && !!n.options.length).sort((a,b) => a.name.localeCompare(b.name))
+
return (
<div className='app live centered'>
<Player width={424} height={256} fullscreen={this.props.fullscreen} />
@@ -116,16 +156,16 @@ class Pix2PixHDLive extends Component {
options={['a','b','sequence','recursive']}
onChange={this.props.actions.live.set_param}
/>
- <Select live
+ <SelectGroup live
name='sequence_name'
title='sequence'
- options={this.props.pix2pixhd.data.sequences.map(file => file.name)}
+ options={sequenceGroups}
onChange={this.changeSequence}
/>
- <Select live
+ <SelectGroup live
name='checkpoint_name'
title='checkpoint'
- options={this.props.pix2pixhd.data.checkpoints.map(file => file.name)}
+ options={checkpointGroups}
onChange={this.changeCheckpoint}
/>
<Select live
diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.show.js b/app/client/modules/pix2pixhd/views/pix2pixhd.show.js
index d58ee80..3266d59 100644
--- a/app/client/modules/pix2pixhd/views/pix2pixhd.show.js
+++ b/app/client/modules/pix2pixhd/views/pix2pixhd.show.js
@@ -62,7 +62,6 @@ class Pix2PixHDShow extends Component {
</div>
</div>
-
<DatasetComponent
loading={pix2pixhd.loading}
progress={pix2pixhd.progress}
diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
new file mode 100644
index 0000000..df3a1f2
--- /dev/null
+++ b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
@@ -0,0 +1,187 @@
+import { h, Component } from 'preact'
+import { bindActionCreators } from 'redux'
+import { connect } from 'react-redux'
+import util from '../../../util'
+
+import * as pix2pixhdActions from '../pix2pixhd.actions'
+import * as pix2pixhdTasks from '../pix2pixhd.tasks'
+
+import {
+ Loading,
+ FileList, FileRow,
+ Select, SelectGroup, Group, Button,
+ TextInput, NumberInput,
+ CurrentTask, TaskList
+} from '../../../common'
+import DatasetForm from '../../../dataset/dataset.form'
+import NewDatasetForm from '../../../dataset/dataset.new'
+import UploadStatus from '../../../dataset/upload.status'
+
+import DatasetComponent from '../../../dataset/dataset.component'
+
+import pix2pixhdModule from '../pix2pixhd.module'
+
+class Pix2PixHDTrain extends Component {
+ constructor(props){
+ super(props)
+ this.handleChange = this.handleChange.bind(this)
+ this.state = {
+ checkpoint_name: '',
+ epoch: 'latest',
+ augment_name: '',
+ augment_take: 100,
+ augment_make: 20,
+ }
+ }
+ componentWillMount(){
+ const id = this.props.match.params.id || localStorage.getItem('pix2pixhd.last_id')
+ console.log('load dataset:', id)
+ const { match, pix2pixhd, actions } = this.props
+ if (id === 'new') return
+ if (id) {
+ if (parseInt(id)) localStorage.setItem('pix2pixhd.last_id', id)
+ if (! pix2pixhd.folder || pix2pixhd.folder.id !== id) {
+ actions.load_directories(id)
+ }
+ } else {
+ this.props.history.push('/pix2pixhd/new/')
+ }
+ }
+ componentDidUpdate(prevProps, prevState){
+ if (! prevProps.pix2pixhd.data && this.props.pix2pixhd.data) {
+ console.log('set checkpoint_name')
+ this.setState({
+ checkpoint_name: this.props.pix2pixhd.data.sequences[0].name,
+ epoch: 'latest'
+ })
+ }
+ else if (prevState.checkpoint_name !== this.state.checkpoint_name) {
+ this.setState({ epoch: 'latest' })
+ this.props.actions.list_epochs(this.state.checkpoint_name)
+ }
+ }
+ handleChange(name, value){
+ console.log('name', name, 'value', value)
+ this.setState({ [name]: value })
+ }
+ render(){
+ if (this.props.pix2pixhd.loading) {
+ return <Loading progress={this.props.pix2pixhd.progress} />
+ }
+ const { pix2pixhd, match, history, queue } = this.props
+ const { folderLookup, datasetLookup } = (pix2pixhd.data || {})
+ const folder = (folderLookup || {})[pix2pixhd.folder_id] || {}
+ // console.log(pix2pixhd)
+
+ const checkpointGroups = Object.keys(folderLookup).map(id => {
+ const folder = this.props.pix2pixhd.data.folderLookup[id]
+ if (folder.name === 'results') return
+ const datasets = folder.datasets.map(name => {
+ const dataset = datasetLookup[name]
+ if (dataset.checkpoints.length) {
+ return name
+ }
+ return null
+ }).filter(n => !!n)
+ return {
+ name: folder.name,
+ options: datasets.sort(),
+ }
+ }).filter(n => !!n && !!n.options.length).sort((a,b) => a.name.localeCompare(b.name))
+
+ console.log('state', this.props.pix2pixhd.data.epochs)
+ // console.log(this.state.checkpoint_name, this.state.epoch)
+
+ return (
+ <div className='app pix2pixhd'>
+ <div class='heading'>
+ <h1>pix2pixhd training</h1>
+ </div>
+ <div class='heading'>
+ <SelectGroup
+ name='checkpoint_name'
+ title='Checkpoint'
+ options={checkpointGroups}
+ onChange={this.handleChange}
+ value={this.state.checkpoint_name}
+ />
+ <Select
+ title="Epoch"
+ name="epoch"
+ options={this.props.pix2pixhd.data.epochs}
+ onChange={this.handleChange}
+ value={this.state.epoch}
+ />
+ <br/>
+ <Group title='Augment'>
+ <NumberInput
+ name="augment_take"
+ title="Pick N random frames"
+ value={this.state.augment_take}
+ onChange={this.handleChange}
+ type="int"
+ min="1"
+ max="1000"
+ />
+ <NumberInput
+ name="augment_make"
+ title="Generate N recursively"
+ value={this.state.augment_make}
+ onChange={this.handleChange}
+ type="int"
+ min="1"
+ max="1000"
+ />
+ <TextInput
+ name="augment_name"
+ title="Tag this epoch"
+ value={this.state.augment_name}
+ onChange={this.handleChange}
+ />
+ <Button
+ title="Augment dataset"
+ value="Augment"
+ onClick={() => remote.augment_task(this.state.checkpoint_name, this.state)}
+ />
+ </Group>
+
+ <Group title='Train'>
+ <Button
+ title="Train one epoch"
+ value="Train"
+ onClick={() => remote.train_task(this.state.checkpoint_name, pix2pixhd.folder_id, 1)}
+ />
+ </Group>
+
+ <Group title='Clear'>
+ <Button
+ title="Delete recursive frames"
+ value="Clear"
+ onClick={() => remote.clear_recursive_task(this.state.checkpoint_name)}
+ />
+ </Group>
+ </div>
+ <div>
+ <CurrentTask />
+ {!!queue.queue.length &&
+ <Group title='Upcoming Tasks'>
+ <TaskList tasks={queue.queue.map(id => queue.tasks[id])} />
+ </Group>
+ }
+ </div>
+ </div>
+ )
+ }
+}
+
+const mapStateToProps = state => ({
+ pix2pixhd: state.module.pix2pixhd,
+ queue: state.queue,
+})
+
+const mapDispatchToProps = (dispatch, ownProps) => ({
+ actions: bindActionCreators(pix2pixhdActions, dispatch),
+ remote: bindActionCreators(pix2pixhdTasks, dispatch),
+})
+
+export default connect(mapStateToProps, mapDispatchToProps)(Pix2PixHDTrain)