summaryrefslogtreecommitdiff
path: root/become_yukarin/config.py
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <hihokaruta@gmail.com>2018-01-14 07:40:07 +0900
committerHiroshiba Kazuyuki <hihokaruta@gmail.com>2018-01-14 07:40:07 +0900
commit2be3f03adc5695f82c6ab86da780108f786ed014 (patch)
treeae4b95aa3e45706598e66cc00ff5ad9f00ef97a9 /become_yukarin/config.py
parentf9185301a22f1632b16dd5266197bb40cb7c302e (diff)
超解像
Diffstat (limited to 'become_yukarin/config.py')
-rw-r--r--become_yukarin/config.py170
1 files changed, 0 insertions, 170 deletions
diff --git a/become_yukarin/config.py b/become_yukarin/config.py
deleted file mode 100644
index 4ba953e..0000000
--- a/become_yukarin/config.py
+++ /dev/null
@@ -1,170 +0,0 @@
-import json
-from pathlib import Path
-from typing import Dict
-from typing import List
-from typing import NamedTuple
-from typing import Optional
-from typing import Union
-
-from .param import Param
-
-
class DatasetConfig(NamedTuple):
    """Dataset settings consumed by ``create_from_json``.

    Holds the feature parameters, input/target file globs, paths to
    normalization statistics, and per-side noise magnitudes.
    """
    param: Param                # feature extraction parameters (see .param)
    input_glob: Path            # glob matching input-side feature files
    target_glob: Path           # glob matching target-side feature files
    input_mean_path: Path       # presumably mean stats for input normalization — confirm with dataset code
    input_var_path: Path        # presumably variance stats for input normalization — confirm with dataset code
    target_mean_path: Path
    target_var_path: Path
    features: List[str]         # names of the features to load/use
    train_crop_size: int        # crop length used during training
    input_global_noise: float   # noise scales; see backward_compatible for legacy single-value form
    input_local_noise: float
    target_global_noise: float
    target_local_noise: float
    seed: int                   # RNG seed for dataset shuffling/splitting
    num_test: int               # number of examples held out for testing
-
-
class DiscriminatorModelConfig(NamedTuple):
    """Discriminator network shape (optional; ``None`` disables the GAN branch)."""
    in_channels: int                  # number of input channels to the discriminator
    hidden_channels_list: List[int]   # channel count for each hidden layer, in order
-
-
class ModelConfig(NamedTuple):
    """Predictor network hyperparameters.

    Field names suggest a CBHG-style architecture (conv bank + highway +
    RNN) — confirm against the model implementation.
    """
    in_channels: int
    conv_bank_out_channels: int
    conv_bank_k: int                   # presumably the number of conv-bank kernel sizes — confirm
    max_pooling_k: int
    conv_projections_hidden_channels: int
    highway_layers: int                # number of highway layers
    out_channels: int
    out_size: int
    aligner_out_time_length: int       # only meaningful when enable_aligner is True — verify
    disable_last_rnn: bool
    enable_aligner: bool
    discriminator: Optional[DiscriminatorModelConfig]  # None => no adversarial training
-
-
class LossConfig(NamedTuple):
    """Weights for the individual loss terms."""
    l1: float                   # weight of the L1 reconstruction loss
    predictor_fake: float       # adversarial weight on the predictor side
    discriminator_true: float   # discriminator loss weight on real samples
    discriminator_fake: float   # discriminator loss weight on generated samples
    discriminator_grad: float   # gradient-penalty-style term weight — confirm in training code
-
-
class TrainConfig(NamedTuple):
    """Training-loop settings."""
    batchsize: int           # minibatch size
    gpu: int                 # GPU device id (convention elsewhere: negative = CPU — confirm)
    log_iteration: int       # log every N iterations
    snapshot_iteration: int  # save a snapshot every N iterations
-
-
class ProjectConfig(NamedTuple):
    """Experiment bookkeeping metadata."""
    name: str        # human-readable experiment name
    tags: List[str]  # free-form tags for grouping/filtering experiments
-
-
class Config(NamedTuple):
    """Top-level configuration aggregating all sub-configurations."""
    dataset: DatasetConfig
    model: ModelConfig
    loss: LossConfig
    train: TrainConfig
    project: ProjectConfig

    def save_as_json(self, path):
        """Serialize this config to *path* as pretty-printed, key-sorted JSON.

        ``Path`` values inside the config are converted to strings by the
        ``_default_path`` hook.
        """
        d = _namedtuple_to_dict(self)
        # Use a context manager so the file handle is closed deterministically;
        # the original `json.dump(d, open(path, 'w'), ...)` left closing to the GC
        # and leaked the handle if json.dump raised.
        with open(path, 'w') as f:
            json.dump(d, f, indent=2, sort_keys=True, default=_default_path)
