diff options
Diffstat (limited to 'app/client/modules/samplernn/samplernn.loss.js')
| -rw-r--r-- | app/client/modules/samplernn/samplernn.loss.js | 143 |
1 files changed, 143 insertions, 0 deletions
diff --git a/app/client/modules/samplernn/samplernn.loss.js b/app/client/modules/samplernn/samplernn.loss.js new file mode 100644 index 0000000..3900c31 --- /dev/null +++ b/app/client/modules/samplernn/samplernn.loss.js @@ -0,0 +1,143 @@ +import { h, Component } from 'preact' +import { bindActionCreators } from 'redux' +import { connect } from 'react-redux' + +import { lerp, norm, randint, randrange } from '../../util' + +import * as samplernnActions from './samplernn.actions' + +import Dataset from '../../dataset/dataset.component' + +import Group from '../../common/group.component' +import Slider from '../../common/slider.component' +import Select from '../../common/select.component' +import Button from '../../common/button.component' +import { FileList } from '../../common/fileList.component' +import TextInput from '../../common/textInput.component' + +class SampleRNNLoss extends Component { + constructor(props){ + super() + props.actions.load_loss() + } + render(){ + this.refs = {} + return ( + <div className='app lossGraph'> + <div className='heading'> + <h3>SampleRNN Loss</h3> + <canvas ref={(ref) => this.refs['canvas'] = ref} /> + </div> + </div> + ) + } + componentDidUpdate(){ + const { lossReport } = this.props.samplernn + if (! lossReport) return + const canvas = this.refs.canvas + canvas.width = window.innerWidth + canvas.height = window.innerHeight + canvas.style.width = canvas.width + 'px' + canvas.style.height = canvas.height + 'px' + + const ctx = canvas.getContext('2d') + const w = canvas.width = canvas.width * devicePixelRatio + const h = canvas.height = canvas.height * devicePixelRatio + + const keys = Object.keys(lossReport).sort().filter(k => !!lossReport[k].length) + let scaleMax = 0 + let scaleMin = Infinity + let epochsMax = 0 + keys.forEach(key => { + const loss = lossReport[key] + epochsMax = Math.max(loss.length, epochsMax) + loss.forEach((a) => { + const v = parseFloat(a.training_loss) + if (! v) return + scaleMax = Math.max(v, scaleMax) + scaleMin = Math.min(v, scaleMin) + }) + }) + // scaleMax *= 10 + console.log(scaleMax, scaleMin, epochsMax) + + scaleMax = 3 + scaleMin = 0 + const margin = 0 + const wmin = 0 + const wmax = w + const hmin = 0 + const hmax = h + const epochsScaleFactor = 1 // 3/2 + + let X, Y + for (var ii = 0; ii < epochsMax; ii++) { + X = lerp((ii)/(epochsMax/(epochsScaleFactor))*(epochsScaleFactor), wmin, wmax) + ctx.strokeStyle = 'rgba(0,0,0,0.3)' + ctx.beginPath(0, 0) + ctx.moveTo(X, 0) + ctx.lineTo(X, h) + ctx.lineWidth = 1 + // ctx.stroke() + if ( (ii % 5) === 0 ) { + ctx.lineWidth = 2 + ctx.stroke() + } + } + for (var ii = scaleMin; ii < scaleMax; ii += 1) { + Y = lerp(ii/scaleMax, wmin, wmax) + // ctx.strokeStyle = 'rgba(255,255,255,1.0)' + ctx.beginPath(0, 0) + ctx.moveTo(0, (h-Y)) + ctx.lineTo(w, (h-Y)) + ctx.lineWidth = 1 + // ctx.stroke() + // if ( (ii % 1) < 0.1) { + // ctx.strokeStyle = 'rgba(255,255,255,1.0)' + ctx.lineWidth = 2 + ctx.stroke() + ctx.stroke() + ctx.stroke() + // } + } + ctx.lineWidth = 1 + + keys.forEach(key => { + const loss = lossReport[key] + const vf = parseFloat(loss[loss.length-1].training_loss) || 0 + const vg = parseFloat(loss[0].training_loss) || 5 + console.log(vf) + const vv = 1 - norm(vf, scaleMin, scaleMax/2) + ctx.lineWidth = (1-norm(vf, scaleMin, scaleMax)) * 5 + ctx.strokeStyle = 'rgba(' + [randrange(30,190), randrange(30,150), randrange(60,120)].join(',') + ',' + 0.8+ ')' + let begun = false + loss.forEach((a, i) => { + const v = parseFloat(a.training_loss) + if (! v) return + const x = lerp((i-2)/(epochsMax/(epochsScaleFactor))*(epochsScaleFactor), wmin, wmax) + const y = lerp(norm(v, scaleMin, scaleMax), hmax, hmin) + if (i === 0) { + return + } + if (! begun) { + begun = true + ctx.beginPath(x,y) + } else { + ctx.lineTo(x,y) + // ctx.stroke() + } + }) + ctx.stroke() + }) + } +} + +const mapStateToProps = state => ({ + samplernn: state.module.samplernn, +}) + +const mapDispatchToProps = (dispatch, ownProps) => ({ + actions: bindActionCreators(samplernnActions, dispatch), +}) + +export default connect(mapStateToProps, mapDispatchToProps)(SampleRNNLoss) |
