summary | refs | log | tree | commit | diff
path: root/become_yukarin/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'become_yukarin/model.py')
-rw-r--r-- become_yukarin/model.py 44
1 files changed, 43 insertions, 1 deletions
diff --git a/become_yukarin/model.py b/become_yukarin/model.py
index c475685..8a727ae 100644
--- a/become_yukarin/model.py
+++ b/become_yukarin/model.py
@@ -1,5 +1,8 @@
+from typing import List
+
import chainer
+from .config import DiscriminatorModelConfig
from .config import ModelConfig
@@ -193,6 +196,35 @@ class Aligner(chainer.link.Chain):
return h
class Discriminator(chainer.link.Chain):
    """Convolution stack, then an LSTM cell stepped over axis 2, then a linear head.

    The input passes through one ``Convolution1D`` (``ksize=2``, ``stride=2``,
    no bias) per entry of ``hidden_channels_list``, each followed by ReLU.
    The result is split along axis 2, the slices are fed to a
    ``StatelessLSTM`` cell from last to first, and the final hidden state is
    projected to a single unit by ``last_linear``.
    """

    def __init__(self, in_channels: int, hidden_channels_list: List[int], last_channels: int):
        """
        :param in_channels: channel count given to the first convolution
        :param hidden_channels_list: output channel count of each convolution, in order
        :param last_channels: output size of the LSTM cell and input size of the linear head
        """
        super().__init__()
        # Each convolution consumes what the previous one produced; the first
        # one consumes ``in_channels``.
        input_sizes = [in_channels] + hidden_channels_list[:-1]
        with self.init_scope():
            conv_layers = [
                Convolution1D(n_in, n_out, ksize=2, stride=2, nobias=True)
                for n_in, n_out in zip(input_sizes, hidden_channels_list)
            ]
            self.convs = chainer.link.ChainList(*conv_layers)
            self.lstm_cell = chainer.links.StatelessLSTM(hidden_channels_list[-1], last_channels)
            self.last_linear = chainer.links.Linear(last_channels, 1)

    def __call__(self, x):
        """
        :param x: (batch, channel, time)
        :return: output of ``last_linear`` applied to the final LSTM hidden state
        """
        feature = x
        for conv in self.convs.children():
            feature = chainer.functions.relu(conv(feature))

        # One slice per position along axis 2 (presumably time, assuming
        # Convolution1D keeps the (batch, channel, time) layout -- TODO confirm).
        frames = chainer.functions.separate(feature, axis=2)

        # The cell starts with no state and is stepped from the last slice
        # back to the first.
        cell = hidden = None
        for frame in reversed(frames):
            cell, hidden = self.lstm_cell(cell, hidden, frame)

        return self.last_linear(hidden)
+
+
def create_predictor(config: ModelConfig):
network = CBHG(
in_channels=config.in_channels,
@@ -220,10 +252,20 @@ def create_aligner(config: ModelConfig):
return aligner
def create_discriminator(config: DiscriminatorModelConfig):
    """Build a ``Discriminator`` from the fields of *config*.

    :param config: supplies ``in_channels``, ``hidden_channels_list`` and ``last_channels``
    :return: the newly constructed ``Discriminator``
    """
    return Discriminator(
        in_channels=config.in_channels,
        hidden_channels_list=config.hidden_channels_list,
        last_channels=config.last_channels,
    )
+
+
def create(config: ModelConfig):
    """Build every network described by *config*.

    :param config: model configuration; ``config.discriminator`` is passed to
        ``create_discriminator``
    :return: ``(predictor, aligner, discriminator)``; ``aligner`` is ``None``
        when ``config.enable_aligner`` is falsy
    """
    predictor = create_predictor(config)
    # The aligner is optional and is only constructed when enabled.
    aligner = create_aligner(config) if config.enable_aligner else None
    discriminator = create_discriminator(config.discriminator)
    return predictor, aligner, discriminator