summaryrefslogtreecommitdiff
path: root/become_yukarin/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'become_yukarin/model.py')
-rw-r--r--become_yukarin/model.py8
1 files changed, 3 insertions, 5 deletions
diff --git a/become_yukarin/model.py b/become_yukarin/model.py
index d4fa369..8879f11 100644
--- a/become_yukarin/model.py
+++ b/become_yukarin/model.py
@@ -222,14 +222,14 @@ class Aligner(chainer.link.Chain):
class Discriminator(chainer.link.Chain):
- def __init__(self, in_channels: int, hidden_channels_list: List[int], last_channels: int):
+ def __init__(self, in_channels: int, hidden_channels_list: List[int]):
super().__init__()
with self.init_scope():
self.convs = chainer.link.ChainList(*(
LegacyConvolution1D(i_c, o_c, ksize=2, stride=2)
for i_c, o_c in zip([in_channels] + hidden_channels_list[:-1], hidden_channels_list)
))
- self.last_linear = chainer.links.Linear(None, 1)
+ self.last_conv = LegacyConvolution1D(hidden_channels_list[-1], 1, ksize=1)
def __call__(self, x):
"""
@@ -239,9 +239,8 @@ class Discriminator(chainer.link.Chain):
h = chainer.functions.reshape(h, h.shape + (1,))
for conv in self.convs.children():
h = chainer.functions.relu(conv(h))
+ h = self.last_conv(h)
h = chainer.functions.reshape(h, h.shape[:-1])
-
- h = self.last_linear(h)
return h
@@ -276,7 +275,6 @@ def create_discriminator(config: DiscriminatorModelConfig):
discriminator = Discriminator(
in_channels=config.in_channels,
hidden_channels_list=config.hidden_channels_list,
- last_channels=config.last_channels,
)
return discriminator