summaryrefslogtreecommitdiff
path: root/become_yukarin/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'become_yukarin/model.py')
-rw-r--r--become_yukarin/model.py43
1 files changed, 35 insertions, 8 deletions
diff --git a/become_yukarin/model.py b/become_yukarin/model.py
index 087afcd..c67f351 100644
--- a/become_yukarin/model.py
+++ b/become_yukarin/model.py
@@ -1,24 +1,51 @@
import chainer
+from .config import ModelConfig
-class DeepConvolution(chainer.link.Chain):
- def __init__(self, num_scale: int, base_num_z: int, **kwargs):
+
+class DeepConvolution1D(chainer.link.Chain):
+ def __init__(self, in_size: int, num_scale: int, base_num_z: int, **kwargs):
super().__init__(**kwargs)
self.num_scale = num_scale
+ self.out_size = base_num_z * 2 ** (num_scale - 1)
for i in range(num_scale):
l = base_num_z * 2 ** i
- self.add_link('conv{}'.format(i + 1),
- chainer.links.Convolution2D(None, l, 4, 2, 1, nobias=True))
+ self.add_link('conv{}'.format(i + 1), chainer.links.ConvolutionND(1, in_size, l, 3, 1, 1, nobias=True))
self.add_link('bn{}'.format(i + 1), chainer.links.BatchNormalization(l))
-
- def get_scaled_width(self, base_width):
- return base_width // (2 ** self.num_scale)
+ in_size = l
def __call__(self, x):
h = x
for i in range(self.num_scale):
conv = getattr(self, 'conv{}'.format(i + 1))
bn = getattr(self, 'bn{}'.format(i + 1))
- chainer.functions.relu(bn(conv(h)))
+ h = chainer.functions.relu(bn(conv(h)))
+ return h
+
+
+class Model(chainer.link.Chain):
+ def __init__(self, convs: DeepConvolution1D, out_size: int):
+ super().__init__()
+ with self.init_scope():
+ self.convs = convs
+ self.last = chainer.links.ConvolutionND(1, convs.out_size, out_size, 1)
+
+ def __call__(self, x):
+ h = x
+ h = self.convs(h)
+ h = self.last(h)
return h
+
+
+def create(config: ModelConfig):
+ convs = DeepConvolution1D(
+ in_size=config.in_size,
+ num_scale=config.num_scale,
+ base_num_z=config.base_num_z,
+ )
+ model = Model(
+ convs=convs,
+ out_size=config.out_size,
+ )
+ return model