summaryrefslogtreecommitdiff
path: root/become_yukarin/config/config.py
diff options
context:
space:
mode:
Diffstat (limited to 'become_yukarin/config/config.py')
-rw-r--r--become_yukarin/config/config.py170
1 files changed, 170 insertions, 0 deletions
diff --git a/become_yukarin/config/config.py b/become_yukarin/config/config.py
new file mode 100644
index 0000000..ee1d68f
--- /dev/null
+++ b/become_yukarin/config/config.py
@@ -0,0 +1,170 @@
+import json
+from pathlib import Path
+from typing import Dict
+from typing import List
+from typing import NamedTuple
+from typing import Optional
+from typing import Union
+
+from become_yukarin.param import Param
+
+
+class DatasetConfig(NamedTuple):
+ param: Param
+ input_glob: Path
+ target_glob: Path
+ input_mean_path: Path
+ input_var_path: Path
+ target_mean_path: Path
+ target_var_path: Path
+ features: List[str]
+ train_crop_size: int
+ input_global_noise: float
+ input_local_noise: float
+ target_global_noise: float
+ target_local_noise: float
+ seed: int
+ num_test: int
+
+
+class DiscriminatorModelConfig(NamedTuple):
+ in_channels: int
+ hidden_channels_list: List[int]
+
+
+class ModelConfig(NamedTuple):
+ in_channels: int
+ conv_bank_out_channels: int
+ conv_bank_k: int
+ max_pooling_k: int
+ conv_projections_hidden_channels: int
+ highway_layers: int
+ out_channels: int
+ out_size: int
+ aligner_out_time_length: int
+ disable_last_rnn: bool
+ enable_aligner: bool
+ discriminator: Optional[DiscriminatorModelConfig]
+
+
+class LossConfig(NamedTuple):
+ l1: float
+ predictor_fake: float
+ discriminator_true: float
+ discriminator_fake: float
+ discriminator_grad: float
+
+
+class TrainConfig(NamedTuple):
+ batchsize: int
+ gpu: int
+ log_iteration: int
+ snapshot_iteration: int
+
+
+class ProjectConfig(NamedTuple):
+ name: str
+ tags: List[str]
+
+
+class Config(NamedTuple):
+ dataset: DatasetConfig
+ model: ModelConfig
+ loss: LossConfig
+ train: TrainConfig
+ project: ProjectConfig
+
+ def save_as_json(self, path):
+ d = _namedtuple_to_dict(self)
+ json.dump(d, open(path, 'w'), indent=2, sort_keys=True, default=_default_path)
+
+
+def _default_path(o):
+ if isinstance(o, Path):
+ return str(o)
+ raise TypeError(repr(o) + " is not JSON serializable")
+
+
+def _namedtuple_to_dict(o: NamedTuple):
+ return {
+ k: v if not hasattr(v, '_asdict') else _namedtuple_to_dict(v)
+ for k, v in o._asdict().items()
+ }
+
+
+def create_from_json(s: Union[str, Path]):
+ try:
+ d = json.loads(s)
+ except TypeError:
+ d = json.load(open(s))
+
+ backward_compatible(d)
+
+ if d['model']['discriminator'] is not None:
+ discriminator_model_config = DiscriminatorModelConfig(
+ in_channels=d['model']['discriminator']['in_channels'],
+ hidden_channels_list=d['model']['discriminator']['hidden_channels_list'],
+ )
+ else:
+ discriminator_model_config = None
+
+ return Config(
+ dataset=DatasetConfig(
+ param=Param(),
+ input_glob=Path(d['dataset']['input_glob']),
+ target_glob=Path(d['dataset']['target_glob']),
+ input_mean_path=Path(d['dataset']['input_mean_path']),
+ input_var_path=Path(d['dataset']['input_var_path']),
+ target_mean_path=Path(d['dataset']['target_mean_path']),
+ target_var_path=Path(d['dataset']['target_var_path']),
+ features=d['dataset']['features'],
+ train_crop_size=d['dataset']['train_crop_size'],
+ input_global_noise=d['dataset']['input_global_noise'],
+ input_local_noise=d['dataset']['input_local_noise'],
+ target_global_noise=d['dataset']['target_global_noise'],
+ target_local_noise=d['dataset']['target_local_noise'],
+ seed=d['dataset']['seed'],
+ num_test=d['dataset']['num_test'],
+ ),
+ model=ModelConfig(
+ in_channels=d['model']['in_channels'],
+ conv_bank_out_channels=d['model']['conv_bank_out_channels'],
+ conv_bank_k=d['model']['conv_bank_k'],
+ max_pooling_k=d['model']['max_pooling_k'],
+ conv_projections_hidden_channels=d['model']['conv_projections_hidden_channels'],
+ highway_layers=d['model']['highway_layers'],
+ out_channels=d['model']['out_channels'],
+ out_size=d['model']['out_size'],
+ aligner_out_time_length=d['model']['aligner_out_time_length'],
+ disable_last_rnn=d['model']['disable_last_rnn'],
+ enable_aligner=d['model']['enable_aligner'],
+ discriminator=discriminator_model_config,
+ ),
+ loss=LossConfig(
+ l1=d['loss']['l1'],
+ predictor_fake=d['loss']['predictor_fake'],
+ discriminator_true=d['loss']['discriminator_true'],
+ discriminator_fake=d['loss']['discriminator_fake'],
+ discriminator_grad=d['loss']['discriminator_grad'],
+ ),
+ train=TrainConfig(
+ batchsize=d['train']['batchsize'],
+ gpu=d['train']['gpu'],
+ log_iteration=d['train']['log_iteration'],
+ snapshot_iteration=d['train']['snapshot_iteration'],
+ ),
+ project=ProjectConfig(
+ name=d['project']['name'],
+ tags=d['project']['tags'],
+ )
+ )
+
+
+def backward_compatible(d: Dict):
+ if 'input_global_noise' not in d['dataset']:
+ d['dataset']['input_global_noise'] = d['dataset']['global_noise']
+ d['dataset']['input_local_noise'] = d['dataset']['local_noise']
+
+ if 'target_global_noise' not in d['dataset']:
+ d['dataset']['target_global_noise'] = d['dataset']['global_noise']
+ d['dataset']['target_local_noise'] = d['dataset']['local_noise']