summaryrefslogtreecommitdiff
path: root/app/client/modules/samplernn/views/samplernn.graph.js
diff options
context:
space:
mode:
Diffstat (limited to 'app/client/modules/samplernn/views/samplernn.graph.js')
-rw-r--r--app/client/modules/samplernn/views/samplernn.graph.js27
1 files changed, 21 insertions, 6 deletions
diff --git a/app/client/modules/samplernn/views/samplernn.graph.js b/app/client/modules/samplernn/views/samplernn.graph.js
index 40e47fa..58f8d02 100644
--- a/app/client/modules/samplernn/views/samplernn.graph.js
+++ b/app/client/modules/samplernn/views/samplernn.graph.js
@@ -18,7 +18,7 @@ import TextInput from '../../../common/textInput.component'
class SampleRNNGraph extends Component {
constructor(props){
super()
- props.actions.load_loss()
+ props.actions.load_graph()
}
render(){
this.refs = {}
@@ -32,8 +32,8 @@ class SampleRNNGraph extends Component {
)
}
componentDidUpdate(){
- const { lossReport } = this.props.samplernn
- if (! lossReport) return
+ const { lossReport, results } = this.props.samplernn
+ if (! lossReport || ! results) return
const canvas = this.refs.canvas
canvas.width = window.innerWidth
canvas.height = window.innerHeight
@@ -45,7 +45,16 @@ class SampleRNNGraph extends Component {
const h = canvas.height
ctx.clearRect(0,0,w,h)
- const keys = Object.keys(lossReport).sort().filter(k => !!lossReport[k].length)
+ const resultsByDate = results.map(file => {
+ if (!file.name.match(/^exp:/)) return null
+ const dataset = file.name.split("-")[3].split(":")[1]
+ return [
+ file.date,
+ dataset
+ ]
+ }).filter(a => !!a).sort((a,b) => a[0]-a[1])
+
+ const keys = Object.keys(lossReport).filter(k => !!lossReport[k].length)
let scaleMax = 0
let scaleMin = Infinity
let epochsMax = 0
@@ -115,13 +124,19 @@ class SampleRNNGraph extends Component {
ctx.lineWidth = 1
ctx.restore()
- keys.forEach(key => {
+ const min_date = resultsByDate[0][0]
+ const max_date = resultsByDate[resultsByDate.length-1][0]
+ resultsByDate.forEach(pair => {
+ const date = pair[0]
+ const key = pair[1]
const loss = lossReport[key]
+ if (!key) return
const vf = parseFloat(loss[loss.length-1].training_loss) || 0
const vg = parseFloat(loss[0].training_loss) || 5
// console.log(vf)
const vv = 1 - norm(vf, scaleMin, scaleMax/2)
- ctx.lineWidth = (1-norm(vf, scaleMin, scaleMax)) * 4
+ // ctx.lineWidth = (1-norm(vf, scaleMin, scaleMax)) * 4
+ ctx.lineWidth = norm(date, min_date, max_date) * 4
ctx.strokeStyle = 'rgba(' + [randrange(30,190), randrange(30,150), randrange(60,120)].join(',') + ',' + 0.8+ ')'
let begun = false
loss.forEach((a, i) => {