diff options
| author | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-12-05 13:56:32 +0900 |
|---|---|---|
| committer | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-12-05 13:56:32 +0900 |
| commit | b80026f4773723dc7b2f8927ece16fce61fea267 (patch) | |
| tree | 5fec8cbf811585bb002cea6d92d48e2b5d54efb0 | |
| parent | 421d6aab0953a47fff33062118c45d08ca95b00b (diff) | |
remove last linaer of discriminator
| -rw-r--r-- | become_yukarin/config.py | 2 | ||||
| -rw-r--r-- | become_yukarin/model.py | 8 |
2 files changed, 3 insertions, 7 deletions
diff --git a/become_yukarin/config.py b/become_yukarin/config.py index a65a72c..83f3597 100644 --- a/become_yukarin/config.py +++ b/become_yukarin/config.py @@ -26,7 +26,6 @@ class DatasetConfig(NamedTuple): class DiscriminatorModelConfig(NamedTuple): in_channels: int hidden_channels_list: List[int] - last_channels: int class ModelConfig(NamedTuple): @@ -98,7 +97,6 @@ def create_from_json(s: Union[str, Path]): discriminator_model_config = DiscriminatorModelConfig( in_channels=d['model']['discriminator']['in_channels'], hidden_channels_list=d['model']['discriminator']['hidden_channels_list'], - last_channels=d['model']['discriminator']['last_channels'], ) return Config( diff --git a/become_yukarin/model.py b/become_yukarin/model.py index d4fa369..8879f11 100644 --- a/become_yukarin/model.py +++ b/become_yukarin/model.py @@ -222,14 +222,14 @@ class Aligner(chainer.link.Chain): class Discriminator(chainer.link.Chain): - def __init__(self, in_channels: int, hidden_channels_list: List[int], last_channels: int): + def __init__(self, in_channels: int, hidden_channels_list: List[int]): super().__init__() with self.init_scope(): self.convs = chainer.link.ChainList(*( LegacyConvolution1D(i_c, o_c, ksize=2, stride=2) for i_c, o_c in zip([in_channels] + hidden_channels_list[:-1], hidden_channels_list) )) - self.last_linear = chainer.links.Linear(None, 1) + self.last_conv = LegacyConvolution1D(hidden_channels_list[-1], 1, ksize=1) def __call__(self, x): """ @@ -239,9 +239,8 @@ class Discriminator(chainer.link.Chain): h = chainer.functions.reshape(h, h.shape + (1,)) for conv in self.convs.children(): h = chainer.functions.relu(conv(h)) + h = self.last_conv(h) h = chainer.functions.reshape(h, h.shape[:-1]) - - h = self.last_linear(h) return h @@ -276,7 +275,6 @@ def create_discriminator(config: DiscriminatorModelConfig): discriminator = Discriminator( in_channels=config.in_channels, hidden_channels_list=config.hidden_channels_list, - last_channels=config.last_channels, ) return discriminator |
