summaryrefslogtreecommitdiff
path: root/Codes/flownet2/src/ops/preprocessing
diff options
context:
space:
mode:
authorStevenLiuWen <liuwen@shanghaitech.edu.cn>2018-03-13 03:28:06 -0400
committerStevenLiuWen <liuwen@shanghaitech.edu.cn>2018-03-13 03:28:06 -0400
commitfede6ca1dd0077ff509d84bd24028cc7a93bb119 (patch)
treeaf7f6e759b5dec4fc2964daed09e903958b919ed /Codes/flownet2/src/ops/preprocessing
first commit
Diffstat (limited to 'Codes/flownet2/src/ops/preprocessing')
-rw-r--r--Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.cc420
-rw-r--r--Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.h228
-rw-r--r--Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.cc461
-rw-r--r--Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.cu.cc348
-rw-r--r--Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.h22
-rw-r--r--Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation.cc129
-rw-r--r--Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation.h19
-rw-r--r--Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation_gpu.cu.cc95
-rw-r--r--Codes/flownet2/src/ops/preprocessing/preprocessing.cc96
9 files changed, 1818 insertions, 0 deletions
diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.cc b/Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.cc
new file mode 100644
index 0000000..b93dfa6
--- /dev/null
+++ b/Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.cc
@@ -0,0 +1,420 @@
+#include "augmentation_base.h"
+
+#include <math.h>
+#include <random>
+
+namespace tensorflow {
+/** TransMat Functions **/
+void AugmentationLayerBase::TransMat::fromCoeff(AugmentationCoeff *coeff,
+ int out_width,
+ int out_height,
+ int src_width,
+ int src_height) {
+ leftMultiply(1, 0, -0.5 * out_width,
+ 0, 1, -0.5 * out_height);
+
+ if (coeff->angle) {
+ leftMultiply(cos(coeff->angle()), -sin(coeff->angle()), 0,
+ sin(coeff->angle()), cos(coeff->angle()), 0);
+ }
+
+ if (coeff->dx || coeff->dy) {
+ leftMultiply(1, 0, coeff->dx() * out_width,
+ 0, 1, coeff->dy() * out_height);
+ }
+
+ if (coeff->zoom_x || coeff->zoom_y) {
+ leftMultiply(1.0 / coeff->zoom_x(), 0, 0,
+ 0, 1.0 / coeff->zoom_y(), 0);
+ }
+
+ leftMultiply(1, 0, 0.5 * src_width,
+ 0, 1, 0.5 * src_height);
+}
+
+void AugmentationLayerBase::TransMat::fromTensor(const float *tensor_data) {
+ t0 = tensor_data[0];
+ t1 = tensor_data[1];
+ t2 = tensor_data[2];
+ t3 = tensor_data[3];
+ t4 = tensor_data[4];
+ t5 = tensor_data[5];
+}
+
+AugmentationLayerBase::TransMat AugmentationLayerBase::TransMat::inverse() {
+ float a = this->t0, b = this->t1, c = this->t2;
+ float d = this->t3, e = this->t4, f = this->t5;
+
+ float denom = a * e - b * d;
+
+ TransMat result;
+
+ result.t0 = e / denom;
+ result.t1 = b / -denom;
+ result.t2 = (c * e - b * f) / -denom;
+ result.t3 = d / -denom;
+ result.t4 = a / denom;
+ result.t5 = (c * d - a * f) / denom;
+
+ return result;
+}
+
+void AugmentationLayerBase::TransMat::leftMultiply(float u0,
+ float u1,
+ float u2,
+ float u3,
+ float u4,
+ float u5) {
+ float t0 = this->t0, t1 = this->t1, t2 = this->t2;
+ float t3 = this->t3, t4 = this->t4, t5 = this->t5;
+
+ this->t0 = t0 * u0 + t3 * u1;
+ this->t1 = t1 * u0 + t4 * u1;
+ this->t2 = t2 * u0 + t5 * u1 + u2;
+ this->t3 = t0 * u3 + t3 * u4;
+ this->t4 = t1 * u3 + t4 * u4;
+ this->t5 = t2 * u3 + t5 * u4 + u5;
+}
+
+void AugmentationLayerBase::TransMat::toIdentity() {
+ t0 = 1; t1 = 0; t2 = 0;
+ t3 = 0; t4 = 1; t5 = 0;
+}
+
+/** AugmentationCoeff Functions **/
+void AugmentationCoeff::clear() {
+ // Spatial variables
+ dx.clear();
+ dy.clear();
+ angle.clear();
+ zoom_x.clear();
+ zoom_y.clear();
+
+ // Chromatic variables
+ gamma.clear();
+ brightness.clear();
+ contrast.clear();
+ color1.clear();
+ color2.clear();
+ color3.clear();
+}
+
+void AugmentationCoeff::combine_with(const AugmentationCoeff& coeff) {
+ // Spatial types
+ if (coeff.dx) {
+ dx = dx() * coeff.dx();
+ }
+
+ if (coeff.dy) {
+ dy = dy() * coeff.dy();
+ }
+
+ if (coeff.angle) {
+ angle = angle() * coeff.angle();
+ }
+
+ if (coeff.zoom_x) {
+ zoom_x = zoom_x() * coeff.zoom_x();
+ }
+
+ if (coeff.zoom_y) {
+ zoom_y = zoom_y() * coeff.zoom_y();
+ }
+
+ // Chromatic types
+ if (coeff.gamma) {
+ gamma = gamma() * coeff.gamma();
+ }
+
+ if (coeff.brightness) {
+ brightness = brightness() * coeff.brightness();
+ }
+
+ if (coeff.contrast) {
+ contrast = contrast() * coeff.contrast();
+ }
+
+ if (coeff.color1) {
+ color1 = color1() * coeff.color1();
+ }
+
+ if (coeff.color2) {
+ color2 = color2() * coeff.color2();
+ }
+
+ if (coeff.color3) {
+ color3 = color3() * coeff.color3();
+ }
+}
+
+void AugmentationCoeff::replace_with(const AugmentationCoeff& coeff) {
+ // Spatial types
+ if (coeff.dx) {
+ dx = coeff.dx();
+ }
+
+ if (coeff.dy) {
+ dy = coeff.dy();
+ }
+
+ if (coeff.angle) {
+ angle = coeff.angle();
+ }
+
+ if (coeff.zoom_x) {
+ zoom_x = coeff.zoom_x();
+ }
+
+ if (coeff.zoom_y) {
+ zoom_y = coeff.zoom_y();
+ }
+
+ // Chromatic types
+ if (coeff.gamma) {
+ gamma = gamma() * coeff.gamma();
+ }
+
+ if (coeff.brightness) {
+ brightness = coeff.brightness();
+ }
+
+ if (coeff.contrast) {
+ contrast = coeff.contrast();
+ }
+
+ if (coeff.color1) {
+ color1 = coeff.color1();
+ }
+
+ if (coeff.color2) {
+ color2 = coeff.color2();
+ }
+
+ if (coeff.color3) {
+ color3 = coeff.color3();
+ }
+}
+
+/** AugmentationLayerBase Functions **/
+float AugmentationLayerBase::rng_generate(const AugmentationParam& param,
+ float discount_coeff,
+ const float default_value) {
+ std::random_device rd; // Will be used to obtain a seed for the random number
+ // engine
+ std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with rd()
+
+ float spread = param.spread * discount_coeff;
+
+ if (param.rand_type == "uniform_bernoulli") {
+ float tmp1 = 0.0;
+ bool tmp2 = false;
+
+ if (param.prob > 0.0) {
+ std::bernoulli_distribution bernoulli(param.prob);
+ tmp2 = bernoulli(gen);
+ }
+
+ if (!tmp2) {
+ return default_value;
+ }
+
+ if (param.spread > 0.0) {
+ std::uniform_real_distribution<> uniform(param.mean - spread,
+ param.mean + spread);
+ tmp1 = uniform(gen);
+ } else {
+ tmp1 = param.mean;
+ }
+
+ if (param.should_exp) {
+ tmp1 = exp(tmp1);
+ }
+
+ return tmp1;
+ } else if (param.rand_type == "gaussian_bernoulli") {
+ float tmp1 = 0.0;
+ bool tmp2 = false;
+
+ if (param.prob > 0.0) {
+ std::bernoulli_distribution bernoulli(param.prob);
+ tmp2 = bernoulli(gen);
+ }
+
+ if (!tmp2) {
+ return default_value;
+ }
+
+ if (spread > 0.0) {
+ std::normal_distribution<> normal(param.mean, spread);
+ tmp1 = normal(gen);
+ } else {
+ tmp1 = param.mean;
+ }
+
+ if (param.should_exp) {
+ tmp1 = exp(tmp1);
+ }
+
+ return tmp1;
+ } else {
+ throw "Unknown random type: " + param.rand_type;
+ }
+}
+
+void AugmentationLayerBase::generate_chromatic_coeffs(float discount_coeff,
+ const AugmentationParams& aug,
+ AugmentationCoeff & coeff) {
+ if (aug.gamma) {
+ coeff.gamma = rng_generate(aug.gamma(), discount_coeff, coeff.gamma.get_default());
+ }
+
+ if (aug.brightness) {
+ coeff.brightness =
+ rng_generate(aug.brightness(), discount_coeff, coeff.brightness.get_default());
+ }
+
+ if (aug.contrast) {
+ coeff.contrast = rng_generate(aug.contrast(), discount_coeff, coeff.contrast.get_default());
+ }
+
+ if (aug.color) {
+ coeff.color1 = rng_generate(aug.color(), discount_coeff, coeff.color1.get_default());
+ coeff.color2 = rng_generate(aug.color(), discount_coeff, coeff.color2.get_default());
+ coeff.color3 = rng_generate(aug.color(), discount_coeff, coeff.color3.get_default());
+ }
+}
+
+void AugmentationLayerBase::generate_spatial_coeffs(float discount_coeff,
+ const AugmentationParams& aug,
+ AugmentationCoeff & coeff) {
+ if (aug.translate) {
+ coeff.dx = rng_generate(aug.translate(), discount_coeff, coeff.dx.get_default());
+ coeff.dy = rng_generate(aug.translate(), discount_coeff, coeff.dy.get_default());
+ }
+
+ if (aug.rotate) {
+ coeff.angle = rng_generate(aug.rotate(), discount_coeff, coeff.angle.get_default());
+ }
+
+ if (aug.zoom) {
+ coeff.zoom_x = rng_generate(aug.zoom(), discount_coeff, coeff.zoom_x.get_default());
+ coeff.zoom_y = coeff.zoom_x();
+ }
+
+ if (aug.squeeze) {
+ float squeeze_coeff = rng_generate(aug.squeeze(), discount_coeff, 1.0);
+ coeff.zoom_x = coeff.zoom_x() * squeeze_coeff;
+ coeff.zoom_y = coeff.zoom_y() * squeeze_coeff;
+ }
+}
+
+void AugmentationLayerBase::generate_valid_spatial_coeffs(
+ float discount_coeff,
+ const AugmentationParams& aug,
+ AugmentationCoeff & coeff,
+ int src_width,
+ int src_height,
+ int out_width,
+ int out_height) {
+ int x, y;
+ float x1, y1, x2, y2;
+ int counter = 0;
+ int good_params = 0;
+ AugmentationCoeff incoming_coeff(coeff);
+
+ while (good_params < 4 && counter < 50) {
+ coeff.clear();
+ AugmentationLayerBase::generate_spatial_coeffs(discount_coeff, aug, coeff);
+ coeff.combine_with(incoming_coeff);
+
+ // Check if all 4 corners of the transformed image fit into the original
+ // image
+ good_params = 0;
+
+ for (x = 0; x < out_width; x += out_width - 1) {
+ for (y = 0; y < out_height; y += out_height - 1) {
+ // move the origin
+ x1 = x - 0.5 * out_width;
+ y1 = y - 0.5 * out_height;
+
+ // rotate
+ x2 = cos(coeff.angle()) * x1 - sin(coeff.angle()) * y1;
+ y2 = sin(coeff.angle()) * x1 + sin(coeff.angle()) * y1;
+
+ // translate
+ x2 = x2 + coeff.dx() * out_width;
+ y2 = y2 + coeff.dy() * out_height;
+
+ // zoom
+ x2 = x2 / coeff.zoom_x();
+ y2 = y2 / coeff.zoom_y();
+
+ // move the origin back
+ x2 = x2 + 0.5 * src_width;
+ y2 = y2 + 0.5 * src_height;
+
+ if (!((floor(x2) < 0) || (floor(x2) > src_width - 2.0) ||
+ (floor(y2) < 0) || (floor(y2) > src_height - 2.0))) {
+ good_params++;
+ }
+ }
+ }
+ counter++;
+ }
+
+ if (counter >= 50) {
+ printf("Warning: No suitable spatial transformation after %d attempts.\n", counter);
+ coeff.clear();
+ coeff.replace_with(incoming_coeff);
+ }
+}
+
+void AugmentationLayerBase::copy_chromatic_coeffs_to_tensor(
+ const std::vector<AugmentationCoeff>& coeff_arr,
+ typename TTypes<float, 2>::Tensor& out)
+{
+ float *out_ptr = out.data();
+ int counter = 0;
+
+ for (AugmentationCoeff coeff : coeff_arr) {
+ out_ptr[counter + 0] = coeff.gamma();
+ out_ptr[counter + 1] = coeff.brightness();
+ out_ptr[counter + 2] = coeff.contrast();
+ out_ptr[counter + 3] = coeff.color1();
+ out_ptr[counter + 4] = coeff.color2();
+ out_ptr[counter + 5] = coeff.color3();
+ counter += 6;
+ }
+}
+
+void AugmentationLayerBase::copy_spatial_coeffs_to_tensor(
+ const std::vector<AugmentationCoeff>& coeff_arr,
+ const int out_width,
+ const int out_height,
+ const int src_width,
+ const int src_height,
+ typename TTypes<float, 2>::Tensor& out,
+ const bool invert)
+{
+ float *out_ptr = out.data();
+ int counter = 0;
+ TransMat t;
+
+ for (AugmentationCoeff coeff : coeff_arr) {
+ t.toIdentity();
+ t.fromCoeff(&coeff, out_width, out_height, src_width, src_height);
+
+ if (invert) {
+ t = t.inverse();
+ }
+
+ out_ptr[counter + 0] = t.t0;
+ out_ptr[counter + 1] = t.t1;
+ out_ptr[counter + 2] = t.t2;
+ out_ptr[counter + 3] = t.t3;
+ out_ptr[counter + 4] = t.t4;
+ out_ptr[counter + 5] = t.t5;
+ counter += 6;
+ }
+}
+}
diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.h b/Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.h
new file mode 100644
index 0000000..d2aba2c
--- /dev/null
+++ b/Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.h
@@ -0,0 +1,228 @@
+#ifndef AUGMENTATION_LAYER_BASE_H_
+#define AUGMENTATION_LAYER_BASE_H_
+
+#include "tensorflow/core/framework/tensor_types.h"
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+namespace tensorflow {
+template<typename T>
+class OptionalType {
+ public:
+ OptionalType(const T default_value) : default_value(default_value), has_value(false) {}
+
+ operator bool() const {
+ return has_value;
+ }
+
+ OptionalType& operator=(T val) {
+ has_value = true;
+ value = val;
+ return *this;
+ }
+
+ const T operator()() const {
+ return has_value ? value : default_value;
+ }
+
+ void clear() {
+ has_value = false;
+ }
+
+ const T get_default() {
+ return default_value;
+ }
+
+ private:
+ T value;
+ bool has_value;
+ const T default_value;
+};
+
+class AugmentationCoeff {
+ public:
+ // Spatial Types
+ OptionalType<float>dx;
+ OptionalType<float>dy;
+ OptionalType<float>angle;
+ OptionalType<float>zoom_x;
+ OptionalType<float>zoom_y;
+
+ // Chromatic Types
+ OptionalType<float>gamma;
+ OptionalType<float>brightness;
+ OptionalType<float>contrast;
+ OptionalType<float>color1;
+ OptionalType<float>color2;
+ OptionalType<float>color3;
+
+ AugmentationCoeff() : dx(0.0), dy(0.0), angle(0.0), zoom_x(1.0), zoom_y(1.0), gamma(1.0),
+ brightness(0.0), contrast(1.0), color1(1.0), color2(1.0), color3(1.0) {}
+
+ AugmentationCoeff(const AugmentationCoeff& coeff) : AugmentationCoeff() {
+ replace_with(coeff);
+ }
+
+ void clear();
+
+ void combine_with(const AugmentationCoeff& coeff);
+
+ void replace_with(const AugmentationCoeff& coeff);
+};
+
+typedef struct AugmentationParam {
+ std::string rand_type;
+ bool should_exp;
+ float mean;
+ float spread;
+ float prob;
+} AugmentationParam;
+
+class AugmentationParams {
+ public:
+ int crop_height;
+ int crop_width;
+
+ // Spatial options
+ OptionalType<struct AugmentationParam>translate;
+ OptionalType<struct AugmentationParam>rotate;
+ OptionalType<struct AugmentationParam>zoom;
+ OptionalType<struct AugmentationParam>squeeze;
+
+ // Chromatic options
+ OptionalType<struct AugmentationParam>gamma;
+ OptionalType<struct AugmentationParam>brightness;
+ OptionalType<struct AugmentationParam>contrast;
+ OptionalType<struct AugmentationParam>color;
+
+ inline AugmentationParams(int crop_height,
+ int crop_width,
+ std::vector<std::string>params_name,
+ std::vector<std::string>params_rand_type,
+ std::vector<bool> params_exp,
+ std::vector<float> params_mean,
+ std::vector<float> params_spread,
+ std::vector<float> params_prob) :
+ crop_height(crop_height),
+ crop_width(crop_width),
+ translate(AugmentationParam()),
+ rotate(AugmentationParam()),
+ zoom(AugmentationParam()),
+ squeeze(AugmentationParam()),
+ gamma(AugmentationParam()),
+ brightness(AugmentationParam()),
+ contrast(AugmentationParam()),
+ color(AugmentationParam()) {
+ for (int i = 0; i < params_name.size(); i++) {
+ const std::string name = params_name[i];
+ const std::string rand_type = params_rand_type[i];
+ const bool should_exp = params_exp[i];
+ const float mean = params_mean[i];
+ const float spread = params_spread[i];
+ const float prob = params_prob[i];
+
+ struct AugmentationParam param = { rand_type, should_exp, mean, spread, prob };
+
+ if (name == "translate") {
+ this->translate = param;
+ } else if (name == "rotate") {
+ this->rotate = param;
+ } else if (name == "zoom") {
+ this->zoom = param;
+ } else if (name == "squeeze") {
+ this->squeeze = param;
+ } else if (name == "noise") {
+ // NoOp: We handle noise on the Python side
+ } else if (name == "gamma") {
+ this->gamma = param;
+ } else if (name == "brightness") {
+ this->brightness = param;
+ } else if (name == "contrast") {
+ this->contrast = param;
+ } else if (name == "color") {
+ this->color = param;
+ } else {
+ std::cout << "Ignoring unknown augmentation parameter: " << name << std::endl;
+ }
+ }
+ }
+
+ bool should_do_spatial_transform() {
+ return this->translate || this->rotate || this->zoom || this->squeeze;
+ }
+
+ bool should_do_chromatic_transform() {
+ return this->gamma || this->brightness || this->contrast || this->color;
+ }
+};
+
+class AugmentationLayerBase {
+ public:
+ class TransMat {
+ /**
+ * Translation matrix class for spatial augmentation
+ * | 0 1 2 |
+ * | 3 4 5 |
+ */
+
+ public:
+ float t0, t1, t2;
+ float t3, t4, t5;
+
+
+ void fromCoeff(AugmentationCoeff *coeff,
+ int out_width,
+ int out_height,
+ int src_width,
+ int src_height);
+
+ void fromTensor(const float *tensor_data);
+
+ TransMat inverse();
+
+ void leftMultiply(float u0,
+ float u1,
+ float u2,
+ float u3,
+ float u4,
+ float u5);
+
+ void toIdentity();
+ };
+
+ // TODO: Class ChromaticCoeffs
+
+ static float rng_generate(const AugmentationParam& param,
+ float discount_coeff,
+ const float default_value);
+
+ static void clear_spatial_coeffs(AugmentationCoeff& coeff);
+ static void generate_chromatic_coeffs(float discount_coeff,
+ const AugmentationParams& aug,
+ AugmentationCoeff & coeff);
+ static void generate_spatial_coeffs(float discount_coeff,
+ const AugmentationParams& aug,
+ AugmentationCoeff & coeff);
+ static void generate_valid_spatial_coeffs(float discount_coeff,
+ const AugmentationParams& aug,
+ AugmentationCoeff & coeff,
+ int src_width,
+ int src_height,
+ int out_width,
+ int out_height);
+
+ static void copy_chromatic_coeffs_to_tensor(const std::vector<AugmentationCoeff>& coeff_arr,
+ typename TTypes<float, 2>::Tensor& out);
+ static void copy_spatial_coeffs_to_tensor(const std::vector<AugmentationCoeff>& coeff_arr,
+ const int out_width,
+ const int out_height,
+ const int src_width,
+ const int src_height,
+ typename TTypes<float, 2>::Tensor& out,
+ const bool invert = false);
+};
+} // namespace tensorflow
+
+#endif // AUGMENTATION_LAYER_BASE_H_
diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.cc b/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.cc
new file mode 100644
index 0000000..77b8c83
--- /dev/null
+++ b/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.cc
@@ -0,0 +1,461 @@
+#define EIGEN_USE_THREADS
+
+#include <algorithm>
+#include <iostream>
+#include <random>
+#include <vector>
+
+#include "augmentation_base.h"
+#include "data_augmentation.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+inline float clamp(float f, float a, float b) {
+ return fmaxf(a, fminf(f, b));
+}
+
+template<>
+void Augment(OpKernelContext *context,
+ const CPUDevice& d,
+ const int batch_size,
+ const int channels,
+ const int src_width,
+ const int src_height,
+ const int src_count,
+ const int out_width,
+ const int out_height,
+ const float *src_data,
+ float *out_data,
+ const float *transMats,
+ float *chromatic_coeffs) {
+ const int64 channel_count = batch_size * out_height * out_width;
+ const int kCostPerChannel = 10;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+
+ Shard(worker_threads.num_threads,
+ worker_threads.workers,
+ channel_count,
+ kCostPerChannel,
+ [batch_size, channels, src_width,
+ src_height, src_count, out_width, out_height, src_data,
+ out_data, transMats, chromatic_coeffs](
+ int64 start_channel, int64 end_channel) {
+ // TF, NHWK: ((n * H + h) * W + w) * K + k at point (n, h, w, k)
+ for (int index = start_channel; index < end_channel; index++) {
+ int x = index % out_width;
+ int y = (index / out_width) % out_height;
+ int n = index / out_width / out_height;
+
+ const float *transMat = transMats + n * 6;
+
+ float gamma, brightness, contrast;
+
+ if (chromatic_coeffs) {
+ gamma = chromatic_coeffs[n * 6 + 0];
+ brightness = chromatic_coeffs[n * 6 + 1];
+ contrast = chromatic_coeffs[n * 6 + 2];
+ }
+
+ float xpos = x * transMat[0] + y * transMat[1] + transMat[2];
+ float ypos = x * transMat[3] + y * transMat[4] + transMat[5];
+
+ xpos = clamp(xpos, 0.0f, (float)(src_width) - 1.05f);
+ ypos = clamp(ypos, 0.0f, (float)(src_height) - 1.05f);
+
+ float tlx = floor(xpos);
+ float tly = floor(ypos);
+
+ float xdist = xpos - tlx;
+ float ydist = ypos - tly;
+
+ int srcTLIdxOffset = ((n * src_height + (int)tly) * src_width + (int)tlx) * channels;
+
+ // ((n * src_height + tly) * src_width + (tlx + 1)) * channels
+ int srcTRIdxOffset = srcTLIdxOffset + channels;
+
+ // ((n * src_height + (tly + 1)) * src_width + tlx) * channels
+ int srcBLIdxOffset = srcTLIdxOffset + channels * src_width;
+
+ // ((n * src_height + (tly + 1)) * src_width + (tlx + 1)) * channels
+ int srcBRIdxOffset = srcTLIdxOffset + channels + channels * src_width;
+
+ // Variables for chromatic transform
+ int data_index[3];
+ float rgb[3];
+ float mean_in = 0;
+ float mean_out = 0;
+
+ for (int c = 0; c < channels; c++) {
+ // Bilinear interpolation
+ int srcTLIdx = srcTLIdxOffset + c;
+ int srcTRIdx = std::min(srcTRIdxOffset + c, src_count);
+ int srcBLIdx = std::min(srcBLIdxOffset + c, src_count);
+ int srcBRIdx = std::min(srcBRIdxOffset + c, src_count);
+
+ float dest = (1 - xdist) * (1 - ydist) * src_data[srcTLIdx]
+ + (xdist) * (ydist) * src_data[srcBRIdx]
+ + (1 - xdist) * (ydist) * src_data[srcBLIdx]
+ + (xdist) * (1 - ydist) * src_data[srcTRIdx];
+
+ if (chromatic_coeffs) {
+ // Gather data for chromatic transform
+ data_index[c] = index * channels + c;
+ rgb[c] = dest;
+ mean_in += rgb[c];
+
+ // Note: coeff[3] == color1, coeff[4] == color2, ...
+ rgb[c] *= chromatic_coeffs[n * 6 + (3 + c)];
+
+ mean_out += rgb[c];
+ } else {
+ out_data[index * channels + c] = dest;
+ }
+ }
+
+ float brightness_coeff = mean_in / (mean_out + 0.01f);
+
+ if (chromatic_coeffs) {
+ // Chromatic transformation
+ for (int c = 0; c < channels; c++) {
+ // compensate brightness
+ rgb[c] = clamp(rgb[c] * brightness_coeff, 0.0f, 1.0f);
+
+ // gamma change
+ rgb[c] = pow(rgb[c], gamma);
+
+ // brightness change
+ rgb[c] = rgb[c] + brightness;
+
+ // contrast change
+ rgb[c] = 0.5f + (rgb[c] - 0.5f) * contrast;
+
+ out_data[data_index[c]] = clamp(rgb[c], 0.0f, 1.0f);
+ }
+ }
+ }
+ });
+}
+
+template<typename Device>
+class DataAugmentation : public OpKernel {
+ public:
+ explicit DataAugmentation(OpKernelConstruction *ctx) : OpKernel(ctx) {
+ // Get the crop [height, width] tensor and verify its dimensions
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("crop", &crop_));
+ OP_REQUIRES(ctx, crop_.size() == 2,
+ errors::InvalidArgument("crop must be 2 dimensions"));
+
+ // TODO: Verify params are all the same length
+
+ // Get the tensors for params_a and verify their dimensions
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_name", &params_a_name_));
+ OP_REQUIRES_OK(ctx,
+ ctx->GetAttr("params_a_rand_type", &params_a_rand_type_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_exp", &params_a_exp_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_mean", &params_a_mean_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_spread", &params_a_spread_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_prob", &params_a_prob_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_coeff_schedule", &params_a_coeff_schedule_));
+
+ // Get the tensors for params_b and verify their dimensions
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_name", &params_b_name_));
+ OP_REQUIRES_OK(ctx,
+ ctx->GetAttr("params_b_rand_type", &params_b_rand_type_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_exp", &params_b_exp_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_mean", &params_b_mean_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_spread", &params_b_spread_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_prob", &params_b_prob_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_coeff_schedule", &params_b_coeff_schedule_));
+ }
+
+ void Compute(OpKernelContext *ctx) override {
+ // Get the input images
+ const Tensor& input_a_t = ctx->input(0);
+ const Tensor& input_b_t = ctx->input(1);
+
+ // Get the global step value
+ const Tensor& global_step_t = ctx->input(2);
+ auto global_step_eigen = global_step_t.tensor<int64, 0>();
+ const int64 global_step = global_step_eigen.data()[0];
+
+ // Dimension constants
+ const int batch_size = input_a_t.dim_size(0);
+ const int src_height = input_a_t.dim_size(1);
+ const int src_width = input_a_t.dim_size(2);
+ const int channels = input_a_t.dim_size(3);
+ const int src_count = batch_size * src_height * src_width * channels;
+ const int out_height = crop_[0];
+ const int out_width = crop_[1];
+ const int out_count = batch_size * out_height * out_width * channels;
+
+ // All tensors for this op
+ Tensor chromatic_coeffs_a_t;
+ Tensor chromatic_coeffs_b_t;
+
+ // Allocate the memory for the output images
+ Tensor *output_a_t;
+ Tensor *output_b_t;
+
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({ batch_size, crop_[0], crop_[1],
+ channels }), &output_a_t));
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(1, TensorShape({ batch_size, crop_[0], crop_[1],
+ channels }), &output_b_t));
+
+ // Allocate the memory for the output spatial transforms
+ Tensor *spat_transform_a_t;
+ Tensor *spat_transform_b_t;
+
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(2, TensorShape({ batch_size, 6 }),
+ &spat_transform_a_t));
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(3, TensorShape({ batch_size, 6 }),
+ &spat_transform_b_t));
+
+ // Compute discount for coefficients if using a schedule
+ float discount_coeff_a = 1.0;
+ float discount_coeff_b = 1.0;
+
+ if (params_a_coeff_schedule_.size() == 3) {
+ float half_life = params_a_coeff_schedule_[0];
+ float initial_coeff = params_a_coeff_schedule_[1];
+ float final_coeff = params_a_coeff_schedule_[2];
+ discount_coeff_a = initial_coeff + (final_coeff - initial_coeff) *
+ (2.0 / (1.0 + exp(-1.0986 * global_step / half_life)) - 1.0);
+ }
+
+ if (params_b_coeff_schedule_.size() == 3) {
+ if (params_a_coeff_schedule_.size() == 3) {
+ discount_coeff_b = discount_coeff_a;
+ } else {
+ float half_life = params_b_coeff_schedule_[0];
+ float initial_coeff = params_b_coeff_schedule_[1];
+ float final_coeff = params_b_coeff_schedule_[2];
+ discount_coeff_b = initial_coeff + (final_coeff - initial_coeff) *
+ (2.0 / (1.0 + exp(-1.0986 * global_step / half_life)) - 1.0);
+ }
+ }
+
+ /*** BEGIN AUGMENTATION TO IMAGE A ***/
+ auto input_a = input_a_t.tensor<float, 4>();
+ auto output_a = output_a_t->tensor<float, 4>();
+
+ // Load augmentation parameters for image A
+ AugmentationParams aug_a = AugmentationParams(out_height, out_width,
+ params_a_name_,
+ params_a_rand_type_,
+ params_a_exp_,
+ params_a_mean_,
+ params_a_spread_,
+ params_a_prob_);
+
+ std::vector<AugmentationCoeff> coeffs_a;
+
+
+ bool gen_spatial_transform = aug_a.should_do_spatial_transform();
+ bool gen_chromatic_transform = aug_a.should_do_chromatic_transform();
+
+ for (int n = 0; n < batch_size; n++) {
+ AugmentationCoeff coeff;
+
+ if (gen_spatial_transform) {
+ AugmentationLayerBase::generate_valid_spatial_coeffs(discount_coeff_a, aug_a, coeff,
+ src_width, src_height,
+ out_width, out_height);
+ }
+
+ if (gen_chromatic_transform) {
+ AugmentationLayerBase::generate_chromatic_coeffs(discount_coeff_a, aug_a, coeff);
+ }
+
+ coeffs_a.push_back(coeff);
+ }
+
+ // Copy spatial coefficients A to the output Tensor on the CPU
+ // (output for FlowAugmentation)
+ auto spat_transform_a = spat_transform_a_t->tensor<float, 2>();
+ AugmentationLayerBase::copy_spatial_coeffs_to_tensor(coeffs_a,
+ out_width, out_height,
+ src_width, src_height,
+ spat_transform_a);
+
+ float *chromatic_coeffs_a_data = NULL;
+
+ if (gen_chromatic_transform) {
+ // Allocate a temporary tensor to hold the chromatic coefficients
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_temp(DataTypeToEnum<float>::value,
+ TensorShape({ batch_size, 6 }),
+ &chromatic_coeffs_a_t));
+
+ // Copy the chromatic coefficients A to a temporary Tensor on the CPU
+ auto chromatic_coeffs_a = chromatic_coeffs_a_t.tensor<float, 2>();
+ AugmentationLayerBase::copy_chromatic_coeffs_to_tensor(coeffs_a, chromatic_coeffs_a);
+ chromatic_coeffs_a_data = chromatic_coeffs_a.data();
+ }
+
+ // Perform augmentation either on CPU or GPU
+ Augment<Device>(
+ ctx,
+ ctx->eigen_device<Device>(),
+ batch_size,
+ channels,
+ src_width,
+ src_height,
+ src_count,
+ out_width,
+ out_height,
+ input_a.data(),
+ output_a.data(),
+ spat_transform_a.data(),
+ chromatic_coeffs_a_data);
+
+ /*** END AUGMENTATION TO IMAGE A ***/
+
+ /*** BEGIN GENERATE NEW COEFFICIENTS FOR IMAGE B ***/
+ AugmentationParams aug_b = AugmentationParams(out_height, out_width,
+ params_b_name_,
+ params_b_rand_type_,
+ params_b_exp_,
+ params_b_mean_,
+ params_b_spread_,
+ params_b_prob_);
+
+ std::vector<AugmentationCoeff> coeffs_b;
+
+ bool gen_spatial_transform_b = aug_b.should_do_spatial_transform();
+ bool gen_chromatic_transform_b = aug_b.should_do_chromatic_transform();
+
+ for (int n = 0; n < batch_size; n++) {
+ AugmentationCoeff coeff(coeffs_a[n]);
+
+ // If we did a spatial transform on image A, we need to do the same one
+ // (+ possibly more) on image B
+ if (gen_spatial_transform_b) {
+ AugmentationLayerBase::generate_valid_spatial_coeffs(discount_coeff_b, aug_b, coeff,
+ src_width, src_height,
+ out_width, out_height);
+ }
+
+ if (gen_chromatic_transform_b) {
+ AugmentationLayerBase::generate_chromatic_coeffs(discount_coeff_b, aug_b, coeff);
+ }
+
+ coeffs_b.push_back(coeff);
+ }
+
+ /*** END GENERATE NEW COEFFICIENTS FOR IMAGE B ***/
+
+ /*** BEGIN AUGMENTATION TO IMAGE B ***/
+ auto input_b = input_b_t.tensor<float, 4>();
+ auto output_b = output_b_t->tensor<float, 4>();
+
+ // Copy spatial coefficients B to the output Tensor on the CPU
+ auto spat_transform_b = spat_transform_b_t->tensor<float, 2>();
+ AugmentationLayerBase::copy_spatial_coeffs_to_tensor(coeffs_b,
+ out_width, out_height,
+ src_width, src_height,
+ spat_transform_b);
+
+ float *chromatic_coeffs_b_data = NULL;
+
+ if (gen_chromatic_transform || gen_chromatic_transform_b) {
+ // Allocate a temporary tensor to hold the chromatic coefficients
+ tensorflow::AllocatorAttributes pinned_allocator;
+ pinned_allocator.set_on_host(true);
+ pinned_allocator.set_gpu_compatible(true);
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_temp(DataTypeToEnum<float>::value,
+ TensorShape({ batch_size, 6 }),
+ &chromatic_coeffs_b_t, pinned_allocator));
+
+ // Copy the chromatic coefficients A to a temporary Tensor on the CPU
+ auto chromatic_coeffs_b = chromatic_coeffs_b_t.tensor<float, 2>();
+ AugmentationLayerBase::copy_chromatic_coeffs_to_tensor(coeffs_b, chromatic_coeffs_b);
+ chromatic_coeffs_b_data = chromatic_coeffs_b.data();
+ }
+
+ // Perform augmentation either on CPU or GPU
+ Augment<Device>(
+ ctx,
+ ctx->eigen_device<Device>(),
+ batch_size,
+ channels,
+ src_width,
+ src_height,
+ src_count,
+ out_width,
+ out_height,
+ input_b.data(),
+ output_b.data(),
+ spat_transform_b.data(),
+ chromatic_coeffs_b_data);
+
+ // FlowAugmentation needs the inverse
+ // TODO: To avoid rewriting, can we invert when we read on the
+ // FlowAugmentation side?
+ AugmentationLayerBase::copy_spatial_coeffs_to_tensor(coeffs_b,
+ out_width, out_height,
+ src_width, src_height,
+ spat_transform_b,
+ true);
+
+ /*** END AUGMENTATION TO IMAGE B ***/
+ }
+
+ private:
+ std::vector<int32>crop_;
+
+ // Params A
+ std::vector<string>params_a_name_;
+ std::vector<string>params_a_rand_type_;
+ std::vector<bool>params_a_exp_;
+ std::vector<float>params_a_mean_;
+ std::vector<float>params_a_spread_;
+ std::vector<float>params_a_prob_;
+ std::vector<float>params_a_coeff_schedule_;
+
+ // Params B
+ std::vector<string>params_b_name_;
+ std::vector<string>params_b_rand_type_;
+ std::vector<bool>params_b_exp_;
+ std::vector<float>params_b_mean_;
+ std::vector<float>params_b_spread_;
+ std::vector<float>params_b_prob_;
+ std::vector<float>params_b_coeff_schedule_;
+};
+
+
+REGISTER_KERNEL_BUILDER(Name("DataAugmentation")
+ .Device(DEVICE_CPU)
+ .HostMemory("global_step")
+ .HostMemory("transforms_from_a")
+ .HostMemory("transforms_from_b"),
+ DataAugmentation<CPUDevice>)
+
+#if GOOGLE_CUDA
+
+REGISTER_KERNEL_BUILDER(Name("DataAugmentation")
+ .Device(DEVICE_GPU)
+ .HostMemory("global_step")
+ .HostMemory("transforms_from_a")
+ .HostMemory("transforms_from_b"),
+ DataAugmentation<GPUDevice>)
+#endif // GOOGLE_CUDA
+} // namespace tensorflow
diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.cu.cc b/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.cu.cc
new file mode 100644
index 0000000..7a2101d
--- /dev/null
+++ b/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.cu.cc
@@ -0,0 +1,348 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "augmentation_base.h"
+#include "data_augmentation.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+inline __device__ __host__ float clamp(float f, float a, float b) {
+ return fmaxf(a, fminf(f, b));
+}
+
+__global__ void SpatialAugmentation(
+ const int32 nthreads,
+ const int src_width,
+ const int src_height,
+ const int channels,
+ const int src_count,
+ const int out_width,
+ const int out_height,
+ const float *src_data,
+ float *out_data,
+ const float *transMats) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ // Caffe, NKHW: ((n * K + k) * H + h) * W + w at point (n, k, h, w)
+ // TF, NHWK: ((n * H + h) * W + w) * K + k at point (n, h, w, k)
+ int c = index % channels;
+ int x = (index / channels) % out_width;
+ int y = (index / channels / out_width) % out_height;
+ int n = index / channels / out_width / out_height;
+
+ const float *transMat = transMats + n * 6;
+ float xpos = x * transMat[0] + y * transMat[1] + transMat[2];
+ float ypos = x * transMat[3] + y * transMat[4] + transMat[5];
+
+ xpos = clamp(xpos, 0.0f, (float)(src_width) - 1.05f);
+ ypos = clamp(ypos, 0.0f, (float)(src_height) - 1.05f);
+
+ float tlx = floor(xpos);
+ float tly = floor(ypos);
+
+ // Bilinear interpolation
+ int srcTLIdx = ((n * src_height + tly) * src_width + tlx) * channels + c;
+ int srcTRIdx = min((int)(((n * src_height + tly) * src_width + (tlx + 1)) * channels + c),
+ src_count);
+ int srcBLIdx = min((int)(((n * src_height + (tly + 1)) * src_width + tlx) * channels + c),
+ src_count);
+ int srcBRIdx = min((int)(((n * src_height + (tly + 1)) * src_width + (tlx + 1)) * channels + c),
+ src_count);
+
+ float xdist = xpos - tlx;
+ float ydist = ypos - tly;
+
+ float dest = (1 - xdist) * (1 - ydist) * src_data[srcTLIdx]
+ + (xdist) * (ydist) * src_data[srcBRIdx]
+ + (1 - xdist) * (ydist) * src_data[srcBLIdx]
+ + (xdist) * (1 - ydist) * src_data[srcTRIdx];
+
+ out_data[index] = dest;
+ }
+}
+
+typedef Eigen::GpuDevice GPUDevice;
+
+template<>
+void Augment(OpKernelContext *context,
+ const GPUDevice& d,
+ const int batch_size,
+ const int channels,
+ const int src_width,
+ const int src_height,
+ const int src_count,
+ const int out_width,
+ const int out_height,
+ const float *src_data,
+ float *out_data,
+ const float *transMats,
+ float *chromatic_coeffs) {
+ const int out_count = batch_size * out_height * out_width * channels;
+ CudaLaunchConfig config = GetCudaLaunchConfig(out_count, d);
+
+ printf("Chromatic transform not yet implemented on GPU, ignoring.");
+
+ SpatialAugmentation << < config.block_count, config.thread_per_block, 0, d.stream() >> > (
+ config.virtual_thread_count, src_width, src_height, channels, src_count,
+ out_width, out_height,
+ src_data, out_data, transMats);
+}
+
+//
+// template<typename Device>
+// class DataAugmentation : public OpKernel {
+// public:
+// explicit DataAugmentation(OpKernelConstruction *ctx) : OpKernel(ctx) {
+// // Get the crop [height, width] tensor and verify its dimensions
+// OP_REQUIRES_OK(ctx, ctx->GetAttr("crop", &crop_));
+// OP_REQUIRES(ctx, crop_.size() == 2,
+// errors::InvalidArgument("crop must be 2 dimensions"));
+//
+// // TODO: Verify params are all the same length
+//
+// // Get the tensors for params_a and verify their dimensions
+// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_name", &params_a_name_));
+// OP_REQUIRES_OK(ctx,
+// ctx->GetAttr("params_a_rand_type",
+// &params_a_rand_type_));
+// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_exp", &params_a_exp_));
+// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_mean", &params_a_mean_));
+// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_spread",
+// &params_a_spread_));
+// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_prob", &params_a_prob_));
+//
+// // Get the tensors for params_b and verify their dimensions
+// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_name", &params_b_name_));
+// OP_REQUIRES_OK(ctx,
+// ctx->GetAttr("params_b_rand_type",
+// &params_b_rand_type_));
+// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_exp", &params_b_exp_));
+// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_mean", &params_b_mean_));
+// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_spread",
+// &params_b_spread_));
+// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_prob", &params_b_prob_));
+// }
+//
+// void Compute(OpKernelContext *ctx) override {
+// const GPUDevice& device = ctx->eigen_gpu_device();
+//
+// // Get the input images
+// const Tensor& input_a_t = ctx->input(0);
+// const Tensor& input_b_t = ctx->input(1);
+//
+// // Dimension constants
+// const int batch_size = input_a_t.dim_size(0);
+// const int src_height = input_a_t.dim_size(1);
+// const int src_width = input_a_t.dim_size(2);
+// const int channels = input_a_t.dim_size(3);
+// const int src_count = batch_size * src_height * src_width * channels;
+// const int out_height = crop_[0];
+// const int out_width = crop_[1];
+// const int out_count = batch_size * out_height * out_width * channels;
+//
+// // Allocate the memory for the output images
+// Tensor *output_a_t;
+// Tensor *output_b_t;
+//
+// OP_REQUIRES_OK(ctx,
+// ctx->allocate_output(0, TensorShape({ batch_size,
+// crop_[0], crop_[1],
+// channels }),
+// &output_a_t));
+// OP_REQUIRES_OK(ctx,
+// ctx->allocate_output(1, TensorShape({ batch_size,
+// crop_[0], crop_[1],
+// channels }),
+// &output_b_t));
+//
+// // Allocate the memory for the output spatial transforms
+// Tensor *spat_transform_a_t;
+// Tensor *spat_transform_b_t;
+//
+// OP_REQUIRES_OK(ctx,
+// ctx->allocate_output(2, TensorShape({ batch_size, 6 }),
+// &spat_transform_a_t));
+// OP_REQUIRES_OK(ctx,
+// ctx->allocate_output(3, TensorShape({ batch_size, 6 }),
+// &spat_transform_b_t));
+//
+// // Allocate temporary pinned memory for the spatial transforms to be
+// used
+// // on the GPU
+// tensorflow::AllocatorAttributes pinned_allocator;
+// pinned_allocator.set_on_host(true);
+// pinned_allocator.set_gpu_compatible(true);
+//
+// Tensor spat_transform_a_pinned_t;
+// Tensor spat_transform_b_pinned_t;
+// OP_REQUIRES_OK(ctx,
+// ctx->allocate_temp(DataTypeToEnum<float>::value,
+// TensorShape({ batch_size, 6 }),
+// &spat_transform_a_pinned_t,
+// pinned_allocator));
+// OP_REQUIRES_OK(ctx,
+// ctx->allocate_temp(DataTypeToEnum<float>::value,
+// TensorShape({ batch_size, 6 }),
+// &spat_transform_b_pinned_t,
+// pinned_allocator));
+// auto spat_transform_a_pinned = spat_transform_a_pinned_t.tensor<float,
+// 2>();
+// auto spat_transform_b_pinned = spat_transform_b_pinned_t.tensor<float,
+// 2>();
+//
+// /*** BEGIN AUGMENTATION TO IMAGE A ***/
+// auto input_a = input_a_t.tensor<float, 4>();
+// auto output_a = output_a_t->tensor<float, 4>();
+//
+// // Load augmentation parameters for image A
+// AugmentationParams aug_a = AugmentationParams(out_height, out_width,
+// params_a_name_,
+// params_a_rand_type_,
+// params_a_exp_,
+// params_a_mean_,
+// params_a_spread_,
+// params_a_prob_);
+//
+// std::vector<AugmentationCoeff> coeffs_a;
+//
+// bool gen_spatial_transform = aug_a.should_do_spatial_transform();
+//
+// for (int n = 0; n < batch_size; n++) {
+// AugmentationCoeff coeff;
+//
+// if (gen_spatial_transform) {
+// AugmentationLayerBase::generate_valid_spatial_coeffs(aug_a, coeff,
+// src_width,
+// src_height,
+// out_width,
+// out_height);
+// }
+//
+// coeffs_a.push_back(coeff);
+// }
+//
+// // Copy spatial coefficients A to the output Tensor on the CPU (output
+// for
+// // FlowAugmentation)
+// auto spat_transform_a = spat_transform_a_t->tensor<float, 2>();
+// AugmentationLayerBase::copy_spatial_coeffs_to_tensor(coeffs_a,
+// out_width,
+// out_height,
+// src_width,
+// src_height,
+// spat_transform_a);
+//
+// // ...as well as a Tensor going to the GPU
+// AugmentationLayerBase::copy_spatial_coeffs_to_tensor(coeffs_a,
+// out_width,
+// out_height,
+// src_width,
+// src_height,
+//
+//
+//
+// spat_transform_a_pinned);
+//
+// CudaLaunchConfig config = GetCudaLaunchConfig(out_count, device);
+// SpatialAugmentation << < config.block_count, config.thread_per_block,
+// 0,
+// device.stream() >> > (
+// config.virtual_thread_count, src_width, src_height, channels,
+// src_count,
+// out_width, out_height,
+// input_a.data(), output_a.data(), spat_transform_a_pinned.data());
+//
+// /*** END AUGMENTATION TO IMAGE A ***/
+//
+// /*** BEGIN GENERATE NEW COEFFICIENTS FOR IMAGE B ***/
+// AugmentationParams aug_b = AugmentationParams(out_height, out_width,
+// params_b_name_,
+// params_b_rand_type_,
+// params_b_exp_,
+// params_b_mean_,
+// params_b_spread_,
+// params_b_prob_);
+//
+// std::vector<AugmentationCoeff> coeffs_b;
+//
+// gen_spatial_transform = aug_b.should_do_spatial_transform();
+//
+// for (int n = 0; n < batch_size; n++) {
+// AugmentationCoeff coeff;
+//
+// if (gen_spatial_transform) {
+// AugmentationLayerBase::generate_valid_spatial_coeffs(aug_b, coeff,
+// src_width,
+// src_height,
+// out_width,
+// out_height);
+// }
+//
+// coeffs_b.push_back(coeff);
+// }
+//
+// /*** END GENERATE NEW COEFFICIENTS FOR IMAGE B ***/
+//
+// /*** BEGIN AUGMENTATION TO IMAGE B ***/
+// auto input_b = input_b_t.tensor<float, 4>();
+// auto output_b = output_b_t->tensor<float, 4>();
+//
+// // Copy spatial coefficients B to the output Tensor on the CPU
+// auto spat_transform_b = spat_transform_b_t->tensor<float, 2>();
+// AugmentationLayerBase::copy_spatial_coeffs_to_tensor(coeffs_b,
+// out_width,
+// out_height,
+// src_width,
+// src_height,
+// spat_transform_b,
+// true);
+// AugmentationLayerBase::copy_spatial_coeffs_to_tensor(coeffs_b,
+// out_width,
+// out_height,
+// src_width,
+// src_height,
+//
+//
+//
+// spat_transform_b_pinned);
+//
+// SpatialAugmentation << < config.block_count, config.thread_per_block,
+// 0,
+// device.stream() >> > (
+// config.virtual_thread_count, src_width, src_height, channels,
+// src_count,
+// out_width, out_height,
+// input_b.data(), output_b.data(), spat_transform_b_pinned.data());
+//
+// /*** END AUGMENTATION TO IMAGE B ***/
+// }
+//
+// private:
+// std::vector<int32>crop_;
+//
+// // Params A
+// std::vector<string>params_a_name_;
+// std::vector<string>params_a_rand_type_;
+// std::vector<bool>params_a_exp_;
+// std::vector<float>params_a_mean_;
+// std::vector<float>params_a_spread_;
+// std::vector<float>params_a_prob_;
+//
+// // Params B
+// std::vector<string>params_b_name_;
+// std::vector<string>params_b_rand_type_;
+// std::vector<bool>params_b_exp_;
+// std::vector<float>params_b_mean_;
+// std::vector<float>params_b_spread_;
+// std::vector<float>params_b_prob_;
+// };
+} // namespace tensorflow
+#endif // GOOGLE_CUDA
diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.h b/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.h
new file mode 100644
index 0000000..545b8a0
--- /dev/null
+++ b/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.h
@@ -0,0 +1,22 @@
+#ifndef FLOWNET_DATA_AUGMENTATION_H_
+#define FLOWNET_DATA_AUGMENTATION_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+template<class Device>
+void Augment(OpKernelContext *context,
+ const Device & d,
+ const int batch_size,
+ const int channels,
+ const int src_width,
+ const int src_height,
+ const int src_count,
+ const int out_width,
+ const int out_height,
+ const float *src_data,
+ float *out_data,
+ const float *transMats,
+ float *chromatic_coeffs);
+} // namespace tensorflow
+#endif // FLOWNET_DATA_AUGMENTATION_H_
diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation.cc b/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation.cc
new file mode 100644
index 0000000..b5cc11f
--- /dev/null
+++ b/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation.cc
@@ -0,0 +1,129 @@
+#define EIGEN_USE_THREADS
+
+#include "flow_augmentation.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+inline int clamp(int f, int a, int b) {
+ return std::max(a, std::min(f, b));
+}
+
+template<>
+void FillFlowAugmentation(const CPUDevice& device,
+ typename TTypes<float, 4>::Tensor output,
+ typename TTypes<float, 4>::ConstTensor flows,
+ typename TTypes<float, 2>::ConstTensor transforms_from_a,
+ typename TTypes<float, 2>::ConstTensor transforms_from_b) {
+ const int batch_size = output.dimension(0);
+ const int out_height = output.dimension(1);
+ const int out_width = output.dimension(2);
+ const int src_height = flows.dimension(1);
+ const int src_width = flows.dimension(2);
+ const int src_total_count = flows.dimension(0) * flows.dimension(1) *
+ flows.dimension(2) * flows.dimension(3);
+ float *output_ptr = output.data();
+ const float *flow_ptr = flows.data();
+
+ for (int n = 0; n < batch_size; n++) {
+ const float *transMatA = transforms_from_a.data() + n * 6;
+ const float *transMatB = transforms_from_b.data() + n * 6;
+
+ for (int y = 0; y < out_height; y++) {
+ int outputIdxOffset = (n * out_height + y) * out_width;
+
+ for (int x = 0; x < out_width; x++) {
+ // Apply transformation matrix applied to first image
+ const float xpos1 = x * transMatA[0] + y * transMatA[1] + transMatA[2];
+ const float ypos1 = x * transMatA[3] + y * transMatA[4] + transMatA[5];
+
+ const int srcXIdx =
+ ((n * src_height + (int)(ypos1 + 0.5)) * src_width + (int)(xpos1 + 0.5)) * 2 + 0;
+ const int srcYIdx = srcXIdx + 1;
+
+ const float xpos2 = xpos1 + flow_ptr[clamp(srcXIdx, 0, src_total_count - 1)];
+ const float ypos2 = ypos1 + flow_ptr[clamp(srcYIdx, 0, src_total_count - 1)];
+
+ // Apply inverse of the transformation matrix applied to second image
+ const float xpos3 = xpos2 * transMatB[0] + ypos2 * transMatB[1] + transMatB[2];
+ const float ypos3 = xpos2 * transMatB[3] + ypos2 * transMatB[4] + transMatB[5];
+
+ output_ptr[(outputIdxOffset + x) * 2 + 0] = xpos3 - (float)x;
+ output_ptr[(outputIdxOffset + x) * 2 + 1] = ypos3 - (float)y;
+ }
+ }
+ }
+}
+
+template<typename Device>
+class FlowAugmentation : public OpKernel {
+ public:
+ explicit FlowAugmentation(OpKernelConstruction *ctx) : OpKernel(ctx) {
+ // Get the crop [height, width] tensor and verify its dimensions
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("crop", &crop_));
+ OP_REQUIRES(ctx, crop_.size() == 2,
+ errors::InvalidArgument("crop must be 2 dimensions"));
+ }
+
+ void Compute(OpKernelContext *ctx) override {
+ // Get the input images and transforms and verify their dimensions
+ const Tensor& flows_t = ctx->input(0);
+ const Tensor& transforms_from_a_t = ctx->input(1);
+ const Tensor& transforms_from_b_t = ctx->input(2);
+
+ OP_REQUIRES(ctx, flows_t.dims() == 4,
+ errors::InvalidArgument("Input images must have rank 4"));
+ OP_REQUIRES(ctx,
+ (TensorShapeUtils::IsMatrix(transforms_from_a_t.shape()) &&
+ transforms_from_a_t.dim_size(0) ==
+ flows_t.dim_size(0) &&
+ transforms_from_a_t.dim_size(1) == 6),
+ errors::InvalidArgument(
+ "Input transforms_from_a should be num_images x 6"));
+ OP_REQUIRES(ctx,
+ (TensorShapeUtils::IsMatrix(transforms_from_b_t.shape()) &&
+ transforms_from_b_t.dim_size(0) ==
+ flows_t.dim_size(0) &&
+ transforms_from_b_t.dim_size(1) == 6),
+ errors::InvalidArgument(
+ "Input transforms_from_b should be num_images x 6"));
+
+ // Allocate the memory for the output
+ Tensor *output_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(
+ 0,
+ TensorShape({ flows_t.dim_size(0), crop_[0], crop_[1],
+ flows_t.dim_size(3) }), &output_t));
+
+ // Perform flow augmentation
+ auto flows = flows_t.tensor<float, 4>();
+ auto transforms_from_a = transforms_from_a_t.tensor<float, 2>();
+ auto transforms_from_b = transforms_from_b_t.tensor<float, 2>();
+ auto output = output_t->tensor<float, 4>();
+
+ FillFlowAugmentation(ctx->eigen_device<Device>(),
+ output,
+ flows,
+ transforms_from_a,
+ transforms_from_b);
+ }
+
+ private:
+ std::vector<int32>crop_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("FlowAugmentation")
+ .Device(DEVICE_CPU),
+ FlowAugmentation<CPUDevice>)
+
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("FlowAugmentation")
+ .Device(DEVICE_GPU),
+ FlowAugmentation<GPUDevice>)
+#endif // GOOGLE_CUDA
+} // end namespace tensorflow
diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation.h b/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation.h
new file mode 100644
index 0000000..7795991
--- /dev/null
+++ b/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation.h
@@ -0,0 +1,19 @@
+#ifndef FLOWNET_FLOW_AUG_H_
+#define FLOWNET_FLOW_AUG_H_
+
+// See docs in ../ops/image_ops.cc.
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+template<class Device>
+void FillFlowAugmentation(const Device& device,
+ typename TTypes<float, 4>::Tensor output,
+ typename TTypes<float, 4>::ConstTensor flows,
+ typename TTypes<float, 2>::ConstTensor transforms_from_a,
+ typename TTypes<float, 2>::ConstTensor transforms_from_b);
+} // end namespace tensorflow
+
+#endif // FLOWNET_FLOW_AUG_H_
diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation_gpu.cu.cc b/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation_gpu.cu.cc
new file mode 100644
index 0000000..7e10864
--- /dev/null
+++ b/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation_gpu.cu.cc
@@ -0,0 +1,95 @@
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include <stdio.h>
+#include <iostream>
+
+#include "flow_augmentation.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+inline __device__ __host__ int clamp(int f, int a, int b) {
+ return max(a, min(f, b));
+}
+
+__global__ void FillFlowAugmentationKernel(
+ const int32 nthreads,
+ const float *flow_ptr,
+ const float *transforms_from_a,
+ const float *inv_transforms_from_b,
+ const int src_total_count, const int src_height, const int src_width,
+ const int batch_size, const int out_height,
+ const int out_width, float *output_ptr) {
+ CUDA_1D_KERNEL_LOOP(index, nthreads) {
+ const float x = (float)(index % out_width);
+ const float y = (float)((index / out_width) % out_height);
+ const int n = (index / out_width / out_height);
+
+ const int transformIdx = n * 6;
+
+ // Apply transformation matrix applied to second image
+ const float xpos1 = x * transforms_from_a[transformIdx + 0]
+ + y * transforms_from_a[transformIdx + 1]
+ + transforms_from_a[transformIdx + 2];
+ const float ypos1 = x * transforms_from_a[transformIdx + 3]
+ + y * transforms_from_a[transformIdx + 4]
+ + transforms_from_a[transformIdx + 5];
+
+ // Caffe, NKHW: ((n * K + k) * H + h) * W + w at point (n, k, h, w)
+ // TF, NHWK: ((n * H + h) * W + w) * K + k at point (n, h, w, k)
+ const int srcXIdx =
+ ((n * src_height + (int)(ypos1 + 0.5)) * src_width + (int)(xpos1 + 0.5)) *
+ 2 + 0;
+ const int srcYIdx = srcXIdx + 1;
+
+ const float xpos2 = xpos1 + flow_ptr[clamp(srcXIdx, 0, src_total_count - 1)];
+ const float ypos2 = ypos1 + flow_ptr[clamp(srcYIdx, 0, src_total_count - 1)];
+
+ // Apply inverse of the transformation matrix applied to first image
+ const float xpos3 = xpos2 * inv_transforms_from_b[transformIdx + 0]
+ + ypos2 * inv_transforms_from_b[transformIdx + 1]
+ + inv_transforms_from_b[transformIdx + 2];
+ const float ypos3 = xpos2 * inv_transforms_from_b[transformIdx + 3]
+ + ypos2 * inv_transforms_from_b[transformIdx + 4]
+ + inv_transforms_from_b[transformIdx + 5];
+
+ output_ptr[((n * out_height + (int)y) * out_width + (int)x) * 2 + 0] = xpos3 -
+ x;
+ output_ptr[((n * out_height + (int)y) * out_width + (int)x) * 2 + 1] = ypos3 -
+ y;
+ }
+}
+
+template<>
+void FillFlowAugmentation(const GPUDevice& device,
+ typename TTypes<float, 4>::Tensor output,
+ typename TTypes<float, 4>::ConstTensor flows,
+ typename TTypes<const float, 2>::ConstTensor transforms_from_a,
+ typename TTypes<const float, 2>::ConstTensor transforms_from_b) {
+ const int batch_size = output.dimension(0);
+ const int out_height = output.dimension(1);
+ const int out_width = output.dimension(2);
+ const int depth = 2;
+ const int total_count = batch_size * out_height * out_width * depth;
+ const int src_total_count = flows.dimension(0) * flows.dimension(1) *
+ flows.dimension(2) * flows.dimension(3);
+
+ CudaLaunchConfig config = GetCudaLaunchConfig(total_count / 2, device);
+
+ FillFlowAugmentationKernel << < config.block_count, config.thread_per_block, 0,
+ device.stream() >> > (
+ total_count / 2, flows.data(), transforms_from_a.data(),
+ transforms_from_b.data(),
+ src_total_count, flows.dimension(1), flows.dimension(2), batch_size,
+ out_height, out_width, output.data());
+}
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/Codes/flownet2/src/ops/preprocessing/preprocessing.cc b/Codes/flownet2/src/ops/preprocessing/preprocessing.cc
new file mode 100644
index 0000000..086a0d0
--- /dev/null
+++ b/Codes/flownet2/src/ops/preprocessing/preprocessing.cc
@@ -0,0 +1,96 @@
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+using shape_inference::DimensionHandle;
+
+Status SetOutputToSizedImage(InferenceContext *c) {
+ ShapeHandle input;
+
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
+ DimensionHandle batch = c->Dim(input, 0);
+ DimensionHandle depth = c->Dim(input, 3);
+ std::vector<int32> crop_;
+ c->GetAttr("crop", &crop_);
+ DimensionHandle height = c->MakeDim(crop_[0]);
+ DimensionHandle width = c->MakeDim(crop_[1]);
+ c->set_output(0, c->MakeShape({ batch, height, width, depth }));
+ return Status::OK();
+}
+
+REGISTER_OP("DataAugmentation")
+.Input("image_a: float32")
+.Input("image_b: float32")
+.Input("global_step: int64")
+.Attr("crop: list(int) >= 2")
+.Attr("params_a_name: list(string)")
+.Attr("params_a_rand_type: list(string)")
+.Attr("params_a_exp: list(bool)")
+.Attr("params_a_mean: list(float)")
+.Attr("params_a_spread: list(float)")
+.Attr("params_a_prob: list(float)")
+.Attr("params_a_coeff_schedule: list(float)")
+.Attr("params_b_name: list(string)")
+.Attr("params_b_rand_type: list(string)")
+.Attr("params_b_exp: list(bool)")
+.Attr("params_b_mean: list(float)")
+.Attr("params_b_spread: list(float)")
+.Attr("params_b_prob: list(float)")
+.Attr("params_b_coeff_schedule: list(float)")
+.Output("aug_image_a: float32")
+.Output("aug_image_b: float32")
+.Output("transforms_from_a: float32")
+.Output("transforms_from_b: float32")
+.SetShapeFn([](InferenceContext *c) {
+ // Verify input A and input B both have 4 dimensions
+ ShapeHandle input_shape_a, input_shape_b;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape_a));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &input_shape_b));
+
+ // TODO: Verify params vectors all have the same length
+
+ // TODO: Move this out of here and into Compute
+ // Verify input A and input B are the same shape
+ DimensionHandle batch_size, unused;
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input_shape_a, 0),
+ c->Value(c->Dim(input_shape_b, 0)),
+ &batch_size));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input_shape_a, 1),
+ c->Value(c->Dim(input_shape_b, 1)), &unused));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input_shape_a, 2),
+ c->Value(c->Dim(input_shape_b, 2)), &unused));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input_shape_a, 3),
+ c->Value(c->Dim(input_shape_b, 3)), &unused));
+
+ // Get cropping dimensions
+ std::vector<int32>crop_;
+ TF_RETURN_IF_ERROR(c->GetAttr("crop", &crop_));
+
+ // Reshape input shape to cropped shape
+ TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape_a, 1, c->MakeDim(crop_[0]),
+ &input_shape_a));
+ TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape_a, 2, c->MakeDim(crop_[1]),
+ &input_shape_a));
+
+ // Set output images shapes
+ c->set_output(0, input_shape_a);
+ c->set_output(1, input_shape_a);
+
+ // Set output spatial transforms shapes
+ c->set_output(2, c->MakeShape({ batch_size, 6 }));
+ c->set_output(3, c->MakeShape({ batch_size, 6 }));
+
+ return Status::OK();
+ });
+
+REGISTER_OP("FlowAugmentation")
+.Input("flows: float32")
+.Input("transforms_from_a: float32")
+.Input("transforms_from_b: float32")
+.Attr("crop: list(int) >= 2")
+.Output("transformed_flows: float32")
+.SetShapeFn(SetOutputToSizedImage);
+} // namespace tensorflow