diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2018-06-16 16:02:33 +0200 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-06-16 16:02:33 +0200 |
| commit | 7917bca6c4da52f65de7f5ff139d66db2ac9ec51 (patch) | |
| tree | e6790843122e733aba70ccce356e066a0d25f397 /app/relay/modules | |
| parent | 3fbc955e814eb26e55fb083e688b49545b125f5e (diff) | |
navigationnnn
Diffstat (limited to 'app/relay/modules')
| -rw-r--r-- | app/relay/modules/pix2pix.js | 19 |
1 files changed, 17 insertions, 2 deletions
diff --git a/app/relay/modules/pix2pix.js b/app/relay/modules/pix2pix.js index 54fcdc3..6979be5 100644 --- a/app/relay/modules/pix2pix.js +++ b/app/relay/modules/pix2pix.js @@ -1,4 +1,5 @@ import path from 'path' +import fs from 'fs' const name = 'pix2pix' const cwd = process.env.PIX2PIX_CWD || path.join(process.env.HOME, 'code/' + name + '/') @@ -58,8 +59,20 @@ const train = { type: 'pytorch', script: 'train.py', params: (task) => { + let epoch = 0 + const datasets_path = path.join(cwd, 'datasets', task.module, task.dataset) + const checkpoints_path = path.join(cwd, 'checkpoints', task.module, task.dataset) + if (fs.existsSync(checkpoints_path)) { + try { + const checkpoints = fs.readdirSync(checkpoints_path) + checkpoints.forEach(c => { + epoch = Math.max(parseInt(c.name) || 0, epoch) + }) + console.log(task.module, task.dataset, epoch, task.epochs) + } catch (e) { } + } return [ - '--dataroot', path.join(cwd, 'datasets', task.module, task.dataset), + '--dataroot', datasets_path, '--module-name', task.module, '--name', task.dataset, '--model', 'pix2pix', @@ -69,7 +82,9 @@ const train = { '--which_direction', 'AtoB', '--lambda_B', 100, '--dataset_mode', 'aligned', - '--epoch_count', task.epochs, + '--epoch_count', task.epoch + task.epochs + 1, + '--niter', task.epochs, + '--niter_decay', 0, '--which_epoch', 'latest', '--continue_train', '--no_lsgan', |
