diff options
Diffstat (limited to 'become_yukarin/model.py')
| -rw-r--r-- | become_yukarin/model.py | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/become_yukarin/model.py b/become_yukarin/model.py index d4fa369..8879f11 100644 --- a/become_yukarin/model.py +++ b/become_yukarin/model.py @@ -222,14 +222,14 @@ class Aligner(chainer.link.Chain): class Discriminator(chainer.link.Chain): - def __init__(self, in_channels: int, hidden_channels_list: List[int], last_channels: int): + def __init__(self, in_channels: int, hidden_channels_list: List[int]): super().__init__() with self.init_scope(): self.convs = chainer.link.ChainList(*( LegacyConvolution1D(i_c, o_c, ksize=2, stride=2) for i_c, o_c in zip([in_channels] + hidden_channels_list[:-1], hidden_channels_list) )) - self.last_linear = chainer.links.Linear(None, 1) + self.last_conv = LegacyConvolution1D(hidden_channels_list[-1], 1, ksize=1) def __call__(self, x): """ @@ -239,9 +239,8 @@ class Discriminator(chainer.link.Chain): h = chainer.functions.reshape(h, h.shape + (1,)) for conv in self.convs.children(): h = chainer.functions.relu(conv(h)) + h = self.last_conv(h) h = chainer.functions.reshape(h, h.shape[:-1]) - - h = self.last_linear(h) return h @@ -276,7 +275,6 @@ def create_discriminator(config: DiscriminatorModelConfig): discriminator = Discriminator( in_channels=config.in_channels, hidden_channels_list=config.hidden_channels_list, - last_channels=config.last_channels, ) return discriminator |
