import { h, Component } from 'preact' import { bindActionCreators } from 'redux' import { connect } from 'react-redux' import util from '../../../util' const { lerp, norm, randint, randrange } = util import * as samplernnActions from '../samplernn.actions' import Group from '../../../common/group.component' import Slider from '../../../common/slider.component' import Select from '../../../common/select.component' import Button from '../../../common/button.component' import { FileList } from '../../../common/fileList.component' import TextInput from '../../../common/textInput.component' class SampleRNNGraph extends Component { constructor(props){ super() props.actions.load_graph() } render(){ this.refs = {} return (

SampleRNN Loss Graph

this.refs['canvas'] = ref} />
) } componentDidUpdate(){ const { lossReport, results } = this.props.samplernn if (! lossReport || ! results) return const canvas = this.refs.canvas canvas.width = window.innerWidth canvas.height = window.innerHeight canvas.style.width = canvas.width + 'px' canvas.style.height = canvas.height + 'px' const ctx = canvas.getContext('2d') const w = canvas.width const h = canvas.height ctx.clearRect(0,0,w,h) const resultsByDate = results.map(file => { if (!file.name.match(/^exp:/)) return null const dataset = file.name.split("-")[3].split(":")[1] return [ +new Date(file.date), dataset ] }).filter(a => !!a).sort((a,b) => a[0]-a[1]) const keys = Object.keys(lossReport).filter(k => !!lossReport[k].length) let scaleMax = 0 let scaleMin = Infinity let epochsMax = 0 keys.forEach(key => { const loss = lossReport[key] epochsMax = Math.max(loss.length, epochsMax) loss.forEach((a) => { const v = parseFloat(a.training_loss) if (! v) return scaleMax = Math.max(v, scaleMax) scaleMin = Math.min(v, scaleMin) }) }) // scaleMax *= 10 // console.log(scaleMax, scaleMin, epochsMax) scaleMax = 3 scaleMin = 0 const margin = 0 const wmin = 0 const wmax = w/2 const hmin = 0 const hmax = h/2 const epochsScaleFactor = 1 // 3/2 ctx.save() let X, Y for (var ii = 0; ii < epochsMax; ii++) { X = lerp((ii)/(epochsMax/(epochsScaleFactor))*(epochsScaleFactor), wmin, wmax) ctx.strokeStyle = 'rgba(0,0,0,0.3)' ctx.beginPath(0, 0) ctx.moveTo(X, 0) ctx.lineTo(X, h) ctx.lineWidth = 0.5 // ctx.stroke() if ( ((ii+1) % 6) === 0 ) { ctx.lineWidth = 0.5 ctx.stroke() const fontSize = 12 ctx.font = 'italic ' + (fontSize) + 'px "Georgia"' ctx.fillStyle = 'rgba(0,12,28,0.6)' ctx.fillText(ii/5*6, X + (8), h - ((fontSize + 4))) } } for (var ii = scaleMin; ii < scaleMax; ii += 1) { Y = lerp(ii/scaleMax, hmin, hmax) // ctx.strokeStyle = 'rgba(255,255,255,1.0)' ctx.beginPath(0, 0) ctx.moveTo(0, (h-Y)) ctx.lineTo(w, (h-Y)) ctx.lineWidth = 1 // ctx.stroke() // if ( (ii % 1) < 0.1) { // ctx.strokeStyle = 'rgba(255,255,255,1.0)' ctx.lineWidth = 2 ctx.setLineDash([4, 4]) ctx.stroke() ctx.stroke() ctx.stroke() ctx.setLineDash([1,1]) const fontSize = 12 ctx.font = 'italic ' + (fontSize) + 'px "Georgia"' ctx.fillStyle = 'rgba(0,12,28,0.6)' ctx.fillText(ii.toFixed(1), w-50, (h-Y) + fontSize + (10)) // } } ctx.lineWidth = 1 ctx.restore() const min_date = resultsByDate[0][0] const max_date = resultsByDate[resultsByDate.length-1][0] resultsByDate.forEach(pair => { const date = pair[0] const key = pair[1] const loss = lossReport[key] if (!key || !loss || !loss.length) return const vf = parseFloat(loss[loss.length-1].training_loss) || 0 const vg = parseFloat(loss[0].training_loss) || 5 // console.log(vf) const vv = 1 - norm(vf, scaleMin, scaleMax/2) ctx.lineWidth = (1-norm(vf, scaleMin, scaleMax)) * 4 // ctx.lineWidth = norm(date, min_date, max_date) * 3 // console.log(date, min_date, max_date) ctx.strokeStyle = 'rgba(' + [randrange(30,190), randrange(30,150), randrange(60,120)].join(',') + ',' + 0.8+ ')' let begun = false loss.forEach((a, i) => { const v = parseFloat(a.training_loss) if (! v) return const x = lerp((i)/(epochsMax/(epochsScaleFactor))*(epochsScaleFactor), wmin, wmax) const y = lerp(norm(v, scaleMin, scaleMax), hmax, hmin) if (! begun) { begun = true ctx.beginPath(0,0) ctx.moveTo(x,y) } else { ctx.lineTo(x,y) // ctx.stroke() } }) ctx.stroke() const i = loss.length-1 const v = parseFloat(loss[i].training_loss) const x = lerp((i)/(epochsMax/(epochsScaleFactor))*(epochsScaleFactor), wmin, wmax) const y = lerp(norm(v, scaleMin, scaleMax), hmax, hmin) const fontSize = 9 ctx.font = 'italic ' + (fontSize) + 'px "Georgia"' ctx.fillStyle = 'rgba(0,12,28,0.6)' ctx.fillText(key, x + 4, y + fontSize/2) }) } } const mapStateToProps = state => ({ samplernn: state.module.samplernn, }) const mapDispatchToProps = (dispatch, ownProps) => ({ actions: bindActionCreators(samplernnActions, dispatch), }) export default connect(mapStateToProps, mapDispatchToProps)(SampleRNNGraph)