summaryrefslogtreecommitdiff
path: root/app/relay
diff options
context:
space:
mode:
Diffstat (limited to 'app/relay')
-rw-r--r--app/relay/modules/samplernn.js4
-rw-r--r--app/relay/runner.js6
2 files changed, 6 insertions, 4 deletions
diff --git a/app/relay/modules/samplernn.js b/app/relay/modules/samplernn.js
index 8b3b8e6..3962755 100644
--- a/app/relay/modules/samplernn.js
+++ b/app/relay/modules/samplernn.js
@@ -16,8 +16,8 @@ const train = {
script: 'train.py',
params: (task) => {
return [
- '--exp', task.dataset,
- '--dataset', task.dataset,
+ '--exp', task.dataset.name,
+ '--dataset', task.dataset.name,
'--frame_sizes', '8', '2',
'--n_rnn', '2',
'--epoch_limit', task.epochs || 4,
diff --git a/app/relay/runner.js b/app/relay/runner.js
index 811dff3..8231e3b 100644
--- a/app/relay/runner.js
+++ b/app/relay/runner.js
@@ -89,9 +89,10 @@ export function status () {
export function build_params(module, activity, task) {
const interpreter = interpreters[activity.type]
- let opt_params;
+ let opt_params, activity_params;
if (typeof activity.params === 'function') {
opt_params = activity.params(task)
+ activity_params = []
}
else {
const opt = task.opt || {}
@@ -103,10 +104,11 @@ export function build_params(module, activity, task) {
}
return [flag, value]
}).reduce((acc, cur) => acc.concat(cur), [])
+ activity_params = activity.params
}
const params = (interpreter.params || [])
.concat([ activity.script ])
- .concat(activity.params || [])
+ .concat(activity_params)
.concat(opt_params)
return {
activity,