summaryrefslogtreecommitdiff
path: root/app/client/modules/samplernn
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-06-16 19:05:45 +0200
committerJules Laplace <julescarbon@gmail.com>2018-06-16 19:05:45 +0200
commit72485efe1d7e17c883beba80c7a6b65ff1dca841 (patch)
tree1be6cdf358c43db0e27a63a719a60686287a6575 /app/client/modules/samplernn
parent408086ca5fced2552461f2d00f6d4a95be7b2636 (diff)
deploy bundle
Diffstat (limited to 'app/client/modules/samplernn')
-rw-r--r--app/client/modules/samplernn/samplernn.actions.js16
-rw-r--r--app/client/modules/samplernn/samplernn.reducer.js8
-rw-r--r--app/client/modules/samplernn/views/samplernn.graph.js27
3 files changed, 44 insertions, 7 deletions
diff --git a/app/client/modules/samplernn/samplernn.actions.js b/app/client/modules/samplernn/samplernn.actions.js
index 807a3d0..3fb38cc 100644
--- a/app/client/modules/samplernn/samplernn.actions.js
+++ b/app/client/modules/samplernn/samplernn.actions.js
@@ -122,6 +122,22 @@ export const load_directories = (id) => (dispatch) => {
})
}
+export const load_graph = () => dispatch => {
+ util.allProgress([
+ load_loss()(dispatch),
+ actions.socket.list_directory({ module, dir: 'results' }),
+ ], (percent, i, n) => {
+ dispatch({ type: types.app.load_progress, progress: { i, n }})
+ }).then(res => {
+ const [lossReport, results] = res
+ dispatch({
+ type: types.samplernn.load_graph,
+ lossReport,
+ results,
+ })
+ })
+}
+
export const load_loss = () => dispatch => {
return actions.socket.run_script({ module: 'samplernn', activity: 'report' })
.then(report => {
diff --git a/app/client/modules/samplernn/samplernn.reducer.js b/app/client/modules/samplernn/samplernn.reducer.js
index ce3a549..4758f61 100644
--- a/app/client/modules/samplernn/samplernn.reducer.js
+++ b/app/client/modules/samplernn/samplernn.reducer.js
@@ -10,6 +10,7 @@ const samplernnInitialState = {
folder_id: 0,
data: null,
lossReport: null,
+ results: null,
}
const samplernnReducer = (state = samplernnInitialState, action) => {
@@ -23,7 +24,12 @@ const samplernnReducer = (state = samplernnInitialState, action) => {
...state,
lossReport: action.lossReport,
}
-
+ case types.samplernn.load_graph:
+ return {
+ ...state,
+ lossReport: action.lossReport,
+ results: action.results,
+ }
default:
return state
}
diff --git a/app/client/modules/samplernn/views/samplernn.graph.js b/app/client/modules/samplernn/views/samplernn.graph.js
index 40e47fa..58f8d02 100644
--- a/app/client/modules/samplernn/views/samplernn.graph.js
+++ b/app/client/modules/samplernn/views/samplernn.graph.js
@@ -18,7 +18,7 @@ import TextInput from '../../../common/textInput.component'
class SampleRNNGraph extends Component {
constructor(props){
super()
- props.actions.load_loss()
+ props.actions.load_graph()
}
render(){
this.refs = {}
@@ -32,8 +32,8 @@ class SampleRNNGraph extends Component {
)
}
componentDidUpdate(){
- const { lossReport } = this.props.samplernn
- if (! lossReport) return
+ const { lossReport, results } = this.props.samplernn
+ if (! lossReport || ! results) return
const canvas = this.refs.canvas
canvas.width = window.innerWidth
canvas.height = window.innerHeight
@@ -45,7 +45,16 @@ class SampleRNNGraph extends Component {
const h = canvas.height
ctx.clearRect(0,0,w,h)
- const keys = Object.keys(lossReport).sort().filter(k => !!lossReport[k].length)
+ const resultsByDate = results.map(file => {
+ if (!file.name.match(/^exp:/)) return null
+ const dataset = file.name.split("-")[3].split(":")[1]
+ return [
+ file.date,
+ dataset
+ ]
+ }).filter(a => !!a).sort((a,b) => a[0]-a[1])
+
+ const keys = Object.keys(lossReport).filter(k => !!lossReport[k].length)
let scaleMax = 0
let scaleMin = Infinity
let epochsMax = 0
@@ -115,13 +124,19 @@ class SampleRNNGraph extends Component {
ctx.lineWidth = 1
ctx.restore()
- keys.forEach(key => {
+ const min_date = resultsByDate[0][0]
+ const max_date = resultsByDate[resultsByDate.length-1][0]
+ resultsByDate.forEach(pair => {
+ const date = pair[0]
+ const key = pair[1]
const loss = lossReport[key]
+ if (!key) return
const vf = parseFloat(loss[loss.length-1].training_loss) || 0
const vg = parseFloat(loss[0].training_loss) || 5
// console.log(vf)
const vv = 1 - norm(vf, scaleMin, scaleMax/2)
- ctx.lineWidth = (1-norm(vf, scaleMin, scaleMax)) * 4
+ // ctx.lineWidth = (1-norm(vf, scaleMin, scaleMax)) * 4
+ ctx.lineWidth = norm(date, min_date, max_date) * 4
ctx.strokeStyle = 'rgba(' + [randrange(30,190), randrange(30,150), randrange(60,120)].join(',') + ',' + 0.8+ ')'
let begun = false
loss.forEach((a, i) => {