| field | value | date |
|---|---|---|
| author | Hiroshiba Kazuyuki <hihokaruta@gmail.com> | 2018-01-14 07:40:07 +0900 |
| committer | Hiroshiba Kazuyuki <hihokaruta@gmail.com> | 2018-01-14 07:40:07 +0900 |
| commit | 2be3f03adc5695f82c6ab86da780108f786ed014 (patch) | |
| tree | ae4b95aa3e45706598e66cc00ff5ad9f00ef97a9 | |
| parent | f9185301a22f1632b16dd5266197bb40cb7c302e (diff) | |
Super-resolution (超解像)
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | become_yukarin/config/__init__.py | 2 |
| -rw-r--r-- | become_yukarin/config/config.py (renamed from become_yukarin/config.py) | 2 |
| -rw-r--r-- | become_yukarin/config/sr_config.py | 122 |
| -rw-r--r-- | become_yukarin/data_struct.py | 10 |
| -rw-r--r-- | become_yukarin/dataset/__init__.py | 1 |
| -rw-r--r-- | become_yukarin/dataset/dataset.py | 125 |
| -rw-r--r-- | become_yukarin/model/__init__.py | 2 |
| -rw-r--r-- | become_yukarin/model/model.py (renamed from become_yukarin/model.py) | 4 |
| -rw-r--r-- | become_yukarin/model/sr_model.py | 119 |
| -rw-r--r-- | become_yukarin/param.py | 2 |
| -rw-r--r-- | become_yukarin/updater/__init__.py | 2 |
| -rw-r--r-- | become_yukarin/updater/sr_updater.py | 69 |
| -rw-r--r-- | become_yukarin/updater/updater.py (renamed from become_yukarin/updater.py) | 11 |
| -rw-r--r-- | become_yukarin/voice_changer.py | 4 |
| -rw-r--r-- | scripts/extract_acoustic_feature.py | 13 |
| -rw-r--r-- | train.py | 7 |
| -rw-r--r-- | train_sr.py | 98 |
17 files changed, 556 insertions, 37 deletions
```diff
diff --git a/become_yukarin/config/__init__.py b/become_yukarin/config/__init__.py
new file mode 100644
index 0000000..fb73169
--- /dev/null
+++ b/become_yukarin/config/__init__.py
@@ -0,0 +1,2 @@
+from . import config
+from . import sr_config
diff --git a/become_yukarin/config.py b/become_yukarin/config/config.py
index 4ba953e..ee1d68f 100644
--- a/become_yukarin/config.py
+++ b/become_yukarin/config/config.py
@@ -6,7 +6,7 @@ from typing import NamedTuple
 from typing import Optional
 from typing import Union
 
-from .param import Param
+from become_yukarin.param import Param
 
 
 class DatasetConfig(NamedTuple):
diff --git a/become_yukarin/config/sr_config.py b/become_yukarin/config/sr_config.py
new file mode 100644
index 0000000..93db424
--- /dev/null
+++ b/become_yukarin/config/sr_config.py
@@ -0,0 +1,122 @@
+import json
+from pathlib import Path
+from typing import Dict
+from typing import List
+from typing import NamedTuple
+from typing import Union
+
+from become_yukarin.param import Param
+
+
+class SRDatasetConfig(NamedTuple):
+    param: Param
+    input_glob: Path
+    train_crop_size: int
+    seed: int
+    num_test: int
+
+
+class SRModelConfig(NamedTuple):
+    in_channels: int
+    conv_bank_out_channels: int
+    conv_bank_k: int
+    max_pooling_k: int
+    conv_projections_hidden_channels: int
+    highway_layers: int
+    out_channels: int
+    out_size: int
+    aligner_out_time_length: int
+    disable_last_rnn: bool
+    enable_aligner: bool
+
+
+class SRLossConfig(NamedTuple):
+    mse: float
+    adversarial: float
+
+
+class SRTrainConfig(NamedTuple):
+    batchsize: int
+    gpu: int
+    log_iteration: int
+    snapshot_iteration: int
+
+
+class SRProjectConfig(NamedTuple):
+    name: str
+    tags: List[str]
+
+
+class SRConfig(NamedTuple):
+    dataset: SRDatasetConfig
+    model: SRModelConfig
+    loss: SRLossConfig
+    train: SRTrainConfig
+    project: SRProjectConfig
+
+    def save_as_json(self, path):
+        d = _namedtuple_to_dict(self)
+        json.dump(d, open(path, 'w'), indent=2, sort_keys=True, default=_default_path)
+
+
+def _default_path(o):
+    if isinstance(o, Path):
+        return str(o)
+    raise TypeError(repr(o) + " is not JSON serializable")
+
+
+def _namedtuple_to_dict(o: NamedTuple):
+    return {
+        k: v if not hasattr(v, '_asdict') else _namedtuple_to_dict(v)
+        for k, v in o._asdict().items()
+    }
+
+
+def create_from_json(s: Union[str, Path]):
+    try:
+        d = json.loads(s)
+    except TypeError:
+        d = json.load(open(s))
+
+    backward_compatible(d)
+
+    return SRConfig(
+        dataset=SRDatasetConfig(
+            param=Param(),
+            input_glob=Path(d['dataset']['input_glob']),
+            train_crop_size=d['dataset']['train_crop_size'],
+            seed=d['dataset']['seed'],
+            num_test=d['dataset']['num_test'],
+        ),
+        model=SRModelConfig(
+            in_channels=d['model']['in_channels'],
+            conv_bank_out_channels=d['model']['conv_bank_out_channels'],
+            conv_bank_k=d['model']['conv_bank_k'],
+            max_pooling_k=d['model']['max_pooling_k'],
+            conv_projections_hidden_channels=d['model']['conv_projections_hidden_channels'],
+            highway_layers=d['model']['highway_layers'],
+            out_channels=d['model']['out_channels'],
+            out_size=d['model']['out_size'],
+            aligner_out_time_length=d['model']['aligner_out_time_length'],
+            disable_last_rnn=d['model']['disable_last_rnn'],
+            enable_aligner=d['model']['enable_aligner'],
+        ),
+        loss=SRLossConfig(
+            mse=d['loss']['mse'],
+            adversarial=d['loss']['adversarial'],
+        ),
+        train=SRTrainConfig(
+            batchsize=d['train']['batchsize'],
+            gpu=d['train']['gpu'],
+            log_iteration=d['train']['log_iteration'],
+            snapshot_iteration=d['train']['snapshot_iteration'],
+        ),
+        project=SRProjectConfig(
+            name=d['project']['name'],
+            tags=d['project']['tags'],
+        )
+    )
+
+
+def backward_compatible(d: Dict):
+    pass
```
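For orientation, here is a hypothetical JSON document that `create_from_json` above would accept. The key layout follows `sr_config.py` exactly, but every value is illustrative rather than taken from the repository:

```python
import json

from become_yukarin.config.sr_config import create_from_json

# All values below are made up for illustration; only the key structure
# mirrors what create_from_json reads.
example = {
    'dataset': {'input_glob': '/tmp/features/*.npy', 'train_crop_size': 512,
                'seed': 0, 'num_test': 5},
    'model': {'in_channels': 1, 'conv_bank_out_channels': 128, 'conv_bank_k': 8,
              'max_pooling_k': 2, 'conv_projections_hidden_channels': 128,
              'highway_layers': 4, 'out_channels': 1, 'out_size': 1,
              'aligner_out_time_length': 1000, 'disable_last_rnn': False,
              'enable_aligner': False},
    'loss': {'mse': 100.0, 'adversarial': 1.0},
    'train': {'batchsize': 16, 'gpu': -1, 'log_iteration': 100,
              'snapshot_iteration': 1000},
    'project': {'name': 'sr-example', 'tags': []},
}

# json.loads succeeds on a str, so the json.load(open(...)) fallback is skipped.
config = create_from_json(json.dumps(example))
assert config.train.batchsize == 16
```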
```diff
diff --git a/become_yukarin/data_struct.py b/become_yukarin/data_struct.py
index 73b9b3b..78c8cf3 100644
--- a/become_yukarin/data_struct.py
+++ b/become_yukarin/data_struct.py
@@ -60,3 +60,13 @@ class AcousticFeature(NamedTuple):
             mfcc=order + 1,
             voiced=1,
         )
+
+
+class LowHighSpectrogramFeature(NamedTuple):
+    low: numpy.ndarray
+    high: numpy.ndarray
+
+    def validate(self):
+        assert self.low.ndim == 2
+        assert self.high.ndim == 2
+        assert self.low.shape == self.high.shape
diff --git a/become_yukarin/dataset/__init__.py b/become_yukarin/dataset/__init__.py
index 4606e7b..591eaa4 100644
--- a/become_yukarin/dataset/__init__.py
+++ b/become_yukarin/dataset/__init__.py
@@ -1,3 +1,4 @@
 from . import dataset
 from . import utility
 from .dataset import create
+from .dataset import create_sr
diff --git a/become_yukarin/dataset/dataset.py b/become_yukarin/dataset/dataset.py
index 368073d..b0f9807 100644
--- a/become_yukarin/dataset/dataset.py
+++ b/become_yukarin/dataset/dataset.py
@@ -1,4 +1,5 @@
 import copy
+import glob
 import typing
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict
@@ -13,8 +14,10 @@
 import numpy
 import pysptk
 import pyworld
 
-from ..config import DatasetConfig
+from ..config.config import DatasetConfig
+from ..config.sr_config import SRDatasetConfig
 from ..data_struct import AcousticFeature
+from ..data_struct import LowHighSpectrogramFeature
 from ..data_struct import Wave
 
@@ -112,6 +115,35 @@
         return feature
 
 
+class LowHighSpectrogramFeatureProcess(BaseDataProcess):
+    def __init__(self, frame_period, order, alpha, dtype=numpy.float32):
+        self._acoustic_feature_process = AcousticFeatureProcess(
+            frame_period=frame_period,
+            order=order,
+            alpha=alpha,
+        )
+        self._dtype = dtype
+        self._alpha = alpha
+
+    def __call__(self, data: Wave, test):
+        acoustic_feature = self._acoustic_feature_process(data, test=True).astype_only_float(self._dtype)
+        high_spectrogram = acoustic_feature.spectrogram
+
+        fftlen = pyworld.get_cheaptrick_fft_size(data.sampling_rate)
+        low_spectrogram = pysptk.mc2sp(
+            acoustic_feature.mfcc,
+            alpha=self._alpha,
+            fftlen=fftlen,
+        )
+
+        feature = LowHighSpectrogramFeature(
+            low=low_spectrogram,
+            high=high_spectrogram,
+        )
+        feature.validate()
+        return feature
+
+
 class AcousticFeatureLoadProcess(BaseDataProcess):
     def __init__(self, validate=False):
         self._validate = validate
@@ -130,6 +162,21 @@
         return feature
 
 
+class LowHighSpectrogramFeatureLoadProcess(BaseDataProcess):
+    def __init__(self, validate=False):
+        self._validate = validate
+
+    def __call__(self, path: Path, test=None):
+        d = numpy.load(path.expanduser()).item()  # type: dict
+        feature = LowHighSpectrogramFeature(
+            low=d['low'],
+            high=d['high'],
+        )
+        if self._validate:
+            feature.validate()
+        return feature
+
+
 class AcousticFeatureSaveProcess(BaseDataProcess):
     def __init__(self, validate=False, ignore: List[str] = None):
         self._validate = validate
@@ -353,11 +400,6 @@
 
 
 def create(config: DatasetConfig):
-    import glob
-    input_paths = list(sorted([Path(p) for p in glob.glob(str(config.input_glob))]))
-    target_paths = list(sorted([Path(p) for p in glob.glob(str(config.target_glob))]))
-    assert len(input_paths) == len(target_paths)
-
     acoustic_feature_load_process = AcousticFeatureLoadProcess()
     input_mean = acoustic_feature_load_process(config.input_mean_path, test=True)
     input_var = acoustic_feature_load_process(config.input_var_path, test=True)
@@ -433,7 +475,7 @@
             ]),
         )))
 
-    data_process_test = data_process_base
+    data_process_test = copy.deepcopy(data_process_base)
     if config.train_crop_size is not None:
         data_process_test.append(SplitProcess(dict(
             input=ChainProcess([
@@ -453,6 +495,10 @@
             ]),
         )))
 
+    input_paths = list(sorted([Path(p) for p in glob.glob(str(config.input_glob))]))
+    target_paths = list(sorted([Path(p) for p in glob.glob(str(config.target_glob))]))
+    assert len(input_paths) == len(target_paths)
+
     num_test = config.num_test
     pairs = [
         dict(input_path=input_path, target_path=target_path)
@@ -468,3 +514,68 @@
         'test': DataProcessDataset(test_paths, data_process_test),
         'train_eval': DataProcessDataset(train_for_evaluate_paths, data_process_test),
     }
+
+
+def create_sr(config: SRDatasetConfig):
+    data_process_base = ChainProcess([
+        LowHighSpectrogramFeatureLoadProcess(validate=True),
+        SplitProcess(dict(
+            input=LambdaProcess(lambda d, test: numpy.log(d.low)),
+            target=LambdaProcess(lambda d, test: numpy.log(d.high)),
+        )),
+    ])
+
+    data_process_train = copy.deepcopy(data_process_base)
+
+    # cropping
+    if config.train_crop_size is not None:
+        def add_seed():
+            return LambdaProcess(lambda d, test: dict(seed=numpy.random.randint(2 ** 32), **d))
+
+        def padding(s):
+            return ChainProcess([
+                LambdaProcess(lambda d, test: dict(data=d[s], seed=d['seed'])),
+                RandomPaddingProcess(min_size=config.train_crop_size),
+            ])
+
+        def crop(s):
+            return ChainProcess([
+                LambdaProcess(lambda d, test: dict(data=d[s], seed=d['seed'])),
+                RandomCropProcess(crop_size=config.train_crop_size),
+            ])
+
+        data_process_train.append(ChainProcess([
+            add_seed(),
+            SplitProcess(dict(input=padding('input'), target=padding('target'))),
+            add_seed(),
+            SplitProcess(dict(input=crop('input'), target=crop('target'))),
+        ]))
+
+    data_process_test = copy.deepcopy(data_process_base)
+    if config.train_crop_size is not None:
+        data_process_test.append(SplitProcess(dict(
+            input=ChainProcess([
+                LambdaProcess(lambda d, test: d['input']),
+                LastPaddingProcess(min_size=config.train_crop_size),
+                FirstCropProcess(crop_size=config.train_crop_size),
+            ]),
+            target=ChainProcess([
+                LambdaProcess(lambda d, test: d['target']),
+                LastPaddingProcess(min_size=config.train_crop_size),
+                FirstCropProcess(crop_size=config.train_crop_size),
+            ]),
+        )))
+
+    input_paths = list(sorted([Path(p) for p in glob.glob(str(config.input_glob))]))
+
+    num_test = config.num_test
+    numpy.random.RandomState(config.seed).shuffle(input_paths)
+    train_paths = input_paths[num_test:]
+    test_paths = input_paths[:num_test]
+    train_for_evaluate_paths = train_paths[:num_test]
+
+    return {
+        'train': DataProcessDataset(train_paths, data_process_train),
+        'test': DataProcessDataset(test_paths, data_process_test),
+        'train_eval': DataProcessDataset(train_for_evaluate_paths, data_process_test),
+    }
```
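A detail worth noting in `create_sr` above: `add_seed` draws one random seed per example and hands the same seed to both the input and target branches, so `RandomPaddingProcess`/`RandomCropProcess` cut the low and high spectrograms at identical positions. A standalone sketch of that idea (only numpy assumed; `random_crop` is a hypothetical stand-in for the repository's crop process):

```python
import numpy

def random_crop(data, crop_size, seed):
    # The same seed yields the same start offset, keeping pairs aligned.
    start = numpy.random.RandomState(seed).randint(data.shape[0] - crop_size + 1)
    return data[start:start + crop_size]

low = numpy.arange(1000.0).reshape(100, 10)  # stand-ins for log low/high spectrograms
high = low * 2

seed = numpy.random.randint(2 ** 32)  # drawn once per example, as in add_seed()
a = random_crop(low, 32, seed)
b = random_crop(high, 32, seed)
assert (b == a * 2).all()  # both crops cover the same time window
```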
```diff
diff --git a/become_yukarin/model/__init__.py b/become_yukarin/model/__init__.py
new file mode 100644
index 0000000..6e6c130
--- /dev/null
+++ b/become_yukarin/model/__init__.py
@@ -0,0 +1,2 @@
+from . import model
+from . import sr_model
diff --git a/become_yukarin/model.py b/become_yukarin/model/model.py
index 8a6af14..fc2d722 100644
--- a/become_yukarin/model.py
+++ b/become_yukarin/model/model.py
@@ -3,8 +3,8 @@ from typing import List
 
 import chainer
 
-from .config import DiscriminatorModelConfig
-from .config import ModelConfig
+from become_yukarin.config.config import DiscriminatorModelConfig
+from become_yukarin.config.config import ModelConfig
 
 
 class Convolution1D(chainer.links.ConvolutionND):
diff --git a/become_yukarin/model/sr_model.py b/become_yukarin/model/sr_model.py
new file mode 100644
index 0000000..74119a4
--- /dev/null
+++ b/become_yukarin/model/sr_model.py
@@ -0,0 +1,119 @@
+import chainer
+import chainer.functions as F
+import chainer.links as L
+
+from become_yukarin.config.sr_config import SRModelConfig
+
+
+class CBR(chainer.Chain):
+    def __init__(self, ch0, ch1, bn=True, sample='down', activation=F.relu, dropout=False):
+        super().__init__()
+        self.bn = bn
+        self.activation = activation
+        self.dropout = dropout
+
+        w = chainer.initializers.Normal(0.02)
+        with self.init_scope():
+            if sample == 'down':
+                self.c = L.Convolution2D(ch0, ch1, 4, 2, 1, initialW=w)
+            else:
+                self.c = L.Deconvolution2D(ch0, ch1, 4, 2, 1, initialW=w)
+            if bn:
+                self.batchnorm = L.BatchNormalization(ch1)
+
+    def __call__(self, x):
+        h = self.c(x)
+        if self.bn:
+            h = self.batchnorm(h)
+        if self.dropout:
+            h = F.dropout(h)
+        if self.activation is not None:
+            h = self.activation(h)
+        return h
+
+
+class Encoder(chainer.Chain):
+    def __init__(self, in_ch):
+        super().__init__()
+        w = chainer.initializers.Normal(0.02)
+        with self.init_scope():
+            self.c0 = L.Convolution2D(in_ch, 64, 3, 1, 1, initialW=w)
+            self.c1 = CBR(64, 128, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
+            self.c2 = CBR(128, 256, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
+            self.c3 = CBR(256, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
+            self.c4 = CBR(512, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
+            self.c5 = CBR(512, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
+            self.c6 = CBR(512, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
+            self.c7 = CBR(512, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
+
+    def __call__(self, x):
+        x = F.reshape(x, (len(x), 1) + x.shape[1:])
+        hs = [F.leaky_relu(self.c0(x))]
+        for i in range(1, 8):
+            hs.append(self['c%d' % i](hs[i - 1]))
+        return hs
+
+
+class Decoder(chainer.Chain):
+    def __init__(self, out_ch):
+        super().__init__()
+        w = chainer.initializers.Normal(0.02)
+        with self.init_scope():
+            self.c0 = CBR(512, 512, bn=True, sample='up', activation=F.relu, dropout=True)
+            self.c1 = CBR(1024, 512, bn=True, sample='up', activation=F.relu, dropout=True)
+            self.c2 = CBR(1024, 512, bn=True, sample='up', activation=F.relu, dropout=True)
+            self.c3 = CBR(1024, 512, bn=True, sample='up', activation=F.relu, dropout=False)
+            self.c4 = CBR(1024, 256, bn=True, sample='up', activation=F.relu, dropout=False)
+            self.c5 = CBR(512, 128, bn=True, sample='up', activation=F.relu, dropout=False)
+            self.c6 = CBR(256, 64, bn=True, sample='up', activation=F.relu, dropout=False)
+            self.c7 = L.Convolution2D(128, out_ch, 3, 1, 1, initialW=w)
+
+    def __call__(self, hs):
+        h = self.c0(hs[-1])
+        for i in range(1, 8):
+            h = F.concat([h, hs[-i - 1]])
+            if i < 7:
+                h = self['c%d' % i](h)
+            else:
+                h = self.c7(h)
+        return h
+
+
+class SRPredictor(chainer.Chain):
+    def __init__(self, in_ch, out_ch):
+        super().__init__()
+        with self.init_scope():
+            self.encoder = Encoder(in_ch)
+            self.decoder = Decoder(out_ch)
+
+    def __call__(self, x):
+        return self.decoder(self.encoder(x))
+
+
+class SRDiscriminator(chainer.Chain):
+    def __init__(self, in_ch, out_ch):
+        super().__init__()
+        w = chainer.initializers.Normal(0.02)
+        with self.init_scope():
+            self.c0_0 = CBR(in_ch, 32, bn=False, sample='down', activation=F.leaky_relu, dropout=False)
+            self.c0_1 = CBR(out_ch, 32, bn=False, sample='down', activation=F.leaky_relu, dropout=False)
+            self.c1 = CBR(64, 128, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
+            self.c2 = CBR(128, 256, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
+            self.c3 = CBR(256, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
+            self.c4 = L.Convolution2D(512, 1, 3, 1, 1, initialW=w)
+
+    def __call__(self, x_0, x_1):
+        x_0 = F.reshape(x_0, (len(x_0), 1) + x_0.shape[1:])
+        h = F.concat([self.c0_0(x_0), self.c0_1(x_1)])
+        h = self.c1(h)
+        h = self.c2(h)
+        h = self.c3(h)
+        h = self.c4(h)
+        # h = F.average_pooling_2d(h, h.data.shape[2], 1, 0)
+        return h
+
+
+def create_sr(config: SRModelConfig):
+    predictor = SRPredictor(in_ch=1, out_ch=3)
+    discriminator = SRDiscriminator(in_ch=1, out_ch=3)
+    return predictor, discriminator
```
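The `Encoder`/`Decoder` pair above is a pix2pix-style U-Net: after a stride-1 stem, seven stride-2 `CBR` blocks halve both axes, and the decoder concatenates each mirrored encoder activation before upsampling, which is why its inner blocks take 1024 input channels. A minimal shape check (a sketch; it assumes chainer is installed and feeds a 128×128 crop, since seven halvings need dimensions divisible by 2**7):

```python
import numpy
from become_yukarin.model.sr_model import Decoder, Encoder

x = numpy.zeros((1, 128, 128), dtype=numpy.float32)  # (batch, freq, time); size illustrative
hs = Encoder(in_ch=1)(x)          # internally reshaped to (batch, 1, 128, 128)
print(hs[0].shape, hs[-1].shape)  # (1, 64, 128, 128) ... (1, 512, 1, 1)
y = Decoder(out_ch=1)(hs)         # skip concats double the channels at each step
print(y.shape)                    # (1, 1, 128, 128)
```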
```diff
diff --git a/become_yukarin/param.py b/become_yukarin/param.py
index e6f46bc..5a43d74 100644
--- a/become_yukarin/param.py
+++ b/become_yukarin/param.py
@@ -9,7 +9,7 @@ class VoiceParam(NamedTuple):
 
 
 class AcousticFeatureParam(NamedTuple):
     frame_period: int = 5
-    order: int = 25
+    order: int = 8
     alpha: float = 0.466
 
diff --git a/become_yukarin/updater/__init__.py b/become_yukarin/updater/__init__.py
new file mode 100644
index 0000000..d85003a
--- /dev/null
+++ b/become_yukarin/updater/__init__.py
@@ -0,0 +1,2 @@
+from . import sr_updater
+from . import updater
```
```diff
diff --git a/become_yukarin/updater/sr_updater.py b/become_yukarin/updater/sr_updater.py
new file mode 100644
index 0000000..a6b1d22
--- /dev/null
+++ b/become_yukarin/updater/sr_updater.py
@@ -0,0 +1,69 @@
+import chainer
+import chainer.functions as F
+from become_yukarin.config.sr_config import SRLossConfig
+
+from become_yukarin.model.sr_model import SRDiscriminator
+from become_yukarin.model.sr_model import SRPredictor
+
+
+class SRUpdater(chainer.training.StandardUpdater):
+    def __init__(
+            self,
+            loss_config: SRLossConfig,
+            predictor: SRPredictor,
+            discriminator: SRDiscriminator,
+            *args,
+            **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.loss_config = loss_config
+        self.predictor = predictor
+        self.discriminator = discriminator
+
+    def _loss_predictor(self, predictor, output, target, d_fake):
+        b, _, w, h = d_fake.data.shape
+
+        loss_mse = (F.mean_absolute_error(output, target))
+        chainer.report({'mse': loss_mse}, predictor)
+
+        loss_adv = F.sum(F.softplus(-d_fake)) / (b * w * h)
+        chainer.report({'adversarial': loss_adv}, predictor)
+
+        loss = self.loss_config.mse * loss_mse + self.loss_config.adversarial * loss_adv
+        chainer.report({'loss': loss}, predictor)
+        return loss
+
+    def _loss_discriminator(self, discriminator, y_in, y_out):
+        b, _, w, h = y_in.data.shape
+
+        loss_real = F.sum(F.softplus(-y_in)) / (b * w * h)
+        chainer.report({'real': loss_real}, discriminator)
+
+        loss_fake = F.sum(F.softplus(y_out)) / (b * w * h)
+        chainer.report({'fake': loss_fake}, discriminator)
+
+        loss = loss_real + loss_fake
+        chainer.report({'loss': loss}, discriminator)
+        return loss
+
+    def forward(self, input, target):
+        output = self.predictor(input)
+        d_fake = self.discriminator(input, output)
+        d_real = self.discriminator(input, target)
+
+        loss = {
+            'predictor': self._loss_predictor(self.predictor, output, target, d_fake),
+            'discriminator': self._loss_discriminator(self.discriminator, d_real, d_fake),
+        }
+        return loss
+
+    def update_core(self):
+        opt_predictor = self.get_optimizer('predictor')
+        opt_discriminator = self.get_optimizer('discriminator')
+
+        batch = self.get_iterator('main').next()
+        batch = self.converter(batch, self.device)
+        loss = self.forward(**batch)
+
+        opt_predictor.update(loss.get, 'predictor')
+        opt_discriminator.update(loss.get, 'discriminator')
```
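The adversarial terms in `SRUpdater` are the standard logistic GAN losses written on raw logits: `softplus(-y)` equals `-log(sigmoid(y))` (the real and generator terms) and `softplus(y)` equals `-log(1 - sigmoid(y))` (the fake term). A quick numerical confirmation with plain numpy:

```python
import numpy

def softplus(x):
    return numpy.log1p(numpy.exp(x))

def sigmoid(x):
    return 1.0 / (1.0 + numpy.exp(-x))

y = numpy.linspace(-5.0, 5.0, 11)  # discriminator logits
assert numpy.allclose(softplus(-y), -numpy.log(sigmoid(y)))
assert numpy.allclose(softplus(y), -numpy.log(1.0 - sigmoid(y)))
```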
```diff
diff --git a/become_yukarin/updater.py b/become_yukarin/updater/updater.py
index f6444d0..ef77e77 100644
--- a/become_yukarin/updater.py
+++ b/become_yukarin/updater/updater.py
@@ -2,18 +2,16 @@ import chainer
 import numpy
 from chainer import reporter
 
-from .config import LossConfig
-from .config import ModelConfig
-from .model import Aligner
-from .model import Discriminator
-from .model import Predictor
+from become_yukarin.config.config import LossConfig
+from become_yukarin.model.model import Aligner
+from become_yukarin.model.model import Discriminator
+from become_yukarin.model.model import Predictor
 
 
 class Updater(chainer.training.StandardUpdater):
     def __init__(
             self,
             loss_config: LossConfig,
-            model_config: ModelConfig,
             predictor: Predictor,
             aligner: Aligner = None,
             discriminator: Discriminator = None,
@@ -22,7 +20,6 @@
     ):
         super().__init__(*args, **kwargs)
         self.loss_config = loss_config
-        self.model_config = model_config
         self.predictor = predictor
         self.aligner = aligner
         self.discriminator = discriminator
diff --git a/become_yukarin/voice_changer.py b/become_yukarin/voice_changer.py
index 140a2a0..a8a207a 100644
--- a/become_yukarin/voice_changer.py
+++ b/become_yukarin/voice_changer.py
@@ -7,7 +7,7 @@ import numpy
 import pysptk
 import pyworld
 
-from become_yukarin.config import Config
+from become_yukarin.config.config import Config
 from become_yukarin.data_struct import AcousticFeature
 from become_yukarin.data_struct import Wave
 from become_yukarin.dataset.dataset import AcousticFeatureDenormalizeProcess
@@ -17,7 +17,7 @@ from become_yukarin.dataset.dataset import AcousticFeatureProcess
 from become_yukarin.dataset.dataset import DecodeFeatureProcess
 from become_yukarin.dataset.dataset import EncodeFeatureProcess
 from become_yukarin.dataset.dataset import WaveFileLoadProcess
-from become_yukarin.model import create_predictor
+from become_yukarin.model.model import create_predictor
 
 
 class VoiceChanger(object):
```
```diff
diff --git a/scripts/extract_acoustic_feature.py b/scripts/extract_acoustic_feature.py
index 297c10b..7943639 100644
--- a/scripts/extract_acoustic_feature.py
+++ b/scripts/extract_acoustic_feature.py
@@ -40,19 +40,6 @@ arguments = parser.parse_args()
 pprint(dir(arguments))
 
 
-def make_feature(
-        path,
-        sample_rate,
-        top_db,
-        frame_period,
-        order,
-        alpha,
-):
-    wave = WaveFileLoadProcess(sample_rate=sample_rate, top_db=top_db)(path, test=True)
-    feature = AcousticFeatureProcess(frame_period=frame_period, order=order, alpha=alpha)(wave, test=True)
-    return feature
-
-
 def generate_feature(path1, path2):
     out1 = Path(arguments.output1_directory, path1.stem + '.npy')
     out2 = Path(arguments.output2_directory, path2.stem + '.npy')
diff --git a/train.py b/train.py
--- a/train.py
+++ b/train.py
@@ -10,10 +10,10 @@ from chainer.iterators import MultiprocessIterator
 from chainer.training import extensions
 from chainerui.utils import save_args
 
-from become_yukarin.config import create_from_json
+from become_yukarin.config.config import create_from_json
 from become_yukarin.dataset import create as create_dataset
-from become_yukarin.model import create
-from become_yukarin.updater import Updater
+from become_yukarin.model.model import create
+from become_yukarin.updater.updater import Updater
 
 parser = argparse.ArgumentParser()
 parser.add_argument('config_json_path', type=Path)
@@ -54,7 +54,6 @@ opts = {key: create_optimizer(model) for key, model in models.items()}
 converter = partial(convert.concat_examples, padding=0)
 updater = Updater(
     loss_config=config.loss,
-    model_config=config.model,
     predictor=predictor,
     aligner=aligner,
     discriminator=discriminator,
```
```diff
diff --git a/train_sr.py b/train_sr.py
new file mode 100644
index 0000000..c714aa0
--- /dev/null
+++ b/train_sr.py
@@ -0,0 +1,98 @@
+import argparse
+from functools import partial
+from pathlib import Path
+
+from chainer import cuda
+from chainer import optimizers
+from chainer import training
+from chainer.dataset import convert
+from chainer.iterators import MultiprocessIterator
+from chainer.training import extensions
+from chainerui.utils import save_args
+
+from become_yukarin.config.sr_config import create_from_json
+from become_yukarin.dataset import create_sr as create_sr_dataset
+from become_yukarin.model.sr_model import create_sr as create_sr_model
+from become_yukarin.updater.sr_updater import SRUpdater
+
+parser = argparse.ArgumentParser()
+parser.add_argument('config_json_path', type=Path)
+parser.add_argument('output', type=Path)
+arguments = parser.parse_args()
+
+config = create_from_json(arguments.config_json_path)
+arguments.output.mkdir(exist_ok=True)
+config.save_as_json((arguments.output / 'config.json').absolute())
+
+# model
+if config.train.gpu >= 0:
+    cuda.get_device_from_id(config.train.gpu).use()
+predictor, discriminator = create_sr_model(config.model)
+models = {
+    'predictor': predictor,
+    'discriminator': discriminator,
+}
+
+# dataset
+dataset = create_sr_dataset(config.dataset)
+train_iter = MultiprocessIterator(dataset['train'], config.train.batchsize)
+test_iter = MultiprocessIterator(dataset['test'], config.train.batchsize, repeat=False, shuffle=False)
+train_eval_iter = MultiprocessIterator(dataset['train_eval'], config.train.batchsize, repeat=False, shuffle=False)
+
+
+# optimizer
+def create_optimizer(model):
+    optimizer = optimizers.Adam(alpha=0.0002, beta1=0.5, beta2=0.999)
+    optimizer.setup(model)
+    return optimizer
+
+
+opts = {key: create_optimizer(model) for key, model in models.items()}
+
+# updater
+converter = partial(convert.concat_examples, padding=0)
+updater = SRUpdater(
+    loss_config=config.loss,
+    predictor=predictor,
+    discriminator=discriminator,
+    device=config.train.gpu,
+    iterator=train_iter,
+    optimizer=opts,
+    converter=converter,
+)
+
+# trainer
+trigger_log = (config.train.log_iteration, 'iteration')
+trigger_snapshot = (config.train.snapshot_iteration, 'iteration')
+
+trainer = training.Trainer(updater, out=arguments.output)
+
+ext = extensions.Evaluator(test_iter, models, converter, device=config.train.gpu, eval_func=updater.forward)
+trainer.extend(ext, name='test', trigger=trigger_log)
+ext = extensions.Evaluator(train_eval_iter, models, converter, device=config.train.gpu, eval_func=updater.forward)
+trainer.extend(ext, name='train', trigger=trigger_log)
+
+trainer.extend(extensions.dump_graph('predictor/loss'))
+
+ext = extensions.snapshot_object(predictor, filename='predictor_{.updater.iteration}.npz')
+trainer.extend(ext, trigger=trigger_snapshot)
+
+trainer.extend(extensions.LogReport(trigger=trigger_log))
+
+if extensions.PlotReport.available():
+    trainer.extend(extensions.PlotReport(
+        y_keys=[
+            'predictor/loss',
+            'predictor/mse',
+            'predictor/adversarial',
+            'discriminator/accuracy',
+            'discriminator/fake',
+            'discriminator/real',
+        ],
+        x_key='iteration',
+        file_name='loss.png',
+        trigger=trigger_log,
+    ))
+
+save_args(arguments, arguments.output)
+trainer.run()
```
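`snapshot_object` above writes the predictor weights as `predictor_<iteration>.npz` every `snapshot_iteration` steps. A later process could restore one roughly like this (a sketch: the snapshot path is illustrative, and the `using_config` inference guard is an assumption, not part of this commit):

```python
import chainer
from become_yukarin.model.sr_model import SRPredictor

# Channel counts must match create_sr in sr_model.py.
predictor = SRPredictor(in_ch=1, out_ch=3)
chainer.serializers.load_npz('output/predictor_10000.npz', predictor)

with chainer.using_config('train', False):  # disable dropout in the decoder
    pass  # predictor(x) on a batch of log low-resolution spectrograms goes here
```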