summary | refs | log | tree | commit | diff
path: root/Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.h
diff options
context:
space:
mode:
Diffstat (limited to 'Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.h')
-rw-r--r-- Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.h 228
1 files changed, 228 insertions, 0 deletions
diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.h b/Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.h
new file mode 100644
index 0000000..d2aba2c
--- /dev/null
+++ b/Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.h
@@ -0,0 +1,228 @@
+#ifndef AUGMENTATION_LAYER_BASE_H_
+#define AUGMENTATION_LAYER_BASE_H_
+
+#include "tensorflow/core/framework/tensor_types.h"
+
+#include <iostream>
+#include <string>
+#include <vector>
+
+namespace tensorflow {
+/**
+ * A value of type T paired with a fallback default and an "is set" flag.
+ *
+ * The object starts unset. Assigning a T marks it as set; clear() marks it
+ * as unset again. Calling the object (operator()) yields the assigned value
+ * while set and the default otherwise.
+ */
+template<typename T>
+class OptionalType {
+ public:
+  // Starts unset. `value` is seeded with the default so that copying an
+  // object that was never assigned does not read an indeterminate T.
+  // The initializers are listed in member declaration order, which is the
+  // order they run in; inside them `default_value` names the parameter.
+  OptionalType(const T default_value)
+      : value(default_value),
+        has_value(false),
+        default_value(default_value) {}
+
+  // True while a value is set.
+  // NOTE(review): the conversion is implicit, so an OptionalType can also
+  // end up converted to an arithmetic 0/1; read the value via operator().
+  operator bool() const {
+    return has_value;
+  }
+
+  // Stores `val` and marks the object as set.
+  OptionalType& operator=(T val) {
+    value = val;
+    has_value = true;
+    return *this;
+  }
+
+  // Returns the stored value while set, otherwise the default.
+  const T operator()() const {
+    return has_value ? value : default_value;
+  }
+
+  // Marks the object as unset; reads fall back to the default afterwards.
+  void clear() {
+    has_value = false;
+  }
+
+  // Returns the default supplied at construction. Const-qualified so it can
+  // be called through const objects and const references.
+  const T get_default() const {
+    return default_value;
+  }
+
+ private:
+  T value;
+  bool has_value;
+  const T default_value;
+};
+
+/**
+ * Bundle of augmentation coefficients, one OptionalType<float> per spatial
+ * or chromatic quantity. The default constructor gives each coefficient its
+ * fallback default and leaves it unset.
+ */
+class AugmentationCoeff {
+ public:
+  // Spatial coefficients
+  OptionalType<float> dx;
+  OptionalType<float> dy;
+  OptionalType<float> angle;
+  OptionalType<float> zoom_x;
+  OptionalType<float> zoom_y;
+
+  // Chromatic coefficients
+  OptionalType<float> gamma;
+  OptionalType<float> brightness;
+  OptionalType<float> contrast;
+  OptionalType<float> color1;
+  OptionalType<float> color2;
+  OptionalType<float> color3;
+
+  // Defaults: 0 for dx, dy, angle and brightness; 1 for all the others.
+  AugmentationCoeff()
+      : dx(0.0f),
+        dy(0.0f),
+        angle(0.0f),
+        zoom_x(1.0f),
+        zoom_y(1.0f),
+        gamma(1.0f),
+        brightness(0.0f),
+        contrast(1.0f),
+        color1(1.0f),
+        color2(1.0f),
+        color3(1.0f) {}
+
+  // Copying first builds the defaults via the default constructor, then
+  // passes `coeff` to replace_with(), whose definition is not in this header.
+  AugmentationCoeff(const AugmentationCoeff& coeff) : AugmentationCoeff() {
+    replace_with(coeff);
+  }
+
+  // Declared here only; the definitions are not part of this header.
+  void clear();
+  void combine_with(const AugmentationCoeff& coeff);
+  void replace_with(const AugmentationCoeff& coeff);
+};
+
+// Settings for one augmentation parameter. Plain aggregate: this header
+// only stores the fields and copies them around, it does not interpret them.
+struct AugmentationParam {
+  std::string rand_type;
+  bool should_exp;
+  float mean;
+  float spread;
+  float prob;
+};
+
+/**
+ * Crop size plus the optional settings for each supported augmentation.
+ *
+ * Every option is constructed with a value-initialized AugmentationParam as
+ * its default and is assigned when the constructor finds its name in
+ * `params_name`.
+ */
+class AugmentationParams {
+ public:
+  int crop_height;
+  int crop_width;
+
+  // Spatial options
+  OptionalType<AugmentationParam> translate;
+  OptionalType<AugmentationParam> rotate;
+  OptionalType<AugmentationParam> zoom;
+  OptionalType<AugmentationParam> squeeze;
+
+  // Chromatic options
+  OptionalType<AugmentationParam> gamma;
+  OptionalType<AugmentationParam> brightness;
+  OptionalType<AugmentationParam> contrast;
+  OptionalType<AugmentationParam> color;
+
+  /**
+   * Builds the option set from parallel vectors: entry i of each params_*
+   * vector describes the parameter named params_name[i].
+   *
+   * Recognized names are translate, rotate, zoom, squeeze, gamma,
+   * brightness, contrast and color. "noise" is accepted and skipped; any
+   * other name is reported on stdout and skipped.
+   *
+   * Only indices present in every vector are read. If an attribute vector
+   * is shorter than params_name, the trailing names are reported on stdout
+   * and skipped rather than indexed out of bounds.
+   */
+  AugmentationParams(int crop_height,
+                     int crop_width,
+                     const std::vector<std::string>& params_name,
+                     const std::vector<std::string>& params_rand_type,
+                     const std::vector<bool>& params_exp,
+                     const std::vector<float>& params_mean,
+                     const std::vector<float>& params_spread,
+                     const std::vector<float>& params_prob)
+      : crop_height(crop_height),
+        crop_width(crop_width),
+        translate(AugmentationParam()),
+        rotate(AugmentationParam()),
+        zoom(AugmentationParam()),
+        squeeze(AugmentationParam()),
+        gamma(AugmentationParam()),
+        brightness(AugmentationParam()),
+        contrast(AugmentationParam()),
+        color(AugmentationParam()) {
+    // Number of leading entries that exist in all six vectors.
+    std::size_t usable = params_name.size();
+    const std::size_t attribute_sizes[] = {
+        params_rand_type.size(), params_exp.size(), params_mean.size(),
+        params_spread.size(), params_prob.size()};
+    for (const std::size_t size : attribute_sizes) {
+      if (size < usable) {
+        usable = size;
+      }
+    }
+    if (usable < params_name.size()) {
+      std::cout << "Ignoring " << (params_name.size() - usable)
+                << " augmentation parameter(s) with missing attributes"
+                << std::endl;
+    }
+
+    for (std::size_t i = 0; i < usable; ++i) {
+      const std::string& name = params_name[i];
+      const bool should_exp = params_exp[i];
+      const AugmentationParam param = {params_rand_type[i], should_exp,
+                                       params_mean[i], params_spread[i],
+                                       params_prob[i]};
+
+      if (name == "translate") {
+        this->translate = param;
+      } else if (name == "rotate") {
+        this->rotate = param;
+      } else if (name == "zoom") {
+        this->zoom = param;
+      } else if (name == "squeeze") {
+        this->squeeze = param;
+      } else if (name == "noise") {
+        // NoOp: We handle noise on the Python side
+      } else if (name == "gamma") {
+        this->gamma = param;
+      } else if (name == "brightness") {
+        this->brightness = param;
+      } else if (name == "contrast") {
+        this->contrast = param;
+      } else if (name == "color") {
+        this->color = param;
+      } else {
+        std::cout << "Ignoring unknown augmentation parameter: " << name << std::endl;
+      }
+    }
+  }
+
+  // True when at least one spatial option has been assigned. Const so it
+  // can be called through a const AugmentationParams reference.
+  bool should_do_spatial_transform() const {
+    return this->translate || this->rotate || this->zoom || this->squeeze;
+  }
+
+  // True when at least one chromatic option has been assigned. Const so it
+  // can be called through a const AugmentationParams reference.
+  bool should_do_chromatic_transform() const {
+    return this->gamma || this->brightness || this->contrast || this->color;
+  }
+};
+
+/**
+ * Holder for the TransMat helper type and a set of static functions.
+ * Everything in this class is declared only; the definitions are not part
+ * of this header.
+ */
+class AugmentationLayerBase {
+ public:
+  /**
+   * Translation matrix class for spatial augmentation. The six entries are
+   * numbered as follows:
+   *   | 0 1 2 |
+   *   | 3 4 5 |
+   */
+  class TransMat {
+   public:
+    float t0, t1, t2;
+    float t3, t4, t5;
+
+    // NOTE(review): presumably fills the matrix from the spatial
+    // coefficients and the output/source sizes; confirm in the definition.
+    void fromCoeff(AugmentationCoeff *coeff,
+                   int out_width, int out_height,
+                   int src_width, int src_height);
+
+    // NOTE(review): presumably reads the six entries from `tensor_data`;
+    // the expected element count and order are not visible here.
+    void fromTensor(const float *tensor_data);
+
+    TransMat inverse();
+
+    void leftMultiply(float u0, float u1, float u2,
+                      float u3, float u4, float u5);
+
+    void toIdentity();
+  };
+
+  // TODO: Class ChromaticCoeffs
+
+  static float rng_generate(const AugmentationParam& param,
+                            float discount_coeff,
+                            const float default_value);
+
+  static void clear_spatial_coeffs(AugmentationCoeff& coeff);
+
+  // The generate_* functions take `coeff` by non-const reference.
+  static void generate_chromatic_coeffs(float discount_coeff,
+                                        const AugmentationParams& aug,
+                                        AugmentationCoeff& coeff);
+
+  static void generate_spatial_coeffs(float discount_coeff,
+                                      const AugmentationParams& aug,
+                                      AugmentationCoeff& coeff);
+
+  static void generate_valid_spatial_coeffs(float discount_coeff,
+                                            const AugmentationParams& aug,
+                                            AugmentationCoeff& coeff,
+                                            int src_width, int src_height,
+                                            int out_width, int out_height);
+
+  // The copy_* functions take the destination tensor `out` by non-const
+  // reference. NOTE(review): the row/column layout they expect is set by
+  // the definitions, which are not visible here.
+  static void copy_chromatic_coeffs_to_tensor(
+      const std::vector<AugmentationCoeff>& coeff_arr,
+      typename TTypes<float, 2>::Tensor& out);
+
+  static void copy_spatial_coeffs_to_tensor(
+      const std::vector<AugmentationCoeff>& coeff_arr,
+      const int out_width,
+      const int out_height,
+      const int src_width,
+      const int src_height,
+      typename TTypes<float, 2>::Tensor& out,
+      const bool invert = false);
+};
+} // namespace tensorflow
+
+#endif // AUGMENTATION_LAYER_BASE_H_