summaryrefslogtreecommitdiff
path: root/become_yukarin/model/sr_model.py
diff options
context:
space:
mode:
Diffstat (limited to 'become_yukarin/model/sr_model.py')
-rw-r--r--become_yukarin/model/sr_model.py97
1 files changed, 64 insertions, 33 deletions
diff --git a/become_yukarin/model/sr_model.py b/become_yukarin/model/sr_model.py
index f8e55d6..12863a7 100644
--- a/become_yukarin/model/sr_model.py
+++ b/become_yukarin/model/sr_model.py
@@ -16,8 +16,10 @@ class CBR(chainer.Chain):
with self.init_scope():
if sample == 'down':
self.c = L.Convolution2D(ch0, ch1, 4, 2, 1, initialW=w)
- else:
+ elif sample == 'up':
self.c = L.Deconvolution2D(ch0, ch1, 4, 2, 1, initialW=w)
+ else:
+ self.c = L.Convolution2D(ch0, ch1, 1, 1, 0, initialW=w)
if bn:
self.batchnorm = L.BatchNormalization(ch1)
@@ -32,19 +34,24 @@ class CBR(chainer.Chain):
return h
-class Encoder(chainer.Chain):
- def __init__(self, in_ch) -> None:
+class SREncoder(chainer.Chain):
+ def __init__(self, in_ch, base=64, extensive_layers=8) -> None:
super().__init__()
w = chainer.initializers.Normal(0.02)
with self.init_scope():
- self.c0 = L.Convolution2D(in_ch, 64, 3, 1, 1, initialW=w)
- self.c1 = CBR(64, 128, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
- self.c2 = CBR(128, 256, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
- self.c3 = CBR(256, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
- self.c4 = CBR(512, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
- self.c5 = CBR(512, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
- self.c6 = CBR(512, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
- self.c7 = CBR(512, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
+ if extensive_layers > 0:
+ self.c0 = L.Convolution2D(in_ch, base * 1, 3, 1, 1, initialW=w)
+ else:
+ self.c0 = L.Convolution2D(in_ch, base * 1, 1, 1, 0, initialW=w)
+
+ _choose = lambda i: 'down' if i < extensive_layers else 'same'
+ self.c1 = CBR(base * 1, base * 2, bn=True, sample=_choose(1), activation=F.leaky_relu, dropout=False)
+ self.c2 = CBR(base * 2, base * 4, bn=True, sample=_choose(2), activation=F.leaky_relu, dropout=False)
+ self.c3 = CBR(base * 4, base * 8, bn=True, sample=_choose(3), activation=F.leaky_relu, dropout=False)
+ self.c4 = CBR(base * 8, base * 8, bn=True, sample=_choose(4), activation=F.leaky_relu, dropout=False)
+ self.c5 = CBR(base * 8, base * 8, bn=True, sample=_choose(5), activation=F.leaky_relu, dropout=False)
+ self.c6 = CBR(base * 8, base * 8, bn=True, sample=_choose(6), activation=F.leaky_relu, dropout=False)
+ self.c7 = CBR(base * 8, base * 8, bn=True, sample=_choose(7), activation=F.leaky_relu, dropout=False)
def __call__(self, x):
hs = [F.leaky_relu(self.c0(x))]
@@ -53,19 +60,24 @@ class Encoder(chainer.Chain):
return hs
-class Decoder(chainer.Chain):
- def __init__(self, out_ch) -> None:
+class SRDecoder(chainer.Chain):
+ def __init__(self, out_ch, base=64, extensive_layers=8) -> None:
super().__init__()
w = chainer.initializers.Normal(0.02)
with self.init_scope():
- self.c0 = CBR(512, 512, bn=True, sample='up', activation=F.relu, dropout=True)
- self.c1 = CBR(1024, 512, bn=True, sample='up', activation=F.relu, dropout=True)
- self.c2 = CBR(1024, 512, bn=True, sample='up', activation=F.relu, dropout=True)
- self.c3 = CBR(1024, 512, bn=True, sample='up', activation=F.relu, dropout=False)
- self.c4 = CBR(1024, 256, bn=True, sample='up', activation=F.relu, dropout=False)
- self.c5 = CBR(512, 128, bn=True, sample='up', activation=F.relu, dropout=False)
- self.c6 = CBR(256, 64, bn=True, sample='up', activation=F.relu, dropout=False)
- self.c7 = L.Convolution2D(128, out_ch, 3, 1, 1, initialW=w)
+ _choose = lambda i: 'up' if i >= 8 - extensive_layers else 'same'
+ self.c0 = CBR(base * 8, base * 8, bn=True, sample=_choose(0), activation=F.relu, dropout=True)
+ self.c1 = CBR(base * 16, base * 8, bn=True, sample=_choose(1), activation=F.relu, dropout=True)
+ self.c2 = CBR(base * 16, base * 8, bn=True, sample=_choose(2), activation=F.relu, dropout=True)
+ self.c3 = CBR(base * 16, base * 8, bn=True, sample=_choose(3), activation=F.relu, dropout=False)
+ self.c4 = CBR(base * 16, base * 4, bn=True, sample=_choose(4), activation=F.relu, dropout=False)
+ self.c5 = CBR(base * 8, base * 2, bn=True, sample=_choose(5), activation=F.relu, dropout=False)
+ self.c6 = CBR(base * 4, base * 1, bn=True, sample=_choose(6), activation=F.relu, dropout=False)
+
+ if extensive_layers > 0:
+ self.c7 = L.Convolution2D(base * 2, out_ch, 3, 1, 1, initialW=w)
+ else:
+ self.c7 = L.Convolution2D(base * 2, out_ch, 1, 1, 0, initialW=w)
def __call__(self, hs):
h = self.c0(hs[-1])
@@ -79,27 +91,32 @@ class Decoder(chainer.Chain):
class SRPredictor(chainer.Chain):
- def __init__(self, in_ch, out_ch) -> None:
+ def __init__(self, in_ch, out_ch, base, extensive_layers) -> None:
super().__init__()
with self.init_scope():
- self.encoder = Encoder(in_ch)
- self.decoder = Decoder(out_ch)
+            self.encoder = SREncoder(in_ch, base=base, extensive_layers=extensive_layers)
+            self.decoder = SRDecoder(out_ch, base=base, extensive_layers=extensive_layers)
def __call__(self, x):
return self.decoder(self.encoder(x))
class SRDiscriminator(chainer.Chain):
- def __init__(self, in_ch, out_ch) -> None:
+ def __init__(self, in_ch, out_ch, base=32, extensive_layers=5) -> None:
super().__init__()
w = chainer.initializers.Normal(0.02)
with self.init_scope():
- self.c0_0 = CBR(in_ch, 32, bn=False, sample='down', activation=F.leaky_relu, dropout=False)
- self.c0_1 = CBR(out_ch, 32, bn=False, sample='down', activation=F.leaky_relu, dropout=False)
- self.c1 = CBR(64, 128, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
- self.c2 = CBR(128, 256, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
- self.c3 = CBR(256, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
- self.c4 = L.Convolution2D(512, 1, 3, 1, 1, initialW=w)
+ _choose = lambda i: 'down' if i < extensive_layers else 'same'
+ self.c0_0 = CBR(in_ch, base * 1, bn=False, sample=_choose(0), activation=F.leaky_relu, dropout=False)
+ self.c0_1 = CBR(out_ch, base * 1, bn=False, sample=_choose(0), activation=F.leaky_relu, dropout=False)
+ self.c1 = CBR(base * 2, base * 4, bn=True, sample=_choose(1), activation=F.leaky_relu, dropout=False)
+ self.c2 = CBR(base * 4, base * 8, bn=True, sample=_choose(2), activation=F.leaky_relu, dropout=False)
+ self.c3 = CBR(base * 8, base * 16, bn=True, sample=_choose(3), activation=F.leaky_relu, dropout=False)
+
+ if extensive_layers > 4:
+ self.c4 = L.Convolution2D(base * 16, 1, 3, 1, 1, initialW=w)
+ else:
+ self.c4 = L.Convolution2D(base * 16, 1, 1, 1, 0, initialW=w)
def __call__(self, x_0, x_1):
h = F.concat([self.c0_0(x_0), self.c0_1(x_1)])
@@ -112,10 +129,24 @@ class SRDiscriminator(chainer.Chain):
def create_predictor_sr(config: SRModelConfig):
- return SRPredictor(in_ch=1, out_ch=1)
+ return SRPredictor(
+ in_ch=1,
+ out_ch=1,
+ base=config.generator_base_channels,
+ extensive_layers=config.generator_extensive_layers,
+ )
+
+
+def create_discriminator_sr(config: SRModelConfig):
+ return SRDiscriminator(
+ in_ch=1,
+ out_ch=1,
+ base=config.discriminator_base_channels,
+ extensive_layers=config.discriminator_extensive_layers,
+ )
def create_sr(config: SRModelConfig):
predictor = create_predictor_sr(config)
- discriminator = SRDiscriminator(in_ch=1, out_ch=1)
+ discriminator = create_discriminator_sr(config)
return predictor, discriminator