diff options
| author | tingchunw <tingchunw@nvidia.com> | 2017-12-09 01:02:39 +0000 |
|---|---|---|
| committer | tingchunw <tingchunw@nvidia.com> | 2017-12-09 01:02:39 +0000 |
| commit | db4a24df4a8482089bcb88d0e79c9b484307fc75 (patch) | |
| tree | 50f05acb700ff3b71e94d937ccd1b57da165cea0 /data | |
| parent | 99d031b469478434ea185e1da07f12b7b007c6b6 (diff) | |
add explanation for training with new dataset
Diffstat (limited to 'data')
| -rwxr-xr-x | data/aligned_dataset.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/data/aligned_dataset.py b/data/aligned_dataset.py index a0c9a0a..a3cdc76 100755 --- a/data/aligned_dataset.py +++ b/data/aligned_dataset.py @@ -33,12 +33,16 @@ class AlignedDataset(BaseDataset): self.dataset_size = len(self.label_paths) def __getitem__(self, index): - ### label maps + ### label maps label_path = self.label_paths[index] label = Image.open(label_path) - params = get_params(self.opt, label.size) - transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) - label_tensor = transform_label(label) * 255.0 + params = get_params(self.opt, label.size) + if self.opt.label_nc == 0: + transform_label = get_transform(self.opt, params) + label_tensor = transform_label(label.convert('RGB')) + else: + transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) + label_tensor = transform_label(label) * 255.0 image_tensor = inst_tensor = feat_tensor = 0 ### real images |
