summaryrefslogtreecommitdiff
path: root/Code/tfutils_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'Code/tfutils_test.py')
-rw-r--r--Code/tfutils_test.py102
1 files changed, 102 insertions, 0 deletions
diff --git a/Code/tfutils_test.py b/Code/tfutils_test.py
new file mode 100644
index 0000000..4e2b490
--- /dev/null
+++ b/Code/tfutils_test.py
@@ -0,0 +1,102 @@
+from tfutils import *
+
+imgs = tf.constant(np.ones([2, 2, 2, 3]))
+sess = tf.Session()
+
+
+# noinspection PyClassHasNoInit,PyMethodMayBeStatic
+class TestPad:
+ def test_rb(self):
+ res = sess.run(batch_pad_to_bounding_box(imgs, 0, 0, 4, 4))
+ assert np.array_equal(res, np.array([[[[1, 1, 1],
+ [1, 1, 1],
+ [0, 0, 0],
+ [0, 0, 0]],
+ [[1, 1, 1],
+ [1, 1, 1],
+ [0, 0, 0],
+ [0, 0, 0]],
+ [[0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0]],
+ [[0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0]]
+ ],
+ [[[1, 1, 1],
+ [1, 1, 1],
+ [0, 0, 0],
+ [0, 0, 0]],
+ [[1, 1, 1],
+ [1, 1, 1],
+ [0, 0, 0],
+ [0, 0, 0]],
+ [[0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0]],
+ [[0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0]]
+ ]], dtype=float))
+
+ def test_center(self):
+ res = sess.run(batch_pad_to_bounding_box(imgs, 1, 1, 4, 4))
+ assert np.array_equal(res, np.array([[[[0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0]],
+ [[0, 0, 0],
+ [1, 1, 1],
+ [1, 1, 1],
+ [0, 0, 0]],
+ [[0, 0, 0],
+ [1, 1, 1],
+ [1, 1, 1],
+ [0, 0, 0]],
+ [[0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0]]
+ ],
+ [[[0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0]],
+ [[0, 0, 0],
+ [1, 1, 1],
+ [1, 1, 1],
+ [0, 0, 0]],
+ [[0, 0, 0],
+ [1, 1, 1],
+ [1, 1, 1],
+ [0, 0, 0]],
+ [[0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0],
+ [0, 0, 0]]
+ ]], dtype=float))
+
+
+padded = batch_pad_to_bounding_box(imgs, 1, 1, 4, 4)
+
+
+# noinspection PyClassHasNoInit
+class TestCrop:
+ def test_rb(self):
+ res = sess.run(batch_crop_to_bounding_box(padded, 0, 0, 2, 2))
+ assert np.array_equal(res, np.array([[[[0, 0, 0],
+ [0, 0, 0]],
+ [[0, 0, 0],
+ [1, 1, 1]]],
+ [[[0, 0, 0],
+ [0, 0, 0]],
+ [[0, 0, 0],
+ [1, 1, 1]]]]))
+
+ def test_center(self):
+ res = sess.run(batch_crop_to_bounding_box(padded, 1, 1, 2, 2))
+ assert np.array_equal(res, np.ones([2, 2, 2, 3]))