diff options
| -rw-r--r-- | become_yukarin/config.py | 4 | ||||
| -rw-r--r-- | become_yukarin/dataset/dataset.py | 30 |
2 files changed, 34 insertions, 0 deletions
diff --git a/become_yukarin/config.py b/become_yukarin/config.py index 864de4e..05b0790 100644 --- a/become_yukarin/config.py +++ b/become_yukarin/config.py @@ -17,6 +17,8 @@ class DatasetConfig(NamedTuple): target_var_path: Path features: List[str] train_crop_size: int + global_noise: float + local_noise: float seed: int num_test: int @@ -84,6 +86,8 @@ def create_from_json(s: Union[str, Path]): target_var_path=Path(d['dataset']['target_var_path']).expanduser(), features=d['dataset']['features'], train_crop_size=d['dataset']['train_crop_size'], + global_noise=d['dataset']['global_noise'], + local_noise=d['dataset']['local_noise'], seed=d['dataset']['seed'], num_test=d['dataset']['num_test'], ), diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py index a6550b7..83936b1 100644 --- a/become_yukarin/dataset/dataset.py +++ b/become_yukarin/dataset/dataset.py @@ -270,6 +270,21 @@ class CropProcess(BaseDataProcess): return numpy.split(data, [start, start + self._crop_size], axis=self._time_axis)[1] +class AddNoiseProcess(BaseDataProcess): + def __init__(self, p_global: float = None, p_local: float = None): + assert p_global is None or 0 <= p_global + assert p_local is None or 0 <= p_local + self._p_global = p_global + self._p_local = p_local + + def __call__(self, data: numpy.ndarray, test): + assert not test + + g = numpy.random.randn() * self._p_global + l = numpy.random.randn(*data.shape).astype(data.dtype) * self._p_local + return data + g + l + + class DataProcessDataset(chainer.dataset.DatasetMixin): def __init__(self, data: typing.List, data_process: BaseDataProcess): self._data = data @@ -342,6 +357,21 @@ def create(config: DatasetConfig): )), ])) + # add noise + data_process_train.append(SplitProcess(dict( + input=ChainProcess([ + LambdaProcess(lambda d, test: d['input']), + AddNoiseProcess(p_global=config.global_noise, p_local=config.local_noise), + ]), + target=ChainProcess([ + LambdaProcess(lambda d, test: d['target']), + AddNoiseProcess(p_global=config.global_noise, p_local=config.local_noise), + ]), + mask=ChainProcess([ + LambdaProcess(lambda d, test: d['mask']), + ]), + ))) + data_process_test = data_process_base num_test = config.num_test |
