summaryrefslogtreecommitdiff
path: root/scripts/voice_conversion_test.py
blob: 24982ea41aff7dc892a980f64e50e78016c55212 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import argparse
import glob
import multiprocessing
import re
from functools import partial
from pathlib import Path

import librosa
import numpy

from become_yukarin import VoiceChanger
from become_yukarin.config.config import create_from_json as create_config

# Command-line interface.
#   model_names           one or more model names (at least one is required)
#   -md / --model_directory        directory that the model names are joined onto
#   -iwd / --input_wave_directory  directory searched for wave files by stem
parser = argparse.ArgumentParser()
parser.add_argument('model_names', nargs='+')
parser.add_argument(
    '-md', '--model_directory',
    type=Path,
    default=Path('/mnt/dwango/hiroshiba/become-yukarin/'),
)
parser.add_argument(
    '-iwd', '--input_wave_directory',
    type=Path,
    default=Path('/mnt/dwango/hiroshiba/become-yukarin/dataset/hiho-wave/hiho-pause-atr503-subset/'),
)
args = parser.parse_args()

# Both values are pathlib.Path instances (argparse applies type=Path).
model_directory, input_wave_directory = args.model_directory, args.input_wave_directory

# Extra wav files taken from ./test_data/, resolved against the current working directory.
paths_test = [*Path('./test_data/').glob('*.wav')]


def extract_number(f) -> int:
    """Return the last run of decimal digits found in ``str(f)``.

    Used as a sort key so that file names carrying a counter are ordered
    numerically rather than lexicographically.

    :param f: any object; it is converted with ``str()`` before searching,
        so digits anywhere in the text (including parent directories of a
        path) are considered.
    :return: the last digit run as an ``int``, or ``-1`` when the text
        contains no digits.
    """
    # Raw string: "\d" in a plain literal is an invalid escape sequence and
    # triggers a warning on recent Python versions.
    digit_runs = re.findall(r"\d+", str(f))
    return int(digit_runs[-1]) if digit_runs else -1


def process(p: Path, voice_changer: VoiceChanger):
    """Run ``voice_changer`` on one input file and write the result as a wav.

    The output file is ``<output>/<stem of p>.wav``, where ``output`` is the
    module-level directory assigned by the driving loop.  Failures are
    reported (path plus traceback) instead of being raised, so one bad file
    does not abort the whole batch.

    :param p: input path.  If its suffix is ``.npy`` or ``.npz`` it is
        replaced by the first file in ``input_wave_directory`` that has the
        same stem (any extension).
    :param voice_changer: callable applied to the path; its result must
        expose ``wave`` and ``sampling_rate`` attributes.
    :return: ``None``.
    """
    # NOTE(review): relies on the module-level `output` being visible in the
    # worker process - presumably true with the fork start method; confirm
    # before running on a spawn-based platform.
    try:
        if p.suffix in ['.npy', '.npz']:
            candidates = glob.glob(str(input_wave_directory / p.stem) + '.*')
            if not candidates:
                raise FileNotFoundError(
                    'no file with stem {!r} in {}'.format(p.stem, input_wave_directory))
            p = Path(candidates[0])
        wave = voice_changer(p)
        librosa.output.write_wav(str(output / p.stem) + '.wav', wave.wave, wave.sampling_rate, norm=True)
    except Exception:
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit still propagate.
        import traceback
        print('error!', str(p))
        # print_exc() actually emits the traceback; format_exc() only returns
        # it as a string, which was previously discarded.
        traceback.print_exc()


# For every requested model: load its config, pick the predictor snapshot
# with the highest number in its file name, and convert a few sample files
# in parallel into ./output/<model name>/.
for model_name in args.model_names:
    base_model = model_directory / model_name
    config = create_config(base_model / 'config.json')

    # Sort first so the seeded shuffle is deterministic for a given file set.
    # NOTE(review): presumably this mirrors the shuffle used when the dataset
    # was split for training, so [0] / [-1] are a train and a test sample -
    # confirm against the training code.
    input_paths = sorted(Path(p) for p in glob.glob(str(config.dataset.input_glob)))
    if not input_paths:
        raise FileNotFoundError('no input files match {}'.format(config.dataset.input_glob))
    numpy.random.RandomState(config.dataset.seed).shuffle(input_paths)
    path_train = input_paths[0]
    path_test = input_paths[-1]

    # Snapshots are ordered by the last number in their path; take the largest.
    model_paths = sorted(base_model.glob('predictor*.npz'), key=extract_number)
    if not model_paths:
        raise FileNotFoundError('no predictor*.npz found in {}'.format(base_model))
    model_path = model_paths[-1]
    print(model_path)
    voice_changer = VoiceChanger(config, model_path)

    # `output` is read by process(); parents=True so a missing ./output
    # directory is created instead of raising FileNotFoundError.
    output = Path('./output').absolute() / base_model.name
    output.mkdir(parents=True, exist_ok=True)

    paths = [path_train, path_test] + paths_test

    process_partial = partial(process, voice_changer=voice_changer)
    # map() blocks until every path has been processed; the context manager
    # then shuts the pool down so worker processes do not pile up per model.
    with multiprocessing.Pool() as pool:
        pool.map(process_partial, paths)