diff options
| author | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-11-22 23:50:31 +0900 |
|---|---|---|
| committer | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-11-22 23:50:31 +0900 |
| commit | 9f87a74de09e38f9d8f3e7ebb5fd26fac44a3b0e (patch) | |
| tree | ae662b5319256e3864877cacbd21c527f33448f0 /become_yukarin | |
| parent | d6af2a851644afe253b97461b35138011a479a95 (diff) | |
can remove aligner
Diffstat (limited to 'become_yukarin')
| -rw-r--r-- | become_yukarin/config.py | 2 | ||||
| -rw-r--r-- | become_yukarin/loss.py | 5 | ||||
| -rw-r--r-- | become_yukarin/model.py | 10 |
3 files changed, 15 insertions, 2 deletions
diff --git a/become_yukarin/config.py b/become_yukarin/config.py index 07f35fd..0efbf04 100644 --- a/become_yukarin/config.py +++ b/become_yukarin/config.py @@ -34,6 +34,7 @@ class ModelConfig(NamedTuple): out_size: int aligner_out_time_length: int disable_last_rnn: bool + enable_aligner: bool class LossConfig(NamedTuple): @@ -104,6 +105,7 @@ def create_from_json(s: Union[str, Path]): out_size=d['model']['out_size'], aligner_out_time_length=d['model']['aligner_out_time_length'], disable_last_rnn=d['model']['disable_last_rnn'], + enable_aligner=d['model']['enable_aligner'], ), loss=LossConfig( l1=d['loss']['l1'], diff --git a/become_yukarin/loss.py b/become_yukarin/loss.py index c59747a..b2b03fc 100644 --- a/become_yukarin/loss.py +++ b/become_yukarin/loss.py @@ -7,7 +7,7 @@ from .model import Predictor class Loss(chainer.link.Chain): - def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner): + def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner = None): super().__init__() self.config = config @@ -21,7 +21,8 @@ class Loss(chainer.link.Chain): mask = chainer.as_variable(mask) h = input - h = self.aligner(h) + if self.aligner is not None: + h = self.aligner(h) y = self.predictor(h) loss = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask) diff --git a/become_yukarin/model.py b/become_yukarin/model.py index 3b5102e..c475685 100644 --- a/become_yukarin/model.py +++ b/become_yukarin/model.py @@ -212,8 +212,18 @@ def create_predictor(config: ModelConfig): def create_aligner(config: ModelConfig): + assert config.enable_aligner aligner = Aligner( in_size=config.in_channels, out_time_length=config.aligner_out_time_length, ) return aligner + + +def create(config: ModelConfig): + predictor = create_predictor(config) + if config.enable_aligner: + aligner = create_aligner(config) + else: + aligner = None + return predictor, aligner |
