author    Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> 2017-11-19 23:44:01 +0900
committer Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> 2017-11-19 23:44:01 +0900
commit    437a869590c989c184d33990b1d788149d073ee9 (patch)
tree      cc8dbc06096e0323fede3e985505acc605bd9c20
parent    3a69b426bfaed71f9ba1ca02fe0767300ac05a06 (diff)
add noise
-rw-r--r--  become_yukarin/config.py          |  4 ++++
-rw-r--r--  become_yukarin/dataset/dataset.py | 30 ++++++++++++++++++++++++++++++
2 files changed, 34 insertions(+), 0 deletions(-)
diff --git a/become_yukarin/config.py b/become_yukarin/config.py
index 864de4e..05b0790 100644
--- a/become_yukarin/config.py
+++ b/become_yukarin/config.py
@@ -17,6 +17,8 @@ class DatasetConfig(NamedTuple):
     target_var_path: Path
     features: List[str]
     train_crop_size: int
+    global_noise: float
+    local_noise: float
     seed: int
     num_test: int
@@ -84,6 +86,8 @@ def create_from_json(s: Union[str, Path]):
         target_var_path=Path(d['dataset']['target_var_path']).expanduser(),
         features=d['dataset']['features'],
         train_crop_size=d['dataset']['train_crop_size'],
+        global_noise=d['dataset']['global_noise'],
+        local_noise=d['dataset']['local_noise'],
         seed=d['dataset']['seed'],
         num_test=d['dataset']['num_test'],
     ),
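The two new keys are read straight out of the JSON's dataset section. For reference, a sketch of a matching config: the key names are taken from the hunk above, but every value here is illustrative, not a default from the repository.

# Illustrative dataset section for create_from_json; keys follow the hunk
# above, all values are hypothetical.
d = {
    'dataset': {
        'target_var_path': '~/voice/target_var.npy',  # hypothetical path, expanded via Path(...).expanduser()
        'features': ['mcep', 'bap'],  # hypothetical feature names
        'train_crop_size': 512,
        'global_noise': 0.1,   # scale of the single shared offset (see AddNoiseProcess below)
        'local_noise': 0.01,   # scale of the independent per-element noise
        'seed': 0,
        'num_test': 10,
    },
}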
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py
index a6550b7..83936b1 100644
--- a/become_yukarin/dataset/dataset.py
+++ b/become_yukarin/dataset/dataset.py
@@ -270,6 +270,21 @@ class CropProcess(BaseDataProcess):
         return numpy.split(data, [start, start + self._crop_size], axis=self._time_axis)[1]


+class AddNoiseProcess(BaseDataProcess):
+    def __init__(self, p_global: float = None, p_local: float = None):
+        assert p_global is None or 0 <= p_global
+        assert p_local is None or 0 <= p_local
+        self._p_global = p_global
+        self._p_local = p_local
+
+    def __call__(self, data: numpy.ndarray, test):
+        assert not test
+
+        g = numpy.random.randn() * self._p_global
+        l = numpy.random.randn(*data.shape).astype(data.dtype) * self._p_local
+        return data + g + l
+
+
 class DataProcessDataset(chainer.dataset.DatasetMixin):
     def __init__(self, data: typing.List, data_process: BaseDataProcess):
         self._data = data
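AddNoiseProcess draws one scalar from a zero-mean Gaussian scaled by p_global and adds it to the whole array, then adds independent per-element Gaussian noise scaled by p_local; the assert not test guard keeps it train-only. Note that although the signature defaults both magnitudes to None, passing None through would fail at the multiplication, so in practice both values always come from the config. A standalone sketch of the same arithmetic:

import numpy

# Same math as AddNoiseProcess.__call__, minus the BaseDataProcess plumbing.
p_global, p_local = 0.1, 0.01  # illustrative magnitudes
data = numpy.zeros((2, 3), dtype=numpy.float32)

g = numpy.random.randn() * p_global  # one scalar offset shared by the whole array
l = numpy.random.randn(*data.shape).astype(data.dtype) * p_local  # fresh noise per element
noisy = data + g + l  # broadcasting spreads g over every element
print(noisy.shape)  # (2, 3)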
@@ -342,6 +357,21 @@ def create(config: DatasetConfig):
         )),
     ]))

+    # add noise
+    data_process_train.append(SplitProcess(dict(
+        input=ChainProcess([
+            LambdaProcess(lambda d, test: d['input']),
+            AddNoiseProcess(p_global=config.global_noise, p_local=config.local_noise),
+        ]),
+        target=ChainProcess([
+            LambdaProcess(lambda d, test: d['target']),
+            AddNoiseProcess(p_global=config.global_noise, p_local=config.local_noise),
+        ]),
+        mask=ChainProcess([
+            LambdaProcess(lambda d, test: d['mask']),
+        ]),
+    )))
+
     data_process_test = data_process_base

     num_test = config.num_test
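In create, the noise is appended only to the training pipeline: input and target each get their own AddNoiseProcess instance (so their noise draws are independent), mask is forwarded untouched, and data_process_test stays at data_process_base, consistent with the assert not test guard above. A self-contained mock of that per-key flow, using a plain function in place of the repository's SplitProcess/ChainProcess/LambdaProcess classes:

import numpy

def add_noise(a: numpy.ndarray, p_global: float, p_local: float) -> numpy.ndarray:
    # Mirrors AddNoiseProcess: one shared offset plus per-element noise.
    g = numpy.random.randn() * p_global
    l = numpy.random.randn(*a.shape).astype(a.dtype) * p_local
    return a + g + l

batch = {  # hypothetical shapes and contents
    'input': numpy.zeros((513, 64), dtype=numpy.float32),
    'target': numpy.zeros((513, 64), dtype=numpy.float32),
    'mask': numpy.ones((513, 64), dtype=numpy.float32),
}

train_batch = {
    'input': add_noise(batch['input'], 0.1, 0.01),    # noised independently...
    'target': add_noise(batch['target'], 0.1, 0.01),  # ...of the input's draw
    'mask': batch['mask'],  # passed through unchanged
}
test_batch = batch  # the test pipeline skips the noise step entirely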