diff options
| author | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-12-24 20:26:32 +0900 |
|---|---|---|
| committer | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-12-24 20:27:25 +0900 |
| commit | 3b38bf420774f2a7f718be927689b67446e680c9 (patch) | |
| tree | 8c18c84042e500e5ff78729a10d21481f7bd4903 | |
| parent | 93df4c160b8332a4ef41190860b5056905143def (diff) | |
separate-noise-level
| -rw-r--r-- | become_yukarin/config.py | 39 | ||||
| -rw-r--r-- | become_yukarin/dataset/dataset.py | 4 | ||||
| -rw-r--r-- | become_yukarin/model.py | 5 |
3 files changed, 36 insertions, 12 deletions
diff --git a/become_yukarin/config.py b/become_yukarin/config.py index 83f3597..4ba953e 100644 --- a/become_yukarin/config.py +++ b/become_yukarin/config.py @@ -1,7 +1,9 @@ import json from pathlib import Path +from typing import Dict from typing import List from typing import NamedTuple +from typing import Optional from typing import Union from .param import Param @@ -17,8 +19,10 @@ class DatasetConfig(NamedTuple): target_var_path: Path features: List[str] train_crop_size: int - global_noise: float - local_noise: float + input_global_noise: float + input_local_noise: float + target_global_noise: float + target_local_noise: float seed: int num_test: int @@ -40,7 +44,7 @@ class ModelConfig(NamedTuple): aligner_out_time_length: int disable_last_rnn: bool enable_aligner: bool - discriminator: DiscriminatorModelConfig + discriminator: Optional[DiscriminatorModelConfig] class LossConfig(NamedTuple): @@ -94,10 +98,15 @@ def create_from_json(s: Union[str, Path]): except TypeError: d = json.load(open(s)) - discriminator_model_config = DiscriminatorModelConfig( - in_channels=d['model']['discriminator']['in_channels'], - hidden_channels_list=d['model']['discriminator']['hidden_channels_list'], - ) + backward_compatible(d) + + if d['model']['discriminator'] is not None: + discriminator_model_config = DiscriminatorModelConfig( + in_channels=d['model']['discriminator']['in_channels'], + hidden_channels_list=d['model']['discriminator']['hidden_channels_list'], + ) + else: + discriminator_model_config = None return Config( dataset=DatasetConfig( @@ -110,8 +119,10 @@ def create_from_json(s: Union[str, Path]): target_var_path=Path(d['dataset']['target_var_path']), features=d['dataset']['features'], train_crop_size=d['dataset']['train_crop_size'], - global_noise=d['dataset']['global_noise'], - local_noise=d['dataset']['local_noise'], + input_global_noise=d['dataset']['input_global_noise'], + input_local_noise=d['dataset']['input_local_noise'], + target_global_noise=d['dataset']['target_global_noise'], + target_local_noise=d['dataset']['target_local_noise'], seed=d['dataset']['seed'], num_test=d['dataset']['num_test'], ), @@ -147,3 +158,13 @@ def create_from_json(s: Union[str, Path]): tags=d['project']['tags'], ) ) + + +def backward_compatible(d: Dict): + if 'input_global_noise' not in d['dataset']: + d['dataset']['input_global_noise'] = d['dataset']['global_noise'] + d['dataset']['input_local_noise'] = d['dataset']['local_noise'] + + if 'target_global_noise' not in d['dataset']: + d['dataset']['target_global_noise'] = d['dataset']['global_noise'] + d['dataset']['target_local_noise'] = d['dataset']['local_noise'] diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py index fa68a78..5ad7a80 100644 --- a/become_yukarin/dataset/dataset.py +++ b/become_yukarin/dataset/dataset.py @@ -420,11 +420,11 @@ def create(config: DatasetConfig): data_process_train.append(SplitProcess(dict( input=ChainProcess([ LambdaProcess(lambda d, test: d['input']), - AddNoiseProcess(p_global=config.global_noise, p_local=config.local_noise), + AddNoiseProcess(p_global=config.input_global_noise, p_local=config.input_local_noise), ]), target=ChainProcess([ LambdaProcess(lambda d, test: d['target']), - AddNoiseProcess(p_global=config.global_noise, p_local=config.local_noise), + AddNoiseProcess(p_global=config.target_global_noise, p_local=config.target_local_noise), ]), mask=ChainProcess([ LambdaProcess(lambda d, test: d['mask']), diff --git a/become_yukarin/model.py b/become_yukarin/model.py index 8879f11..8a6af14 100644 --- a/become_yukarin/model.py +++ b/become_yukarin/model.py @@ -285,5 +285,8 @@ def create(config: ModelConfig): aligner = create_aligner(config) else: aligner = None - discriminator = create_discriminator(config.discriminator) + if config.discriminator is not None: + discriminator = create_discriminator(config.discriminator) + else: + discriminator = None return predictor, aligner, discriminator |
