diff options
Diffstat (limited to 'become_yukarin/model.py')
| -rw-r--r-- | become_yukarin/model.py | 44 |
1 files changed, 43 insertions, 1 deletions
diff --git a/become_yukarin/model.py b/become_yukarin/model.py index c475685..8a727ae 100644 --- a/become_yukarin/model.py +++ b/become_yukarin/model.py @@ -1,5 +1,8 @@ +from typing import List + import chainer +from .config import DiscriminatorModelConfig from .config import ModelConfig @@ -193,6 +196,35 @@ class Aligner(chainer.link.Chain): return h +class Discriminator(chainer.link.Chain): + def __init__(self, in_channels: int, hidden_channels_list: List[int], last_channels: int): + super().__init__() + with self.init_scope(): + self.convs = chainer.link.ChainList(*( + Convolution1D(i_c, o_c, ksize=2, stride=2, nobias=True) + for i_c, o_c in zip([in_channels] + hidden_channels_list[:-1], hidden_channels_list) + )) + self.lstm_cell = chainer.links.StatelessLSTM(hidden_channels_list[-1], last_channels) + self.last_linear = chainer.links.Linear(last_channels, 1) + + def __call__(self, x): + """ + :param x: (batch, channel, time) + """ + h = x + for conv in self.convs.children(): + h = chainer.functions.relu(conv(h)) + + hs = chainer.functions.separate(h, axis=2) + c_next = h_next = None + for h in reversed(hs): + c_next, h_next = self.lstm_cell(c_next, h_next, h) + h = h_next + + h = self.last_linear(h) + return h + + def create_predictor(config: ModelConfig): network = CBHG( in_channels=config.in_channels, @@ -220,10 +252,20 @@ def create_aligner(config: ModelConfig): return aligner +def create_discriminator(config: DiscriminatorModelConfig): + discriminator = Discriminator( + in_channels=config.in_channels, + hidden_channels_list=config.hidden_channels_list, + last_channels=config.last_channels, + ) + return discriminator + + def create(config: ModelConfig): predictor = create_predictor(config) if config.enable_aligner: aligner = create_aligner(config) else: aligner = None - return predictor, aligner + discriminator = create_discriminator(config.discriminator) + return predictor, aligner, discriminator |
