summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-12-05 13:56:32 +0900
committerHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-12-05 13:56:32 +0900
commitb80026f4773723dc7b2f8927ece16fce61fea267 (patch)
tree5fec8cbf811585bb002cea6d92d48e2b5d54efb0
parent421d6aab0953a47fff33062118c45d08ca95b00b (diff)
remove last linaer of discriminator
-rw-r--r--become_yukarin/config.py2
-rw-r--r--become_yukarin/model.py8
2 files changed, 3 insertions, 7 deletions
diff --git a/become_yukarin/config.py b/become_yukarin/config.py
index a65a72c..83f3597 100644
--- a/become_yukarin/config.py
+++ b/become_yukarin/config.py
@@ -26,7 +26,6 @@ class DatasetConfig(NamedTuple):
class DiscriminatorModelConfig(NamedTuple):
in_channels: int
hidden_channels_list: List[int]
- last_channels: int
class ModelConfig(NamedTuple):
@@ -98,7 +97,6 @@ def create_from_json(s: Union[str, Path]):
discriminator_model_config = DiscriminatorModelConfig(
in_channels=d['model']['discriminator']['in_channels'],
hidden_channels_list=d['model']['discriminator']['hidden_channels_list'],
- last_channels=d['model']['discriminator']['last_channels'],
)
return Config(
diff --git a/become_yukarin/model.py b/become_yukarin/model.py
index d4fa369..8879f11 100644
--- a/become_yukarin/model.py
+++ b/become_yukarin/model.py
@@ -222,14 +222,14 @@ class Aligner(chainer.link.Chain):
class Discriminator(chainer.link.Chain):
- def __init__(self, in_channels: int, hidden_channels_list: List[int], last_channels: int):
+ def __init__(self, in_channels: int, hidden_channels_list: List[int]):
super().__init__()
with self.init_scope():
self.convs = chainer.link.ChainList(*(
LegacyConvolution1D(i_c, o_c, ksize=2, stride=2)
for i_c, o_c in zip([in_channels] + hidden_channels_list[:-1], hidden_channels_list)
))
- self.last_linear = chainer.links.Linear(None, 1)
+ self.last_conv = LegacyConvolution1D(hidden_channels_list[-1], 1, ksize=1)
def __call__(self, x):
"""
@@ -239,9 +239,8 @@ class Discriminator(chainer.link.Chain):
h = chainer.functions.reshape(h, h.shape + (1,))
for conv in self.convs.children():
h = chainer.functions.relu(conv(h))
+ h = self.last_conv(h)
h = chainer.functions.reshape(h, h.shape[:-1])
-
- h = self.last_linear(h)
return h
@@ -276,7 +275,6 @@ def create_discriminator(config: DiscriminatorModelConfig):
discriminator = Discriminator(
in_channels=config.in_channels,
hidden_channels_list=config.hidden_channels_list,
- last_channels=config.last_channels,
)
return discriminator