diff options
Diffstat (limited to 'become_yukarin/model')
| -rw-r--r-- | become_yukarin/model/sr_model.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/become_yukarin/model/sr_model.py b/become_yukarin/model/sr_model.py index 74119a4..64158ca 100644 --- a/become_yukarin/model/sr_model.py +++ b/become_yukarin/model/sr_model.py @@ -47,7 +47,6 @@ class Encoder(chainer.Chain): self.c7 = CBR(512, 512, bn=True, sample='down', activation=F.leaky_relu, dropout=False) def __call__(self, x): - x = F.reshape(x, (len(x), 1) + x.shape[1:]) hs = [F.leaky_relu(self.c0(x))] for i in range(1, 8): hs.append(self['c%d' % i](hs[i - 1])) @@ -103,7 +102,6 @@ class SRDiscriminator(chainer.Chain): self.c4 = L.Convolution2D(512, 1, 3, 1, 1, initialW=w) def __call__(self, x_0, x_1): - x_0 = F.reshape(x_0, (len(x_0), 1) + x_0.shape[1:]) h = F.concat([self.c0_0(x_0), self.c0_1(x_1)]) h = self.c1(h) h = self.c2(h) @@ -114,6 +112,6 @@ class SRDiscriminator(chainer.Chain): def create_sr(config: SRModelConfig): - predictor = SRPredictor(in_ch=1, out_ch=3) - discriminator = SRDiscriminator(in_ch=1, out_ch=3) + predictor = SRPredictor(in_ch=1, out_ch=1) + discriminator = SRDiscriminator(in_ch=1, out_ch=1) return predictor, discriminator |
