diff options
Diffstat (limited to 'app/client/modules/samplernn/views/samplernn.graph.js')
| -rw-r--r-- | app/client/modules/samplernn/views/samplernn.graph.js | 159 |
1 files changed, 159 insertions, 0 deletions
diff --git a/app/client/modules/samplernn/views/samplernn.graph.js b/app/client/modules/samplernn/views/samplernn.graph.js new file mode 100644 index 0000000..821f1cb --- /dev/null +++ b/app/client/modules/samplernn/views/samplernn.graph.js @@ -0,0 +1,159 @@ +import { h, Component } from 'preact' +import { bindActionCreators } from 'redux' +import { connect } from 'react-redux' + +import { lerp, norm, randint, randrange } from '../../../util' + +import * as samplernnActions from '../samplernn.actions' + +import Group from '../../../common/group.component' +import Slider from '../../../common/slider.component' +import Select from '../../../common/select.component' +import Button from '../../../common/button.component' +import { FileList } from '../../../common/fileList.component' +import TextInput from '../../../common/textInput.component' + +class SampleRNNGraph extends Component { + constructor(props){ + super() + props.actions.load_loss() + } + render(){ + this.refs = {} + return ( + <div className='app lossGraph'> + <div className='heading'> + <h3>SampleRNN Loss Graph</h3> + <canvas ref={(ref) => this.refs['canvas'] = ref} /> + </div> + </div> + ) + } + componentDidUpdate(){ + const { lossReport } = this.props.samplernn + if (! lossReport) return + const canvas = this.refs.canvas + canvas.width = window.innerWidth + canvas.height = window.innerHeight + canvas.style.width = canvas.width + 'px' + canvas.style.height = canvas.height + 'px' + + const ctx = canvas.getContext('2d') + const w = canvas.width = canvas.width * devicePixelRatio + const h = canvas.height = canvas.height * devicePixelRatio + ctx.clearRect(0,0,w,h) + + const keys = Object.keys(lossReport).sort().filter(k => !!lossReport[k].length) + let scaleMax = 0 + let scaleMin = Infinity + let epochsMax = 0 + keys.forEach(key => { + const loss = lossReport[key] + epochsMax = Math.max(loss.length, epochsMax) + loss.forEach((a) => { + const v = parseFloat(a.training_loss) + if (! v) return + scaleMax = Math.max(v, scaleMax) + scaleMin = Math.min(v, scaleMin) + }) + }) + // scaleMax *= 10 + console.log(scaleMax, scaleMin, epochsMax) + + scaleMax = 3 + scaleMin = 0 + const margin = 0 + const wmin = 0 + const wmax = w + const hmin = 0 + const hmax = h + const epochsScaleFactor = 1 // 3/2 + + let X, Y + for (var ii = 0; ii < epochsMax; ii++) { + X = lerp((ii)/(epochsMax/(epochsScaleFactor))*(epochsScaleFactor), wmin, wmax) + ctx.strokeStyle = 'rgba(0,0,0,0.3)' + ctx.beginPath(0, 0) + ctx.moveTo(X, 0) + ctx.lineTo(X, h) + ctx.lineWidth = 1 + // ctx.stroke() + if ( (ii % 5) === 0 ) { + ctx.lineWidth = 2 + ctx.stroke() + const fontSize = 12 + ctx.font = 'italic ' + (fontSize * devicePixelRatio) + 'px "Georgia"' + ctx.fillStyle = 'rgba(0,12,28,0.6)' + ctx.fillText(ii/5*6, X + (8 * devicePixelRatio), h - ((fontSize + 4) * devicePixelRatio)) + } + } + for (var ii = scaleMin; ii < scaleMax; ii += 1) { + Y = lerp(ii/scaleMax, hmin, hmax) + // ctx.strokeStyle = 'rgba(255,255,255,1.0)' + ctx.beginPath(0, 0) + ctx.moveTo(0, (h-Y)) + ctx.lineTo(w, (h-Y)) + ctx.lineWidth = 1 + // ctx.stroke() + // if ( (ii % 1) < 0.1) { + // ctx.strokeStyle = 'rgba(255,255,255,1.0)' + ctx.lineWidth = 2 + ctx.setLineDash([4, 4]) + ctx.stroke() + ctx.stroke() + ctx.stroke() + ctx.setLineDash([0,0]) + const fontSize = 12 + ctx.font = 'italic ' + (fontSize * devicePixelRatio) + 'px "Georgia"' + ctx.fillStyle = 'rgba(0,12,28,0.6)' + ctx.fillText(ii.toFixed(1), w-50, (h-Y) + fontSize + (10 * devicePixelRatio)) + // } + } + ctx.lineWidth = 1 + + keys.forEach(key => { + const loss = lossReport[key] + const vf = parseFloat(loss[loss.length-1].training_loss) || 0 + const vg = parseFloat(loss[0].training_loss) || 5 + // console.log(vf) + const vv = 1 - norm(vf, scaleMin, scaleMax/2) + ctx.lineWidth = (1-norm(vf, scaleMin, scaleMax)) * 5 + ctx.strokeStyle = 'rgba(' + [randrange(30,190), randrange(30,150), randrange(60,120)].join(',') + ',' + 0.8+ ')' + let begun = false + loss.forEach((a, i) => { + const v = parseFloat(a.training_loss) + if (! v) return + const x = lerp((i-2)/(epochsMax/(epochsScaleFactor))*(epochsScaleFactor), wmin, wmax) + const y = lerp(norm(v, scaleMin, scaleMax), hmax, hmin) + if (i === 0) { + return + } + if (! begun) { + begun = true + ctx.beginPath(x,y) + } else { + ctx.lineTo(x,y) + // ctx.stroke() + } + }) + ctx.stroke() + const i = loss.length-1 + const v = parseFloat(loss[i].training_loss) + const x = lerp((i-2)/(epochsMax/(epochsScaleFactor))*(epochsScaleFactor), wmin, wmax) + const y = lerp(norm(v, scaleMin, scaleMax), hmax, hmin) + const fontSize = 9 + ctx.font = 'italic ' + (fontSize * devicePixelRatio) + 'px "Georgia"' + ctx.fillText(key, x + fontSize, y + fontSize) + }) + } +} + +const mapStateToProps = state => ({ + samplernn: state.module.samplernn, +}) + +const mapDispatchToProps = (dispatch, ownProps) => ({ + actions: bindActionCreators(samplernnActions, dispatch), +}) + +export default connect(mapStateToProps, mapDispatchToProps)(SampleRNNGraph) |
