summaryrefslogtreecommitdiff
path: root/become_yukarin/dataset/dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'become_yukarin/dataset/dataset.py')
-rw-r--r--become_yukarin/dataset/dataset.py25
1 files changed, 17 insertions, 8 deletions
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py
index b0f9807..38cf749 100644
--- a/become_yukarin/dataset/dataset.py
+++ b/become_yukarin/dataset/dataset.py
@@ -313,7 +313,6 @@ class ShapeAlignProcess(BaseDataProcess):
class RandomPaddingProcess(BaseDataProcess):
def __init__(self, min_size: int, time_axis: int = 1):
- assert time_axis == 1
self._min_size = min_size
self._time_axis = time_axis
@@ -328,7 +327,9 @@ class RandomPaddingProcess(BaseDataProcess):
pre = random.randint(self._min_size - data.shape[self._time_axis] + 1)
post = self._min_size - pre
- return numpy.pad(data, ((0, 0), (pre, post)), mode='constant')
+ pad = [(0, 0)] * data.ndim
+ pad[self._time_axis] = (pre, post)
+ return numpy.pad(data, pad, mode='constant')
class LastPaddingProcess(BaseDataProcess):
@@ -520,8 +521,8 @@ def create_sr(config: SRDatasetConfig):
data_process_base = ChainProcess([
LowHighSpectrogramFeatureLoadProcess(validate=True),
SplitProcess(dict(
- input=LambdaProcess(lambda d, test: numpy.log(d.low)),
- target=LambdaProcess(lambda d, test: numpy.log(d.high)),
+ input=LambdaProcess(lambda d, test: numpy.log(d.low[:, :-1])),
+ target=LambdaProcess(lambda d, test: numpy.log(d.high[:, :-1])),
)),
])
@@ -535,13 +536,13 @@ def create_sr(config: SRDatasetConfig):
def padding(s):
return ChainProcess([
LambdaProcess(lambda d, test: dict(data=d[s], seed=d['seed'])),
- RandomPaddingProcess(min_size=config.train_crop_size),
+ RandomPaddingProcess(min_size=config.train_crop_size, time_axis=0),
])
def crop(s):
return ChainProcess([
LambdaProcess(lambda d, test: dict(data=d[s], seed=d['seed'])),
- RandomCropProcess(crop_size=config.train_crop_size),
+ RandomCropProcess(crop_size=config.train_crop_size, time_axis=0),
])
data_process_train.append(ChainProcess([
@@ -550,6 +551,10 @@ def create_sr(config: SRDatasetConfig):
add_seed(),
SplitProcess(dict(input=crop('input'), target=crop('target'))),
]))
+ data_process_train.append(LambdaProcess(lambda d, test: {
+ 'input': d['input'][numpy.newaxis],
+ 'target': d['target'][numpy.newaxis],
+ }))
data_process_test = copy.deepcopy(data_process_base)
if config.train_crop_size is not None:
@@ -557,14 +562,18 @@ def create_sr(config: SRDatasetConfig):
input=ChainProcess([
LambdaProcess(lambda d, test: d['input']),
LastPaddingProcess(min_size=config.train_crop_size),
- FirstCropProcess(crop_size=config.train_crop_size),
+ FirstCropProcess(crop_size=config.train_crop_size, time_axis=0),
]),
target=ChainProcess([
LambdaProcess(lambda d, test: d['target']),
LastPaddingProcess(min_size=config.train_crop_size),
- FirstCropProcess(crop_size=config.train_crop_size),
+ FirstCropProcess(crop_size=config.train_crop_size, time_axis=0),
]),
)))
+ data_process_test.append(LambdaProcess(lambda d, test: {
+ 'input': d['input'][numpy.newaxis],
+ 'target': d['target'][numpy.newaxis],
+ }))
input_paths = list(sorted([Path(p) for p in glob.glob(str(config.input_glob))]))