diff options
| author | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-11-22 23:50:31 +0900 |
|---|---|---|
| committer | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-11-22 23:50:31 +0900 |
| commit | 9f87a74de09e38f9d8f3e7ebb5fd26fac44a3b0e (patch) | |
| tree | ae662b5319256e3864877cacbd21c527f33448f0 | |
| parent | d6af2a851644afe253b97461b35138011a479a95 (diff) | |
can remove aligner
| -rw-r--r-- | become_yukarin/config.py | 2 | ||||
| -rw-r--r-- | become_yukarin/loss.py | 5 | ||||
| -rw-r--r-- | become_yukarin/model.py | 10 | ||||
| -rw-r--r-- | train.py | 6 |
4 files changed, 17 insertions, 6 deletions
diff --git a/become_yukarin/config.py b/become_yukarin/config.py index 07f35fd..0efbf04 100644 --- a/become_yukarin/config.py +++ b/become_yukarin/config.py @@ -34,6 +34,7 @@ class ModelConfig(NamedTuple): out_size: int aligner_out_time_length: int disable_last_rnn: bool + enable_aligner: bool class LossConfig(NamedTuple): @@ -104,6 +105,7 @@ def create_from_json(s: Union[str, Path]): out_size=d['model']['out_size'], aligner_out_time_length=d['model']['aligner_out_time_length'], disable_last_rnn=d['model']['disable_last_rnn'], + enable_aligner=d['model']['enable_aligner'], ), loss=LossConfig( l1=d['loss']['l1'], diff --git a/become_yukarin/loss.py b/become_yukarin/loss.py index c59747a..b2b03fc 100644 --- a/become_yukarin/loss.py +++ b/become_yukarin/loss.py @@ -7,7 +7,7 @@ from .model import Predictor class Loss(chainer.link.Chain): - def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner): + def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner = None): super().__init__() self.config = config @@ -21,7 +21,8 @@ class Loss(chainer.link.Chain): mask = chainer.as_variable(mask) h = input - h = self.aligner(h) + if self.aligner is not None: + h = self.aligner(h) y = self.predictor(h) loss = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask) diff --git a/become_yukarin/model.py b/become_yukarin/model.py index 3b5102e..c475685 100644 --- a/become_yukarin/model.py +++ b/become_yukarin/model.py @@ -212,8 +212,18 @@ def create_predictor(config: ModelConfig): def create_aligner(config: ModelConfig): + assert config.enable_aligner aligner = Aligner( in_size=config.in_channels, out_time_length=config.aligner_out_time_length, ) return aligner + + +def create(config: ModelConfig): + predictor = create_predictor(config) + if config.enable_aligner: + aligner = create_aligner(config) + else: + aligner = None + return predictor, aligner @@ -12,8 +12,7 @@ from chainer.training import extensions from become_yukarin.config import create_from_json from become_yukarin.dataset import create as create_dataset from become_yukarin.loss import Loss -from become_yukarin.model import create_aligner -from become_yukarin.model import create_predictor +from become_yukarin.model import create parser = argparse.ArgumentParser() parser.add_argument('config_json_path', type=Path) @@ -27,8 +26,7 @@ config.save_as_json((arguments.output / 'config.json').absolute()) # model if config.train.gpu >= 0: cuda.get_device_from_id(config.train.gpu).use() -predictor = create_predictor(config.model) -aligner = create_aligner(config.model) +predictor, aligner = create(config.model) model = Loss(config.loss, predictor=predictor, aligner=aligner) # dataset |
