summaryrefslogtreecommitdiff
path: root/util/image_pool.py
diff options
context:
space:
mode:
authorTaesung Park <taesung_park@berkeley.edu>2017-12-10 23:04:41 -0800
committerTaesung Park <taesung_park@berkeley.edu>2017-12-10 23:04:41 -0800
commitf33f098be9b25c3b62523540c9c703af1db0b1c0 (patch)
tree9b51e547067b46ad8b55ddb34b207825550df867 /util/image_pool.py
parent3d2c534933b356dc313a620639a713cb940dc756 (diff)
parent2d96edbee5a488a7861833731a2cb71b23b55727 (diff)
merged conflicts
Diffstat (limited to 'util/image_pool.py')
-rw-r--r--util/image_pool.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/util/image_pool.py b/util/image_pool.py
index 152ef5b..ada1627 100644
--- a/util/image_pool.py
+++ b/util/image_pool.py
@@ -2,6 +2,8 @@ import random
import numpy as np
import torch
from torch.autograd import Variable
+
+
class ImagePool():
def __init__(self, pool_size):
self.pool_size = pool_size
@@ -11,9 +13,9 @@ class ImagePool():
def query(self, images):
if self.pool_size == 0:
- return images
+ return Variable(images)
return_images = []
- for image in images.data:
+ for image in images:
image = torch.unsqueeze(image, 0)
if self.num_imgs < self.pool_size:
self.num_imgs = self.num_imgs + 1