summaryrefslogtreecommitdiff
path: root/scripts/super_resolution_test.py
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2018-01-31 05:09:07 +0900
committerHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2018-01-31 05:09:07 +0900
commit48addd22a87f248bb8041bca47e9c209a16175a4 (patch)
tree3c2386adafdea434483106a646f33c6f6a7e10cb /scripts/super_resolution_test.py
parentb432502ccc924bb10bee0cf8fe11afd0a5f4757d (diff)
RealtimeVocoderを試せるコード追加
Diffstat (limited to 'scripts/super_resolution_test.py')
-rw-r--r--scripts/super_resolution_test.py12
1 files changed, 9 insertions, 3 deletions
diff --git a/scripts/super_resolution_test.py b/scripts/super_resolution_test.py
index 8b04ce0..4f34632 100644
--- a/scripts/super_resolution_test.py
+++ b/scripts/super_resolution_test.py
@@ -18,10 +18,12 @@ parser.add_argument('model_names', nargs='+')
parser.add_argument('-md', '--model_directory', type=Path, default=Path('/mnt/dwango/hiroshiba/become-yukarin/'))
parser.add_argument('-iwd', '--input_wave_directory', type=Path,
default=Path('/mnt/dwango/hiroshiba/become-yukarin/dataset/yukari-wave/yukari-news/'))
+parser.add_argument('-g', '--gpu', type=int)
args = parser.parse_args()
model_directory = args.model_directory # type: Path
input_wave_directory = args.input_wave_directory # type: Path
+gpu = args.gpu
paths_test = list(Path('./test_data_sr/').glob('*.wav'))
@@ -41,6 +43,7 @@ def process(p: Path, super_resolution: SuperResolution):
frame_period=param.acoustic_feature_param.frame_period,
order=param.acoustic_feature_param.order,
alpha=param.acoustic_feature_param.alpha,
+ f0_estimating_method=param.acoustic_feature_param.f0_estimating_method,
)
try:
@@ -68,7 +71,7 @@ for model_name in args.model_names:
model_paths = base_model.glob('predictor*.npz')
model_path = list(sorted(model_paths, key=extract_number))[-1]
print(model_path)
- super_resolution = SuperResolution(config, model_path)
+ super_resolution = SuperResolution(config, model_path, gpu=gpu)
output = Path('./output').absolute() / base_model.name
output.mkdir(exist_ok=True)
@@ -76,5 +79,8 @@ for model_name in args.model_names:
paths = [path_train, path_test] + paths_test
process_partial = partial(process, super_resolution=super_resolution)
- pool = multiprocessing.Pool()
- pool.map(process_partial, paths)
+ if gpu is None:
+ pool = multiprocessing.Pool()
+ pool.map(process_partial, paths)
+ else:
+ list(map(process_partial, paths))