diff options
Diffstat (limited to 'become_yukarin/super_resolution.py')
| -rw-r--r-- | become_yukarin/super_resolution.py | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/become_yukarin/super_resolution.py b/become_yukarin/super_resolution.py index bdb2e61..163057d 100644 --- a/become_yukarin/super_resolution.py +++ b/become_yukarin/super_resolution.py @@ -15,12 +15,15 @@ from become_yukarin.model.sr_model import create_predictor_sr class SuperResolution(object): - def __init__(self, config: SRConfig, model_path: Path): + def __init__(self, config: SRConfig, model_path: Path, gpu: int = None): self.config = config self.model_path = model_path + self.gpu = gpu self.model = model = create_predictor_sr(config.model) chainer.serializers.load_npz(str(model_path), model) + if self.gpu is not None: + model.to_gpu(self.gpu) self._param = param = config.dataset.param self._wave_process = WaveFileLoadProcess( @@ -37,7 +40,7 @@ class SuperResolution(object): ) def convert(self, input: numpy.ndarray) -> numpy.ndarray: - converter = partial(chainer.dataset.convert.concat_examples, padding=0) + converter = partial(chainer.dataset.convert.concat_examples, device=self.gpu, padding=0) pad = 128 - len(input) % 128 input = numpy.pad(input, [(0, pad), (0, 0)], mode='minimum') input = numpy.log(input)[:, :-1] @@ -47,6 +50,9 @@ class SuperResolution(object): with chainer.using_config('train', False): out = self.model(inputs).data[0] + if self.gpu is not None: + out = chainer.cuda.to_cpu(out) + out = out[0] out = numpy.pad(out, [(0, 0), (0, 1)], mode='edge') out = numpy.exp(out) |
