diff options
| author | Hiroshiba <Hiroshiba@users.noreply.github.com> | 2018-02-27 21:54:30 +0900 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2018-02-27 21:54:30 +0900 |
| commit | a7606629eab569cf98029870c8414ffaa6d320b7 (patch) | |
| tree | 6a7a313d3c067c776d259167ae3945c6b75dded0 /become_yukarin/model/sr_model.py | |
| parent | 4741dc5b82563403db43546a1ee49ddcf0ebd1c8 (diff) | |
| parent | 6a84007c044c1664eb998f1b4fd95c6af9878fe2 (diff) | |
Merge pull request #1 from Hiroshiba/pix2pix-expand-size
pix2pixモデルのフィルタサイズを可変にした
Diffstat (limited to 'become_yukarin/model/sr_model.py')
| -rw-r--r-- | become_yukarin/model/sr_model.py | 97 |
1 files changed, 64 insertions, 33 deletions
diff --git a/become_yukarin/model/sr_model.py b/become_yukarin/model/sr_model.py
index f8e55d6..12863a7 100644
--- a/become_yukarin/model/sr_model.py
+++ b/become_yukarin/model/sr_model.py
@@ -16,8 +16,10 @@ class CBR(chainer.Chain):
         with self.init_scope():
             if sample == 'down':
                 self.c = L.Convolution2D(ch0, ch1, 4, 2, 1, initialW=w)
-            else:
+            elif sample == 'up':
                 self.c = L.Deconvolution2D(ch0, ch1, 4, 2, 1, initialW=w)
+            else:
+                self.c = L.Convolution2D(ch0, ch1, 1, 1, 0, initialW=w)
             if bn:
                 self.batchnorm = L.BatchNormalization(ch1)
@@ -32,19 +34,24 @@ class CBR(chainer.Chain):
         return h
 
 
-class Encoder(chainer.Chain):
-    def __init__(self, in_ch) -> None:
+class SREncoder(chainer.Chain):
+    def __init__(self, in_ch, base=64, extensive_layers=8) -> None:
         super().__init__()
         w = chainer.initializers.Normal(0.02)
         with self.init_scope():
