summaryrefslogtreecommitdiff
path: root/become_yukarin/model
diff options
context:
space:
mode:
Diffstat (limited to 'become_yukarin/model')
-rw-r--r--become_yukarin/model/sr_model.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/become_yukarin/model/sr_model.py b/become_yukarin/model/sr_model.py
index 74119a4..64158ca 100644
--- a/become_yukarin/model/sr_model.py
+++ b/become_yukarin/model/sr_model.py
@@ -47,7 +47,6 @@ class Encoder(chainer.Chain):
self.c7 = CBR(512, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False)
def __call__(self, x):
- x = F.reshape(x, (len(x), 1) + x.shape[1:])
hs = [F.leaky_relu(self.c0(x))]
for i in range(1, 8):
hs.append(self['c%d' % i](hs[i - 1]))
@@ -103,7 +102,6 @@ class SRDiscriminator(chainer.Chain):
self.c4 = L.Convolution2D(512, 1, 3, 1, 1, initialW=w)
def __call__(self, x_0, x_1):
- x_0 = F.reshape(x_0, (len(x_0), 1) + x_0.shape[1:])
h = F.concat([self.c0_0(x_0), self.c0_1(x_1)])
h = self.c1(h)
h = self.c2(h)
@@ -114,6 +112,6 @@ class SRDiscriminator(chainer.Chain):
def create_sr(config: SRModelConfig):
- predictor = SRPredictor(in_ch=1, out_ch=3)
- discriminator = SRDiscriminator(in_ch=1, out_ch=3)
+ predictor = SRPredictor(in_ch=1, out_ch=1)
+ discriminator = SRDiscriminator(in_ch=1, out_ch=1)
return predictor, discriminator