From 2f69126b9112386b08e77958316ccf849b6f6504 Mon Sep 17 00:00:00 2001 From: Hiroshiba Kazuyuki Date: Fri, 1 Dec 2017 04:07:48 +0700 Subject: add gan --- become_yukarin/model.py | 44 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) (limited to 'become_yukarin/model.py') diff --git a/become_yukarin/model.py b/become_yukarin/model.py index c475685..8a727ae 100644 --- a/become_yukarin/model.py +++ b/become_yukarin/model.py @@ -1,5 +1,8 @@ +from typing import List + import chainer +from .config import DiscriminatorModelConfig from .config import ModelConfig @@ -193,6 +196,35 @@ class Aligner(chainer.link.Chain): return h +class Discriminator(chainer.link.Chain): + def __init__(self, in_channels: int, hidden_channels_list: List[int], last_channels: int): + super().__init__() + with self.init_scope(): + self.convs = chainer.link.ChainList(*( + Convolution1D(i_c, o_c, ksize=2, stride=2, nobias=True) + for i_c, o_c in zip([in_channels] + hidden_channels_list[:-1], hidden_channels_list) + )) + self.lstm_cell = chainer.links.StatelessLSTM(hidden_channels_list[-1], last_channels) + self.last_linear = chainer.links.Linear(last_channels, 1) + + def __call__(self, x): + """ + :param x: (batch, channel, time) + """ + h = x + for conv in self.convs.children(): + h = chainer.functions.relu(conv(h)) + + hs = chainer.functions.separate(h, axis=2) + c_next = h_next = None + for h in reversed(hs): + c_next, h_next = self.lstm_cell(c_next, h_next, h) + h = h_next + + h = self.last_linear(h) + return h + + def create_predictor(config: ModelConfig): network = CBHG( in_channels=config.in_channels, @@ -220,10 +252,20 @@ def create_aligner(config: ModelConfig): return aligner +def create_discriminator(config: DiscriminatorModelConfig): + discriminator = Discriminator( + in_channels=config.in_channels, + hidden_channels_list=config.hidden_channels_list, + last_channels=config.last_channels, + ) + return discriminator + + def create(config: ModelConfig): predictor = create_predictor(config) if config.enable_aligner: aligner = create_aligner(config) else: aligner = None - return predictor, aligner + discriminator = create_discriminator(config.discriminator) + return predictor, aligner, discriminator -- cgit v1.2.3-70-g09d2