diff options
Diffstat (limited to 'become_yukarin/config.py')
| -rw-r--r-- | become_yukarin/config.py | 42 |
1 files changed, 36 insertions, 6 deletions
diff --git a/become_yukarin/config.py b/become_yukarin/config.py index 0efbf04..80212b6 100644 --- a/become_yukarin/config.py +++ b/become_yukarin/config.py @@ -23,6 +23,12 @@ class DatasetConfig(NamedTuple): num_test: int +class DiscriminatorModelConfig(NamedTuple): + in_channels: int + hidden_channels_list: List[int] + last_channels: int + + class ModelConfig(NamedTuple): in_channels: int conv_bank_out_channels: int @@ -35,10 +41,14 @@ class ModelConfig(NamedTuple): aligner_out_time_length: int disable_last_rnn: bool enable_aligner: bool + discriminator: DiscriminatorModelConfig class LossConfig(NamedTuple): l1: float + predictor_fake: float + discriminator_true: float + discriminator_fake: float class TrainConfig(NamedTuple): @@ -48,11 +58,17 @@ class TrainConfig(NamedTuple): snapshot_iteration: int +class ProjectConfig(NamedTuple): + name: str + tags: List[str] + + class Config(NamedTuple): dataset: DatasetConfig model: ModelConfig loss: LossConfig train: TrainConfig + project: ProjectConfig def save_as_json(self, path): d = _namedtuple_to_dict(self) @@ -78,15 +94,21 @@ def create_from_json(s: Union[str, Path]): except TypeError: d = json.load(open(s)) + discriminator_model_config = DiscriminatorModelConfig( + in_channels=d['model']['discriminator']['in_channels'], + hidden_channels_list=d['model']['discriminator']['hidden_channels_list'], + last_channels=d['model']['discriminator']['last_channels'], + ) + return Config( dataset=DatasetConfig( param=Param(), - input_glob=Path(d['dataset']['input_glob']).expanduser(), - target_glob=Path(d['dataset']['target_glob']).expanduser(), - input_mean_path=Path(d['dataset']['input_mean_path']).expanduser(), - input_var_path=Path(d['dataset']['input_var_path']).expanduser(), - target_mean_path=Path(d['dataset']['target_mean_path']).expanduser(), - target_var_path=Path(d['dataset']['target_var_path']).expanduser(), + input_glob=Path(d['dataset']['input_glob']), + target_glob=Path(d['dataset']['target_glob']), + input_mean_path=Path(d['dataset']['input_mean_path']), + input_var_path=Path(d['dataset']['input_var_path']), + target_mean_path=Path(d['dataset']['target_mean_path']), + target_var_path=Path(d['dataset']['target_var_path']), features=d['dataset']['features'], train_crop_size=d['dataset']['train_crop_size'], global_noise=d['dataset']['global_noise'], @@ -106,9 +128,13 @@ def create_from_json(s: Union[str, Path]): aligner_out_time_length=d['model']['aligner_out_time_length'], disable_last_rnn=d['model']['disable_last_rnn'], enable_aligner=d['model']['enable_aligner'], + discriminator=discriminator_model_config, ), loss=LossConfig( l1=d['loss']['l1'], + predictor_fake=d['loss']['predictor_fake'], + discriminator_true=d['loss']['discriminator_true'], + discriminator_fake=d['loss']['discriminator_fake'], ), train=TrainConfig( batchsize=d['train']['batchsize'], @@ -116,4 +142,8 @@ def create_from_json(s: Union[str, Path]): log_iteration=d['train']['log_iteration'], snapshot_iteration=d['train']['snapshot_iteration'], ), + project=ProjectConfig( + name=d['project']['name'], + tags=d['project']['tags'], + ) ) |
