diff options
Diffstat (limited to 'app/client/modules/pix2pixhd/views')
| -rw-r--r-- | app/client/modules/pix2pixhd/views/pix2pixhd.live.js | 52 | ||||
| -rw-r--r-- | app/client/modules/pix2pixhd/views/pix2pixhd.show.js | 1 | ||||
| -rw-r--r-- | app/client/modules/pix2pixhd/views/pix2pixhd.train.js | 187 |
3 files changed, 233 insertions, 7 deletions
diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.live.js b/app/client/modules/pix2pixhd/views/pix2pixhd.live.js index b127e23..52b4b61 100644 --- a/app/client/modules/pix2pixhd/views/pix2pixhd.live.js +++ b/app/client/modules/pix2pixhd/views/pix2pixhd.live.js @@ -4,7 +4,7 @@ import { connect } from 'react-redux' import { ParamGroup, Param, Player, Group, - Slider, Select, TextInput, Button, Loading + Slider, SelectGroup, Select, TextInput, Button, Loading } from '../../../common/' import { startRecording, stopRecording, saveFrame, toggleFPS } from '../../../live/player' @@ -39,6 +39,7 @@ class Pix2PixHDLive extends Component { } componentWillUpdate(nextProps) { if (nextProps.opt.checkpoint_name && nextProps.opt.checkpoint_name !== this.props.opt.checkpoint_name) { + console.log('listing epochs') this.props.actions.live.list_epochs('pix2pixhd', nextProps.opt.checkpoint_name) } } @@ -99,8 +100,47 @@ class Pix2PixHDLive extends Component { render(){ // console.log(this.props) if (this.props.pix2pixhd.loading) { - return <Loading /> + return <Loading progress={this.props.pix2pixhd.progress} /> } + const { folderLookup, datasetLookup, sequences } = this.props.pix2pixhd.data + + const sequenceLookup = sequences.reduce((a,b) => { + a[b.name] = true + return a + }, {}) + + const sequenceGroups = Object.keys(folderLookup).map(id => { + const folder = this.props.pix2pixhd.data.folderLookup[id] + if (folder.name === 'results') return + const datasets = folder.datasets.map(name => { + const sequence = sequenceLookup[name] + if (sequence) { + return name + } + return null + }).filter(n => !!n) + return { + name: folder.name, + options: datasets.sort(), + } + }).filter(n => !!n && !!n.options.length).sort((a,b) => a.name.localeCompare(b.name)) + + const checkpointGroups = Object.keys(folderLookup).map(id => { + const folder = this.props.pix2pixhd.data.folderLookup[id] + if (folder.name === 'results') return + const datasets = folder.datasets.map(name => { + const dataset = datasetLookup[name] + if (dataset.checkpoints.length) { + return name + } + return null + }).filter(n => !!n) + return { + name: folder.name, + options: datasets.sort(), + } + }).filter(n => !!n && !!n.options.length).sort((a,b) => a.name.localeCompare(b.name)) + return ( <div className='app live centered'> <Player width={424} height={256} fullscreen={this.props.fullscreen} /> @@ -116,16 +156,16 @@ class Pix2PixHDLive extends Component { options={['a','b','sequence','recursive']} onChange={this.props.actions.live.set_param} /> - <Select live + <SelectGroup live name='sequence_name' title='sequence' - options={this.props.pix2pixhd.data.sequences.map(file => file.name)} + options={sequenceGroups} onChange={this.changeSequence} /> - <Select live + <SelectGroup live name='checkpoint_name' title='checkpoint' - options={this.props.pix2pixhd.data.checkpoints.map(file => file.name)} + options={checkpointGroups} onChange={this.changeCheckpoint} /> <Select live diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.show.js b/app/client/modules/pix2pixhd/views/pix2pixhd.show.js index d58ee80..3266d59 100644 --- a/app/client/modules/pix2pixhd/views/pix2pixhd.show.js +++ b/app/client/modules/pix2pixhd/views/pix2pixhd.show.js @@ -62,7 +62,6 @@ class Pix2PixHDShow extends Component { </div> </div> - <DatasetComponent loading={pix2pixhd.loading} progress={pix2pixhd.progress} diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js new file mode 100644 index 0000000..df3a1f2 --- /dev/null +++ b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js @@ -0,0 +1,187 @@ +import { h, Component } from 'preact' +import { bindActionCreators } from 'redux' +import { connect } from 'react-redux' +import util from '../../../util' + +import * as pix2pixhdActions from '../pix2pixhd.actions' +import * as pix2pixhdTasks from '../pix2pixhd.tasks' + +import { + Loading, + FileList, FileRow, + Select, SelectGroup, Group, Button, + TextInput, NumberInput, + CurrentTask, TaskList +} from '../../../common' +import DatasetForm from '../../../dataset/dataset.form' +import NewDatasetForm from '../../../dataset/dataset.new' +import UploadStatus from '../../../dataset/upload.status' + +import DatasetComponent from '../../../dataset/dataset.component' + +import pix2pixhdModule from '../pix2pixhd.module' + +class Pix2PixHDTrain extends Component { + constructor(props){ + super(props) + this.handleChange = this.handleChange.bind(this) + this.state = { + checkpoint_name: '', + epoch: 'latest', + augment_name: '', + augment_take: 100, + augment_make: 20, + } + } + componentWillMount(){ + const id = this.props.match.params.id || localStorage.getItem('pix2pixhd.last_id') + console.log('load dataset:', id) + const { match, pix2pixhd, actions } = this.props + if (id === 'new') return + if (id) { + if (parseInt(id)) localStorage.setItem('pix2pixhd.last_id', id) + if (! pix2pixhd.folder || pix2pixhd.folder.id !== id) { + actions.load_directories(id) + } + } else { + this.props.history.push('/pix2pixhd/new/') + } + } + componentDidUpdate(prevProps, prevState){ + if (! prevProps.pix2pixhd.data && this.props.pix2pixhd.data) { + console.log('set checkpoint_name') + this.setState({ + checkpoint_name: this.props.pix2pixhd.data.sequences[0].name, + epoch: 'latest' + }) + } + else if (prevState.checkpoint_name !== this.state.checkpoint_name) { + this.setState({ epoch: 'latest' }) + this.props.actions.list_epochs(this.state.checkpoint_name) + } + } + handleChange(name, value){ + console.log('name', name, 'value', value) + this.setState({ [name]: value }) + } + render(){ + if (this.props.pix2pixhd.loading) { + return <Loading progress={this.props.pix2pixhd.progress} /> + } + const { pix2pixhd, match, history, queue } = this.props + const { folderLookup, datasetLookup } = (pix2pixhd.data || {}) + const folder = (folderLookup || {})[pix2pixhd.folder_id] || {} + // console.log(pix2pixhd) + + const checkpointGroups = Object.keys(folderLookup).map(id => { + const folder = this.props.pix2pixhd.data.folderLookup[id] + if (folder.name === 'results') return + const datasets = folder.datasets.map(name => { + const dataset = datasetLookup[name] + if (dataset.checkpoints.length) { + return name + } + return null + }).filter(n => !!n) + return { + name: folder.name, + options: datasets.sort(), + } + }).filter(n => !!n && !!n.options.length).sort((a,b) => a.name.localeCompare(b.name)) + + console.log('state', this.props.pix2pixhd.data.epochs) + // console.log(this.state.checkpoint_name, this.state.epoch) + + return ( + <div className='app pix2pixhd'> + <div class='heading'> + <h1>pix2pixhd training</h1> + </div> + <div class='heading'> + <SelectGroup + name='checkpoint_name' + title='Checkpoint' + options={checkpointGroups} + onChange={this.handleChange} + value={this.state.checkpoint_name} + /> + <Select + title="Epoch" + name="epoch" + options={this.props.pix2pixhd.data.epochs} + onChange={this.handleChange} + value={this.state.epoch} + /> + <br/> + <Group title='Augment'> + <NumberInput + name="augment_take" + title="Pick N random frames" + value={this.state.augment_take} + onChange={this.handleChange} + type="int" + min="1" + max="1000" + /> + <NumberInput + name="augment_make" + title="Generate N recursively" + value={this.state.augment_make} + onChange={this.handleChange} + type="int" + min="1" + max="1000" + /> + <TextInput + name="augment_name" + title="Tag this epoch" + value={this.state.augment_name} + onChange={this.handleChange} + /> + <Button + title="Augment dataset" + value="Augment" + onClick={() => remote.augment_task(this.state.checkpoint_name, this.state)} + /> + </Group> + + <Group title='Train'> + <Button + title="Train one epoch" + value="Train" + onClick={() => remote.train_task(this.state.checkpoint_name, pix2pixhd.folder_id, 1)} + /> + </Group> + + <Group title='Clear'> + <Button + title="Delete recursive frames" + value="Clear" + onClick={() => remote.clear_recursive_task(this.state.checkpoint_name)} + /> + </Group> + </div> + <div> + <CurrentTask /> + {!!queue.queue.length && + <Group title='Upcoming Tasks'> + <TaskList tasks={queue.queue.map(id => queue.tasks[id])} /> + </Group> + } + </div> + </div> + ) + } +} + +const mapStateToProps = state => ({ + pix2pixhd: state.module.pix2pixhd, + queue: state.queue, +}) + +const mapDispatchToProps = (dispatch, ownProps) => ({ + actions: bindActionCreators(pix2pixhdActions, dispatch), + remote: bindActionCreators(pix2pixhdTasks, dispatch), +}) + +export default connect(mapStateToProps, mapDispatchToProps)(Pix2PixHDTrain) |
