diff options
Diffstat (limited to 'become_yukarin/config')
| -rw-r--r-- | become_yukarin/config/config.py | 10 | ||||
| -rw-r--r-- | become_yukarin/config/sr_config.py | 9 |
2 files changed, 18 insertions, 1 deletions
diff --git a/become_yukarin/config/config.py b/become_yukarin/config/config.py index f49b185..68ba1bd 100644 --- a/become_yukarin/config/config.py +++ b/become_yukarin/config/config.py @@ -30,6 +30,11 @@ class DatasetConfig(NamedTuple): class ModelConfig(NamedTuple): in_channels: int out_channels: int + generator_base_channels: int + generator_extensive_layers: int + discriminator_base_channels: int + discriminator_extensive_layers: int + weak_discriminator: bool class LossConfig(NamedTuple): @@ -103,6 +108,11 @@ def create_from_json(s: Union[str, Path]): model=ModelConfig( in_channels=d['model']['in_channels'], out_channels=d['model']['out_channels'], + generator_base_channels=d['model']['generator_base_channels'], + generator_extensive_layers=d['model']['generator_extensive_layers'], + discriminator_base_channels=d['model']['discriminator_base_channels'], + discriminator_extensive_layers=d['model']['discriminator_extensive_layers'], + weak_discriminator=d['model']['weak_discriminator'], ), loss=LossConfig( mse=d['loss']['mse'], diff --git a/become_yukarin/config/sr_config.py b/become_yukarin/config/sr_config.py index 4f980a2..75cf6ff 100644 --- a/become_yukarin/config/sr_config.py +++ b/become_yukarin/config/sr_config.py @@ -20,7 +20,10 @@ class SRDatasetConfig(NamedTuple): class SRModelConfig(NamedTuple): - pass + generator_base_channels: int + generator_extensive_layers: int + discriminator_base_channels: int + discriminator_extensive_layers: int class SRLossConfig(NamedTuple): @@ -85,6 +88,10 @@ def create_from_json(s: Union[str, Path]): num_test=d['dataset']['num_test'], ), model=SRModelConfig( + generator_base_channels=d['model']['generator_base_channels'], + generator_extensive_layers=d['model']['generator_extensive_layers'], + discriminator_base_channels=d['model']['discriminator_base_channels'], + discriminator_extensive_layers=d['model']['discriminator_extensive_layers'], ), loss=SRLossConfig( mse=d['loss']['mse'], |
