commit    4003ae8e457905070b789b75c5972ca93cc5756b
tree      d79bd6ba959ee673adcc89e8d5c69c0ec8cf0d93
parent    a4f60ab4cd44d1fc89e83bb662fe430e3824d0dc
author    Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>  2017-11-15 02:27:33 +0900
committer Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>  2017-11-15 02:27:33 +0900
little modify
-rw-r--r--  become_yukarin/config.py           10
-rw-r--r--  become_yukarin/dataset/dataset.py   6
-rw-r--r--  train.py                            7
3 files changed, 11 insertions, 12 deletions
diff --git a/become_yukarin/config.py b/become_yukarin/config.py
index f74c83e..424e598 100644
--- a/become_yukarin/config.py
+++ b/become_yukarin/config.py
@@ -9,8 +9,8 @@ from .param import Param
class DatasetConfig(NamedTuple):
param: Param
- input_glob: str
- target_glob: str
+ input_glob: Path
+ target_glob: Path
input_mean_path: Path
input_var_path: Path
target_mean_path: Path
@@ -40,7 +40,6 @@ class TrainConfig(NamedTuple):
gpu: int
log_iteration: int
snapshot_iteration: int
- output: Path
class Config(NamedTuple):
@@ -76,8 +75,8 @@ def create_from_json(s: Union[str, Path]):
return Config(
dataset=DatasetConfig(
param=Param(),
- input_glob=d['dataset']['input_glob'],
- target_glob=d['dataset']['target_glob'],
+ input_glob=Path(d['dataset']['input_glob']).expanduser(),
+ target_glob=Path(d['dataset']['target_glob']).expanduser(),
input_mean_path=Path(d['dataset']['input_mean_path']).expanduser(),
input_var_path=Path(d['dataset']['input_var_path']).expanduser(),
target_mean_path=Path(d['dataset']['target_mean_path']).expanduser(),
@@ -104,6 +103,5 @@ def create_from_json(s: Union[str, Path]):
gpu=d['train']['gpu'],
log_iteration=d['train']['log_iteration'],
snapshot_iteration=d['train']['snapshot_iteration'],
- output=Path(d['train']['output']).expanduser(),
),
)
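Note on the config.py change: the two glob fields are now parsed the same way as the other *_path entries, via Path(...).expanduser(). The sketch below (not from the repository; the value of raw is a made-up example) only illustrates what expanduser() does to a leading "~" in a glob string read from config.json.

from pathlib import Path

raw = "~/voice_data/input/*.npy"        # hypothetical value taken from d['dataset']['input_glob']
input_glob = Path(raw).expanduser()     # "~" is replaced by the user's home directory
print(input_glob)                       # e.g. /home/<user>/voice_data/input/*.npy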
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py
index dc5bc74..633843e 100644
--- a/become_yukarin/dataset/dataset.py
+++ b/become_yukarin/dataset/dataset.py
@@ -97,7 +97,7 @@ class AcousticFeatureProcess(BaseDataProcess):
spectrogram=spectrogram.astype(self._dtype),
aperiodicity=aperiodicity.astype(self._dtype),
mfcc=mfcc.astype(self._dtype),
- voiced=voiced[:, None].astype(self._dtype),
+ voiced=voiced[:, None],
)
feature.validate()
return feature
@@ -232,8 +232,8 @@ class DataProcessDataset(chainer.dataset.DatasetMixin):
def create(config: DatasetConfig):
import glob
- input_paths = list(sorted([Path(p) for p in glob.glob(config.input_glob)]))
- target_paths = list(sorted([Path(p) for p in glob.glob(config.target_glob)]))
+ input_paths = list(sorted([Path(p) for p in glob.glob(str(config.input_glob))]))
+ target_paths = list(sorted([Path(p) for p in glob.glob(str(config.target_glob))]))
assert len(input_paths) == len(target_paths)
acoustic_feature_load_process = AcousticFeatureLoadProcess()
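Note on the dataset.py change: since config.input_glob and config.target_glob are Path objects now, create() converts them back to str before calling glob.glob. A minimal sketch of that lookup, with a hypothetical glob pattern standing in for the real config value:

import glob
from pathlib import Path

input_glob = Path("~/voice_data/input/*.npy").expanduser()        # Path, as produced by config.py
input_paths = sorted(Path(p) for p in glob.glob(str(input_glob))) # str() before globbing, sort as Paths
print(len(input_paths), "input files matched")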
diff --git a/train.py b/train.py
index b1a2213..27fd1fb 100644
--- a/train.py
+++ b/train.py
@@ -16,11 +16,12 @@ from functools import partial
parser = argparse.ArgumentParser()
parser.add_argument('config_json_path', type=Path)
+parser.add_argument('output', type=Path)
arguments = parser.parse_args()
config = create_from_json(arguments.config_json_path)
-config.train.output.mkdir(exist_ok=True)
-config.save_as_json((config.train.output / 'config.json').absolute())
+arguments.output.mkdir(exist_ok=True)
+config.save_as_json((arguments.output / 'config.json').absolute())
# model
predictor = create_model(config.model)
@@ -42,7 +43,7 @@ trigger_snapshot = (config.train.snapshot_iteration, 'iteration')
converter = partial(convert.concat_examples, padding=0)
updater = training.StandardUpdater(train_iter, optimizer, device=config.train.gpu, converter=converter)
-trainer = training.Trainer(updater, out=config.train.output)
+trainer = training.Trainer(updater, out=arguments.output)
ext = extensions.Evaluator(test_iter, model, converter, device=config.train.gpu)
trainer.extend(ext, name='test', trigger=trigger_log)
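Note on the train.py change: the output directory is no longer read from the "train" section of config.json; it is a second positional CLI argument, created with mkdir(exist_ok=True), and a copy of the parsed config is written into it before training starts. A minimal, self-contained sketch of that flow (print() stands in for config.save_as_json; invocation would look like: python train.py config.json ./output):

import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument('config_json_path', type=Path)
parser.add_argument('output', type=Path)
arguments = parser.parse_args()

arguments.output.mkdir(exist_ok=True)                         # run directory comes from the CLI
config_copy = (arguments.output / 'config.json').absolute()
print('config copy would be written to', config_copy)         # stand-in for config.save_as_json(...)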