summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--become_yukarin/config/config.py6
-rw-r--r--become_yukarin/config/sr_config.py6
-rw-r--r--become_yukarin/model/sr_model.py4
3 files changed, 14 insertions, 2 deletions
diff --git a/become_yukarin/config/config.py b/become_yukarin/config/config.py
index 68ba1bd..f1f24cf 100644
--- a/become_yukarin/config/config.py
+++ b/become_yukarin/config/config.py
@@ -139,3 +139,9 @@ def backward_compatible(d: Dict):
if 'target_global_noise' not in d['dataset']:
d['dataset']['target_global_noise'] = d['dataset']['global_noise']
d['dataset']['target_local_noise'] = d['dataset']['local_noise']
+
+ if 'generator_base_channels' not in d['model']:
+ d['model']['generator_base_channels'] = 64
+ d['model']['generator_extensive_layers'] = 8
+ d['model']['discriminator_base_channels'] = 32
+ d['model']['discriminator_extensive_layers'] = 5
diff --git a/become_yukarin/config/sr_config.py b/become_yukarin/config/sr_config.py
index 75cf6ff..b9a0ef2 100644
--- a/become_yukarin/config/sr_config.py
+++ b/become_yukarin/config/sr_config.py
@@ -113,3 +113,9 @@ def create_from_json(s: Union[str, Path]):
def backward_compatible(d: Dict):
if 'blur_size_factor' not in d['dataset']:
d['dataset']['blur_size_factor'] = 0
+
+ if 'generator_base_channels' not in d['model']:
+ d['model']['generator_base_channels'] = 64
+ d['model']['generator_extensive_layers'] = 8
+ d['model']['discriminator_base_channels'] = 32
+ d['model']['discriminator_extensive_layers'] = 5
diff --git a/become_yukarin/model/sr_model.py b/become_yukarin/model/sr_model.py
index 12863a7..28ba0c4 100644
--- a/become_yukarin/model/sr_model.py
+++ b/become_yukarin/model/sr_model.py
@@ -94,8 +94,8 @@ class SRPredictor(chainer.Chain):
def __init__(self, in_ch, out_ch, base, extensive_layers) -> None:
super().__init__()
with self.init_scope():
- self.encoder = Encoder(in_ch, base=base, extensive_layers=extensive_layers)
- self.decoder = Decoder(out_ch, base=base, extensive_layers=extensive_layers)
+ self.encoder = SREncoder(in_ch, base=base, extensive_layers=extensive_layers)
+ self.decoder = SRDecoder(out_ch, base=base, extensive_layers=extensive_layers)
def __call__(self, x):
return self.decoder(self.encoder(x))