diff options
Diffstat (limited to 'become_yukarin/model.py')
| -rw-r--r-- | become_yukarin/model.py | 43 |
1 files changed, 35 insertions, 8 deletions
diff --git a/become_yukarin/model.py b/become_yukarin/model.py index 087afcd..c67f351 100644 --- a/become_yukarin/model.py +++ b/become_yukarin/model.py @@ -1,24 +1,51 @@ import chainer +from .config import ModelConfig -class DeepConvolution(chainer.link.Chain): - def __init__(self, num_scale: int, base_num_z: int, **kwargs): + +class DeepConvolution1D(chainer.link.Chain): + def __init__(self, in_size: int, num_scale: int, base_num_z: int, **kwargs): super().__init__(**kwargs) self.num_scale = num_scale + self.out_size = base_num_z * 2 ** (num_scale - 1) for i in range(num_scale): l = base_num_z * 2 ** i - self.add_link('conv{}'.format(i + 1), - chainer.links.Convolution2D(None, l, 4, 2, 1, nobias=True)) + self.add_link('conv{}'.format(i + 1), chainer.links.ConvolutionND(1, in_size, l, 3, 1, 1, nobias=True)) self.add_link('bn{}'.format(i + 1), chainer.links.BatchNormalization(l)) - - def get_scaled_width(self, base_width): - return base_width // (2 ** self.num_scale) + in_size = l def __call__(self, x): h = x for i in range(self.num_scale): conv = getattr(self, 'conv{}'.format(i + 1)) bn = getattr(self, 'bn{}'.format(i + 1)) - chainer.functions.relu(bn(conv(h))) + h = chainer.functions.relu(bn(conv(h))) + return h + + +class Model(chainer.link.Chain): + def __init__(self, convs: DeepConvolution1D, out_size: int): + super().__init__() + with self.init_scope(): + self.convs = convs + self.last = chainer.links.ConvolutionND(1, convs.out_size, out_size, 1) + + def __call__(self, x): + h = x + h = self.convs(h) + h = self.last(h) return h + + +def create(config: ModelConfig): + convs = DeepConvolution1D( + in_size=config.in_size, + num_scale=config.num_scale, + base_num_z=config.base_num_z, + ) + model = Model( + convs=convs, + out_size=config.out_size, + ) + return model |
