diff options
Diffstat (limited to 'app')
| -rw-r--r-- | app/client/common/numberInput.component.js | 1 | ||||
| -rw-r--r-- | app/client/dataset/dataset.reducer.js | 9 | ||||
| -rw-r--r-- | app/client/modules/pix2pixhd/pix2pixhd.actions.js | 21 | ||||
| -rw-r--r-- | app/client/modules/pix2pixhd/views/pix2pixhd.train.js | 56 | ||||
| -rw-r--r-- | app/client/types.js | 1 |
5 files changed, 72 insertions, 16 deletions
diff --git a/app/client/common/numberInput.component.js b/app/client/common/numberInput.component.js index c3ad24c..43f9878 100644 --- a/app/client/common/numberInput.component.js +++ b/app/client/common/numberInput.component.js @@ -13,6 +13,7 @@ class NumberInput extends Component { changed: true, }) this.props.onInput && this.props.onInput(e.target.value, e.target.name) + this.props.onChange && this.props.onInput(e.target.name, e.target.value) } handleKeydown(e){ if (e.keyCode === 13) { diff --git a/app/client/dataset/dataset.reducer.js b/app/client/dataset/dataset.reducer.js index 065d3da..f303a7f 100644 --- a/app/client/dataset/dataset.reducer.js +++ b/app/client/dataset/dataset.reducer.js @@ -58,6 +58,15 @@ const datasetReducer = (state = datasetInitialState(), action) => { case types.file.destroy: return handleFileDestroy(state, action) + case types.dataset.list_epochs: + return { + ...state, + data: { + ...state.data, + epochs: action.data.epochs, + } + } + default: return state } diff --git a/app/client/modules/pix2pixhd/pix2pixhd.actions.js b/app/client/modules/pix2pixhd/pix2pixhd.actions.js index 6459794..c1cd2b1 100644 --- a/app/client/modules/pix2pixhd/pix2pixhd.actions.js +++ b/app/client/modules/pix2pixhd/pix2pixhd.actions.js @@ -179,4 +179,25 @@ export const load_results = (id) => (dispatch) => { } }) }) +} + +const G_NET_REGEXP = new RegExp('_net_G.pth$') + +export const list_epochs = (checkpoint_name) => (dispatch) => { + const module = pix2pixhdModule.name + actions.socket.list_directory({ module, dir: 'checkpoints/' + checkpoint_name }).then(files => { + // console.log(files) + const epochs = files.map(f => { + if (!f.name.match(G_NET_REGEXP)) return null + return f.name.replace(G_NET_REGEXP, '') + }).filter(f => !!f) + // console.log(epochs) + dispatch({ + type: types.dataset.list_epochs, + data: { + epochs, + module + }, + }) + }) }
\ No newline at end of file diff --git a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js index 8901ee8..cd6507f 100644 --- a/app/client/modules/pix2pixhd/views/pix2pixhd.train.js +++ b/app/client/modules/pix2pixhd/views/pix2pixhd.train.js @@ -10,7 +10,7 @@ import { Loading, FileList, FileRow, Select, SelectGroup, Group, Button, - TextInput, + TextInput, NumberInput, CurrentTask, TaskList } from '../../../common' import DatasetForm from '../../../dataset/dataset.form' @@ -26,7 +26,7 @@ class Pix2PixHDTrain extends Component { super(props) this.handleChange = this.handleChange.bind(this) this.state = { - sequence: '', + checkpoint_name: '', epoch: 'latest', augment_take: 100, augment_make: 20, @@ -46,11 +46,22 @@ class Pix2PixHDTrain extends Component { this.props.history.push('/pix2pixhd/new/') } } - handleChange(value, name){ - this.setState({ [name]: value }) + componentDidUpdate(prevProps, prevState){ + if (! prevProps.pix2pixhd.data && this.props.pix2pixhd.data) { + console.log('set checkpoint_name') + this.setState({ + checkpoint_name: this.props.pix2pixhd.data.sequences[0].name, + epoch: 'latest' + }) + } + else if (prevState.checkpoint_name !== this.state.checkpoint_name) { + this.setState({ epoch: 'latest' }) + this.props.actions.list_epochs(this.state.checkpoint_name) + } } - changeCheckpoint(name){ - // this.props.actions.list_epochs('pix2pixhd', nextProps.opt.checkpoint_name) + handleChange(name, value){ + console.log('name', name, 'value', value) + this.setState({ [name]: value }) } render(){ if (this.props.pix2pixhd.loading) { @@ -59,7 +70,7 @@ class Pix2PixHDTrain extends Component { const { pix2pixhd, match, history, queue } = this.props const { folderLookup, datasetLookup } = (pix2pixhd.data || {}) const folder = (folderLookup || {})[pix2pixhd.folder_id] || {} - console.log(pix2pixhd) + // console.log(pix2pixhd) const checkpointGroups = Object.keys(folderLookup).map(id => { const folder = this.props.pix2pixhd.data.folderLookup[id] @@ -77,35 +88,48 @@ class Pix2PixHDTrain extends Component { } }).filter(n => !!n && !!n.options.length).sort((a,b) => a.name.localeCompare(b.name)) + console.log('state', this.props.pix2pixhd.data.epochs) + // console.log(this.state.checkpoint_name, this.state.epoch) + return ( <div className='app pix2pixhd'> - <h1>pix2pixhd training</h1> <div class='heading'> - <SelectGroup live + <h1>pix2pixhd training</h1> + </div> + <div class='heading'> + <SelectGroup name='checkpoint_name' title='Checkpoint' options={checkpointGroups} - onChange={this.changeCheckpoint} + onChange={this.handleChange} + value={this.state.checkpoint_name} /> <Select title="Epoch" + name="epoch" + options={this.props.pix2pixhd.data.epochs} + onChange={this.handleChange} value={this.state.epoch} /> <br/> <Group title='Augment'> - <TextInput - type="number" + <NumberInput name="augment_take" title="Pick N random frames" value={this.state.augment_take} - onInput={this.handleChange} + onChange={this.handleChange} + type="int" + min="1" + max="1000" /> - <TextInput - type="number" + <NumberInput name="augment_make" title="Generate N recursively" value={this.state.augment_make} - onInput={this.handleChange} + onChange={this.handleChange} + type="int" + min="1" + max="1000" /> <Button title="Augment dataset" diff --git a/app/client/types.js b/app/client/types.js index d12ac91..2df494b 100644 --- a/app/client/types.js +++ b/app/client/types.js @@ -92,6 +92,7 @@ export default { file_uploaded: 'DATASET_FILE_UPLOADED', fetch_url: 'DATASET_FETCH_URL', fetch_progress: 'DATASET_FETCH_PROGRESS', + list_epochs: 'DATASET_LIST_EPOCHS', }, samplernn: { init: 'SAMPLERNN_INIT', |
