diff options
Diffstat (limited to 'become_yukarin/config.py')
| -rw-r--r-- | become_yukarin/config.py | 39 |
1 files changed, 30 insertions, 9 deletions
diff --git a/become_yukarin/config.py b/become_yukarin/config.py index 83f3597..4ba953e 100644 --- a/become_yukarin/config.py +++ b/become_yukarin/config.py @@ -1,7 +1,9 @@ import json from pathlib import Path +from typing import Dict from typing import List from typing import NamedTuple +from typing import Optional from typing import Union from .param import Param @@ -17,8 +19,10 @@ class DatasetConfig(NamedTuple): target_var_path: Path features: List[str] train_crop_size: int - global_noise: float - local_noise: float + input_global_noise: float + input_local_noise: float + target_global_noise: float + target_local_noise: float seed: int num_test: int @@ -40,7 +44,7 @@ class ModelConfig(NamedTuple): aligner_out_time_length: int disable_last_rnn: bool enable_aligner: bool - discriminator: DiscriminatorModelConfig + discriminator: Optional[DiscriminatorModelConfig] class LossConfig(NamedTuple): @@ -94,10 +98,15 @@ def create_from_json(s: Union[str, Path]): except TypeError: d = json.load(open(s)) - discriminator_model_config = DiscriminatorModelConfig( - in_channels=d['model']['discriminator']['in_channels'], - hidden_channels_list=d['model']['discriminator']['hidden_channels_list'], - ) + backward_compatible(d) + + if d['model']['discriminator'] is not None: + discriminator_model_config = DiscriminatorModelConfig( + in_channels=d['model']['discriminator']['in_channels'], + hidden_channels_list=d['model']['discriminator']['hidden_channels_list'], + ) + else: + discriminator_model_config = None return Config( dataset=DatasetConfig( @@ -110,8 +119,10 @@ def create_from_json(s: Union[str, Path]): target_var_path=Path(d['dataset']['target_var_path']), features=d['dataset']['features'], train_crop_size=d['dataset']['train_crop_size'], - global_noise=d['dataset']['global_noise'], - local_noise=d['dataset']['local_noise'], + input_global_noise=d['dataset']['input_global_noise'], + input_local_noise=d['dataset']['input_local_noise'], + target_global_noise=d['dataset']['target_global_noise'], + target_local_noise=d['dataset']['target_local_noise'], seed=d['dataset']['seed'], num_test=d['dataset']['num_test'], ), @@ -147,3 +158,13 @@ def create_from_json(s: Union[str, Path]): tags=d['project']['tags'], ) ) + + +def backward_compatible(d: Dict): + if 'input_global_noise' not in d['dataset']: + d['dataset']['input_global_noise'] = d['dataset']['global_noise'] + d['dataset']['input_local_noise'] = d['dataset']['local_noise'] + + if 'target_global_noise' not in d['dataset']: + d['dataset']['target_global_noise'] = d['dataset']['global_noise'] + d['dataset']['target_local_noise'] = d['dataset']['local_noise'] |
