summaryrefslogtreecommitdiff
path: root/become_yukarin/config.py
diff options
context:
space:
mode:
Diffstat (limited to 'become_yukarin/config.py')
-rw-r--r--become_yukarin/config.py39
1 files changed, 30 insertions, 9 deletions
diff --git a/become_yukarin/config.py b/become_yukarin/config.py
index 83f3597..4ba953e 100644
--- a/become_yukarin/config.py
+++ b/become_yukarin/config.py
@@ -1,7 +1,9 @@
import json
from pathlib import Path
+from typing import Dict
from typing import List
from typing import NamedTuple
+from typing import Optional
from typing import Union
from .param import Param
@@ -17,8 +19,10 @@ class DatasetConfig(NamedTuple):
target_var_path: Path
features: List[str]
train_crop_size: int
- global_noise: float
- local_noise: float
+ input_global_noise: float
+ input_local_noise: float
+ target_global_noise: float
+ target_local_noise: float
seed: int
num_test: int
@@ -40,7 +44,7 @@ class ModelConfig(NamedTuple):
aligner_out_time_length: int
disable_last_rnn: bool
enable_aligner: bool
- discriminator: DiscriminatorModelConfig
+ discriminator: Optional[DiscriminatorModelConfig]
class LossConfig(NamedTuple):
@@ -94,10 +98,15 @@ def create_from_json(s: Union[str, Path]):
except TypeError:
d = json.load(open(s))
- discriminator_model_config = DiscriminatorModelConfig(
- in_channels=d['model']['discriminator']['in_channels'],
- hidden_channels_list=d['model']['discriminator']['hidden_channels_list'],
- )
+ backward_compatible(d)
+
+ if d['model']['discriminator'] is not None:
+ discriminator_model_config = DiscriminatorModelConfig(
+ in_channels=d['model']['discriminator']['in_channels'],
+ hidden_channels_list=d['model']['discriminator']['hidden_channels_list'],
+ )
+ else:
+ discriminator_model_config = None
return Config(
dataset=DatasetConfig(
@@ -110,8 +119,10 @@ def create_from_json(s: Union[str, Path]):
target_var_path=Path(d['dataset']['target_var_path']),
features=d['dataset']['features'],
train_crop_size=d['dataset']['train_crop_size'],
- global_noise=d['dataset']['global_noise'],
- local_noise=d['dataset']['local_noise'],
+ input_global_noise=d['dataset']['input_global_noise'],
+ input_local_noise=d['dataset']['input_local_noise'],
+ target_global_noise=d['dataset']['target_global_noise'],
+ target_local_noise=d['dataset']['target_local_noise'],
seed=d['dataset']['seed'],
num_test=d['dataset']['num_test'],
),
@@ -147,3 +158,13 @@ def create_from_json(s: Union[str, Path]):
tags=d['project']['tags'],
)
)
+
+
+def backward_compatible(d: Dict):
+ if 'input_global_noise' not in d['dataset']:
+ d['dataset']['input_global_noise'] = d['dataset']['global_noise']
+ d['dataset']['input_local_noise'] = d['dataset']['local_noise']
+
+ if 'target_global_noise' not in d['dataset']:
+ d['dataset']['target_global_noise'] = d['dataset']['global_noise']
+ d['dataset']['target_local_noise'] = d['dataset']['local_noise']