summaryrefslogtreecommitdiff
path: root/app
diff options
context:
space:
mode:
authorjules@lens <julescarbon@gmail.com>2018-09-05 12:00:28 +0200
committerjules@lens <julescarbon@gmail.com>2018-09-05 12:00:28 +0200
commit9abfa16dc059d042c21f1636ecc8797ef29a030d (patch)
treed0583cb5dae01de1abc57ed8f7587d23242ed6f0 /app
parent0a3c6743543dd3dfcb876f5ce735b72d050e981d (diff)
parent15eb6806b6e216255f33abcb885f6cdbc38a7663 (diff)
Merge branch 'master' of asdf.us:live-cortex
Diffstat (limited to 'app')
-rw-r--r--app/client/common/index.js5
-rw-r--r--app/client/common/numberInput.component.js50
-rw-r--r--app/client/common/selectGroup.component.js67
-rw-r--r--app/client/common/textInput.component.js7
-rw-r--r--app/client/dashboard/dashboard.actions.js6
-rw-r--r--app/client/dashboard/dashboard.component.js12
-rw-r--r--app/client/dashboard/dashboard.reducer.js10
-rw-r--r--app/client/dashboard/dashboardHeader.component.js1
-rw-r--r--app/client/dataset/dataset.reducer.js10
-rw-r--r--app/client/live/live.reducer.js9
-rw-r--r--app/client/modules/morph/morph.actions.js6
-rw-r--r--app/client/modules/pix2pixhd/index.js5
-rw-r--r--app/client/modules/pix2pixhd/pix2pixhd.actions.js33
-rw-r--r--app/client/modules/pix2pixhd/pix2pixhd.tasks.js26
-rw-r--r--app/client/modules/pix2pixhd/views/pix2pixhd.live.js52
-rw-r--r--app/client/modules/pix2pixhd/views/pix2pixhd.show.js1
-rw-r--r--app/client/modules/pix2pixhd/views/pix2pixhd.train.js187
-rw-r--r--app/client/modules/samplernn/samplernn.actions.js12
-rw-r--r--app/client/socket/socket.actions.js25
-rw-r--r--app/client/socket/socket.live.js14
-rw-r--r--app/client/types.js1
-rw-r--r--app/relay/modules/pix2pixhd.js73
22 files changed, 556 insertions, 56 deletions
diff --git a/app/client/common/index.js b/app/client/common/index.js
index 3981fa7..eeb8bfc 100644
--- a/app/client/common/index.js
+++ b/app/client/common/index.js
@@ -8,11 +8,13 @@ import Gallery from './gallery.component'
import Group from './group.component'
import Header from './header.component'
import Loading from './loading.component'
+import NumberInput from './numberInput.component'
import Param from './param.component'
import ParamGroup from './paramGroup.component'
import Player from './player.component'
import Progress from './progress.component'
import Select from './select.component'
+import SelectGroup from './selectGroup.component'
import Slider from './slider.component'
import TextInput from './textInput.component'
import * as Views from './views'
@@ -23,6 +25,7 @@ export {
FolderList, FileList, FileRow, FileUpload,
Gallery, Player,
Group, ParamGroup, Param,
- TextInput, Slider, Select, Button, Checkbox,
+ TextInput, NumberInput,
+ Slider, Select, SelectGroup, Button, Checkbox,
CurrentTask,
} \ No newline at end of file
diff --git a/app/client/common/numberInput.component.js b/app/client/common/numberInput.component.js
new file mode 100644
index 0000000..43f9878
--- /dev/null
+++ b/app/client/common/numberInput.component.js
@@ -0,0 +1,50 @@
+import { h, Component } from 'preact'
+
+class NumberInput extends Component {
+ constructor(props){
+ super(props)
+ this.state = { value: null, changed: false }
+ this.handleInput = this.handleInput.bind(this)
+ this.handleKeydown = this.handleKeydown.bind(this)
+ }
+ handleInput(e){
+ this.setState({
+ value: e.target.value,
+ changed: true,
+ })
+ this.props.onInput && this.props.onInput(e.target.value, e.target.name)
+ this.props.onChange && this.props.onInput(e.target.name, e.target.value)
+ }
+ handleKeydown(e){
+ if (e.keyCode === 13) {
+ this.setState({
+ value: e.target.value,
+ changed: false,
+ })
+ this.props.onSave && this.props.onSave(e.target.value, e.target.name)
+ }
+ }
+ render() {
+ return (
+ <div className='numberInput param'>
+ <label>
+ <span>{this.props.title}</span>
+ <input
+ type={'number'}
+ name={this.props.name || 'number'}
+ value={this.state.changed ? this.state.value : this.props.value}
+ onInput={this.handleInput}
+ onKeydown={this.handleKeydown}
+ placeholder={this.props.placeholder}
+ autofocus={this.props.autofocus}
+ min={this.props.min}
+ max={this.props.max}
+ step={this.props.step || this.props.type === 'int' ? 1 : 0.01}
+ />
+ </label>
+ </div>
+ )
+ }
+}
+
+export default NumberInput
diff --git a/app/client/common/selectGroup.component.js b/app/client/common/selectGroup.component.js
new file mode 100644
index 0000000..5c1af51
--- /dev/null
+++ b/app/client/common/selectGroup.component.js
@@ -0,0 +1,67 @@
+import { h, Component } from 'preact'
+import { connect } from 'react-redux'
+import { bindActionCreators } from 'redux'
+
+class SelectGroup extends Component {
+ constructor(props){
+ super(props)
+ this.handleChange = this.handleChange.bind(this)
+ }
+ handleChange(e){
+ clearTimeout(this.timeout)
+ let new_value = e.target.value
+ if (new_value === 'PLACEHOLDER') return
+ this.props.onChange && this.props.onChange(this.props.name, new_value)
+ }
+ render() {
+ const currentValue = this.props.live ? this.props.opt[this.props.name] : this.props.value
+ let lastValue
+ const options = (this.props.options || []).map((group, i) => {
+ const groupName = group.name
+ const children = group.options.map(key => {
+ let name = key.length < 2 ? key.toUpperCase() : key
+ name = name.replace(/_/g, ' ')
+ let value = key
+ lastValue = value
+ return (
+ <option value={value} key={value}>
+ {name}
+ </option>
+ )
+ })
+ return (
+ <optgroup label={groupName} key={groupName}>
+ {children}
+ </optgroup>
+ )
+ })
+ return (
+ <div className='select param'>
+ <label>
+ <span>{this.props.title}</span>
+ <select
+ onChange={this.handleChange}
+ value={currentValue || lastValue}
+ >
+ {this.props.placeholder && <option value="PLACEHOLDER">{this.props.placeholder}</option>}
+ {options}
+ </select>
+ </label>
+ {this.props.children}
+ </div>
+ )
+ }
+}
+
+function capitalize(s){
+ return (s || "").replace(/(?:^|\s)\S/g, function(a) { return a.toUpperCase(); });
+}
+
+const mapStateToProps = (state, props) => ({
+ opt: props.opt || state.live.opt,
+})
+
+const mapDispatchToProps = (dispatch, ownProps) => ({
+})
+
+export default connect(mapStateToProps, mapDispatchToProps)(SelectGroup)
diff --git a/app/client/common/textInput.component.js b/app/client/common/textInput.component.js
index a3739d4..44e1349 100644
--- a/app/client/common/textInput.component.js
+++ b/app/client/common/textInput.component.js
@@ -12,7 +12,7 @@ class TextInput extends Component {
value: e.target.value,
changed: true,
})
- this.props.onInput && this.props.onInput(e.target.value)
+ this.props.onInput && this.props.onInput(e.target.value, e.target.name)
}
handleKeydown(e){
if (e.keyCode === 13) {
@@ -20,7 +20,7 @@ class TextInput extends Component {
value: e.target.value,
changed: false,
})
- this.props.onSave && this.props.onSave(e.target.value)
+ this.props.onSave && this.props.onSave(e.target.value, e.target.name)
}
}
render() {
@@ -29,7 +29,8 @@ class TextInput extends Component {
<label>
<span>{this.props.title}</span>
<input
- type='text'
+ type={this.props.type || 'text'}
+ name={this.props.name || 'text'}
value={this.state.changed ? this.state.value : this.props.value}
onInput={this.handleInput}
onKeydown={this.handleKeydown}
diff --git a/app/client/dashboard/dashboard.actions.js b/app/client/dashboard/dashboard.actions.js
index 8b5502a..c428d0b 100644
--- a/app/client/dashboard/dashboard.actions.js
+++ b/app/client/dashboard/dashboard.actions.js
@@ -11,7 +11,11 @@ export const load = () => (dispatch) => {
actions.file.index({ module: 'morph', generated: 1, limit: 15, orderBy: 'created_at desc', }),
], (percent, i, n) => {
// console.log('dashboard load progress', i, n)
- dispatch({ type: types.app.load_progress, progress: { i, n }})
+ dispatch({
+ type: types.app.load_progress,
+ progress: { i, n },
+ data: { module: 'dashboard' }
+ })
}).then(res => {
const [ tasks, folders, samplernn, pix2pixhd, morph ] = res
const { mapFn, sortFn } = util.sort.orderByFn('date desc')
diff --git a/app/client/dashboard/dashboard.component.js b/app/client/dashboard/dashboard.component.js
index 3c9b2de..0c15f99 100644
--- a/app/client/dashboard/dashboard.component.js
+++ b/app/client/dashboard/dashboard.component.js
@@ -11,7 +11,7 @@ import Button from '../common/button.component'
import DashboardHeader from './dashboardheader.component'
import TaskList from './tasklist.component'
-import { FolderList, FileList } from '../common'
+import { Loading, FolderList, FileList } from '../common'
import Gallery from '../common/gallery.component'
import * as dashboardActions from './dashboard.actions'
@@ -22,7 +22,7 @@ import actions from '../actions'
class Dashboard extends Component {
constructor(props){
super()
- console.log(props)
+ // console.log(props)
props.actions.load()
}
componentWillUpdate(nextProps) {
@@ -30,9 +30,11 @@ class Dashboard extends Component {
// this.props.actions.list_epochs(nextProps.opt.checkpoint_name)
}
render(){
- const { site, foldersByModule, renders, queue, images } = this.props
+ const { loading, progress, site, foldersByModule, renders, queue, images } = this.props
+ if (loading) {
+ return <Loading progress={progress} />
+ }
const { tasks } = queue
- console.log(foldersByModule)
const folders = foldersByModule && Object.keys(modules).sort().map(key => {
let path = key === 'samplernn' ? '/samplernn/datasets/' : '/' + key + '/sequences/'
let folder_list = (foldersByModule[key] || []).map(folder => {
@@ -99,6 +101,8 @@ class Dashboard extends Component {
}
}
const mapStateToProps = state => ({
+ loading: state.dashboard.loading,
+ progress: state.dashboard.progress,
site: state.system.site,
foldersByModule: state.dashboard.data.foldersByModule,
renders: state.dashboard.data.renders,
diff --git a/app/client/dashboard/dashboard.reducer.js b/app/client/dashboard/dashboard.reducer.js
index 812a501..f10f3fa 100644
--- a/app/client/dashboard/dashboard.reducer.js
+++ b/app/client/dashboard/dashboard.reducer.js
@@ -5,6 +5,7 @@ let FileSaver = require('file-saver')
const dashboardInitialState = {
loading: false,
+ progress: null,
error: null,
data: {},
images: [
@@ -15,6 +16,15 @@ const dashboardInitialState = {
const dashboardReducer = (state = dashboardInitialState, action) => {
switch(action.type) {
+ case types.app.load_progress:
+ if (!action.data || action.data.module !== 'dashboard') {
+ return state
+ }
+ return {
+ ...state,
+ loading: true,
+ progress: action.progress,
+ }
case types.dashboard.load:
return {
...state,
diff --git a/app/client/dashboard/dashboardHeader.component.js b/app/client/dashboard/dashboardHeader.component.js
index 5f1306c..063cd47 100644
--- a/app/client/dashboard/dashboardHeader.component.js
+++ b/app/client/dashboard/dashboardHeader.component.js
@@ -23,7 +23,6 @@ class DashboardHeader extends Component {
)
}
renderStatus(name, gpu){
- console.log(gpu)
if (gpu.status === 'IDLE') {
return <div>{name} idle</div>
}
diff --git a/app/client/dataset/dataset.reducer.js b/app/client/dataset/dataset.reducer.js
index 10b4a94..f303a7f 100644
--- a/app/client/dataset/dataset.reducer.js
+++ b/app/client/dataset/dataset.reducer.js
@@ -8,6 +8,7 @@ import types from '../types'
const datasetInitialState = () => ({
loading: true,
+ progress: null,
error: null,
data: null,
folder_id: 0,
@@ -57,6 +58,15 @@ const datasetReducer = (state = datasetInitialState(), action) => {
case types.file.destroy:
return handleFileDestroy(state, action)
+ case types.dataset.list_epochs:
+ return {
+ ...state,
+ data: {
+ ...state.data,
+ epochs: action.data.epochs,
+ }
+ }
+
default:
return state
}
diff --git a/app/client/live/live.reducer.js b/app/client/live/live.reducer.js
index 7fd7667..dfdf7ee 100644
--- a/app/client/live/live.reducer.js
+++ b/app/client/live/live.reducer.js
@@ -11,6 +11,7 @@ const liveInitialState = {
recurse_roll: 0, rotate: 0, scale: 0, process_frac: 0.5,
view_mode: 'b',
},
+ all_checkpoints: [],
checkpoints: [],
epochs: ['latest'],
sequences: [],
@@ -57,7 +58,15 @@ const liveReducer = (state = liveInitialState, action) => {
epochs: [],
}
+ case types.socket.list_all_checkpoints:
+ return {
+ ...state,
+ all_checkpoints: action.all_checkpoints,
+ epochs: [],
+ }
+
case types.socket.list_epochs:
+ console.log(action)
if (action.epochs === "not found") return { ...state, epochs: [] }
return {
...state,
diff --git a/app/client/modules/morph/morph.actions.js b/app/client/modules/morph/morph.actions.js
index 04f452d..6586778 100644
--- a/app/client/modules/morph/morph.actions.js
+++ b/app/client/modules/morph/morph.actions.js
@@ -19,7 +19,11 @@ export const load_data = (id) => (dispatch) => {
actions.socket.list_directory({ module, dir: 'renders' }),
], (percent, i, n) => {
console.log('morph load progress', i, n)
- dispatch({ type: types.app.load_progress, progress: { i, n }})
+ dispatch({
+ type: types.app.load_progress,
+ progress: { i, n },
+ data: { module: 'morph' },
+ })
}).then(res => {
const [datasetApiReport, sequences, renders] = res
const {
diff --git a/app/client/modules/pix2pixhd/index.js b/app/client/modules/pix2pixhd/index.js
index ea224f3..cbd3136 100644
--- a/app/client/modules/pix2pixhd/index.js
+++ b/app/client/modules/pix2pixhd/index.js
@@ -8,6 +8,7 @@ import util from '../../util'
import Pix2PixHDNew from './views/pix2pixhd.new'
import Pix2PixHDShow from './views/pix2pixhd.show'
import Pix2PixHDResults from './views/pix2pixhd.results'
+import Pix2PixHDTrain from './views/pix2pixhd.train'
import Pix2PixHDLive from './views/pix2pixhd.live'
class router {
@@ -25,6 +26,7 @@ class router {
<Route exact path='/pix2pixhd/new/' component={Pix2PixHDNew} />
<Route exact path='/pix2pixhd/sequences/' component={Pix2PixHDShow} />
<Route exact path='/pix2pixhd/sequences/:id/' component={Pix2PixHDShow} />
+ <Route exact path='/pix2pixhd/train/' component={Pix2PixHDTrain} />
<Route exact path='/pix2pixhd/results/' component={Pix2PixHDResults} />
<Route exact path='/pix2pixhd/live/' component={Pix2PixHDLive} />
</section>
@@ -34,8 +36,9 @@ class router {
function links(){
return [
- { url: '/pix2pixhd/new/', name: 'new' },
+ { url: '/pix2pixhd/new/', name: 'folders' },
{ url: '/pix2pixhd/sequences/', name: 'sequences' },
+ { url: '/pix2pixhd/train/', name: 'checkpoints' },
{ url: '/pix2pixhd/results/', name: 'results' },
{ url: '/pix2pixhd/live/', name: 'live' },
]
diff --git a/app/client/modules/pix2pixhd/pix2pixhd.actions.js b/app/client/modules/pix2pixhd/pix2pixhd.actions.js
index 8e481d3..c1cd2b1 100644
--- a/app/client/modules/pix2pixhd/pix2pixhd.actions.js
+++ b/app/client/modules/pix2pixhd/pix2pixhd.actions.js
@@ -21,7 +21,11 @@ export const load_directories = (id) => (dispatch) => {
// actions.socket.disk_usage({ module, dir: 'datasets' }),
], (percent, i, n) => {
console.log('pix2pixhd load progress', i, n)
- dispatch({ type: types.app.load_progress, progress: { i, n }})
+ dispatch({
+ type: types.app.load_progress,
+ progress: { i, n },
+ data: { module: 'pix2pixhd' },
+ })
}).then(res => {
const [datasetApiReport, sequences, datasets, checkpoints] = res //, datasets, results, output, datasetUsage, lossReport] = res
const {
@@ -157,7 +161,11 @@ export const load_results = (id) => (dispatch) => {
actions.socket.list_directory({ module, dir: 'renders' }),
], (percent, i, n) => {
console.log('pix2pixhd load progress', i, n)
- dispatch({ type: types.app.load_progress, progress: { i, n }})
+ dispatch({
+ type: types.app.load_progress,
+ progress: { i, n },
+ data: { module: 'pix2pixhd' },
+ })
}).then(res => {
const [folders, files, results, renders] = res //, datasets, results, output, datasetUsage, lossReport] = res
console.log(files, results, renders)
@@ -171,4 +179,25 @@ export const load_results = (id) => (dispatch) => {
}
})
})
+}
+
+const G_NET_REGEXP = new RegExp('_net_G.pth$')
+
+export const list_epochs = (checkpoint_name) => (dispatch) => {
+ const module = pix2pixhdModule.name
+ actions.socket.list_directory({ module, dir: 'checkpoints/' + checkpoint_name }).then(files => {
+ // console.log(files)
+ const epochs = files.map(f => {
+ if (!f.name.match(G_NET_REGEXP)) return null
+ return f.name.replace(G_NET_REGEXP, '')
+ }).filter(f => !!f)
+ // console.log(epochs)
+ dispatch({
+ type: types.dataset.list_epochs,
+ data: {
+ epochs,
+ module
+ },
+ })
+ })
} \ No newline at end of file
diff --git a/app/client/modules/pix2pixhd/pix2pixhd.tasks.js b/app/client/modules/pix2pixhd/pix2pixhd.tasks.js
index f3c5342..bd51f2b 100644
--- a/app/client/modules/pix2pixhd/pix2pixhd.tasks.js
+++ b/app/client/modules/pix2pixhd/pix2pixhd.tasks.js
@@ -52,3 +52,29 @@ export const live_task = (sequence, checkpoint, opt) => dispatch => {
console.log('add live task')
return actions.queue.add_task(task)
}
+
+export const augment_task = (dataset, opt) => dispatch => {
+ const task = {
+ module: module.name,
+ activity: 'augment',
+ dataset,
+ opt: {
+ ...opt,
+ }
+ }
+ console.log(task)
+ console.log('add augment task')
+ return actions.queue.add_task(task)
+}
+
+export const clear_recursive_task = (dataset) => dispatch => {
+ const task = {
+ module: module.name,
+ activity: 'clear_recursive',
+ dataset,
+ }
+ console.log(task)
+ console.log('add clear recursive task')
+ return actions.queue.add_task(task)
+}
+
diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.live.js b/app/client/modules/pix2pixhd/views/pix2pixhd.live.js
index b127e23..52b4b61 100644
--- a/app/client/modules/pix2pixhd/views/pix2pixhd.live.js
+++ b/app/client/modules/pix2pixhd/views/pix2pixhd.live.js
@@ -4,7 +4,7 @@ import { connect } from 'react-redux'
import {
ParamGroup, Param, Player, Group,
- Slider, Select, TextInput, Button, Loading
+ Slider, SelectGroup, Select, TextInput, Button, Loading
} from '../../../common/'
import { startRecording, stopRecording, saveFrame, toggleFPS } from '../../../live/player'
@@ -39,6 +39,7 @@ class Pix2PixHDLive extends Component {
}
componentWillUpdate(nextProps) {
if (nextProps.opt.checkpoint_name && nextProps.opt.checkpoint_name !== this.props.opt.checkpoint_name) {
+ console.log('listing epochs')
this.props.actions.live.list_epochs('pix2pixhd', nextProps.opt.checkpoint_name)
}
}
@@ -99,8 +100,47 @@ class Pix2PixHDLive extends Component {
render(){
// console.log(this.props)
if (this.props.pix2pixhd.loading) {
- return <Loading />
+ return <Loading progress={this.props.pix2pixhd.progress} />
}
+ const { folderLookup, datasetLookup, sequences } = this.props.pix2pixhd.data
+
+ const sequenceLookup = sequences.reduce((a,b) => {
+ a[b.name] = true
+ return a
+ }, {})
+
+ const sequenceGroups = Object.keys(folderLookup).map(id => {
+ const folder = this.props.pix2pixhd.data.folderLookup[id]
+ if (folder.name === 'results') return
+ const datasets = folder.datasets.map(name => {
+ const sequence = sequenceLookup[name]
+ if (sequence) {
+ return name
+ }
+ return null
+ }).filter(n => !!n)
+ return {
+ name: folder.name,
+ options: datasets.sort(),
+ }
+ }).filter(n => !!n && !!n.options.length).sort((a,b) => a.name.localeCompare(b.name))
+
+ const checkpointGroups = Object.keys(folderLookup).map(id => {
+ const folder = this.props.pix2pixhd.data.folderLookup[id]
+ if (folder.name === 'results') return
+ const datasets = folder.datasets.map(name => {
+ const dataset = datasetLookup[name]
+ if (dataset.checkpoints.length) {
+ return name
+ }
+ return null
+ }).filter(n => !!n)
+ return {
+ name: folder.name,
+ options: datasets.sort(),
+ }
+ }).filter(n => !!n && !!n.options.length).sort((a,b) => a.name.localeCompare(b.name))
+
return (
<div className='app live centered'>
<Player width={424} height={256} fullscreen={this.props.fullscreen} />
@@ -116,16 +156,16 @@ class Pix2PixHDLive extends Component {
options={['a','b','sequence','recursive']}
onChange={this.props.actions.live.set_param}
/>
- <Select live
+ <SelectGroup live
name='sequence_name'
title='sequence'
- options={this.props.pix2pixhd.data.sequences.map(file => file.name)}
+ options={sequenceGroups}
onChange={this.changeSequence}
/>
- <Select live
+ <SelectGroup live
name='checkpoint_name'
title='checkpoint'
- options={this.props.pix2pixhd.data.checkpoints.map(file => file.name)}
+ options={checkpointGroups}
onChange={this.changeCheckpoint}
/>
<Select live
diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.show.js b/app/client/modules/pix2pixhd/views/pix2pixhd.show.js
index d58ee80..3266d59 100644
--- a/app/client/modules/pix2pixhd/views/pix2pixhd.show.js
+++ b/app/client/modules/pix2pixhd/views/pix2pixhd.show.js
@@ -62,7 +62,6 @@ class Pix2PixHDShow extends Component {
</div>
</div>
-
<DatasetComponent
loading={pix2pixhd.loading}
progress={pix2pixhd.progress}
diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
new file mode 100644
index 0000000..df3a1f2
--- /dev/null
+++ b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js
@@ -0,0 +1,187 @@
+import { h, Component } from 'preact'
+import { bindActionCreators } from 'redux'
+import { connect } from 'react-redux'
+import util from '../../../util'
+
+import * as pix2pixhdActions from '../pix2pixhd.actions'
+import * as pix2pixhdTasks from '../pix2pixhd.tasks'
+
+import {
+ Loading,
+ FileList, FileRow,
+ Select, SelectGroup, Group, Button,
+ TextInput, NumberInput,
+ CurrentTask, TaskList
+} from '../../../common'
+import DatasetForm from '../../../dataset/dataset.form'
+import NewDatasetForm from '../../../dataset/dataset.new'
+import UploadStatus from '../../../dataset/upload.status'
+
+import DatasetComponent from '../../../dataset/dataset.component'
+
+import pix2pixhdModule from '../pix2pixhd.module'
+
+class Pix2PixHDTrain extends Component {
+ constructor(props){
+ super(props)
+ this.handleChange = this.handleChange.bind(this)
+ this.state = {
+ checkpoint_name: '',
+ epoch: 'latest',
+ augment_name: '',
+ augment_take: 100,
+ augment_make: 20,
+ }
+ }
+ componentWillMount(){
+ const id = this.props.match.params.id || localStorage.getItem('pix2pixhd.last_id')
+ console.log('load dataset:', id)
+ const { match, pix2pixhd, actions } = this.props
+ if (id === 'new') return
+ if (id) {
+ if (parseInt(id)) localStorage.setItem('pix2pixhd.last_id', id)
+ if (! pix2pixhd.folder || pix2pixhd.folder.id !== id) {
+ actions.load_directories(id)
+ }
+ } else {
+ this.props.history.push('/pix2pixhd/new/')
+ }
+ }
+ componentDidUpdate(prevProps, prevState){
+ if (! prevProps.pix2pixhd.data && this.props.pix2pixhd.data) {
+ console.log('set checkpoint_name')
+ this.setState({
+ checkpoint_name: this.props.pix2pixhd.data.sequences[0].name,
+ epoch: 'latest'
+ })
+ }
+ else if (prevState.checkpoint_name !== this.state.checkpoint_name) {
+ this.setState({ epoch: 'latest' })
+ this.props.actions.list_epochs(this.state.checkpoint_name)
+ }
+ }
+ handleChange(name, value){
+ console.log('name', name, 'value', value)
+ this.setState({ [name]: value })
+ }
+ render(){
+ if (this.props.pix2pixhd.loading) {
+ return <Loading progress={this.props.pix2pixhd.progress} />
+ }
+ const { pix2pixhd, match, history, queue } = this.props
+ const { folderLookup, datasetLookup } = (pix2pixhd.data || {})
+ const folder = (folderLookup || {})[pix2pixhd.folder_id] || {}
+ // console.log(pix2pixhd)
+
+ const checkpointGroups = Object.keys(folderLookup).map(id => {
+ const folder = this.props.pix2pixhd.data.folderLookup[id]
+ if (folder.name === 'results') return
+ const datasets = folder.datasets.map(name => {
+ const dataset = datasetLookup[name]
+ if (dataset.checkpoints.length) {
+ return name
+ }
+ return null
+ }).filter(n => !!n)
+ return {
+ name: folder.name,
+ options: datasets.sort(),
+ }
+ }).filter(n => !!n && !!n.options.length).sort((a,b) => a.name.localeCompare(b.name))
+
+ console.log('state', this.props.pix2pixhd.data.epochs)
+ // console.log(this.state.checkpoint_name, this.state.epoch)
+
+ return (
+ <div className='app pix2pixhd'>
+ <div class='heading'>
+ <h1>pix2pixhd training</h1>
+ </div>
+ <div class='heading'>
+ <SelectGroup
+ name='checkpoint_name'
+ title='Checkpoint'
+ options={checkpointGroups}
+ onChange={this.handleChange}
+ value={this.state.checkpoint_name}
+ />
+ <Select
+ title="Epoch"
+ name="epoch"
+ options={this.props.pix2pixhd.data.epochs}
+ onChange={this.handleChange}
+ value={this.state.epoch}
+ />
+ <br/>
+ <Group title='Augment'>
+ <NumberInput
+ name="augment_take"
+ title="Pick N random frames"
+ value={this.state.augment_take}
+ onChange={this.handleChange}
+ type="int"
+ min="1"
+ max="1000"
+ />
+ <NumberInput
+ name="augment_make"
+ title="Generate N recursively"
+ value={this.state.augment_make}
+ onChange={this.handleChange}
+ type="int"
+ min="1"
+ max="1000"
+ />
+ <TextInput
+ name="augment_name"
+ title="Tag this epoch"
+ value={this.state.augment_name}
+ onChange={this.handleChange}
+ />
+ <Button
+ title="Augment dataset"
+ value="Augment"
+ onClick={() => remote.augment_task(this.state.checkpoint_name, this.state)}
+ />
+ </Group>
+
+ <Group title='Train'>
+ <Button
+ title="Train one epoch"
+ value="Train"
+ onClick={() => remote.train_task(this.state.checkpoint_name, pix2pixhd.folder_id, 1)}
+ />
+ </Group>
+
+ <Group title='Clear'>
+ <Button
+ title="Delete recursive frames"
+ value="Clear"
+ onClick={() => remote.clear_recursive_task(this.state.checkpoint_name)}
+ />
+ </Group>
+ </div>
+ <div>
+ <CurrentTask />
+ {!!queue.queue.length &&
+ <Group title='Upcoming Tasks'>
+ <TaskList tasks={queue.queue.map(id => queue.tasks[id])} />
+ </Group>
+ }
+ </div>
+ </div>
+ )
+ }
+}
+
+const mapStateToProps = state => ({
+ pix2pixhd: state.module.pix2pixhd,
+ queue: state.queue,
+})
+
+const mapDispatchToProps = (dispatch, ownProps) => ({
+ actions: bindActionCreators(pix2pixhdActions, dispatch),
+ remote: bindActionCreators(pix2pixhdTasks, dispatch),
+})
+
+export default connect(mapStateToProps, mapDispatchToProps)(Pix2PixHDTrain)
diff --git a/app/client/modules/samplernn/samplernn.actions.js b/app/client/modules/samplernn/samplernn.actions.js
index f4f7a6d..e4de1dd 100644
--- a/app/client/modules/samplernn/samplernn.actions.js
+++ b/app/client/modules/samplernn/samplernn.actions.js
@@ -21,7 +21,11 @@ export const load_directories = (id) => (dispatch) => {
actions.socket.disk_usage({ module, dir: 'datasets' }),
load_loss()(dispatch),
], (percent, i, n) => {
- dispatch({ type: types.app.load_progress, progress: { i, n }})
+ dispatch({
+ type: types.app.load_progress,
+ progress: { i, n },
+ data: { module: 'samplernn' },
+ })
}).then(res => {
// console.log(res)
const [datasetApiReport, datasets, results, output, datasetUsage, lossReport] = res
@@ -128,7 +132,11 @@ export const load_graph = () => dispatch => {
load_loss()(dispatch),
actions.socket.list_directory({ module, dir: 'results' }),
], (percent, i, n) => {
- dispatch({ type: types.app.load_progress, progress: { i, n }})
+ dispatch({
+ type: types.app.load_progress,
+ progress: { i, n },
+ data: { module: 'samplernn' },
+ })
}).then(res => {
const [lossReport, results] = res
dispatch({
diff --git a/app/client/socket/socket.actions.js b/app/client/socket/socket.actions.js
index e15dda2..78b0517 100644
--- a/app/client/socket/socket.actions.js
+++ b/app/client/socket/socket.actions.js
@@ -1,24 +1,13 @@
import uuidv1 from 'uuid/v1'
import { socket } from './socket.connection'
-export function run_system_command(opt) {
- return syscall_async('run_system_command', opt)
-}
-export function disk_usage(opt) {
- return syscall_async('run_system_command', { cmd: 'du', ...opt })
-}
-export function list_directory(opt) {
- return syscall_async('list_directory', opt).then(res => res.files)
-}
-export function list_sequences(opt) {
- return syscall_async('list_sequences', opt).then(res => res.sequences)
-}
-export function run_script(opt) {
- return syscall_async('run_script', opt)
-}
-export function upload_file(opt) {
- return syscall_async('upload_file', opt)
-}
+export const run_system_command = opt => syscall_async('run_system_command', opt)
+export const disk_usage = opt => syscall_async('run_system_command', { cmd: 'du', ...opt })
+export const list_directory = opt => syscall_async('list_directory', opt).then(res => res.files)
+export const list_sequences = opt => syscall_async('list_sequences', opt).then(res => res.sequences)
+export const run_script = opt => syscall_async('run_script', opt)
+export const upload_file = opt => syscall_async('upload_file', opt)
+
export const syscall_async = (tag, payload, ttl=10000) => {
ttl = payload.ttl || ttl
return new Promise( (resolve, reject) => {
diff --git a/app/client/socket/socket.live.js b/app/client/socket/socket.live.js
index fc53eb3..a1a7a3f 100644
--- a/app/client/socket/socket.live.js
+++ b/app/client/socket/socket.live.js
@@ -27,6 +27,12 @@ socket.on('res', (data) => {
checkpoints: data.res,
})
break
+ case 'list_all_checkpoints':
+ dispatch({
+ type: types.socket.list_all_checkpoints,
+ checkpoints: data.res,
+ })
+ break
case 'list_epochs':
dispatch({
type: types.socket.list_epochs,
@@ -53,10 +59,16 @@ export function list_checkpoints(module) {
payload: module,
})
}
+export function list_all_checkpoints(module) {
+ socket.emit('cmd', {
+ cmd: 'list_all_checkpoints',
+ payload: module,
+ })
+}
export function list_epochs(module, checkpoint_name) {
socket.emit('cmd', {
cmd: 'list_epochs',
- payload: module + '/' + checkpoint_name,
+ payload: (module === 'pix2pix' || module === 'pix2wav') ? module + '/' + checkpoint_name : checkpoint_name,
})
}
export function list_sequences(module) {
diff --git a/app/client/types.js b/app/client/types.js
index d12ac91..2df494b 100644
--- a/app/client/types.js
+++ b/app/client/types.js
@@ -92,6 +92,7 @@ export default {
file_uploaded: 'DATASET_FILE_UPLOADED',
fetch_url: 'DATASET_FETCH_URL',
fetch_progress: 'DATASET_FETCH_PROGRESS',
+ list_epochs: 'DATASET_LIST_EPOCHS',
},
samplernn: {
init: 'SAMPLERNN_INIT',
diff --git a/app/relay/modules/pix2pixhd.js b/app/relay/modules/pix2pixhd.js
index 821be71..a90fc15 100644
--- a/app/relay/modules/pix2pixhd.js
+++ b/app/relay/modules/pix2pixhd.js
@@ -81,20 +81,64 @@ const generate = {
type: 'pytorch',
script: 'test.py',
params: (task) => {
+ let epoch = 0
+ const dataset = task.dataset.toLowerCase()
+ const datasets_path = path.join(cwd, 'datasets', dataset)
+ const checkpoints_path = path.join(cwd, 'checkpoints', dataset)
+ const iter_txt = path.join(checkpoints_path, 'iter.txt')
+ console.log(dataset, iter_txt)
+ if (fs.existsSync(iter_txt)) {
+ const iter = fs.readFileSync(iter_txt).toString().split('\n');
+ console.log(iter)
+ epoch = iter[0] || 0
+ console.log(task.module, dataset, '=>', epoch, task.epochs)
+ } else {
+ console.log(task.module, dataset, '=>', 'starting new training')
+ }
return [
- '--dataroot', '/sequences/' + task.dataset,
+ '--dataroot', datasets_path,
'--module_name', task.module,
- '--name', task.dataset,
- '--start_img', '/sequences/' + task.dataset + '/frame_00001.png',
- '--how_many', 1000,
- '--model', 'test',
- '--aspect_ratio', 1.777777,
- '--which_model_netG', 'unet_256',
- '--which_direction', 'AtoB',
- '--dataset_mode', 'test',
- '--loadSize', 256,
- '--fineSize', 256,
- '--norm', 'batch'
+ '--name', dataset,
+ '--model', 'pix2pixHD',
+ '--label_nc', 0, '--no_instance',
+ '--niter', task.epochs,
+ '--niter_decay', 0,
+ '--save_epoch_freq', 1,
+ ]
+ },
+ after: 'render',
+}
+const augment = {
+ type: 'pytorch',
+ script: 'augment.py',
+ params: (task) => {
+ let epoch = 0
+ const dataset = task.dataset.toLowerCase()
+ const datasets_path = path.join(cwd, 'datasets', dataset)
+ const checkpoints_path = path.join(cwd, 'checkpoints', dataset)
+ // supply render_dir
+ return [
+ '--dataroot', datasets_path,
+ '--results_dir', './recursive',
+ '--module_name', task.module,
+ '--name', dataset,
+ '--model', 'pix2pixHD',
+ '--label_nc', 0, '--no_instance',
+ '--augment-take', task.opt.augment_take,
+ '--augment-make', task.opt.augment_make,
+ '--augment-name', task.opt.augment_name,
+ '--which_epoch', task.opt.epoch,
+ ]
+ },
+}
+const clear_recursive = {
+ type: 'pytorch',
+ script: 'clear_recursive.py',
+ params: (task) => {
+ const dataset = task.dataset.toLowerCase()
+ return [
+ '--name', dataset,
+ '--epoch', epoch,
]
},
}
@@ -113,7 +157,6 @@ const live = {
'--name', task.checkpoint,
'--module_name', 'pix2pixHD',
'--sequence-name', task.dataset,
- '--recursive', '--recursive-frac', 0.1,
'--sequence', '--sequence-frac', 0.3,
'--process-frac', 0.5,
'--label_nc', '0', '--no_instance',
@@ -121,7 +164,7 @@ const live = {
'--just-copy', '--poll_delay', opt.poll_delay || 0.09,
'--which_epoch', 'latest',
'--norm', 'batch',
- '--store_b', // comment this line to store all live output
+ '--store_b', // uncomment this line to store all live output
]
},
listen: (task, res, i) => {
@@ -159,6 +202,8 @@ export default {
build,
train,
generate,
+ augment,
+ clear_recursive,
live,
render,
}