summaryrefslogtreecommitdiff
path: root/app/relay
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-06-16 21:23:21 +0200
committerJules Laplace <julescarbon@gmail.com>2018-06-16 21:23:21 +0200
commit0d3580fe118e9187dbc88bb28bd3ed9078d96fa9 (patch)
tree7953ab749fcc406f26b9e0e9d2d7d9d34089317e /app/relay
parent5c1441fd70f2967ae978a0a83ffd73ed4ac51ea0 (diff)
module
Diffstat (limited to 'app/relay')
-rw-r--r--app/relay/interpreters.js5
-rw-r--r--app/relay/modules/pix2pix.js8
2 files changed, 9 insertions, 4 deletions
diff --git a/app/relay/interpreters.js b/app/relay/interpreters.js
index 90dfcaa..4671988 100644
--- a/app/relay/interpreters.js
+++ b/app/relay/interpreters.js
@@ -19,6 +19,11 @@ export default {
params: ['-u'],
gpu: true,
},
+ pytorch_cpu: {
+ cmd: process.env.PYTORCH_BIN,
+ params: ['-u'],
+ gpu: false,
+ },
tensorflow: {
cmd: process.env.TENSORFLOW_BIN,
params: ['-u'],
diff --git a/app/relay/modules/pix2pix.js b/app/relay/modules/pix2pix.js
index 7c204dc..3169fce 100644
--- a/app/relay/modules/pix2pix.js
+++ b/app/relay/modules/pix2pix.js
@@ -53,7 +53,7 @@ const make_folds = {
after: 'combine_folds',
}
const combine_folds = {
- type: 'pytorch',
+ type: 'pytorch_cpu',
script: 'datasets/combine_A_and_B.py',
params: (task) => {
return [
@@ -81,7 +81,7 @@ const train = {
}
let args = [
'--dataroot', datasets_path,
- '--module-name', task.module,
+ '--module_name', task.module,
'--name', task.dataset,
'--model', 'pix2pix',
'--loadSize', task.opt.load_size || 264,
@@ -112,7 +112,7 @@ const generate = {
params: (task) => {
return [
'--dataroot', '/sequences/' + task.module + '/' + task.dataset,
- '--module-name', task.module,
+ '--module_name', task.module,
'--name', task.dataset,
'--start_img', '/sequences/' + task.module + '/' + task.dataset + '/frame_00001.png',
'--how_many', 1000,
@@ -139,7 +139,7 @@ const live = {
'--checkpoint-name', task.checkpoint,
'--experiment', task.checkpoint,
'--name', task.checkpoint,
- '--module-name', task.module,
+ '--module_name', task.module,
'--sequence-name', task.dataset,
'--recursive', '--recursive-frac', 0.1,
'--sequence', '--sequence-frac', 0.3,