from functools import partial
from pathlib import Path
from typing import Optional

import chainer
import numpy
import pyworld

from become_yukarin.config.sr_config import SRConfig
from become_yukarin.data_struct import AcousticFeature
from become_yukarin.data_struct import Wave
from become_yukarin.dataset.dataset import LowHighSpectrogramFeatureLoadProcess
from become_yukarin.dataset.dataset import LowHighSpectrogramFeatureProcess
from become_yukarin.dataset.dataset import WaveFileLoadProcess
from become_yukarin.model.sr_model import create_predictor_sr


class SuperResolution(object):
    """Spectrogram super-resolution.

    Wraps a trained Chainer predictor: converts a low-resolution spectrogram
    to a high-resolution one, and can synthesize a waveform from the result
    via WORLD (``pyworld.synthesize``).
    """

    def __init__(self, config: SRConfig, model_path: Path):
        """Load the serialized model and build the preprocessing pipeline.

        :param config: super-resolution configuration (model + dataset params).
        :param model_path: path to the trained Chainer model (``.npz``).
        """
        self.config = config
        self.model_path = model_path

        self.model = model = create_predictor_sr(config.model)
        chainer.serializers.load_npz(str(model_path), model)

        self._param = param = config.dataset.param
        self._wave_process = WaveFileLoadProcess(
            sample_rate=param.voice_param.sample_rate,
            top_db=None,
        )
        self._low_high_spectrogram_process = LowHighSpectrogramFeatureProcess(
            frame_period=param.acoustic_feature_param.frame_period,
            order=param.acoustic_feature_param.order,
            alpha=param.acoustic_feature_param.alpha,
        )
        self._low_high_spectrogram_load_process = LowHighSpectrogramFeatureLoadProcess(
            validate=True,
        )

    def convert(self, input: numpy.ndarray) -> numpy.ndarray:
        """Run the model on a low-resolution (log-domain) spectrogram.

        The last frequency bin is dropped before inference
        (``numpy.log(input)[:, :-1]``); after inference it is restored by
        duplicating the final output column so the width matches the input.

        :param input: low-resolution spectrogram, shape (time, bins);
                      must be strictly positive (``numpy.log`` is applied).
        :return: model output with the dropped bin restored.
        """
        converter = partial(chainer.dataset.convert.concat_examples, padding=0)
        inputs = converter([numpy.log(input)[:, :-1]])

        with chainer.using_config('train', False):
            out = self.model(inputs).data[0]

        out = out[0]
        # BUGFIX: the original line was `out[:, out.shape[1]] = out[:, -1]`,
        # which indexes one column past the end of the array and always raises
        # IndexError (NumPy assignment cannot grow an array). Restore the bin
        # removed before inference by appending a copy of the last column.
        out = numpy.concatenate((out, out[:, -1:]), axis=1)
        return out

    def convert_to_audio(
            self,
            input: numpy.ndarray,
            acoustic_feature: AcousticFeature,
            sampling_rate: Optional[int] = None,
    ):
        """Synthesize a waveform from a spectrogram with WORLD.

        :param input: spectrogram, shape (time, bins).
        :param acoustic_feature: supplies f0 and aperiodicity for synthesis.
        :param sampling_rate: output sampling rate.
            NOTE(review): a ``None`` default is passed straight to
            ``pyworld.synthesize(fs=...)``, which requires an int — callers
            appear expected to always supply it; confirm before relying on
            the default.
        :return: ``Wave`` holding the synthesized samples.
        """
        out = pyworld.synthesize(
            f0=acoustic_feature.f0.ravel(),
            spectrogram=input.astype(numpy.float64),
            aperiodicity=acoustic_feature.aperiodicity,
            fs=sampling_rate,
            frame_period=self._param.acoustic_feature_param.frame_period,
        )
        return Wave(out, sampling_rate=sampling_rate)

    def convert_from_audio_path(self, input: Path):
        """Load a wave file, extract its low spectrogram, and convert it."""
        input = self._wave_process(str(input), test=True)
        input = self._low_high_spectrogram_process(input, test=True)
        return self.convert(input.low)

    def convert_from_feature_path(self, input: Path):
        """Load a precomputed low/high spectrogram feature file and convert it."""
        input = self._low_high_spectrogram_load_process(input, test=True)
        return self.convert(input.low)

    def __call__(
            self,
            input: numpy.ndarray,
            acoustic_feature: AcousticFeature,
            sampling_rate: Optional[int] = None,
    ):
        """Convert a low-resolution spectrogram and synthesize audio in one step."""
        high = self.convert(input)
        return self.convert_to_audio(high, acoustic_feature=acoustic_feature, sampling_rate=sampling_rate)