diff options
Diffstat (limited to 'become_yukarin/config/config.py')
| -rw-r--r-- | become_yukarin/config/config.py | 47 |
1 files changed, 4 insertions, 43 deletions
diff --git a/become_yukarin/config/config.py b/become_yukarin/config/config.py index ee1d68f..f49b185 100644 --- a/become_yukarin/config/config.py +++ b/become_yukarin/config/config.py @@ -27,32 +27,14 @@ class DatasetConfig(NamedTuple): num_test: int -class DiscriminatorModelConfig(NamedTuple): - in_channels: int - hidden_channels_list: List[int] - - class ModelConfig(NamedTuple): in_channels: int - conv_bank_out_channels: int - conv_bank_k: int - max_pooling_k: int - conv_projections_hidden_channels: int - highway_layers: int out_channels: int - out_size: int - aligner_out_time_length: int - disable_last_rnn: bool - enable_aligner: bool - discriminator: Optional[DiscriminatorModelConfig] class LossConfig(NamedTuple): - l1: float - predictor_fake: float - discriminator_true: float - discriminator_fake: float - discriminator_grad: float + mse: float + adversarial: float class TrainConfig(NamedTuple): @@ -100,14 +82,6 @@ def create_from_json(s: Union[str, Path]): backward_compatible(d) - if d['model']['discriminator'] is not None: - discriminator_model_config = DiscriminatorModelConfig( - in_channels=d['model']['discriminator']['in_channels'], - hidden_channels_list=d['model']['discriminator']['hidden_channels_list'], - ) - else: - discriminator_model_config = None - return Config( dataset=DatasetConfig( param=Param(), @@ -128,24 +102,11 @@ def create_from_json(s: Union[str, Path]): ), model=ModelConfig( in_channels=d['model']['in_channels'], - conv_bank_out_channels=d['model']['conv_bank_out_channels'], - conv_bank_k=d['model']['conv_bank_k'], - max_pooling_k=d['model']['max_pooling_k'], - conv_projections_hidden_channels=d['model']['conv_projections_hidden_channels'], - highway_layers=d['model']['highway_layers'], out_channels=d['model']['out_channels'], - out_size=d['model']['out_size'], - aligner_out_time_length=d['model']['aligner_out_time_length'], - disable_last_rnn=d['model']['disable_last_rnn'], - enable_aligner=d['model']['enable_aligner'], - discriminator=discriminator_model_config, ), loss=LossConfig( - l1=d['loss']['l1'], - predictor_fake=d['loss']['predictor_fake'], - discriminator_true=d['loss']['discriminator_true'], - discriminator_fake=d['loss']['discriminator_fake'], - discriminator_grad=d['loss']['discriminator_grad'], + mse=d['loss']['mse'], + adversarial=d['loss']['adversarial'], ), train=TrainConfig( batchsize=d['train']['batchsize'], |
