diff options
| author | jules@lens <julescarbon@gmail.com> | 2018-09-05 12:00:28 +0200 |
|---|---|---|
| committer | jules@lens <julescarbon@gmail.com> | 2018-09-05 12:00:28 +0200 |
| commit | 9abfa16dc059d042c21f1636ecc8797ef29a030d (patch) | |
| tree | d0583cb5dae01de1abc57ed8f7587d23242ed6f0 /app/client | |
| parent | 0a3c6743543dd3dfcb876f5ce735b72d050e981d (diff) | |
| parent | 15eb6806b6e216255f33abcb885f6cdbc38a7663 (diff) | |
Merge branch 'master' of asdf.us:live-cortex
Diffstat (limited to 'app/client')
21 files changed, 497 insertions, 42 deletions
diff --git a/app/client/common/index.js b/app/client/common/index.js index 3981fa7..eeb8bfc 100644 --- a/app/client/common/index.js +++ b/app/client/common/index.js @@ -8,11 +8,13 @@ import Gallery from './gallery.component' import Group from './group.component' import Header from './header.component' import Loading from './loading.component' +import NumberInput from './numberInput.component' import Param from './param.component' import ParamGroup from './paramGroup.component' import Player from './player.component' import Progress from './progress.component' import Select from './select.component' +import SelectGroup from './selectGroup.component' import Slider from './slider.component' import TextInput from './textInput.component' import * as Views from './views' @@ -23,6 +25,7 @@ export { FolderList, FileList, FileRow, FileUpload, Gallery, Player, Group, ParamGroup, Param, - TextInput, Slider, Select, Button, Checkbox, + TextInput, NumberInput, + Slider, Select, SelectGroup, Button, Checkbox, CurrentTask, }
\ No newline at end of file diff --git a/app/client/common/numberInput.component.js b/app/client/common/numberInput.component.js new file mode 100644 index 0000000..43f9878 --- /dev/null +++ b/app/client/common/numberInput.component.js @@ -0,0 +1,50 @@ +import { h, Component } from 'preact' + +class NumberInput extends Component { + constructor(props){ + super(props) + this.state = { value: null, changed: false } + this.handleInput = this.handleInput.bind(this) + this.handleKeydown = this.handleKeydown.bind(this) + } + handleInput(e){ + this.setState({ + value: e.target.value, + changed: true, + }) + this.props.onInput && this.props.onInput(e.target.value, e.target.name) + this.props.onChange && this.props.onInput(e.target.name, e.target.value) + } + handleKeydown(e){ + if (e.keyCode === 13) { + this.setState({ + value: e.target.value, + changed: false, + }) + this.props.onSave && this.props.onSave(e.target.value, e.target.name) + } + } + render() { + return ( + <div className='numberInput param'> + <label> + <span>{this.props.title}</span> + <input + type={'number'} + name={this.props.name || 'number'} + value={this.state.changed ? this.state.value : this.props.value} + onInput={this.handleInput} + onKeydown={this.handleKeydown} + placeholder={this.props.placeholder} + autofocus={this.props.autofocus} + min={this.props.min} + max={this.props.max} + step={this.props.step || this.props.type === 'int' ? 1 : 0.01} + /> + </label> + </div> + ) + } +} + +export default NumberInput diff --git a/app/client/common/selectGroup.component.js b/app/client/common/selectGroup.component.js new file mode 100644 index 0000000..5c1af51 --- /dev/null +++ b/app/client/common/selectGroup.component.js @@ -0,0 +1,67 @@ +import { h, Component } from 'preact' +import { connect } from 'react-redux' +import { bindActionCreators } from 'redux' + +class SelectGroup extends Component { + constructor(props){ + super(props) + this.handleChange = this.handleChange.bind(this) + } + handleChange(e){ + clearTimeout(this.timeout) + let new_value = e.target.value + if (new_value === 'PLACEHOLDER') return + this.props.onChange && this.props.onChange(this.props.name, new_value) + } + render() { + const currentValue = this.props.live ? this.props.opt[this.props.name] : this.props.value + let lastValue + const options = (this.props.options || []).map((group, i) => { + const groupName = group.name + const children = group.options.map(key => { + let name = key.length < 2 ? key.toUpperCase() : key + name = name.replace(/_/g, ' ') + let value = key + lastValue = value + return ( + <option value={value} key={value}> + {name} + </option> + ) + }) + return ( + <optgroup label={groupName} key={groupName}> + {children} + </optgroup> + ) + }) + return ( + <div className='select param'> + <label> + <span>{this.props.title}</span> + <select + onChange={this.handleChange} + value={currentValue || lastValue} + > + {this.props.placeholder && <option value="PLACEHOLDER">{this.props.placeholder}</option>} + {options} + </select> + </label> + {this.props.children} + </div> + ) + } +} + +function capitalize(s){ + return (s || "").replace(/(?:^|\s)\S/g, function(a) { return a.toUpperCase(); }); +} + +const mapStateToProps = (state, props) => ({ + opt: props.opt || state.live.opt, +}) + +const mapDispatchToProps = (dispatch, ownProps) => ({ +}) + +export default connect(mapStateToProps, mapDispatchToProps)(SelectGroup) diff --git a/app/client/common/textInput.component.js b/app/client/common/textInput.component.js index a3739d4..44e1349 100644 --- a/app/client/common/textInput.component.js +++ b/app/client/common/textInput.component.js @@ -12,7 +12,7 @@ class TextInput extends Component { value: e.target.value, changed: true, }) - this.props.onInput && this.props.onInput(e.target.value) + this.props.onInput && this.props.onInput(e.target.value, e.target.name) } handleKeydown(e){ if (e.keyCode === 13) { @@ -20,7 +20,7 @@ class TextInput extends Component { value: e.target.value, changed: false, }) - this.props.onSave && this.props.onSave(e.target.value) + this.props.onSave && this.props.onSave(e.target.value, e.target.name) } } render() { @@ -29,7 +29,8 @@ class TextInput extends Component { <label> <span>{this.props.title}</span> <input - type='text' + type={this.props.type || 'text'} + name={this.props.name || 'text'} value={this.state.changed ? this.state.value : this.props.value} onInput={this.handleInput} onKeydown={this.handleKeydown} diff --git a/app/client/dashboard/dashboard.actions.js b/app/client/dashboard/dashboard.actions.js index 8b5502a..c428d0b 100644 --- a/app/client/dashboard/dashboard.actions.js +++ b/app/client/dashboard/dashboard.actions.js @@ -11,7 +11,11 @@ export const load = () => (dispatch) => { actions.file.index({ module: 'morph', generated: 1, limit: 15, orderBy: 'created_at desc', }), ], (percent, i, n) => { // console.log('dashboard load progress', i, n) - dispatch({ type: types.app.load_progress, progress: { i, n }}) + dispatch({ + type: types.app.load_progress, + progress: { i, n }, + data: { module: 'dashboard' } + }) }).then(res => { const [ tasks, folders, samplernn, pix2pixhd, morph ] = res const { mapFn, sortFn } = util.sort.orderByFn('date desc') diff --git a/app/client/dashboard/dashboard.component.js b/app/client/dashboard/dashboard.component.js index 3c9b2de..0c15f99 100644 --- a/app/client/dashboard/dashboard.component.js +++ b/app/client/dashboard/dashboard.component.js @@ -11,7 +11,7 @@ import Button from '../common/button.component' import DashboardHeader from './dashboardheader.component' import TaskList from './tasklist.component' -import { FolderList, FileList } from '../common' +import { Loading, FolderList, FileList } from '../common' import Gallery from '../common/gallery.component' import * as dashboardActions from './dashboard.actions' @@ -22,7 +22,7 @@ import actions from '../actions' class Dashboard extends Component { constructor(props){ super() - console.log(props) + // console.log(props) props.actions.load() } componentWillUpdate(nextProps) { @@ -30,9 +30,11 @@ class Dashboard extends Component { // this.props.actions.list_epochs(nextProps.opt.checkpoint_name) } render(){ - const { site, foldersByModule, renders, queue, images } = this.props + const { loading, progress, site, foldersByModule, renders, queue, images } = this.props + if (loading) { + return <Loading progress={progress} /> + } const { tasks } = queue - console.log(foldersByModule) const folders = foldersByModule && Object.keys(modules).sort().map(key => { let path = key === 'samplernn' ? '/samplernn/datasets/' : '/' + key + '/sequences/' let folder_list = (foldersByModule[key] || []).map(folder => { @@ -99,6 +101,8 @@ class Dashboard extends Component { } } const mapStateToProps = state => ({ + loading: state.dashboard.loading, + progress: state.dashboard.progress, site: state.system.site, foldersByModule: state.dashboard.data.foldersByModule, renders: state.dashboard.data.renders, diff --git a/app/client/dashboard/dashboard.reducer.js b/app/client/dashboard/dashboard.reducer.js index 812a501..f10f3fa 100644 --- a/app/client/dashboard/dashboard.reducer.js +++ b/app/client/dashboard/dashboard.reducer.js @@ -5,6 +5,7 @@ let FileSaver = require('file-saver') const dashboardInitialState = { loading: false, + progress: null, error: null, data: {}, images: [ @@ -15,6 +16,15 @@ const dashboardInitialState = { const dashboardReducer = (state = dashboardInitialState, action) => { switch(action.type) { + case types.app.load_progress: + if (!action.data || action.data.module !== 'dashboard') { + return state + } + return { + ...state, + loading: true, + progress: action.progress, + } case types.dashboard.load: return { ...state, diff --git a/app/client/dashboard/dashboardHeader.component.js b/app/client/dashboard/dashboardHeader.component.js index 5f1306c..063cd47 100644 --- a/app/client/dashboard/dashboardHeader.component.js +++ b/app/client/dashboard/dashboardHeader.component.js @@ -23,7 +23,6 @@ class DashboardHeader extends Component { ) } renderStatus(name, gpu){ - console.log(gpu) if (gpu.status === 'IDLE') { return <div>{name} idle</div> } diff --git a/app/client/dataset/dataset.reducer.js b/app/client/dataset/dataset.reducer.js index 10b4a94..f303a7f 100644 --- a/app/client/dataset/dataset.reducer.js +++ b/app/client/dataset/dataset.reducer.js @@ -8,6 +8,7 @@ import types from '../types' const datasetInitialState = () => ({ loading: true, + progress: null, error: null, data: null, folder_id: 0, @@ -57,6 +58,15 @@ const datasetReducer = (state = datasetInitialState(), action) => { case types.file.destroy: return handleFileDestroy(state, action) + case types.dataset.list_epochs: + return { + ...state, + data: { + ...state.data, + epochs: action.data.epochs, + } + } + default: return state } diff --git a/app/client/live/live.reducer.js b/app/client/live/live.reducer.js index 7fd7667..dfdf7ee 100644 --- a/app/client/live/live.reducer.js +++ b/app/client/live/live.reducer.js @@ -11,6 +11,7 @@ const liveInitialState = { recurse_roll: 0, rotate: 0, scale: 0, process_frac: 0.5, view_mode: 'b', }, + all_checkpoints: [], checkpoints: [], epochs: ['latest'], sequences: [], @@ -57,7 +58,15 @@ const liveReducer = (state = liveInitialState, action) => { epochs: [], } + case types.socket.list_all_checkpoints: + return { + ...state, + all_checkpoints: action.all_checkpoints, + epochs: [], + } + case types.socket.list_epochs: + console.log(action) if (action.epochs === "not found") return { ...state, epochs: [] } return { ...state, diff --git a/app/client/modules/morph/morph.actions.js b/app/client/modules/morph/morph.actions.js index 04f452d..6586778 100644 --- a/app/client/modules/morph/morph.actions.js +++ b/app/client/modules/morph/morph.actions.js @@ -19,7 +19,11 @@ export const load_data = (id) => (dispatch) => { actions.socket.list_directory({ module, dir: 'renders' }), ], (percent, i, n) => { console.log('morph load progress', i, n) - dispatch({ type: types.app.load_progress, progress: { i, n }}) + dispatch({ + type: types.app.load_progress, + progress: { i, n }, + data: { module: 'morph' }, + }) }).then(res => { const [datasetApiReport, sequences, renders] = res const { diff --git a/app/client/modules/pix2pixhd/index.js b/app/client/modules/pix2pixhd/index.js index ea224f3..cbd3136 100644 --- a/app/client/modules/pix2pixhd/index.js +++ b/app/client/modules/pix2pixhd/index.js @@ -8,6 +8,7 @@ import util from '../../util' import Pix2PixHDNew from './views/pix2pixhd.new' import Pix2PixHDShow from './views/pix2pixhd.show' import Pix2PixHDResults from './views/pix2pixhd.results' +import Pix2PixHDTrain from './views/pix2pixhd.train' import Pix2PixHDLive from './views/pix2pixhd.live' class router { @@ -25,6 +26,7 @@ class router { <Route exact path='/pix2pixhd/new/' component={Pix2PixHDNew} /> <Route exact path='/pix2pixhd/sequences/' component={Pix2PixHDShow} /> <Route exact path='/pix2pixhd/sequences/:id/' component={Pix2PixHDShow} /> + <Route exact path='/pix2pixhd/train/' component={Pix2PixHDTrain} /> <Route exact path='/pix2pixhd/results/' component={Pix2PixHDResults} /> <Route exact path='/pix2pixhd/live/' component={Pix2PixHDLive} /> </section> @@ -34,8 +36,9 @@ class router { function links(){ return [ - { url: '/pix2pixhd/new/', name: 'new' }, + { url: '/pix2pixhd/new/', name: 'folders' }, { url: '/pix2pixhd/sequences/', name: 'sequences' }, + { url: '/pix2pixhd/train/', name: 'checkpoints' }, { url: '/pix2pixhd/results/', name: 'results' }, { url: '/pix2pixhd/live/', name: 'live' }, ] diff --git a/app/client/modules/pix2pixhd/pix2pixhd.actions.js b/app/client/modules/pix2pixhd/pix2pixhd.actions.js index 8e481d3..c1cd2b1 100644 --- a/app/client/modules/pix2pixhd/pix2pixhd.actions.js +++ b/app/client/modules/pix2pixhd/pix2pixhd.actions.js @@ -21,7 +21,11 @@ export const load_directories = (id) => (dispatch) => { // actions.socket.disk_usage({ module, dir: 'datasets' }), ], (percent, i, n) => { console.log('pix2pixhd load progress', i, n) - dispatch({ type: types.app.load_progress, progress: { i, n }}) + dispatch({ + type: types.app.load_progress, + progress: { i, n }, + data: { module: 'pix2pixhd' }, + }) }).then(res => { const [datasetApiReport, sequences, datasets, checkpoints] = res //, datasets, results, output, datasetUsage, lossReport] = res const { @@ -157,7 +161,11 @@ export const load_results = (id) => (dispatch) => { actions.socket.list_directory({ module, dir: 'renders' }), ], (percent, i, n) => { console.log('pix2pixhd load progress', i, n) - dispatch({ type: types.app.load_progress, progress: { i, n }}) + dispatch({ + type: types.app.load_progress, + progress: { i, n }, + data: { module: 'pix2pixhd' }, + }) }).then(res => { const [folders, files, results, renders] = res //, datasets, results, output, datasetUsage, lossReport] = res console.log(files, results, renders) @@ -171,4 +179,25 @@ export const load_results = (id) => (dispatch) => { } }) }) +} + +const G_NET_REGEXP = new RegExp('_net_G.pth$') + +export const list_epochs = (checkpoint_name) => (dispatch) => { + const module = pix2pixhdModule.name + actions.socket.list_directory({ module, dir: 'checkpoints/' + checkpoint_name }).then(files => { + // console.log(files) + const epochs = files.map(f => { + if (!f.name.match(G_NET_REGEXP)) return null + return f.name.replace(G_NET_REGEXP, '') + }).filter(f => !!f) + // console.log(epochs) + dispatch({ + type: types.dataset.list_epochs, + data: { + epochs, + module + }, + }) + }) }
\ No newline at end of file diff --git a/app/client/modules/pix2pixhd/pix2pixhd.tasks.js b/app/client/modules/pix2pixhd/pix2pixhd.tasks.js index f3c5342..bd51f2b 100644 --- a/app/client/modules/pix2pixhd/pix2pixhd.tasks.js +++ b/app/client/modules/pix2pixhd/pix2pixhd.tasks.js @@ -52,3 +52,29 @@ export const live_task = (sequence, checkpoint, opt) => dispatch => { console.log('add live task') return actions.queue.add_task(task) } + +export const augment_task = (dataset, opt) => dispatch => { + const task = { + module: module.name, + activity: 'augment', + dataset, + opt: { + ...opt, + } + } + console.log(task) + console.log('add augment task') + return actions.queue.add_task(task) +} + +export const clear_recursive_task = (dataset) => dispatch => { + const task = { + module: module.name, + activity: 'clear_recursive', + dataset, + } + console.log(task) + console.log('add clear recursive task') + return actions.queue.add_task(task) +} + diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.live.js b/app/client/modules/pix2pixhd/views/pix2pixhd.live.js index b127e23..52b4b61 100644 --- a/app/client/modules/pix2pixhd/views/pix2pixhd.live.js +++ b/app/client/modules/pix2pixhd/views/pix2pixhd.live.js @@ -4,7 +4,7 @@ import { connect } from 'react-redux' import { ParamGroup, Param, Player, Group, - Slider, Select, TextInput, Button, Loading + Slider, SelectGroup, Select, TextInput, Button, Loading } from '../../../common/' import { startRecording, stopRecording, saveFrame, toggleFPS } from '../../../live/player' @@ -39,6 +39,7 @@ class Pix2PixHDLive extends Component { } componentWillUpdate(nextProps) { if (nextProps.opt.checkpoint_name && nextProps.opt.checkpoint_name !== this.props.opt.checkpoint_name) { + console.log('listing epochs') this.props.actions.live.list_epochs('pix2pixhd', nextProps.opt.checkpoint_name) } } @@ -99,8 +100,47 @@ class Pix2PixHDLive extends Component { render(){ // console.log(this.props) if (this.props.pix2pixhd.loading) { - return <Loading /> + return <Loading progress={this.props.pix2pixhd.progress} /> } + const { folderLookup, datasetLookup, sequences } = this.props.pix2pixhd.data + + const sequenceLookup = sequences.reduce((a,b) => { + a[b.name] = true + return a + }, {}) + + const sequenceGroups = Object.keys(folderLookup).map(id => { + const folder = this.props.pix2pixhd.data.folderLookup[id] + if (folder.name === 'results') return + const datasets = folder.datasets.map(name => { + const sequence = sequenceLookup[name] + if (sequence) { + return name + } + return null + }).filter(n => !!n) + return { + name: folder.name, + options: datasets.sort(), + } + }).filter(n => !!n && !!n.options.length).sort((a,b) => a.name.localeCompare(b.name)) + + const checkpointGroups = Object.keys(folderLookup).map(id => { + const folder = this.props.pix2pixhd.data.folderLookup[id] + if (folder.name === 'results') return + const datasets = folder.datasets.map(name => { + const dataset = datasetLookup[name] + if (dataset.checkpoints.length) { + return name + } + return null + }).filter(n => !!n) + return { + name: folder.name, + options: datasets.sort(), + } + }).filter(n => !!n && !!n.options.length).sort((a,b) => a.name.localeCompare(b.name)) + return ( <div className='app live centered'> <Player width={424} height={256} fullscreen={this.props.fullscreen} /> @@ -116,16 +156,16 @@ class Pix2PixHDLive extends Component { options={['a','b','sequence','recursive']} onChange={this.props.actions.live.set_param} /> - <Select live + <SelectGroup live name='sequence_name' title='sequence' - options={this.props.pix2pixhd.data.sequences.map(file => file.name)} + options={sequenceGroups} onChange={this.changeSequence} /> - <Select live + <SelectGroup live name='checkpoint_name' title='checkpoint' - options={this.props.pix2pixhd.data.checkpoints.map(file => file.name)} + options={checkpointGroups} onChange={this.changeCheckpoint} /> <Select live diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.show.js b/app/client/modules/pix2pixhd/views/pix2pixhd.show.js index d58ee80..3266d59 100644 --- a/app/client/modules/pix2pixhd/views/pix2pixhd.show.js +++ b/app/client/modules/pix2pixhd/views/pix2pixhd.show.js @@ -62,7 +62,6 @@ class Pix2PixHDShow extends Component { </div> </div> - <DatasetComponent loading={pix2pixhd.loading} progress={pix2pixhd.progress} diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js new file mode 100644 index 0000000..df3a1f2 --- /dev/null +++ b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js @@ -0,0 +1,187 @@ +import { h, Component } from 'preact' +import { bindActionCreators } from 'redux' +import { connect } from 'react-redux' +import util from '../../../util' + +import * as pix2pixhdActions from '../pix2pixhd.actions' +import * as pix2pixhdTasks from '../pix2pixhd.tasks' + +import { + Loading, + FileList, FileRow, + Select, SelectGroup, Group, Button, + TextInput, NumberInput, + CurrentTask, TaskList +} from '../../../common' +import DatasetForm from '../../../dataset/dataset.form' +import NewDatasetForm from '../../../dataset/dataset.new' +import UploadStatus from '../../../dataset/upload.status' + +import DatasetComponent from '../../../dataset/dataset.component' + +import pix2pixhdModule from '../pix2pixhd.module' + +class Pix2PixHDTrain extends Component { + constructor(props){ + super(props) + this.handleChange = this.handleChange.bind(this) + this.state = { + checkpoint_name: '', + epoch: 'latest', + augment_name: '', + augment_take: 100, + augment_make: 20, + } + } + componentWillMount(){ + const id = this.props.match.params.id || localStorage.getItem('pix2pixhd.last_id') + console.log('load dataset:', id) + const { match, pix2pixhd, actions } = this.props + if (id === 'new') return + if (id) { + if (parseInt(id)) localStorage.setItem('pix2pixhd.last_id', id) + if (! pix2pixhd.folder || pix2pixhd.folder.id !== id) { + actions.load_directories(id) + } + } else { + this.props.history.push('/pix2pixhd/new/') + } + } + componentDidUpdate(prevProps, prevState){ + if (! prevProps.pix2pixhd.data && this.props.pix2pixhd.data) { + console.log('set checkpoint_name') + this.setState({ + checkpoint_name: this.props.pix2pixhd.data.sequences[0].name, + epoch: 'latest' + }) + } + else if (prevState.checkpoint_name !== this.state.checkpoint_name) { + this.setState({ epoch: 'latest' }) + this.props.actions.list_epochs(this.state.checkpoint_name) + } + } + handleChange(name, value){ + console.log('name', name, 'value', value) + this.setState({ [name]: value }) + } + render(){ + if (this.props.pix2pixhd.loading) { + return <Loading progress={this.props.pix2pixhd.progress} /> + } + const { pix2pixhd, match, history, queue } = this.props + const { folderLookup, datasetLookup } = (pix2pixhd.data || {}) + const folder = (folderLookup || {})[pix2pixhd.folder_id] || {} + // console.log(pix2pixhd) + + const checkpointGroups = Object.keys(folderLookup).map(id => { + const folder = this.props.pix2pixhd.data.folderLookup[id] + if (folder.name === 'results') return + const datasets = folder.datasets.map(name => { + const dataset = datasetLookup[name] + if (dataset.checkpoints.length) { + return name + } + return null + }).filter(n => !!n) + return { + name: folder.name, + options: datasets.sort(), + } + }).filter(n => !!n && !!n.options.length).sort((a,b) => a.name.localeCompare(b.name)) + + console.log('state', this.props.pix2pixhd.data.epochs) + // console.log(this.state.checkpoint_name, this.state.epoch) + + return ( + <div className='app pix2pixhd'> + <div class='heading'> + <h1>pix2pixhd training</h1> + </div> + <div class='heading'> + <SelectGroup + name='checkpoint_name' + title='Checkpoint' + options={checkpointGroups} + onChange={this.handleChange} + value={this.state.checkpoint_name} + /> + <Select + title="Epoch" + name="epoch" + options={this.props.pix2pixhd.data.epochs} + onChange={this.handleChange} + value={this.state.epoch} + /> + <br/> + <Group title='Augment'> + <NumberInput + name="augment_take" + title="Pick N random frames" + value={this.state.augment_take} + onChange={this.handleChange} + type="int" + min="1" + max="1000" + /> + <NumberInput + name="augment_make" + title="Generate N recursively" + value={this.state.augment_make} + onChange={this.handleChange} + type="int" + min="1" + max="1000" + /> + <TextInput + name="augment_name" + title="Tag this epoch" + value={this.state.augment_name} + onChange={this.handleChange} + /> + <Button + title="Augment dataset" + value="Augment" + onClick={() => remote.augment_task(this.state.checkpoint_name, this.state)} + /> + </Group> + + <Group title='Train'> + <Button + title="Train one epoch" + value="Train" + onClick={() => remote.train_task(this.state.checkpoint_name, pix2pixhd.folder_id, 1)} + /> + </Group> + + <Group title='Clear'> + <Button + title="Delete recursive frames" + value="Clear" + onClick={() => remote.clear_recursive_task(this.state.checkpoint_name)} + /> + </Group> + </div> + <div> + <CurrentTask /> + {!!queue.queue.length && + <Group title='Upcoming Tasks'> + <TaskList tasks={queue.queue.map(id => queue.tasks[id])} /> + </Group> + } + </div> + </div> + ) + } +} + +const mapStateToProps = state => ({ + pix2pixhd: state.module.pix2pixhd, + queue: state.queue, +}) + +const mapDispatchToProps = (dispatch, ownProps) => ({ + actions: bindActionCreators(pix2pixhdActions, dispatch), + remote: bindActionCreators(pix2pixhdTasks, dispatch), +}) + +export default connect(mapStateToProps, mapDispatchToProps)(Pix2PixHDTrain) diff --git a/app/client/modules/samplernn/samplernn.actions.js b/app/client/modules/samplernn/samplernn.actions.js index f4f7a6d..e4de1dd 100644 --- a/app/client/modules/samplernn/samplernn.actions.js +++ b/app/client/modules/samplernn/samplernn.actions.js @@ -21,7 +21,11 @@ export const load_directories = (id) => (dispatch) => { actions.socket.disk_usage({ module, dir: 'datasets' }), load_loss()(dispatch), ], (percent, i, n) => { - dispatch({ type: types.app.load_progress, progress: { i, n }}) + dispatch({ + type: types.app.load_progress, + progress: { i, n }, + data: { module: 'samplernn' }, + }) }).then(res => { // console.log(res) const [datasetApiReport, datasets, results, output, datasetUsage, lossReport] = res @@ -128,7 +132,11 @@ export const load_graph = () => dispatch => { load_loss()(dispatch), actions.socket.list_directory({ module, dir: 'results' }), ], (percent, i, n) => { - dispatch({ type: types.app.load_progress, progress: { i, n }}) + dispatch({ + type: types.app.load_progress, + progress: { i, n }, + data: { module: 'samplernn' }, + }) }).then(res => { const [lossReport, results] = res dispatch({ diff --git a/app/client/socket/socket.actions.js b/app/client/socket/socket.actions.js index e15dda2..78b0517 100644 --- a/app/client/socket/socket.actions.js +++ b/app/client/socket/socket.actions.js @@ -1,24 +1,13 @@ import uuidv1 from 'uuid/v1' import { socket } from './socket.connection' -export function run_system_command(opt) { - return syscall_async('run_system_command', opt) -} -export function disk_usage(opt) { - return syscall_async('run_system_command', { cmd: 'du', ...opt }) -} -export function list_directory(opt) { - return syscall_async('list_directory', opt).then(res => res.files) -} -export function list_sequences(opt) { - return syscall_async('list_sequences', opt).then(res => res.sequences) -} -export function run_script(opt) { - return syscall_async('run_script', opt) -} -export function upload_file(opt) { - return syscall_async('upload_file', opt) -} +export const run_system_command = opt => syscall_async('run_system_command', opt) +export const disk_usage = opt => syscall_async('run_system_command', { cmd: 'du', ...opt }) +export const list_directory = opt => syscall_async('list_directory', opt).then(res => res.files) +export const list_sequences = opt => syscall_async('list_sequences', opt).then(res => res.sequences) +export const run_script = opt => syscall_async('run_script', opt) +export const upload_file = opt => syscall_async('upload_file', opt) + export const syscall_async = (tag, payload, ttl=10000) => { ttl = payload.ttl || ttl return new Promise( (resolve, reject) => { diff --git a/app/client/socket/socket.live.js b/app/client/socket/socket.live.js index fc53eb3..a1a7a3f 100644 --- a/app/client/socket/socket.live.js +++ b/app/client/socket/socket.live.js @@ -27,6 +27,12 @@ socket.on('res', (data) => { checkpoints: data.res, }) break + case 'list_all_checkpoints': + dispatch({ + type: types.socket.list_all_checkpoints, + checkpoints: data.res, + }) + break case 'list_epochs': dispatch({ type: types.socket.list_epochs, @@ -53,10 +59,16 @@ export function list_checkpoints(module) { payload: module, }) } +export function list_all_checkpoints(module) { + socket.emit('cmd', { + cmd: 'list_all_checkpoints', + payload: module, + }) +} export function list_epochs(module, checkpoint_name) { socket.emit('cmd', { cmd: 'list_epochs', - payload: module + '/' + checkpoint_name, + payload: (module === 'pix2pix' || module === 'pix2wav') ? module + '/' + checkpoint_name : checkpoint_name, }) } export function list_sequences(module) { diff --git a/app/client/types.js b/app/client/types.js index d12ac91..2df494b 100644 --- a/app/client/types.js +++ b/app/client/types.js @@ -92,6 +92,7 @@ export default { file_uploaded: 'DATASET_FILE_UPLOADED', fetch_url: 'DATASET_FETCH_URL', fetch_progress: 'DATASET_FETCH_PROGRESS', + list_epochs: 'DATASET_LIST_EPOCHS', }, samplernn: { init: 'SAMPLERNN_INIT', |
