diff options
| author | Hiroshiba Kazuyuki <hihokaruta@gmail.com> | 2018-01-14 07:40:07 +0900 |
|---|---|---|
| committer | Hiroshiba Kazuyuki <hihokaruta@gmail.com> | 2018-01-14 07:40:07 +0900 |
| commit | 2be3f03adc5695f82c6ab86da780108f786ed014 (patch) | |
| tree | ae4b95aa3e45706598e66cc00ff5ad9f00ef97a9 /become_yukarin/config.py | |
| parent | f9185301a22f1632b16dd5266197bb40cb7c302e (diff) | |
超解像
Diffstat (limited to 'become_yukarin/config.py')
| -rw-r--r-- | become_yukarin/config.py | 170 |
1 files changed, 0 insertions, 170 deletions
diff --git a/become_yukarin/config.py b/become_yukarin/config.py deleted file mode 100644 index 4ba953e..0000000 --- a/become_yukarin/config.py +++ /dev/null @@ -1,170 +0,0 @@ -import json -from pathlib import Path -from typing import Dict -from typing import List -from typing import NamedTuple -from typing import Optional -from typing import Union - -from .param import Param - - -class DatasetConfig(NamedTuple): - param: Param - input_glob: Path - target_glob: Path - input_mean_path: Path - input_var_path: Path - target_mean_path: Path - target_var_path: Path - features: List[str] - train_crop_size: int - input_global_noise: float - input_local_noise: float - target_global_noise: float - target_local_noise: float - seed: int - num_test: int - - -class DiscriminatorModelConfig(NamedTuple): - in_channels: int - hidden_channels_list: List[int] - - -class ModelConfig(NamedTuple): - in_channels: int - conv_bank_out_channels: int - conv_bank_k: int - max_pooling_k: int - conv_projections_hidden_channels: int - highway_layers: int - out_channels: int - out_size: int - aligner_out_time_length: int - disable_last_rnn: bool - enable_aligner: bool - discriminator: Optional[DiscriminatorModelConfig] - - -class LossConfig(NamedTuple): - l1: float - predictor_fake: float - discriminator_true: float - discriminator_fake: float - discriminator_grad: float - - -class TrainConfig(NamedTuple): - batchsize: int - gpu: int - log_iteration: int - snapshot_iteration: int - - -class ProjectConfig(NamedTuple): - name: str - tags: List[str] - - -class Config(NamedTuple): - dataset: DatasetConfig - model: ModelConfig - loss: LossConfig - train: TrainConfig - project: ProjectConfig - - def save_as_json(self, path): - d = _namedtuple_to_dict(self) - json.dump(d, open(path, 'w'), indent=2, sort_keys=True, default=_default_path) - - -def _default_path(o): - if isinstance(o, Path): - return str(o) - raise TypeError(repr(o) + " is not JSON serializable") - - -def _namedtuple_to_dict(o: NamedTuple): - return { - k: v if not hasattr(v, '_asdict') else _namedtuple_to_dict(v) - for k, v in o._asdict().items() - } - - -def create_from_json(s: Union[str, Path]): - try: - d = json.loads(s) - except TypeError: - d = json.load(open(s)) - - backward_compatible(d) - - if d['model']['discriminator'] is not None: - discriminator_model_config = DiscriminatorModelConfig( - in_channels=d['model']['discriminator']['in_channels'], - hidden_channels_list=d['model']['discriminator']['hidden_channels_list'], - ) - else: - discriminator_model_config = None - - return Config( - dataset=DatasetConfig( - param=Param(), - input_glob=Path(d['dataset']['input_glob']), - target_glob=Path(d['dataset']['target_glob']), - input_mean_path=Path(d['dataset']['input_mean_path']), - input_var_path=Path(d['dataset']['input_var_path']), - target_mean_path=Path(d['dataset']['target_mean_path']), - target_var_path=Path(d['dataset']['target_var_path']), - features=d['dataset']['features'], - train_crop_size=d['dataset']['train_crop_size'], - input_global_noise=d['dataset']['input_global_noise'], - input_local_noise=d['dataset']['input_local_noise'], - target_global_noise=d['dataset']['target_global_noise'], - target_local_noise=d['dataset']['target_local_noise'], - seed=d['dataset']['seed'], - num_test=d['dataset']['num_test'], - ), - model=ModelConfig( - in_channels=d['model']['in_channels'], - conv_bank_out_channels=d['model']['conv_bank_out_channels'], - conv_bank_k=d['model']['conv_bank_k'], - max_pooling_k=d['model']['max_pooling_k'], - conv_projections_hidden_channels=d['model']['conv_projections_hidden_channels'], - highway_layers=d['model']['highway_layers'], - out_channels=d['model']['out_channels'], - out_size=d['model']['out_size'], - aligner_out_time_length=d['model']['aligner_out_time_length'], - disable_last_rnn=d['model']['disable_last_rnn'], - enable_aligner=d['model']['enable_aligner'], - discriminator=discriminator_model_config, - ), - loss=LossConfig( - l1=d['loss']['l1'], - predictor_fake=d['loss']['predictor_fake'], - discriminator_true=d['loss']['discriminator_true'], - discriminator_fake=d['loss']['discriminator_fake'], - discriminator_grad=d['loss']['discriminator_grad'], - ), - train=TrainConfig( - batchsize=d['train']['batchsize'], - gpu=d['train']['gpu'], - log_iteration=d['train']['log_iteration'], - snapshot_iteration=d['train']['snapshot_iteration'], - ), - project=ProjectConfig( - name=d['project']['name'], - tags=d['project']['tags'], - ) - ) - - -def backward_compatible(d: Dict): - if 'input_global_noise' not in d['dataset']: - d['dataset']['input_global_noise'] = d['dataset']['global_noise'] - d['dataset']['input_local_noise'] = d['dataset']['local_noise'] - - if 'target_global_noise' not in d['dataset']: - d['dataset']['target_global_noise'] = d['dataset']['global_noise'] - d['dataset']['target_local_noise'] = d['dataset']['local_noise'] |