-            self.c0 = L.Convolution2D(in_ch, 64, 3, 1, 1, initialW=w)
-            self.c1 = CBR(64, 128, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
-            self.c2 = CBR(128, 256, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
-            self.c3 = CBR(256, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
-            self.c4 = CBR(512, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
-            self.c5 = CBR(512, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
-            self.c6 = CBR(512, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
-            self.c7 = CBR(512, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
+            if extensive_layers > 0:
+                self.c0 = L.Convolution2D(in_ch, base * 1, 3, 1, 1, initialW=w)
+            else:
+                self.c0 = L.Convolution2D(in_ch, base * 1, 1, 1, 0, initialW=w)
+
+            _choose = lambda i: 'down' if i < extensive_layers else 'same'
+            self.c1 = CBR(base * 1, base * 2, bn=True, sample=_choose(1), activation=F.leaky_relu, dropout=False)
+            self.c2 = CBR(base * 2, base * 4, bn=True, sample=_choose(2), activation=F.leaky_relu, dropout=False)
+            self.c3 = CBR(base * 4, base * 8, bn=True, sample=_choose(3), activation=F.leaky_relu, dropout=False)
+            self.c4 = CBR(base * 8, base * 8, bn=True, sample=_choose(4), activation=F.leaky_relu, dropout=False)
+            self.c5 = CBR(base * 8, base * 8, bn=True, sample=_choose(5), activation=F.leaky_relu, dropout=False)
+            self.c6 = CBR(base * 8, base * 8, bn=True, sample=_choose(6), activation=F.leaky_relu, dropout=False)
+            self.c7 = CBR(base * 8, base * 8, bn=True, sample=_choose(7), activation=F.leaky_relu, dropout=False)
 
     def __call__(self, x):
         hs = [F.leaky_relu(self.c0(x))]
@@ -53,19 +60,24 @@ class Encoder(chainer.Chain):
         return hs
 
 
-class Decoder(chainer.Chain):
-    def __init__(self, out_ch) -> None:
+class SRDecoder(chainer.Chain):
+    def __init__(self, out_ch, base=64, extensive_layers=8) -> None:
         super().__init__()
         w = chainer.initializers.Normal(0.02)
         with self.init_scope():
-            self.c0 = CBR(512, 512, bn=True, sample='up', activation=F.relu, dropout=True)
-            self.c1 = CBR(1024, 512, bn=True, sample='up', activation=F.relu, dropout=True)
-            self.c2 = CBR(1024, 512, bn=True, sample='up', activation=F.relu, dropout=True)
-            self.c3 = CBR(1024, 512, bn=True, sample='up', activation=F.relu, dropout=False)
-            self.c4 = CBR(1024, 256, bn=True, sample='up', activation=F.relu, dropout=False)
-            self.c5 = CBR(512, 128, bn=True, sample='up', activation=F.relu, dropout=False)
-            self.c6 = CBR(256, 64, bn=True, sample='up', activation=F.relu, dropout=False)
-            self.c7 = L.Convolution2D(128, out_ch, 3, 1, 1, initialW=w)
+            _choose = lambda i: 'up' if i >= 8 - extensive_layers else 'same'
+            self.c0 = CBR(base * 8, base * 8, bn=True, sample=_choose(0), activation=F.relu, dropout=True)
+            self.c1 = CBR(base * 16, base * 8, bn=True, sample=_choose(1), activation=F.relu, dropout=True)
+            self.c2 = CBR(base * 16, base * 8, bn=True, sample=_choose(2), activation=F.relu, dropout=True)
+            self.c3 = CBR(base * 16, base * 8, bn=True, sample=_choose(3), activation=F.relu, dropout=False)
+            self.c4 = CBR(base * 16, base * 4, bn=True, sample=_choose(4), activation=F.relu, dropout=False)
+            self.c5 = CBR(base * 8, base * 2, bn=True, sample=_choose(5), activation=F.relu, dropout=False)
+            self.c6 = CBR(base * 4, base * 1, bn=True, sample=_choose(6), activation=F.relu, dropout=False)
+
+            if extensive_layers > 0:
+                self.c7 = L.Convolution2D(base * 2, out_ch, 3, 1, 1, initialW=w)
+            else:
+                self.c7 = L.Convolution2D(base * 2, out_ch, 1, 1, 0, initialW=w)
 
     def __call__(self, hs):
         h = self.c0(hs[-1])
@@ -79,27 +91,32 @@ class Decoder(chainer.Chain):
 
 
 class SRPredictor(chainer.Chain):
-    def __init__(self, in_ch, out_ch) -> None:
+    def __init__(self, in_ch, out_ch, base, extensive_layers) -> None:
         super().__init__()
         with self.init_scope():
-            self.encoder = Encoder(in_ch)
-            self.decoder = Decoder(out_ch)
+            self.encoder = Encoder(in_ch, base=base, extensive_layers=extensive_layers)
+            self.decoder = Decoder(out_ch, base=base, extensive_layers=extensive_layers)
 
     def __call__(self, x):
         return self.decoder(self.encoder(x))
 
 
 class SRDiscriminator(chainer.Chain):
-    def __init__(self, in_ch, out_ch) -> None:
+    def __init__(self, in_ch, out_ch, base=32, extensive_layers=5) -> None:
         super().__init__()
         w = chainer.initializers.Normal(0.02)
         with self.init_scope():
-            self.c0_0 = CBR(in_ch, 32, bn=False, sample='down', activation=F.leaky_relu, dropout=False)
-            self.c0_1 = CBR(out_ch, 32, bn=False, sample='down', activation=F.leaky_relu, dropout=False)
-            self.c1 = CBR(64, 128, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
-            self.c2 = CBR(128, 256, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
-            self.c3 = CBR(256, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
-            self.c4 = L.Convolution2D(512, 1, 3, 1, 1, initialW=w)
+            _choose = lambda i: 'down' if i < extensive_layers else 'same'
+            self.c0_0 = CBR(in_ch, base * 1, bn=False, sample=_choose(0), activation=F.leaky_relu, dropout=False)
+            self.c0_1 = CBR(out_ch, base * 1, bn=False, sample=_choose(0), activation=F.leaky_relu, dropout=False)
+            self.c1 = CBR(base * 2, base * 4, bn=True, sample=_choose(1), activation=F.leaky_relu, dropout=False)
+            self.c2 = CBR(base * 4, base * 8, bn=True, sample=_choose(2), activation=F.leaky_relu, dropout=False)
+            self.c3 = CBR(base * 8, base * 16, bn=True, sample=_choose(3), activation=F.leaky_relu, dropout=False)
+
+            if extensive_layers > 4:
+                self.c4 = L.Convolution2D(base * 16, 1, 3, 1, 1, initialW=w)
+            else:
+                self.c4 = L.Convolution2D(base * 16, 1, 1, 1, 0, initialW=w)
 
     def __call__(self, x_0, x_1):
         h = F.concat([self.c0_0(x_0), self.c0_1(x_1)])
@@ -112,10 +129,24 @@ class SRDiscriminator(chainer.Chain):
 
 
 def create_predictor_sr(config: SRModelConfig):
-    return SRPredictor(in_ch=1, out_ch=1)
+    return SRPredictor(
+        in_ch=1,
+        out_ch=1,
+        base=config.generator_base_channels,
+        extensive_layers=config.generator_extensive_layers,
+    )
+
+
+def create_discriminator_sr(config: SRModelConfig):
+    return SRDiscriminator(
+        in_ch=1,
+        out_ch=1,
+        base=config.discriminator_base_channels,
+        extensive_layers=config.discriminator_extensive_layers,
+    )
 
 
 def create_sr(config: SRModelConfig):
     predictor = create_predictor_sr(config)
-    discriminator = SRDiscriminator(in_ch=1, out_ch=1)
+    discriminator = create_discriminator_sr(config)
     return predictor, discriminator
