summaryrefslogtreecommitdiff
path: root/util/image_pool.py
diff options
context:
space:
mode:
Diffstat (limited to 'util/image_pool.py')
-rw-r--r--util/image_pool.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/util/image_pool.py b/util/image_pool.py
new file mode 100644
index 0000000..b59e185
--- /dev/null
+++ b/util/image_pool.py
@@ -0,0 +1,33 @@
+import random
+import numpy as np
+import torch
+from pdb import set_trace as st
+from torch.autograd import Variable
+class ImagePool():
+ def __init__(self, pool_size):
+ self.pool_size = pool_size
+ if self.pool_size > 0:
+ self.num_imgs = 0
+ self.images = []
+
+ def query(self, images):
+ if self.pool_size == 0:
+ return images
+ return_images = []
+ for image in images.data:
+ image = torch.unsqueeze(image, 0)
+ if self.num_imgs < self.pool_size:
+ self.num_imgs = self.num_imgs + 1
+ self.images.append(image)
+ return_images.append(image)
+ else:
+ p = random.uniform(0, 1)
+ if p > 0.5:
+ random_id = random.randint(0, self.pool_size-1)
+ tmp = self.images[random_id].clone()
+ self.images[random_id] = image
+ return_images.append(tmp)
+ else:
+ return_images.append(image)
+ return_images = Variable(torch.cat(return_images, 0))
+ return return_images