summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-12-24 20:26:32 +0900
committerHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-12-24 20:27:25 +0900
commit3b38bf420774f2a7f718be927689b67446e680c9 (patch)
tree8c18c84042e500e5ff78729a10d21481f7bd4903
parent93df4c160b8332a4ef41190860b5056905143def (diff)
separate-noise-level
-rw-r--r--become_yukarin/config.py39
-rw-r--r--become_yukarin/dataset/dataset.py4
-rw-r--r--become_yukarin/model.py5
3 files changed, 36 insertions, 12 deletions
diff --git a/become_yukarin/config.py b/become_yukarin/config.py
index 83f3597..4ba953e 100644
--- a/become_yukarin/config.py
+++ b/become_yukarin/config.py
@@ -1,7 +1,9 @@
import json
from pathlib import Path
+from typing import Dict
from typing import List
from typing import NamedTuple
+from typing import Optional
from typing import Union
from .param import Param
@@ -17,8 +19,10 @@ class DatasetConfig(NamedTuple):
target_var_path: Path
features: List[str]
train_crop_size: int
- global_noise: float
- local_noise: float
+ input_global_noise: float
+ input_local_noise: float
+ target_global_noise: float
+ target_local_noise: float
seed: int
num_test: int
@@ -40,7 +44,7 @@ class ModelConfig(NamedTuple):
aligner_out_time_length: int
disable_last_rnn: bool
enable_aligner: bool
- discriminator: DiscriminatorModelConfig
+ discriminator: Optional[DiscriminatorModelConfig]
class LossConfig(NamedTuple):
@@ -94,10 +98,15 @@ def create_from_json(s: Union[str, Path]):
except TypeError:
d = json.load(open(s))
- discriminator_model_config = DiscriminatorModelConfig(
- in_channels=d['model']['discriminator']['in_channels'],
- hidden_channels_list=d['model']['discriminator']['hidden_channels_list'],
- )
+ backward_compatible(d)
+
+ if d['model']['discriminator'] is not None:
+ discriminator_model_config = DiscriminatorModelConfig(
+ in_channels=d['model']['discriminator']['in_channels'],
+ hidden_channels_list=d['model']['discriminator']['hidden_channels_list'],
+ )
+ else:
+ discriminator_model_config = None
return Config(
dataset=DatasetConfig(
@@ -110,8 +119,10 @@ def create_from_json(s: Union[str, Path]):
target_var_path=Path(d['dataset']['target_var_path']),
features=d['dataset']['features'],
train_crop_size=d['dataset']['train_crop_size'],
- global_noise=d['dataset']['global_noise'],
- local_noise=d['dataset']['local_noise'],
+ input_global_noise=d['dataset']['input_global_noise'],
+ input_local_noise=d['dataset']['input_local_noise'],
+ target_global_noise=d['dataset']['target_global_noise'],
+ target_local_noise=d['dataset']['target_local_noise'],
seed=d['dataset']['seed'],
num_test=d['dataset']['num_test'],
),
@@ -147,3 +158,13 @@ def create_from_json(s: Union[str, Path]):
tags=d['project']['tags'],
)
)
+
+
+def backward_compatible(d: Dict):
+ if 'input_global_noise' not in d['dataset']:
+ d['dataset']['input_global_noise'] = d['dataset']['global_noise']
+ d['dataset']['input_local_noise'] = d['dataset']['local_noise']
+
+ if 'target_global_noise' not in d['dataset']:
+ d['dataset']['target_global_noise'] = d['dataset']['global_noise']
+ d['dataset']['target_local_noise'] = d['dataset']['local_noise']
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py
index fa68a78..5ad7a80 100644
--- a/become_yukarin/dataset/dataset.py
+++ b/become_yukarin/dataset/dataset.py
@@ -420,11 +420,11 @@ def create(config: DatasetConfig):
data_process_train.append(SplitProcess(dict(
input=ChainProcess([
LambdaProcess(lambda d, test: d['input']),
- AddNoiseProcess(p_global=config.global_noise, p_local=config.local_noise),
+ AddNoiseProcess(p_global=config.input_global_noise, p_local=config.input_local_noise),
]),
target=ChainProcess([
LambdaProcess(lambda d, test: d['target']),
- AddNoiseProcess(p_global=config.global_noise, p_local=config.local_noise),
+ AddNoiseProcess(p_global=config.target_global_noise, p_local=config.target_local_noise),
]),
mask=ChainProcess([
LambdaProcess(lambda d, test: d['mask']),
diff --git a/become_yukarin/model.py b/become_yukarin/model.py
index 8879f11..8a6af14 100644
--- a/become_yukarin/model.py
+++ b/become_yukarin/model.py
@@ -285,5 +285,8 @@ def create(config: ModelConfig):
aligner = create_aligner(config)
else:
aligner = None
- discriminator = create_discriminator(config.discriminator)
+ if config.discriminator is not None:
+ discriminator = create_discriminator(config.discriminator)
+ else:
+ discriminator = None
return predictor, aligner, discriminator