summaryrefslogtreecommitdiff
path: root/become_yukarin/model.py
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <hihokaruta@gmail.com>2017-11-07 10:20:04 +0900
committerHiroshiba Kazuyuki <hihokaruta@gmail.com>2017-11-07 10:29:28 +0900
commit6119849270c2aed117627d7d2b060f37d1c25de4 (patch)
tree178e8df7a0c3d33f9de776b85ee4ff545f2ecdc3 /become_yukarin/model.py
parent8e637c41a262373786b94d40a8f3559caf5cd44c (diff)
can train
Diffstat (limited to 'become_yukarin/model.py')
-rw-r--r--become_yukarin/model.py43
1 files changed, 35 insertions, 8 deletions
diff --git a/become_yukarin/model.py b/become_yukarin/model.py
index 087afcd..c67f351 100644
--- a/become_yukarin/model.py
+++ b/become_yukarin/model.py
@@ -1,24 +1,51 @@
import chainer
+from .config import ModelConfig
-class DeepConvolution(chainer.link.Chain):
- def __init__(self, num_scale: int, base_num_z: int, **kwargs):
+
+class DeepConvolution1D(chainer.link.Chain):
+ def __init__(self, in_size: int, num_scale: int, base_num_z: int, **kwargs):
super().__init__(**kwargs)
self.num_scale = num_scale
+ self.out_size = base_num_z * 2 ** (num_scale - 1)
for i in range(num_scale):
l = base_num_z * 2 ** i
- self.add_link('conv{}'.format(i + 1),
- chainer.links.Convolution2D(None, l, 4, 2, 1, nobias=True))
+ self.add_link('conv{}'.format(i + 1), chainer.links.ConvolutionND(1, in_size, l, 3, 1, 1, nobias=True))
self.add_link('bn{}'.format(i + 1), chainer.links.BatchNormalization(l))
-
- def get_scaled_width(self, base_width):
- return base_width // (2 ** self.num_scale)
+ in_size = l
def __call__(self, x):
h = x
for i in range(self.num_scale):
conv = getattr(self, 'conv{}'.format(i + 1))
bn = getattr(self, 'bn{}'.format(i + 1))
- chainer.functions.relu(bn(conv(h)))
+ h = chainer.functions.relu(bn(conv(h)))
+ return h
+
+
+class Model(chainer.link.Chain):
+ def __init__(self, convs: DeepConvolution1D, out_size: int):
+ super().__init__()
+ with self.init_scope():
+ self.convs = convs
+ self.last = chainer.links.ConvolutionND(1, convs.out_size, out_size, 1)
+
+ def __call__(self, x):
+ h = x
+ h = self.convs(h)
+ h = self.last(h)
return h
+
+
+def create(config: ModelConfig):
+ convs = DeepConvolution1D(
+ in_size=config.in_size,
+ num_scale=config.num_scale,
+ base_num_z=config.base_num_z,
+ )
+ model = Model(
+ convs=convs,
+ out_size=config.out_size,
+ )
+ return model