summaryrefslogtreecommitdiff
path: root/util/image_pool.py
diff options
context:
space:
mode:
Diffstat (limited to 'util/image_pool.py')
-rw-r--r--util/image_pool.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/util/image_pool.py b/util/image_pool.py
index 152ef5b..ada1627 100644
--- a/util/image_pool.py
+++ b/util/image_pool.py
@@ -2,6 +2,8 @@ import random
import numpy as np
import torch
from torch.autograd import Variable
+
+
class ImagePool():
def __init__(self, pool_size):
self.pool_size = pool_size
@@ -11,9 +13,9 @@ class ImagePool():
def query(self, images):
if self.pool_size == 0:
- return images
+ return Variable(images)
return_images = []
- for image in images.data:
+ for image in images:
image = torch.unsqueeze(image, 0)
if self.num_imgs < self.pool_size:
self.num_imgs = self.num_imgs + 1