import { h, Component } from 'preact'
import { bindActionCreators } from 'redux'
import { connect } from 'react-redux'
import { lerp, norm, randint, randrange } from '../../../util'
import * as samplernnActions from '../samplernn.actions'
import Group from '../../../common/group.component'
import Slider from '../../../common/slider.component'
import Select from '../../../common/select.component'
import Button from '../../../common/button.component'
import { FileList } from '../../../common/fileList.component'
import TextInput from '../../../common/textInput.component'
class SampleRNNGraph extends Component {
constructor(props){
super()
props.actions.load_loss()
}
render(){
this.refs = {}
return (
SampleRNN Loss Graph
)
}
componentDidUpdate(){
const { lossReport } = this.props.samplernn
if (! lossReport) return
const canvas = this.refs.canvas
canvas.width = window.innerWidth
canvas.height = window.innerHeight
canvas.style.width = canvas.width + 'px'
canvas.style.height = canvas.height + 'px'
const ctx = canvas.getContext('2d')
const w = canvas.width = canvas.width * devicePixelRatio
const h = canvas.height = canvas.height * devicePixelRatio
ctx.clearRect(0,0,w,h)
const keys = Object.keys(lossReport).sort().filter(k => !!lossReport[k].length)
let scaleMax = 0
let scaleMin = Infinity
let epochsMax = 0
keys.forEach(key => {
const loss = lossReport[key]
epochsMax = Math.max(loss.length, epochsMax)
loss.forEach((a) => {
const v = parseFloat(a.training_loss)
if (! v) return
scaleMax = Math.max(v, scaleMax)
scaleMin = Math.min(v, scaleMin)
})
})
// scaleMax *= 10
console.log(scaleMax, scaleMin, epochsMax)
scaleMax = 3
scaleMin = 0
const margin = 0
const wmin = 0
const wmax = w
const hmin = 0
const hmax = h
const epochsScaleFactor = 1 // 3/2
let X, Y
for (var ii = 0; ii < epochsMax; ii++) {
X = lerp((ii)/(epochsMax/(epochsScaleFactor))*(epochsScaleFactor), wmin, wmax)
ctx.strokeStyle = 'rgba(0,0,0,0.3)'
ctx.beginPath(0, 0)
ctx.moveTo(X, 0)
ctx.lineTo(X, h)
ctx.lineWidth = 1
// ctx.stroke()
if ( (ii % 5) === 0 ) {
ctx.lineWidth = 2
ctx.stroke()
const fontSize = 12
ctx.font = 'italic ' + (fontSize * devicePixelRatio) + 'px "Georgia"'
ctx.fillStyle = 'rgba(0,12,28,0.6)'
ctx.fillText(ii/5*6, X + (8 * devicePixelRatio), h - ((fontSize + 4) * devicePixelRatio))
}
}
for (var ii = scaleMin; ii < scaleMax; ii += 1) {
Y = lerp(ii/scaleMax, hmin, hmax)
// ctx.strokeStyle = 'rgba(255,255,255,1.0)'
ctx.beginPath(0, 0)
ctx.moveTo(0, (h-Y))
ctx.lineTo(w, (h-Y))
ctx.lineWidth = 1
// ctx.stroke()
// if ( (ii % 1) < 0.1) {
// ctx.strokeStyle = 'rgba(255,255,255,1.0)'
ctx.lineWidth = 2
ctx.setLineDash([4, 4])
ctx.stroke()
ctx.stroke()
ctx.stroke()
ctx.setLineDash([0,0])
const fontSize = 12
ctx.font = 'italic ' + (fontSize * devicePixelRatio) + 'px "Georgia"'
ctx.fillStyle = 'rgba(0,12,28,0.6)'
ctx.fillText(ii.toFixed(1), w-50, (h-Y) + fontSize + (10 * devicePixelRatio))
// }
}
ctx.lineWidth = 1
keys.forEach(key => {
const loss = lossReport[key]
const vf = parseFloat(loss[loss.length-1].training_loss) || 0
const vg = parseFloat(loss[0].training_loss) || 5
// console.log(vf)
const vv = 1 - norm(vf, scaleMin, scaleMax/2)
ctx.lineWidth = (1-norm(vf, scaleMin, scaleMax)) * 5
ctx.strokeStyle = 'rgba(' + [randrange(30,190), randrange(30,150), randrange(60,120)].join(',') + ',' + 0.8+ ')'
let begun = false
loss.forEach((a, i) => {
const v = parseFloat(a.training_loss)
if (! v) return
const x = lerp((i-2)/(epochsMax/(epochsScaleFactor))*(epochsScaleFactor), wmin, wmax)
const y = lerp(norm(v, scaleMin, scaleMax), hmax, hmin)
if (i === 0) {
return
}
if (! begun) {
begun = true
ctx.beginPath(x,y)
} else {
ctx.lineTo(x,y)
// ctx.stroke()
}
})
ctx.stroke()
const i = loss.length-1
const v = parseFloat(loss[i].training_loss)
const x = lerp((i-2)/(epochsMax/(epochsScaleFactor))*(epochsScaleFactor), wmin, wmax)
const y = lerp(norm(v, scaleMin, scaleMax), hmax, hmin)
const fontSize = 9
ctx.font = 'italic ' + (fontSize * devicePixelRatio) + 'px "Georgia"'
ctx.fillText(key, x + fontSize, y + fontSize)
})
}
}
/**
 * Expose the samplernn slice of the store (`state.module.samplernn`) to the
 * component as its `samplernn` prop.
 */
const mapStateToProps = (state) => {
  const { samplernn } = state.module
  return { samplernn }
}
/**
 * Bind every samplernn action creator to `dispatch` and expose the result as
 * the component's `actions` prop.
 *
 * The function deliberately declares only `dispatch`: react-redux re-invokes
 * a two-parameter mapDispatchToProps whenever the component's own props
 * change, which would rebuild the bound action creators for no benefit.
 */
const mapDispatchToProps = (dispatch) => ({
  actions: bindActionCreators(samplernnActions, dispatch),
})
// Default export: SampleRNNGraph wired to the redux store, receiving the
// `samplernn` state slice and the bound `actions` as props.
export default connect(mapStateToProps, mapDispatchToProps)(SampleRNNGraph)