summaryrefslogtreecommitdiff
path: root/app/client/modules/samplernn/samplernn.loss.js
diff options
context:
space:
mode:
Diffstat (limited to 'app/client/modules/samplernn/samplernn.loss.js')
-rw-r--r--app/client/modules/samplernn/samplernn.loss.js143
1 files changed, 143 insertions, 0 deletions
diff --git a/app/client/modules/samplernn/samplernn.loss.js b/app/client/modules/samplernn/samplernn.loss.js
new file mode 100644
index 0000000..3900c31
--- /dev/null
+++ b/app/client/modules/samplernn/samplernn.loss.js
@@ -0,0 +1,143 @@
+import { h, Component } from 'preact'
+import { bindActionCreators } from 'redux'
+import { connect } from 'react-redux'
+
+import { lerp, norm, randint, randrange } from '../../util'
+
+import * as samplernnActions from './samplernn.actions'
+
+import Dataset from '../../dataset/dataset.component'
+
+import Group from '../../common/group.component'
+import Slider from '../../common/slider.component'
+import Select from '../../common/select.component'
+import Button from '../../common/button.component'
+import { FileList } from '../../common/fileList.component'
+import TextInput from '../../common/textInput.component'
+
+class SampleRNNLoss extends Component {
+ constructor(props){
+ super()
+ props.actions.load_loss()
+ }
+ render(){
+ this.refs = {}
+ return (
+ <div className='app lossGraph'>
+ <div className='heading'>
+ <h3>SampleRNN Loss</h3>
+ <canvas ref={(ref) => this.refs['canvas'] = ref} />
+ </div>
+ </div>
+ )
+ }
+ componentDidUpdate(){
+ const { lossReport } = this.props.samplernn
+ if (! lossReport) return
+ const canvas = this.refs.canvas
+ canvas.width = window.innerWidth
+ canvas.height = window.innerHeight
+ canvas.style.width = canvas.width + 'px'
+ canvas.style.height = canvas.height + 'px'
+
+ const ctx = canvas.getContext('2d')
+ const w = canvas.width = canvas.width * devicePixelRatio
+ const h = canvas.height = canvas.height * devicePixelRatio
+
+ const keys = Object.keys(lossReport).sort().filter(k => !!lossReport[k].length)
+ let scaleMax = 0
+ let scaleMin = Infinity
+ let epochsMax = 0
+ keys.forEach(key => {
+ const loss = lossReport[key]
+ epochsMax = Math.max(loss.length, epochsMax)
+ loss.forEach((a) => {
+ const v = parseFloat(a.training_loss)
+ if (! v) return
+ scaleMax = Math.max(v, scaleMax)
+ scaleMin = Math.min(v, scaleMin)
+ })
+ })
+ // scaleMax *= 10
+ console.log(scaleMax, scaleMin, epochsMax)
+
+ scaleMax = 3
+ scaleMin = 0
+ const margin = 0
+ const wmin = 0
+ const wmax = w
+ const hmin = 0
+ const hmax = h
+ const epochsScaleFactor = 1 // 3/2
+
+ let X, Y
+ for (var ii = 0; ii < epochsMax; ii++) {
+ X = lerp((ii)/(epochsMax/(epochsScaleFactor))*(epochsScaleFactor), wmin, wmax)
+ ctx.strokeStyle = 'rgba(0,0,0,0.3)'
+ ctx.beginPath(0, 0)
+ ctx.moveTo(X, 0)
+ ctx.lineTo(X, h)
+ ctx.lineWidth = 1
+ // ctx.stroke()
+ if ( (ii % 5) === 0 ) {
+ ctx.lineWidth = 2
+ ctx.stroke()
+ }
+ }
+ for (var ii = scaleMin; ii < scaleMax; ii += 1) {
+ Y = lerp(ii/scaleMax, wmin, wmax)
+ // ctx.strokeStyle = 'rgba(255,255,255,1.0)'
+ ctx.beginPath(0, 0)
+ ctx.moveTo(0, (h-Y))
+ ctx.lineTo(w, (h-Y))
+ ctx.lineWidth = 1
+ // ctx.stroke()
+ // if ( (ii % 1) < 0.1) {
+ // ctx.strokeStyle = 'rgba(255,255,255,1.0)'
+ ctx.lineWidth = 2
+ ctx.stroke()
+ ctx.stroke()
+ ctx.stroke()
+ // }
+ }
+ ctx.lineWidth = 1
+
+ keys.forEach(key => {
+ const loss = lossReport[key]
+ const vf = parseFloat(loss[loss.length-1].training_loss) || 0
+ const vg = parseFloat(loss[0].training_loss) || 5
+ console.log(vf)
+ const vv = 1 - norm(vf, scaleMin, scaleMax/2)
+ ctx.lineWidth = (1-norm(vf, scaleMin, scaleMax)) * 5
+ ctx.strokeStyle = 'rgba(' + [randrange(30,190), randrange(30,150), randrange(60,120)].join(',') + ',' + 0.8+ ')'
+ let begun = false
+ loss.forEach((a, i) => {
+ const v = parseFloat(a.training_loss)
+ if (! v) return
+ const x = lerp((i-2)/(epochsMax/(epochsScaleFactor))*(epochsScaleFactor), wmin, wmax)
+ const y = lerp(norm(v, scaleMin, scaleMax), hmax, hmin)
+ if (i === 0) {
+ return
+ }
+ if (! begun) {
+ begun = true
+ ctx.beginPath(x,y)
+ } else {
+ ctx.lineTo(x,y)
+ // ctx.stroke()
+ }
+ })
+ ctx.stroke()
+ })
+ }
+}
+
+const mapStateToProps = state => ({
+ samplernn: state.module.samplernn,
+})
+
+const mapDispatchToProps = (dispatch, ownProps) => ({
+ actions: bindActionCreators(samplernnActions, dispatch),
+})
+
+export default connect(mapStateToProps, mapDispatchToProps)(SampleRNNLoss)