-
-
-def _default_path(o):
- if isinstance(o, Path):
- return str(o)
- raise TypeError(repr(o) + " is not JSON serializable")
-
-
-def _namedtuple_to_dict(o: NamedTuple):
- return {
- k: v if not hasattr(v, '_asdict') else _namedtuple_to_dict(v)
- for k, v in o._asdict().items()
- }
-
-
def create_from_json(s: Union[str, Path]):
    """Build a :class:`Config` from a JSON string or a path to a JSON file.

    ``json.loads`` raises ``TypeError`` when given a ``Path``, which is used
    (EAFP) to fall back to reading the file. Legacy key names are migrated
    in place by :func:`backward_compatible` before the config is assembled.
    """
    try:
        d = json.loads(s)
    except TypeError:
        # s was not a str/bytes (i.e. a Path): load from file. Use a context
        # manager so the handle is closed; the original `json.load(open(s))`
        # leaked the file handle.
        with open(s) as f:
            d = json.load(f)

    backward_compatible(d)

    # The discriminator section is optional; None disables the GAN branch.
    if d['model']['discriminator'] is not None:
        discriminator_model_config = DiscriminatorModelConfig(
            in_channels=d['model']['discriminator']['in_channels'],
            hidden_channels_list=d['model']['discriminator']['hidden_channels_list'],
        )
    else:
        discriminator_model_config = None

    return Config(
        dataset=DatasetConfig(
            param=Param(),
            input_glob=Path(d['dataset']['input_glob']),
            target_glob=Path(d['dataset']['target_glob']),
            input_mean_path=Path(d['dataset']['input_mean_path']),
            input_var_path=Path(d['dataset']['input_var_path']),
            target_mean_path=Path(d['dataset']['target_mean_path']),
            target_var_path=Path(d['dataset']['target_var_path']),
            features=d['dataset']['features'],
            train_crop_size=d['dataset']['train_crop_size'],
            input_global_noise=d['dataset']['input_global_noise'],
            input_local_noise=d['dataset']['input_local_noise'],
            target_global_noise=d['dataset']['target_global_noise'],
            target_local_noise=d['dataset']['target_local_noise'],
            seed=d['dataset']['seed'],
            num_test=d['dataset']['num_test'],
        ),
        model=ModelConfig(
            in_channels=d['model']['in_channels'],
            conv_bank_out_channels=d['model']['conv_bank_out_channels'],
            conv_bank_k=d['model']['conv_bank_k'],
            max_pooling_k=d['model']['max_pooling_k'],
            conv_projections_hidden_channels=d['model']['conv_projections_hidden_channels'],
            highway_layers=d['model']['highway_layers'],
            out_channels=d['model']['out_channels'],
            out_size=d['model']['out_size'],
            aligner_out_time_length=d['model']['aligner_out_time_length'],
            disable_last_rnn=d['model']['disable_last_rnn'],
            enable_aligner=d['model']['enable_aligner'],
            discriminator=discriminator_model_config,
        ),
        loss=LossConfig(
            l1=d['loss']['l1'],
            predictor_fake=d['loss']['predictor_fake'],
            discriminator_true=d['loss']['discriminator_true'],
            discriminator_fake=d['loss']['discriminator_fake'],
            discriminator_grad=d['loss']['discriminator_grad'],
        ),
        train=TrainConfig(
            batchsize=d['train']['batchsize'],
            gpu=d['train']['gpu'],
            log_iteration=d['train']['log_iteration'],
            snapshot_iteration=d['train']['snapshot_iteration'],
        ),
        project=ProjectConfig(
            name=d['project']['name'],
            tags=d['project']['tags'],
        )
    )
-
-
def backward_compatible(d: Dict):
    """Migrate a legacy config dict to the current key names, in place.

    Older configs carried a single 'global_noise'/'local_noise' pair;
    newer configs split these into input_* and target_* variants. Each
    side is filled from the legacy keys only when its new keys are absent.
    """
    dataset = d['dataset']
    for side in ('input', 'target'):
        if side + '_global_noise' not in dataset:
            dataset[side + '_global_noise'] = dataset['global_noise']
            dataset[side + '_local_noise'] = dataset['local_noise']