diff options
Diffstat (limited to 'util/image_pool.py')
| -rw-r--r-- | util/image_pool.py | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/util/image_pool.py b/util/image_pool.py index 152ef5b..ada1627 100644 --- a/util/image_pool.py +++ b/util/image_pool.py @@ -2,6 +2,8 @@ import random import numpy as np import torch from torch.autograd import Variable + + class ImagePool(): def __init__(self, pool_size): self.pool_size = pool_size @@ -11,9 +13,9 @@ class ImagePool(): def query(self, images): if self.pool_size == 0: - return images + return Variable(images) return_images = [] - for image in images.data: + for image in images: image = torch.unsqueeze(image, 0) if self.num_imgs < self.pool_size: self.num_imgs = self.num_imgs + 1 |
