summaryrefslogtreecommitdiff
path: root/become_yukarin/super_resolution.py
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <hihokaruta@gmail.com>2018-01-19 22:34:45 +0900
committerHiroshiba Kazuyuki <hihokaruta@gmail.com>2018-01-19 22:34:45 +0900
commit4b581ca1c7552094221d236d596e7488aa69d0de (patch)
treea7019ea6085c06bc42d5e62ae2c08a6de7e56de4 /become_yukarin/super_resolution.py
parent86079f0cea1f79beb7cbbec08f6c19191929207a (diff)
on PUG
Diffstat (limited to 'become_yukarin/super_resolution.py')
-rw-r--r--become_yukarin/super_resolution.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/become_yukarin/super_resolution.py b/become_yukarin/super_resolution.py
index bdb2e61..163057d 100644
--- a/become_yukarin/super_resolution.py
+++ b/become_yukarin/super_resolution.py
@@ -15,12 +15,15 @@ from become_yukarin.model.sr_model import create_predictor_sr
class SuperResolution(object):
- def __init__(self, config: SRConfig, model_path: Path):
+ def __init__(self, config: SRConfig, model_path: Path, gpu: int = None):
self.config = config
self.model_path = model_path
+ self.gpu = gpu
self.model = model = create_predictor_sr(config.model)
chainer.serializers.load_npz(str(model_path), model)
+ if self.gpu is not None:
+ model.to_gpu(self.gpu)
self._param = param = config.dataset.param
self._wave_process = WaveFileLoadProcess(
@@ -37,7 +40,7 @@ class SuperResolution(object):
)
def convert(self, input: numpy.ndarray) -> numpy.ndarray:
- converter = partial(chainer.dataset.convert.concat_examples, padding=0)
+ converter = partial(chainer.dataset.convert.concat_examples, device=self.gpu, padding=0)
pad = 128 - len(input) % 128
input = numpy.pad(input, [(0, pad), (0, 0)], mode='minimum')
input = numpy.log(input)[:, :-1]
@@ -47,6 +50,9 @@ class SuperResolution(object):
with chainer.using_config('train', False):
out = self.model(inputs).data[0]
+ if self.gpu is not None:
+ out = chainer.cuda.to_cpu(out)
+
out = out[0]
out = numpy.pad(out, [(0, 0), (0, 1)], mode='edge')
out = numpy.exp(out)