Diffstat (limited to 'become_yukarin/config.py'):
 become_yukarin/config.py | 42 ++++++++++++++++++++++++++++++++++++------
 1 file changed, 36 insertions(+), 6 deletions(-)
diff --git a/become_yukarin/config.py b/become_yukarin/config.py
index 0efbf04..80212b6 100644
--- a/become_yukarin/config.py
+++ b/become_yukarin/config.py
@@ -23,6 +23,12 @@ class DatasetConfig(NamedTuple):
     num_test: int
 
 
+class DiscriminatorModelConfig(NamedTuple):
+    in_channels: int
+    hidden_channels_list: List[int]
+    last_channels: int
+
+
 class ModelConfig(NamedTuple):
     in_channels: int
     conv_bank_out_channels: int
@@ -35,10 +41,14 @@ class ModelConfig(NamedTuple):
     aligner_out_time_length: int
     disable_last_rnn: bool
     enable_aligner: bool
+    discriminator: DiscriminatorModelConfig
 
 
 class LossConfig(NamedTuple):
     l1: float
+    predictor_fake: float
+    discriminator_true: float
+    discriminator_fake: float
 
 
 class TrainConfig(NamedTuple):
@@ -48,11 +58,17 @@ class TrainConfig(NamedTuple):
     snapshot_iteration: int
 
 
+class ProjectConfig(NamedTuple):
+    name: str
+    tags: List[str]
+
+
 class Config(NamedTuple):
     dataset: DatasetConfig
     model: ModelConfig
     loss: LossConfig
     train: TrainConfig
+    project: ProjectConfig
 
     def save_as_json(self, path):
         d = _namedtuple_to_dict(self)
@@ -78,15 +94,21 @@ def create_from_json(s: Union[str, Path]):
     except TypeError:
         d = json.load(open(s))
 
+    discriminator_model_config = DiscriminatorModelConfig(
+        in_channels=d['model']['discriminator']['in_channels'],
+        hidden_channels_list=d['model']['discriminator']['hidden_channels_list'],
+        last_channels=d['model']['discriminator']['last_channels'],
+    )
+
     return Config(
         dataset=DatasetConfig(
             param=Param(),
-            input_glob=Path(d['dataset']['input_glob']).expanduser(),
-            target_glob=Path(d['dataset']['target_glob']).expanduser(),
-            input_mean_path=Path(d['dataset']['input_mean_path']).expanduser(),
-            input_var_path=Path(d['dataset']['input_var_path']).expanduser(),
-            target_mean_path=Path(d['dataset']['target_mean_path']).expanduser(),
-            target_var_path=Path(d['dataset']['target_var_path']).expanduser(),
+            input_glob=Path(d['dataset']['input_glob']),
+            target_glob=Path(d['dataset']['target_glob']),
+            input_mean_path=Path(d['dataset']['input_mean_path']),
+            input_var_path=Path(d['dataset']['input_var_path']),
+            target_mean_path=Path(d['dataset']['target_mean_path']),
+            target_var_path=Path(d['dataset']['target_var_path']),
             features=d['dataset']['features'],
             train_crop_size=d['dataset']['train_crop_size'],
             global_noise=d['dataset']['global_noise'],
@@ -106,9 +128,13 @@ def create_from_json(s: Union[str, Path]):
             aligner_out_time_length=d['model']['aligner_out_time_length'],
             disable_last_rnn=d['model']['disable_last_rnn'],
             enable_aligner=d['model']['enable_aligner'],
+            discriminator=discriminator_model_config,
         ),
         loss=LossConfig(
             l1=d['loss']['l1'],
+            predictor_fake=d['loss']['predictor_fake'],
+            discriminator_true=d['loss']['discriminator_true'],
+            discriminator_fake=d['loss']['discriminator_fake'],
         ),
         train=TrainConfig(
             batchsize=d['train']['batchsize'],
@@ -116,4 +142,8 @@ def create_from_json(s: Union[str, Path]):
             log_iteration=d['train']['log_iteration'],
             snapshot_iteration=d['train']['snapshot_iteration'],
         ),
+        project=ProjectConfig(
+            name=d['project']['name'],
+            tags=d['project']['tags'],
+        )
     )
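
For reference, a minimal sketch of the JSON sections that create_from_json starts reading with this change. Only the key names come from the diff above; every value (and the project name) is a made-up placeholder, and the unchanged 'dataset' and remaining 'model'/'train' keys are elided:

    # Hypothetical config fragment; key names from the diff, values are placeholders.
    import json

    d = {
        'model': {
            'discriminator': {
                'in_channels': 513,                  # placeholder
                'hidden_channels_list': [256, 256],  # placeholder
                'last_channels': 1,                  # placeholder
            },
            # ...existing ModelConfig keys still required...
        },
        'loss': {
            'l1': 1.0,
            'predictor_fake': 1.0,       # new GAN-style loss weights; semantics assumed
            'discriminator_true': 1.0,
            'discriminator_fake': 1.0,
        },
        'project': {
            'name': 'example',           # placeholder
            'tags': ['example-tag'],     # placeholder
        },
        # ...'dataset' and 'train' sections still required...
    }

    # create_from_json(s) tries json.loads(s) first and falls back to
    # json.load(open(s)) on TypeError, so a JSON string works directly:
    # config = create_from_json(json.dumps(d))   # or pass a pathlib.Path to a JSON file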