diff options
Diffstat (limited to 'become_yukarin/config/config.py')
| -rw-r--r-- | become_yukarin/config/config.py | 170 |
1 file changed, 170 insertions, 0 deletions
import json
from pathlib import Path
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Optional
from typing import Union

from become_yukarin.param import Param


class DatasetConfig(NamedTuple):
    """Dataset file locations, normalization statistics and noise/augmentation settings."""
    param: Param
    input_glob: Path
    target_glob: Path
    input_mean_path: Path
    input_var_path: Path
    target_mean_path: Path
    target_var_path: Path
    features: List[str]
    train_crop_size: int
    input_global_noise: float
    input_local_noise: float
    target_global_noise: float
    target_local_noise: float
    seed: int
    num_test: int


class DiscriminatorModelConfig(NamedTuple):
    """Discriminator network hyperparameters."""
    in_channels: int
    hidden_channels_list: List[int]


class ModelConfig(NamedTuple):
    """Predictor network hyperparameters.

    ``discriminator`` is ``None`` when no discriminator is configured
    (see ``create_from_json``).
    """
    in_channels: int
    conv_bank_out_channels: int
    conv_bank_k: int
    max_pooling_k: int
    conv_projections_hidden_channels: int
    highway_layers: int
    out_channels: int
    out_size: int
    aligner_out_time_length: int
    disable_last_rnn: bool
    enable_aligner: bool
    discriminator: Optional[DiscriminatorModelConfig]


class LossConfig(NamedTuple):
    """Weights of the individual loss terms."""
    l1: float
    predictor_fake: float
    discriminator_true: float
    discriminator_fake: float
    discriminator_grad: float


class TrainConfig(NamedTuple):
    """Training-loop settings."""
    batchsize: int
    gpu: int
    log_iteration: int
    snapshot_iteration: int


class ProjectConfig(NamedTuple):
    """Experiment bookkeeping metadata."""
    name: str
    tags: List[str]


class Config(NamedTuple):
    """Top-level configuration aggregating all sub-configs."""
    dataset: DatasetConfig
    model: ModelConfig
    loss: LossConfig
    train: TrainConfig
    project: ProjectConfig

    def save_as_json(self, path: Union[str, Path]) -> None:
        """Serialize this config to ``path`` as pretty-printed, key-sorted JSON.

        ``Path`` values are converted to strings via ``_default_path``.
        """
        d = _namedtuple_to_dict(self)
        # Context manager so the handle is closed even if serialization
        # raises (the original `json.dump(d, open(path, 'w'), ...)` leaked it).
        with open(path, 'w') as f:
            json.dump(d, f, indent=2, sort_keys=True, default=_default_path)


def _default_path(o):
    """``json.dump`` ``default`` hook: stringify ``Path``s, reject anything else."""
    if isinstance(o, Path):
        return str(o)
    raise TypeError(repr(o) + " is not JSON serializable")


def _namedtuple_to_dict(o: NamedTuple) -> Dict:
    """Recursively convert a NamedTuple (including nested ones) to a plain dict."""
    return {
        k: v if not hasattr(v, '_asdict') else _namedtuple_to_dict(v)
        for k, v in o._asdict().items()
    }


def create_from_json(s: Union[str, Path]) -> Config:
    """Build a :class:`Config` from a JSON string or a path to a JSON file.

    ``json.loads`` raises ``TypeError`` when given a ``Path``; in that case
    the argument is treated as a file path and read instead. Old-format
    dicts are upgraded in place by ``backward_compatible`` before the
    fields are extracted.
    """
    try:
        d = json.loads(s)
    except TypeError:
        # ``s`` is a Path: read the file. Context manager closes the
        # handle (the original `json.load(open(s))` leaked it).
        with open(s) as f:
            d = json.load(f)

    backward_compatible(d)

    if d['model']['discriminator'] is not None:
        discriminator_model_config = DiscriminatorModelConfig(
            in_channels=d['model']['discriminator']['in_channels'],
            hidden_channels_list=d['model']['discriminator']['hidden_channels_list'],
        )
    else:
        discriminator_model_config = None

    return Config(
        dataset=DatasetConfig(
            # ``param`` is not read from JSON; a default Param is used.
            param=Param(),
            input_glob=Path(d['dataset']['input_glob']),
            target_glob=Path(d['dataset']['target_glob']),
            input_mean_path=Path(d['dataset']['input_mean_path']),
            input_var_path=Path(d['dataset']['input_var_path']),
            target_mean_path=Path(d['dataset']['target_mean_path']),
            target_var_path=Path(d['dataset']['target_var_path']),
            features=d['dataset']['features'],
            train_crop_size=d['dataset']['train_crop_size'],
            input_global_noise=d['dataset']['input_global_noise'],
            input_local_noise=d['dataset']['input_local_noise'],
            target_global_noise=d['dataset']['target_global_noise'],
            target_local_noise=d['dataset']['target_local_noise'],
            seed=d['dataset']['seed'],
            num_test=d['dataset']['num_test'],
        ),
        model=ModelConfig(
            in_channels=d['model']['in_channels'],
            conv_bank_out_channels=d['model']['conv_bank_out_channels'],
            conv_bank_k=d['model']['conv_bank_k'],
            max_pooling_k=d['model']['max_pooling_k'],
            conv_projections_hidden_channels=d['model']['conv_projections_hidden_channels'],
            highway_layers=d['model']['highway_layers'],
            out_channels=d['model']['out_channels'],
            out_size=d['model']['out_size'],
            aligner_out_time_length=d['model']['aligner_out_time_length'],
            disable_last_rnn=d['model']['disable_last_rnn'],
            enable_aligner=d['model']['enable_aligner'],
            discriminator=discriminator_model_config,
        ),
        loss=LossConfig(
            l1=d['loss']['l1'],
            predictor_fake=d['loss']['predictor_fake'],
            discriminator_true=d['loss']['discriminator_true'],
            discriminator_fake=d['loss']['discriminator_fake'],
            discriminator_grad=d['loss']['discriminator_grad'],
        ),
        train=TrainConfig(
            batchsize=d['train']['batchsize'],
            gpu=d['train']['gpu'],
            log_iteration=d['train']['log_iteration'],
            snapshot_iteration=d['train']['snapshot_iteration'],
        ),
        project=ProjectConfig(
            name=d['project']['name'],
            tags=d['project']['tags'],
        )
    )


def backward_compatible(d: Dict) -> None:
    """Upgrade an old-format config dict in place.

    Older configs carried single ``global_noise``/``local_noise`` keys;
    copy them into the newer ``input_*``/``target_*`` keys when missing.
    """
    if 'input_global_noise' not in d['dataset']:
        d['dataset']['input_global_noise'] = d['dataset']['global_noise']
        d['dataset']['input_local_noise'] = d['dataset']['local_noise']

    if 'target_global_noise' not in d['dataset']:
        d['dataset']['target_global_noise'] = d['dataset']['global_noise']
        d['dataset']['target_local_noise'] = d['dataset']['local_noise']
