Diffstat (limited to 'Codes/flownet2/src/ops/preprocessing')
9 files changed, 1818 insertions, 0 deletions
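The first file in the patch, augmentation_base.cc, composes the sampled spatial coefficients (rotation, translation, zoom/squeeze) into a single 2x3 affine matrix that maps each output-crop pixel back to a source-image position, by repeatedly left-multiplying elementary transforms: centre the crop, rotate, translate, zoom, then re-centre on the source. The standalone sketch below mirrors that TransMat::leftMultiply / fromCoeff composition so the order of operations can be checked in isolation; the struct name, example coefficients, and main() are illustrative only and not part of the patch.

#include <cmath>
#include <cstdio>

// Minimal 2x3 affine matrix  [ t0 t1 t2 ]
//                            [ t3 t4 t5 ]
// mirroring AugmentationLayerBase::TransMat; names here are illustrative.
struct Affine2x3 {
  float t0 = 1, t1 = 0, t2 = 0;
  float t3 = 0, t4 = 1, t5 = 0;

  // Pre-compose with U = [u0 u1 u2; u3 u4 u5], i.e. this = U * this,
  // exactly as TransMat::leftMultiply does.
  void leftMultiply(float u0, float u1, float u2,
                    float u3, float u4, float u5) {
    float a0 = t0, a1 = t1, a2 = t2, a3 = t3, a4 = t4, a5 = t5;
    t0 = a0 * u0 + a3 * u1;
    t1 = a1 * u0 + a4 * u1;
    t2 = a2 * u0 + a5 * u1 + u2;
    t3 = a0 * u3 + a3 * u4;
    t4 = a1 * u3 + a4 * u4;
    t5 = a2 * u3 + a5 * u4 + u5;
  }

  // Map an output-crop pixel (x, y) to a source-image position, the same
  // way the augmentation kernels use the six matrix entries.
  void apply(float x, float y, float *sx, float *sy) const {
    *sx = x * t0 + y * t1 + t2;
    *sy = x * t3 + y * t4 + t5;
  }
};

int main() {
  // Same composition order as TransMat::fromCoeff: move the origin to the
  // crop centre, rotate, translate, zoom, then move the origin to the
  // source centre.  The coefficients and sizes below are made-up examples.
  const float angle = 0.1f, dx = 0.05f, dy = -0.02f, zoom = 1.2f;
  const int out_w = 448, out_h = 320, src_w = 512, src_h = 384;

  Affine2x3 m;
  m.leftMultiply(1, 0, -0.5f * out_w, 0, 1, -0.5f * out_h);
  m.leftMultiply(std::cos(angle), -std::sin(angle), 0,
                 std::sin(angle),  std::cos(angle), 0);
  m.leftMultiply(1, 0, dx * out_w, 0, 1, dy * out_h);
  m.leftMultiply(1.0f / zoom, 0, 0, 0, 1.0f / zoom, 0);
  m.leftMultiply(1, 0, 0.5f * src_w, 0, 1, 0.5f * src_h);

  float sx, sy;
  m.apply(0.5f * out_w, 0.5f * out_h, &sx, &sy);
  std::printf("crop centre samples the source at (%.2f, %.2f)\n", sx, sy);
  return 0;
}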
diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.cc b/Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.cc new file mode 100644 index 0000000..b93dfa6 --- /dev/null +++ b/Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.cc @@ -0,0 +1,420 @@ +#include "augmentation_base.h" + +#include <math.h> +#include <random> + +namespace tensorflow { +/** TransMat Functions **/ +void AugmentationLayerBase::TransMat::fromCoeff(AugmentationCoeff *coeff, + int out_width, + int out_height, + int src_width, + int src_height) { + leftMultiply(1, 0, -0.5 * out_width, + 0, 1, -0.5 * out_height); + + if (coeff->angle) { + leftMultiply(cos(coeff->angle()), -sin(coeff->angle()), 0, + sin(coeff->angle()), cos(coeff->angle()), 0); + } + + if (coeff->dx || coeff->dy) { + leftMultiply(1, 0, coeff->dx() * out_width, + 0, 1, coeff->dy() * out_height); + } + + if (coeff->zoom_x || coeff->zoom_y) { + leftMultiply(1.0 / coeff->zoom_x(), 0, 0, + 0, 1.0 / coeff->zoom_y(), 0); + } + + leftMultiply(1, 0, 0.5 * src_width, + 0, 1, 0.5 * src_height); +} + +void AugmentationLayerBase::TransMat::fromTensor(const float *tensor_data) { + t0 = tensor_data[0]; + t1 = tensor_data[1]; + t2 = tensor_data[2]; + t3 = tensor_data[3]; + t4 = tensor_data[4]; + t5 = tensor_data[5]; +} + +AugmentationLayerBase::TransMat AugmentationLayerBase::TransMat::inverse() { + float a = this->t0, b = this->t1, c = this->t2; + float d = this->t3, e = this->t4, f = this->t5; + + float denom = a * e - b * d; + + TransMat result; + + result.t0 = e / denom; + result.t1 = b / -denom; + result.t2 = (c * e - b * f) / -denom; + result.t3 = d / -denom; + result.t4 = a / denom; + result.t5 = (c * d - a * f) / denom; + + return result; +} + +void AugmentationLayerBase::TransMat::leftMultiply(float u0, + float u1, + float u2, + float u3, + float u4, + float u5) { + float t0 = this->t0, t1 = this->t1, t2 = this->t2; + float t3 = this->t3, t4 = this->t4, t5 = this->t5; + + this->t0 = t0 * u0 + t3 * u1; + this->t1 = t1 * u0 + t4 * u1; + this->t2 = t2 * u0 + t5 * u1 + u2; + this->t3 = t0 * u3 + t3 * u4; + this->t4 = t1 * u3 + t4 * u4; + this->t5 = t2 * u3 + t5 * u4 + u5; +} + +void AugmentationLayerBase::TransMat::toIdentity() { + t0 = 1; t1 = 0; t2 = 0; + t3 = 0; t4 = 1; t5 = 0; +} + +/** AugmentationCoeff Functions **/ +void AugmentationCoeff::clear() { + // Spatial variables + dx.clear(); + dy.clear(); + angle.clear(); + zoom_x.clear(); + zoom_y.clear(); + + // Chromatic variables + gamma.clear(); + brightness.clear(); + contrast.clear(); + color1.clear(); + color2.clear(); + color3.clear(); +} + +void AugmentationCoeff::combine_with(const AugmentationCoeff& coeff) { + // Spatial types + if (coeff.dx) { + dx = dx() * coeff.dx(); + } + + if (coeff.dy) { + dy = dy() * coeff.dy(); + } + + if (coeff.angle) { + angle = angle() * coeff.angle(); + } + + if (coeff.zoom_x) { + zoom_x = zoom_x() * coeff.zoom_x(); + } + + if (coeff.zoom_y) { + zoom_y = zoom_y() * coeff.zoom_y(); + } + + // Chromatic types + if (coeff.gamma) { + gamma = gamma() * coeff.gamma(); + } + + if (coeff.brightness) { + brightness = brightness() * coeff.brightness(); + } + + if (coeff.contrast) { + contrast = contrast() * coeff.contrast(); + } + + if (coeff.color1) { + color1 = color1() * coeff.color1(); + } + + if (coeff.color2) { + color2 = color2() * coeff.color2(); + } + + if (coeff.color3) { + color3 = color3() * coeff.color3(); + } +} + +void AugmentationCoeff::replace_with(const AugmentationCoeff& coeff) { + // Spatial types + if (coeff.dx) { + dx 
= coeff.dx(); + } + + if (coeff.dy) { + dy = coeff.dy(); + } + + if (coeff.angle) { + angle = coeff.angle(); + } + + if (coeff.zoom_x) { + zoom_x = coeff.zoom_x(); + } + + if (coeff.zoom_y) { + zoom_y = coeff.zoom_y(); + } + + // Chromatic types + if (coeff.gamma) { + gamma = gamma() * coeff.gamma(); + } + + if (coeff.brightness) { + brightness = coeff.brightness(); + } + + if (coeff.contrast) { + contrast = coeff.contrast(); + } + + if (coeff.color1) { + color1 = coeff.color1(); + } + + if (coeff.color2) { + color2 = coeff.color2(); + } + + if (coeff.color3) { + color3 = coeff.color3(); + } +} + +/** AugmentationLayerBase Functions **/ +float AugmentationLayerBase::rng_generate(const AugmentationParam& param, + float discount_coeff, + const float default_value) { + std::random_device rd; // Will be used to obtain a seed for the random number + // engine + std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with rd() + + float spread = param.spread * discount_coeff; + + if (param.rand_type == "uniform_bernoulli") { + float tmp1 = 0.0; + bool tmp2 = false; + + if (param.prob > 0.0) { + std::bernoulli_distribution bernoulli(param.prob); + tmp2 = bernoulli(gen); + } + + if (!tmp2) { + return default_value; + } + + if (param.spread > 0.0) { + std::uniform_real_distribution<> uniform(param.mean - spread, + param.mean + spread); + tmp1 = uniform(gen); + } else { + tmp1 = param.mean; + } + + if (param.should_exp) { + tmp1 = exp(tmp1); + } + + return tmp1; + } else if (param.rand_type == "gaussian_bernoulli") { + float tmp1 = 0.0; + bool tmp2 = false; + + if (param.prob > 0.0) { + std::bernoulli_distribution bernoulli(param.prob); + tmp2 = bernoulli(gen); + } + + if (!tmp2) { + return default_value; + } + + if (spread > 0.0) { + std::normal_distribution<> normal(param.mean, spread); + tmp1 = normal(gen); + } else { + tmp1 = param.mean; + } + + if (param.should_exp) { + tmp1 = exp(tmp1); + } + + return tmp1; + } else { + throw "Unknown random type: " + param.rand_type; + } +} + +void AugmentationLayerBase::generate_chromatic_coeffs(float discount_coeff, + const AugmentationParams& aug, + AugmentationCoeff & coeff) { + if (aug.gamma) { + coeff.gamma = rng_generate(aug.gamma(), discount_coeff, coeff.gamma.get_default()); + } + + if (aug.brightness) { + coeff.brightness = + rng_generate(aug.brightness(), discount_coeff, coeff.brightness.get_default()); + } + + if (aug.contrast) { + coeff.contrast = rng_generate(aug.contrast(), discount_coeff, coeff.contrast.get_default()); + } + + if (aug.color) { + coeff.color1 = rng_generate(aug.color(), discount_coeff, coeff.color1.get_default()); + coeff.color2 = rng_generate(aug.color(), discount_coeff, coeff.color2.get_default()); + coeff.color3 = rng_generate(aug.color(), discount_coeff, coeff.color3.get_default()); + } +} + +void AugmentationLayerBase::generate_spatial_coeffs(float discount_coeff, + const AugmentationParams& aug, + AugmentationCoeff & coeff) { + if (aug.translate) { + coeff.dx = rng_generate(aug.translate(), discount_coeff, coeff.dx.get_default()); + coeff.dy = rng_generate(aug.translate(), discount_coeff, coeff.dy.get_default()); + } + + if (aug.rotate) { + coeff.angle = rng_generate(aug.rotate(), discount_coeff, coeff.angle.get_default()); + } + + if (aug.zoom) { + coeff.zoom_x = rng_generate(aug.zoom(), discount_coeff, coeff.zoom_x.get_default()); + coeff.zoom_y = coeff.zoom_x(); + } + + if (aug.squeeze) { + float squeeze_coeff = rng_generate(aug.squeeze(), discount_coeff, 1.0); + coeff.zoom_x = coeff.zoom_x() * 
squeeze_coeff; + coeff.zoom_y = coeff.zoom_y() * squeeze_coeff; + } +} + +void AugmentationLayerBase::generate_valid_spatial_coeffs( + float discount_coeff, + const AugmentationParams& aug, + AugmentationCoeff & coeff, + int src_width, + int src_height, + int out_width, + int out_height) { + int x, y; + float x1, y1, x2, y2; + int counter = 0; + int good_params = 0; + AugmentationCoeff incoming_coeff(coeff); + + while (good_params < 4 && counter < 50) { + coeff.clear(); + AugmentationLayerBase::generate_spatial_coeffs(discount_coeff, aug, coeff); + coeff.combine_with(incoming_coeff); + + // Check if all 4 corners of the transformed image fit into the original + // image + good_params = 0; + + for (x = 0; x < out_width; x += out_width - 1) { + for (y = 0; y < out_height; y += out_height - 1) { + // move the origin + x1 = x - 0.5 * out_width; + y1 = y - 0.5 * out_height; + + // rotate + x2 = cos(coeff.angle()) * x1 - sin(coeff.angle()) * y1; + y2 = sin(coeff.angle()) * x1 + sin(coeff.angle()) * y1; + + // translate + x2 = x2 + coeff.dx() * out_width; + y2 = y2 + coeff.dy() * out_height; + + // zoom + x2 = x2 / coeff.zoom_x(); + y2 = y2 / coeff.zoom_y(); + + // move the origin back + x2 = x2 + 0.5 * src_width; + y2 = y2 + 0.5 * src_height; + + if (!((floor(x2) < 0) || (floor(x2) > src_width - 2.0) || + (floor(y2) < 0) || (floor(y2) > src_height - 2.0))) { + good_params++; + } + } + } + counter++; + } + + if (counter >= 50) { + printf("Warning: No suitable spatial transformation after %d attempts.\n", counter); + coeff.clear(); + coeff.replace_with(incoming_coeff); + } +} + +void AugmentationLayerBase::copy_chromatic_coeffs_to_tensor( + const std::vector<AugmentationCoeff>& coeff_arr, + typename TTypes<float, 2>::Tensor& out) +{ + float *out_ptr = out.data(); + int counter = 0; + + for (AugmentationCoeff coeff : coeff_arr) { + out_ptr[counter + 0] = coeff.gamma(); + out_ptr[counter + 1] = coeff.brightness(); + out_ptr[counter + 2] = coeff.contrast(); + out_ptr[counter + 3] = coeff.color1(); + out_ptr[counter + 4] = coeff.color2(); + out_ptr[counter + 5] = coeff.color3(); + counter += 6; + } +} + +void AugmentationLayerBase::copy_spatial_coeffs_to_tensor( + const std::vector<AugmentationCoeff>& coeff_arr, + const int out_width, + const int out_height, + const int src_width, + const int src_height, + typename TTypes<float, 2>::Tensor& out, + const bool invert) +{ + float *out_ptr = out.data(); + int counter = 0; + TransMat t; + + for (AugmentationCoeff coeff : coeff_arr) { + t.toIdentity(); + t.fromCoeff(&coeff, out_width, out_height, src_width, src_height); + + if (invert) { + t = t.inverse(); + } + + out_ptr[counter + 0] = t.t0; + out_ptr[counter + 1] = t.t1; + out_ptr[counter + 2] = t.t2; + out_ptr[counter + 3] = t.t3; + out_ptr[counter + 4] = t.t4; + out_ptr[counter + 5] = t.t5; + counter += 6; + } +} +} diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.h b/Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.h new file mode 100644 index 0000000..d2aba2c --- /dev/null +++ b/Codes/flownet2/src/ops/preprocessing/kernels/augmentation_base.h @@ -0,0 +1,228 @@ +#ifndef AUGMENTATION_LAYER_BASE_H_ +#define AUGMENTATION_LAYER_BASE_H_ + +#include "tensorflow/core/framework/tensor_types.h" + +#include <iostream> +#include <string> +#include <vector> + +namespace tensorflow { +template<typename T> +class OptionalType { + public: + OptionalType(const T default_value) : default_value(default_value), has_value(false) {} + + operator bool() const { + return has_value; + 
} + + OptionalType& operator=(T val) { + has_value = true; + value = val; + return *this; + } + + const T operator()() const { + return has_value ? value : default_value; + } + + void clear() { + has_value = false; + } + + const T get_default() { + return default_value; + } + + private: + T value; + bool has_value; + const T default_value; +}; + +class AugmentationCoeff { + public: + // Spatial Types + OptionalType<float>dx; + OptionalType<float>dy; + OptionalType<float>angle; + OptionalType<float>zoom_x; + OptionalType<float>zoom_y; + + // Chromatic Types + OptionalType<float>gamma; + OptionalType<float>brightness; + OptionalType<float>contrast; + OptionalType<float>color1; + OptionalType<float>color2; + OptionalType<float>color3; + + AugmentationCoeff() : dx(0.0), dy(0.0), angle(0.0), zoom_x(1.0), zoom_y(1.0), gamma(1.0), + brightness(0.0), contrast(1.0), color1(1.0), color2(1.0), color3(1.0) {} + + AugmentationCoeff(const AugmentationCoeff& coeff) : AugmentationCoeff() { + replace_with(coeff); + } + + void clear(); + + void combine_with(const AugmentationCoeff& coeff); + + void replace_with(const AugmentationCoeff& coeff); +}; + +typedef struct AugmentationParam { + std::string rand_type; + bool should_exp; + float mean; + float spread; + float prob; +} AugmentationParam; + +class AugmentationParams { + public: + int crop_height; + int crop_width; + + // Spatial options + OptionalType<struct AugmentationParam>translate; + OptionalType<struct AugmentationParam>rotate; + OptionalType<struct AugmentationParam>zoom; + OptionalType<struct AugmentationParam>squeeze; + + // Chromatic options + OptionalType<struct AugmentationParam>gamma; + OptionalType<struct AugmentationParam>brightness; + OptionalType<struct AugmentationParam>contrast; + OptionalType<struct AugmentationParam>color; + + inline AugmentationParams(int crop_height, + int crop_width, + std::vector<std::string>params_name, + std::vector<std::string>params_rand_type, + std::vector<bool> params_exp, + std::vector<float> params_mean, + std::vector<float> params_spread, + std::vector<float> params_prob) : + crop_height(crop_height), + crop_width(crop_width), + translate(AugmentationParam()), + rotate(AugmentationParam()), + zoom(AugmentationParam()), + squeeze(AugmentationParam()), + gamma(AugmentationParam()), + brightness(AugmentationParam()), + contrast(AugmentationParam()), + color(AugmentationParam()) { + for (int i = 0; i < params_name.size(); i++) { + const std::string name = params_name[i]; + const std::string rand_type = params_rand_type[i]; + const bool should_exp = params_exp[i]; + const float mean = params_mean[i]; + const float spread = params_spread[i]; + const float prob = params_prob[i]; + + struct AugmentationParam param = { rand_type, should_exp, mean, spread, prob }; + + if (name == "translate") { + this->translate = param; + } else if (name == "rotate") { + this->rotate = param; + } else if (name == "zoom") { + this->zoom = param; + } else if (name == "squeeze") { + this->squeeze = param; + } else if (name == "noise") { + // NoOp: We handle noise on the Python side + } else if (name == "gamma") { + this->gamma = param; + } else if (name == "brightness") { + this->brightness = param; + } else if (name == "contrast") { + this->contrast = param; + } else if (name == "color") { + this->color = param; + } else { + std::cout << "Ignoring unknown augmentation parameter: " << name << std::endl; + } + } + } + + bool should_do_spatial_transform() { + return this->translate || this->rotate || this->zoom || this->squeeze; + } 
+ + bool should_do_chromatic_transform() { + return this->gamma || this->brightness || this->contrast || this->color; + } +}; + +class AugmentationLayerBase { + public: + class TransMat { + /** + * Translation matrix class for spatial augmentation + * | 0 1 2 | + * | 3 4 5 | + */ + + public: + float t0, t1, t2; + float t3, t4, t5; + + + void fromCoeff(AugmentationCoeff *coeff, + int out_width, + int out_height, + int src_width, + int src_height); + + void fromTensor(const float *tensor_data); + + TransMat inverse(); + + void leftMultiply(float u0, + float u1, + float u2, + float u3, + float u4, + float u5); + + void toIdentity(); + }; + + // TODO: Class ChromaticCoeffs + + static float rng_generate(const AugmentationParam& param, + float discount_coeff, + const float default_value); + + static void clear_spatial_coeffs(AugmentationCoeff& coeff); + static void generate_chromatic_coeffs(float discount_coeff, + const AugmentationParams& aug, + AugmentationCoeff & coeff); + static void generate_spatial_coeffs(float discount_coeff, + const AugmentationParams& aug, + AugmentationCoeff & coeff); + static void generate_valid_spatial_coeffs(float discount_coeff, + const AugmentationParams& aug, + AugmentationCoeff & coeff, + int src_width, + int src_height, + int out_width, + int out_height); + + static void copy_chromatic_coeffs_to_tensor(const std::vector<AugmentationCoeff>& coeff_arr, + typename TTypes<float, 2>::Tensor& out); + static void copy_spatial_coeffs_to_tensor(const std::vector<AugmentationCoeff>& coeff_arr, + const int out_width, + const int out_height, + const int src_width, + const int src_height, + typename TTypes<float, 2>::Tensor& out, + const bool invert = false); +}; +} // namespace tensorflow + +#endif // AUGMENTATION_LAYER_BASE_H_ diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.cc b/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.cc new file mode 100644 index 0000000..77b8c83 --- /dev/null +++ b/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.cc @@ -0,0 +1,461 @@ +#define EIGEN_USE_THREADS + +#include <algorithm> +#include <iostream> +#include <random> +#include <vector> + +#include "augmentation_base.h" +#include "data_augmentation.h" +#include "tensorflow/core/framework/op_kernel.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +inline float clamp(float f, float a, float b) { + return fmaxf(a, fminf(f, b)); +} + +template<> +void Augment(OpKernelContext *context, + const CPUDevice& d, + const int batch_size, + const int channels, + const int src_width, + const int src_height, + const int src_count, + const int out_width, + const int out_height, + const float *src_data, + float *out_data, + const float *transMats, + float *chromatic_coeffs) { + const int64 channel_count = batch_size * out_height * out_width; + const int kCostPerChannel = 10; + const DeviceBase::CpuWorkerThreads& worker_threads = + *context->device()->tensorflow_cpu_worker_threads(); + + Shard(worker_threads.num_threads, + worker_threads.workers, + channel_count, + kCostPerChannel, + [batch_size, 
channels, src_width, + src_height, src_count, out_width, out_height, src_data, + out_data, transMats, chromatic_coeffs]( + int64 start_channel, int64 end_channel) { + // TF, NHWK: ((n * H + h) * W + w) * K + k at point (n, h, w, k) + for (int index = start_channel; index < end_channel; index++) { + int x = index % out_width; + int y = (index / out_width) % out_height; + int n = index / out_width / out_height; + + const float *transMat = transMats + n * 6; + + float gamma, brightness, contrast; + + if (chromatic_coeffs) { + gamma = chromatic_coeffs[n * 6 + 0]; + brightness = chromatic_coeffs[n * 6 + 1]; + contrast = chromatic_coeffs[n * 6 + 2]; + } + + float xpos = x * transMat[0] + y * transMat[1] + transMat[2]; + float ypos = x * transMat[3] + y * transMat[4] + transMat[5]; + + xpos = clamp(xpos, 0.0f, (float)(src_width) - 1.05f); + ypos = clamp(ypos, 0.0f, (float)(src_height) - 1.05f); + + float tlx = floor(xpos); + float tly = floor(ypos); + + float xdist = xpos - tlx; + float ydist = ypos - tly; + + int srcTLIdxOffset = ((n * src_height + (int)tly) * src_width + (int)tlx) * channels; + + // ((n * src_height + tly) * src_width + (tlx + 1)) * channels + int srcTRIdxOffset = srcTLIdxOffset + channels; + + // ((n * src_height + (tly + 1)) * src_width + tlx) * channels + int srcBLIdxOffset = srcTLIdxOffset + channels * src_width; + + // ((n * src_height + (tly + 1)) * src_width + (tlx + 1)) * channels + int srcBRIdxOffset = srcTLIdxOffset + channels + channels * src_width; + + // Variables for chromatic transform + int data_index[3]; + float rgb[3]; + float mean_in = 0; + float mean_out = 0; + + for (int c = 0; c < channels; c++) { + // Bilinear interpolation + int srcTLIdx = srcTLIdxOffset + c; + int srcTRIdx = std::min(srcTRIdxOffset + c, src_count); + int srcBLIdx = std::min(srcBLIdxOffset + c, src_count); + int srcBRIdx = std::min(srcBRIdxOffset + c, src_count); + + float dest = (1 - xdist) * (1 - ydist) * src_data[srcTLIdx] + + (xdist) * (ydist) * src_data[srcBRIdx] + + (1 - xdist) * (ydist) * src_data[srcBLIdx] + + (xdist) * (1 - ydist) * src_data[srcTRIdx]; + + if (chromatic_coeffs) { + // Gather data for chromatic transform + data_index[c] = index * channels + c; + rgb[c] = dest; + mean_in += rgb[c]; + + // Note: coeff[3] == color1, coeff[4] == color2, ... 
+ rgb[c] *= chromatic_coeffs[n * 6 + (3 + c)]; + + mean_out += rgb[c]; + } else { + out_data[index * channels + c] = dest; + } + } + + float brightness_coeff = mean_in / (mean_out + 0.01f); + + if (chromatic_coeffs) { + // Chromatic transformation + for (int c = 0; c < channels; c++) { + // compensate brightness + rgb[c] = clamp(rgb[c] * brightness_coeff, 0.0f, 1.0f); + + // gamma change + rgb[c] = pow(rgb[c], gamma); + + // brightness change + rgb[c] = rgb[c] + brightness; + + // contrast change + rgb[c] = 0.5f + (rgb[c] - 0.5f) * contrast; + + out_data[data_index[c]] = clamp(rgb[c], 0.0f, 1.0f); + } + } + } + }); +} + +template<typename Device> +class DataAugmentation : public OpKernel { + public: + explicit DataAugmentation(OpKernelConstruction *ctx) : OpKernel(ctx) { + // Get the crop [height, width] tensor and verify its dimensions + OP_REQUIRES_OK(ctx, ctx->GetAttr("crop", &crop_)); + OP_REQUIRES(ctx, crop_.size() == 2, + errors::InvalidArgument("crop must be 2 dimensions")); + + // TODO: Verify params are all the same length + + // Get the tensors for params_a and verify their dimensions + OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_name", ¶ms_a_name_)); + OP_REQUIRES_OK(ctx, + ctx->GetAttr("params_a_rand_type", ¶ms_a_rand_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_exp", ¶ms_a_exp_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_mean", ¶ms_a_mean_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_spread", ¶ms_a_spread_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_prob", ¶ms_a_prob_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_coeff_schedule", ¶ms_a_coeff_schedule_)); + + // Get the tensors for params_b and verify their dimensions + OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_name", ¶ms_b_name_)); + OP_REQUIRES_OK(ctx, + ctx->GetAttr("params_b_rand_type", ¶ms_b_rand_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_exp", ¶ms_b_exp_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_mean", ¶ms_b_mean_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_spread", ¶ms_b_spread_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_prob", ¶ms_b_prob_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_coeff_schedule", ¶ms_b_coeff_schedule_)); + } + + void Compute(OpKernelContext *ctx) override { + // Get the input images + const Tensor& input_a_t = ctx->input(0); + const Tensor& input_b_t = ctx->input(1); + + // Get the global step value + const Tensor& global_step_t = ctx->input(2); + auto global_step_eigen = global_step_t.tensor<int64, 0>(); + const int64 global_step = global_step_eigen.data()[0]; + + // Dimension constants + const int batch_size = input_a_t.dim_size(0); + const int src_height = input_a_t.dim_size(1); + const int src_width = input_a_t.dim_size(2); + const int channels = input_a_t.dim_size(3); + const int src_count = batch_size * src_height * src_width * channels; + const int out_height = crop_[0]; + const int out_width = crop_[1]; + const int out_count = batch_size * out_height * out_width * channels; + + // All tensors for this op + Tensor chromatic_coeffs_a_t; + Tensor chromatic_coeffs_b_t; + + // Allocate the memory for the output images + Tensor *output_a_t; + Tensor *output_b_t; + + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({ batch_size, crop_[0], crop_[1], + channels }), &output_a_t)); + OP_REQUIRES_OK(ctx, + ctx->allocate_output(1, TensorShape({ batch_size, crop_[0], crop_[1], + channels }), &output_b_t)); + + // Allocate the memory for the output spatial transforms + Tensor *spat_transform_a_t; + Tensor *spat_transform_b_t; + + 
OP_REQUIRES_OK(ctx, + ctx->allocate_output(2, TensorShape({ batch_size, 6 }), + &spat_transform_a_t)); + OP_REQUIRES_OK(ctx, + ctx->allocate_output(3, TensorShape({ batch_size, 6 }), + &spat_transform_b_t)); + + // Compute discount for coefficients if using a schedule + float discount_coeff_a = 1.0; + float discount_coeff_b = 1.0; + + if (params_a_coeff_schedule_.size() == 3) { + float half_life = params_a_coeff_schedule_[0]; + float initial_coeff = params_a_coeff_schedule_[1]; + float final_coeff = params_a_coeff_schedule_[2]; + discount_coeff_a = initial_coeff + (final_coeff - initial_coeff) * + (2.0 / (1.0 + exp(-1.0986 * global_step / half_life)) - 1.0); + } + + if (params_b_coeff_schedule_.size() == 3) { + if (params_a_coeff_schedule_.size() == 3) { + discount_coeff_b = discount_coeff_a; + } else { + float half_life = params_b_coeff_schedule_[0]; + float initial_coeff = params_b_coeff_schedule_[1]; + float final_coeff = params_b_coeff_schedule_[2]; + discount_coeff_b = initial_coeff + (final_coeff - initial_coeff) * + (2.0 / (1.0 + exp(-1.0986 * global_step / half_life)) - 1.0); + } + } + + /*** BEGIN AUGMENTATION TO IMAGE A ***/ + auto input_a = input_a_t.tensor<float, 4>(); + auto output_a = output_a_t->tensor<float, 4>(); + + // Load augmentation parameters for image A + AugmentationParams aug_a = AugmentationParams(out_height, out_width, + params_a_name_, + params_a_rand_type_, + params_a_exp_, + params_a_mean_, + params_a_spread_, + params_a_prob_); + + std::vector<AugmentationCoeff> coeffs_a; + + + bool gen_spatial_transform = aug_a.should_do_spatial_transform(); + bool gen_chromatic_transform = aug_a.should_do_chromatic_transform(); + + for (int n = 0; n < batch_size; n++) { + AugmentationCoeff coeff; + + if (gen_spatial_transform) { + AugmentationLayerBase::generate_valid_spatial_coeffs(discount_coeff_a, aug_a, coeff, + src_width, src_height, + out_width, out_height); + } + + if (gen_chromatic_transform) { + AugmentationLayerBase::generate_chromatic_coeffs(discount_coeff_a, aug_a, coeff); + } + + coeffs_a.push_back(coeff); + } + + // Copy spatial coefficients A to the output Tensor on the CPU + // (output for FlowAugmentation) + auto spat_transform_a = spat_transform_a_t->tensor<float, 2>(); + AugmentationLayerBase::copy_spatial_coeffs_to_tensor(coeffs_a, + out_width, out_height, + src_width, src_height, + spat_transform_a); + + float *chromatic_coeffs_a_data = NULL; + + if (gen_chromatic_transform) { + // Allocate a temporary tensor to hold the chromatic coefficients + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum<float>::value, + TensorShape({ batch_size, 6 }), + &chromatic_coeffs_a_t)); + + // Copy the chromatic coefficients A to a temporary Tensor on the CPU + auto chromatic_coeffs_a = chromatic_coeffs_a_t.tensor<float, 2>(); + AugmentationLayerBase::copy_chromatic_coeffs_to_tensor(coeffs_a, chromatic_coeffs_a); + chromatic_coeffs_a_data = chromatic_coeffs_a.data(); + } + + // Perform augmentation either on CPU or GPU + Augment<Device>( + ctx, + ctx->eigen_device<Device>(), + batch_size, + channels, + src_width, + src_height, + src_count, + out_width, + out_height, + input_a.data(), + output_a.data(), + spat_transform_a.data(), + chromatic_coeffs_a_data); + + /*** END AUGMENTATION TO IMAGE A ***/ + + /*** BEGIN GENERATE NEW COEFFICIENTS FOR IMAGE B ***/ + AugmentationParams aug_b = AugmentationParams(out_height, out_width, + params_b_name_, + params_b_rand_type_, + params_b_exp_, + params_b_mean_, + params_b_spread_, + params_b_prob_); + + 
std::vector<AugmentationCoeff> coeffs_b; + + bool gen_spatial_transform_b = aug_b.should_do_spatial_transform(); + bool gen_chromatic_transform_b = aug_b.should_do_chromatic_transform(); + + for (int n = 0; n < batch_size; n++) { + AugmentationCoeff coeff(coeffs_a[n]); + + // If we did a spatial transform on image A, we need to do the same one + // (+ possibly more) on image B + if (gen_spatial_transform_b) { + AugmentationLayerBase::generate_valid_spatial_coeffs(discount_coeff_b, aug_b, coeff, + src_width, src_height, + out_width, out_height); + } + + if (gen_chromatic_transform_b) { + AugmentationLayerBase::generate_chromatic_coeffs(discount_coeff_b, aug_b, coeff); + } + + coeffs_b.push_back(coeff); + } + + /*** END GENERATE NEW COEFFICIENTS FOR IMAGE B ***/ + + /*** BEGIN AUGMENTATION TO IMAGE B ***/ + auto input_b = input_b_t.tensor<float, 4>(); + auto output_b = output_b_t->tensor<float, 4>(); + + // Copy spatial coefficients B to the output Tensor on the CPU + auto spat_transform_b = spat_transform_b_t->tensor<float, 2>(); + AugmentationLayerBase::copy_spatial_coeffs_to_tensor(coeffs_b, + out_width, out_height, + src_width, src_height, + spat_transform_b); + + float *chromatic_coeffs_b_data = NULL; + + if (gen_chromatic_transform || gen_chromatic_transform_b) { + // Allocate a temporary tensor to hold the chromatic coefficients + tensorflow::AllocatorAttributes pinned_allocator; + pinned_allocator.set_on_host(true); + pinned_allocator.set_gpu_compatible(true); + OP_REQUIRES_OK(ctx, + ctx->allocate_temp(DataTypeToEnum<float>::value, + TensorShape({ batch_size, 6 }), + &chromatic_coeffs_b_t, pinned_allocator)); + + // Copy the chromatic coefficients A to a temporary Tensor on the CPU + auto chromatic_coeffs_b = chromatic_coeffs_b_t.tensor<float, 2>(); + AugmentationLayerBase::copy_chromatic_coeffs_to_tensor(coeffs_b, chromatic_coeffs_b); + chromatic_coeffs_b_data = chromatic_coeffs_b.data(); + } + + // Perform augmentation either on CPU or GPU + Augment<Device>( + ctx, + ctx->eigen_device<Device>(), + batch_size, + channels, + src_width, + src_height, + src_count, + out_width, + out_height, + input_b.data(), + output_b.data(), + spat_transform_b.data(), + chromatic_coeffs_b_data); + + // FlowAugmentation needs the inverse + // TODO: To avoid rewriting, can we invert when we read on the + // FlowAugmentation side? 
+ AugmentationLayerBase::copy_spatial_coeffs_to_tensor(coeffs_b, + out_width, out_height, + src_width, src_height, + spat_transform_b, + true); + + /*** END AUGMENTATION TO IMAGE B ***/ + } + + private: + std::vector<int32>crop_; + + // Params A + std::vector<string>params_a_name_; + std::vector<string>params_a_rand_type_; + std::vector<bool>params_a_exp_; + std::vector<float>params_a_mean_; + std::vector<float>params_a_spread_; + std::vector<float>params_a_prob_; + std::vector<float>params_a_coeff_schedule_; + + // Params B + std::vector<string>params_b_name_; + std::vector<string>params_b_rand_type_; + std::vector<bool>params_b_exp_; + std::vector<float>params_b_mean_; + std::vector<float>params_b_spread_; + std::vector<float>params_b_prob_; + std::vector<float>params_b_coeff_schedule_; +}; + + +REGISTER_KERNEL_BUILDER(Name("DataAugmentation") + .Device(DEVICE_CPU) + .HostMemory("global_step") + .HostMemory("transforms_from_a") + .HostMemory("transforms_from_b"), + DataAugmentation<CPUDevice>) + +#if GOOGLE_CUDA + +REGISTER_KERNEL_BUILDER(Name("DataAugmentation") + .Device(DEVICE_GPU) + .HostMemory("global_step") + .HostMemory("transforms_from_a") + .HostMemory("transforms_from_b"), + DataAugmentation<GPUDevice>) +#endif // GOOGLE_CUDA +} // namespace tensorflow diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.cu.cc b/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.cu.cc new file mode 100644 index 0000000..7a2101d --- /dev/null +++ b/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.cu.cc @@ -0,0 +1,348 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "augmentation_base.h" +#include "data_augmentation.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +inline __device__ __host__ float clamp(float f, float a, float b) { + return fmaxf(a, fminf(f, b)); +} + +__global__ void SpatialAugmentation( + const int32 nthreads, + const int src_width, + const int src_height, + const int channels, + const int src_count, + const int out_width, + const int out_height, + const float *src_data, + float *out_data, + const float *transMats) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // Caffe, NKHW: ((n * K + k) * H + h) * W + w at point (n, k, h, w) + // TF, NHWK: ((n * H + h) * W + w) * K + k at point (n, h, w, k) + int c = index % channels; + int x = (index / channels) % out_width; + int y = (index / channels / out_width) % out_height; + int n = index / channels / out_width / out_height; + + const float *transMat = transMats + n * 6; + float xpos = x * transMat[0] + y * transMat[1] + transMat[2]; + float ypos = x * transMat[3] + y * transMat[4] + transMat[5]; + + xpos = clamp(xpos, 0.0f, (float)(src_width) - 1.05f); + ypos = clamp(ypos, 0.0f, (float)(src_height) - 1.05f); + + float tlx = floor(xpos); + float tly = floor(ypos); + + // Bilinear interpolation + int srcTLIdx = ((n * src_height + tly) * src_width + tlx) * channels + c; + int srcTRIdx = min((int)(((n * src_height + tly) * src_width + (tlx + 1)) * channels + c), + src_count); + int srcBLIdx = min((int)(((n * src_height + (tly + 1)) * src_width + tlx) * channels + c), + src_count); + int srcBRIdx = min((int)(((n 
* src_height + (tly + 1)) * src_width + (tlx + 1)) * channels + c), + src_count); + + float xdist = xpos - tlx; + float ydist = ypos - tly; + + float dest = (1 - xdist) * (1 - ydist) * src_data[srcTLIdx] + + (xdist) * (ydist) * src_data[srcBRIdx] + + (1 - xdist) * (ydist) * src_data[srcBLIdx] + + (xdist) * (1 - ydist) * src_data[srcTRIdx]; + + out_data[index] = dest; + } +} + +typedef Eigen::GpuDevice GPUDevice; + +template<> +void Augment(OpKernelContext *context, + const GPUDevice& d, + const int batch_size, + const int channels, + const int src_width, + const int src_height, + const int src_count, + const int out_width, + const int out_height, + const float *src_data, + float *out_data, + const float *transMats, + float *chromatic_coeffs) { + const int out_count = batch_size * out_height * out_width * channels; + CudaLaunchConfig config = GetCudaLaunchConfig(out_count, d); + + printf("Chromatic transform not yet implemented on GPU, ignoring."); + + SpatialAugmentation << < config.block_count, config.thread_per_block, 0, d.stream() >> > ( + config.virtual_thread_count, src_width, src_height, channels, src_count, + out_width, out_height, + src_data, out_data, transMats); +} + +// +// template<typename Device> +// class DataAugmentation : public OpKernel { +// public: +// explicit DataAugmentation(OpKernelConstruction *ctx) : OpKernel(ctx) { +// // Get the crop [height, width] tensor and verify its dimensions +// OP_REQUIRES_OK(ctx, ctx->GetAttr("crop", &crop_)); +// OP_REQUIRES(ctx, crop_.size() == 2, +// errors::InvalidArgument("crop must be 2 dimensions")); +// +// // TODO: Verify params are all the same length +// +// // Get the tensors for params_a and verify their dimensions +// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_name", ¶ms_a_name_)); +// OP_REQUIRES_OK(ctx, +// ctx->GetAttr("params_a_rand_type", +// ¶ms_a_rand_type_)); +// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_exp", ¶ms_a_exp_)); +// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_mean", ¶ms_a_mean_)); +// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_spread", +// ¶ms_a_spread_)); +// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_a_prob", ¶ms_a_prob_)); +// +// // Get the tensors for params_b and verify their dimensions +// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_name", ¶ms_b_name_)); +// OP_REQUIRES_OK(ctx, +// ctx->GetAttr("params_b_rand_type", +// ¶ms_b_rand_type_)); +// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_exp", ¶ms_b_exp_)); +// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_mean", ¶ms_b_mean_)); +// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_spread", +// ¶ms_b_spread_)); +// OP_REQUIRES_OK(ctx, ctx->GetAttr("params_b_prob", ¶ms_b_prob_)); +// } +// +// void Compute(OpKernelContext *ctx) override { +// const GPUDevice& device = ctx->eigen_gpu_device(); +// +// // Get the input images +// const Tensor& input_a_t = ctx->input(0); +// const Tensor& input_b_t = ctx->input(1); +// +// // Dimension constants +// const int batch_size = input_a_t.dim_size(0); +// const int src_height = input_a_t.dim_size(1); +// const int src_width = input_a_t.dim_size(2); +// const int channels = input_a_t.dim_size(3); +// const int src_count = batch_size * src_height * src_width * channels; +// const int out_height = crop_[0]; +// const int out_width = crop_[1]; +// const int out_count = batch_size * out_height * out_width * channels; +// +// // Allocate the memory for the output images +// Tensor *output_a_t; +// Tensor *output_b_t; +// +// OP_REQUIRES_OK(ctx, +// ctx->allocate_output(0, TensorShape({ batch_size, +// crop_[0], crop_[1], +// 
channels }), +// &output_a_t)); +// OP_REQUIRES_OK(ctx, +// ctx->allocate_output(1, TensorShape({ batch_size, +// crop_[0], crop_[1], +// channels }), +// &output_b_t)); +// +// // Allocate the memory for the output spatial transforms +// Tensor *spat_transform_a_t; +// Tensor *spat_transform_b_t; +// +// OP_REQUIRES_OK(ctx, +// ctx->allocate_output(2, TensorShape({ batch_size, 6 }), +// &spat_transform_a_t)); +// OP_REQUIRES_OK(ctx, +// ctx->allocate_output(3, TensorShape({ batch_size, 6 }), +// &spat_transform_b_t)); +// +// // Allocate temporary pinned memory for the spatial transforms to be +// used +// // on the GPU +// tensorflow::AllocatorAttributes pinned_allocator; +// pinned_allocator.set_on_host(true); +// pinned_allocator.set_gpu_compatible(true); +// +// Tensor spat_transform_a_pinned_t; +// Tensor spat_transform_b_pinned_t; +// OP_REQUIRES_OK(ctx, +// ctx->allocate_temp(DataTypeToEnum<float>::value, +// TensorShape({ batch_size, 6 }), +// &spat_transform_a_pinned_t, +// pinned_allocator)); +// OP_REQUIRES_OK(ctx, +// ctx->allocate_temp(DataTypeToEnum<float>::value, +// TensorShape({ batch_size, 6 }), +// &spat_transform_b_pinned_t, +// pinned_allocator)); +// auto spat_transform_a_pinned = spat_transform_a_pinned_t.tensor<float, +// 2>(); +// auto spat_transform_b_pinned = spat_transform_b_pinned_t.tensor<float, +// 2>(); +// +// /*** BEGIN AUGMENTATION TO IMAGE A ***/ +// auto input_a = input_a_t.tensor<float, 4>(); +// auto output_a = output_a_t->tensor<float, 4>(); +// +// // Load augmentation parameters for image A +// AugmentationParams aug_a = AugmentationParams(out_height, out_width, +// params_a_name_, +// params_a_rand_type_, +// params_a_exp_, +// params_a_mean_, +// params_a_spread_, +// params_a_prob_); +// +// std::vector<AugmentationCoeff> coeffs_a; +// +// bool gen_spatial_transform = aug_a.should_do_spatial_transform(); +// +// for (int n = 0; n < batch_size; n++) { +// AugmentationCoeff coeff; +// +// if (gen_spatial_transform) { +// AugmentationLayerBase::generate_valid_spatial_coeffs(aug_a, coeff, +// src_width, +// src_height, +// out_width, +// out_height); +// } +// +// coeffs_a.push_back(coeff); +// } +// +// // Copy spatial coefficients A to the output Tensor on the CPU (output +// for +// // FlowAugmentation) +// auto spat_transform_a = spat_transform_a_t->tensor<float, 2>(); +// AugmentationLayerBase::copy_spatial_coeffs_to_tensor(coeffs_a, +// out_width, +// out_height, +// src_width, +// src_height, +// spat_transform_a); +// +// // ...as well as a Tensor going to the GPU +// AugmentationLayerBase::copy_spatial_coeffs_to_tensor(coeffs_a, +// out_width, +// out_height, +// src_width, +// src_height, +// +// +// +// spat_transform_a_pinned); +// +// CudaLaunchConfig config = GetCudaLaunchConfig(out_count, device); +// SpatialAugmentation << < config.block_count, config.thread_per_block, +// 0, +// device.stream() >> > ( +// config.virtual_thread_count, src_width, src_height, channels, +// src_count, +// out_width, out_height, +// input_a.data(), output_a.data(), spat_transform_a_pinned.data()); +// +// /*** END AUGMENTATION TO IMAGE A ***/ +// +// /*** BEGIN GENERATE NEW COEFFICIENTS FOR IMAGE B ***/ +// AugmentationParams aug_b = AugmentationParams(out_height, out_width, +// params_b_name_, +// params_b_rand_type_, +// params_b_exp_, +// params_b_mean_, +// params_b_spread_, +// params_b_prob_); +// +// std::vector<AugmentationCoeff> coeffs_b; +// +// gen_spatial_transform = aug_b.should_do_spatial_transform(); +// +// for (int n = 0; n < 
batch_size; n++) { +// AugmentationCoeff coeff; +// +// if (gen_spatial_transform) { +// AugmentationLayerBase::generate_valid_spatial_coeffs(aug_b, coeff, +// src_width, +// src_height, +// out_width, +// out_height); +// } +// +// coeffs_b.push_back(coeff); +// } +// +// /*** END GENERATE NEW COEFFICIENTS FOR IMAGE B ***/ +// +// /*** BEGIN AUGMENTATION TO IMAGE B ***/ +// auto input_b = input_b_t.tensor<float, 4>(); +// auto output_b = output_b_t->tensor<float, 4>(); +// +// // Copy spatial coefficients B to the output Tensor on the CPU +// auto spat_transform_b = spat_transform_b_t->tensor<float, 2>(); +// AugmentationLayerBase::copy_spatial_coeffs_to_tensor(coeffs_b, +// out_width, +// out_height, +// src_width, +// src_height, +// spat_transform_b, +// true); +// AugmentationLayerBase::copy_spatial_coeffs_to_tensor(coeffs_b, +// out_width, +// out_height, +// src_width, +// src_height, +// +// +// +// spat_transform_b_pinned); +// +// SpatialAugmentation << < config.block_count, config.thread_per_block, +// 0, +// device.stream() >> > ( +// config.virtual_thread_count, src_width, src_height, channels, +// src_count, +// out_width, out_height, +// input_b.data(), output_b.data(), spat_transform_b_pinned.data()); +// +// /*** END AUGMENTATION TO IMAGE B ***/ +// } +// +// private: +// std::vector<int32>crop_; +// +// // Params A +// std::vector<string>params_a_name_; +// std::vector<string>params_a_rand_type_; +// std::vector<bool>params_a_exp_; +// std::vector<float>params_a_mean_; +// std::vector<float>params_a_spread_; +// std::vector<float>params_a_prob_; +// +// // Params B +// std::vector<string>params_b_name_; +// std::vector<string>params_b_rand_type_; +// std::vector<bool>params_b_exp_; +// std::vector<float>params_b_mean_; +// std::vector<float>params_b_spread_; +// std::vector<float>params_b_prob_; +// }; +} // namespace tensorflow +#endif // GOOGLE_CUDA diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.h b/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.h new file mode 100644 index 0000000..545b8a0 --- /dev/null +++ b/Codes/flownet2/src/ops/preprocessing/kernels/data_augmentation.h @@ -0,0 +1,22 @@ +#ifndef FLOWNET_DATA_AUGMENTATION_H_ +#define FLOWNET_DATA_AUGMENTATION_H_ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +template<class Device> +void Augment(OpKernelContext *context, + const Device & d, + const int batch_size, + const int channels, + const int src_width, + const int src_height, + const int src_count, + const int out_width, + const int out_height, + const float *src_data, + float *out_data, + const float *transMats, + float *chromatic_coeffs); +} // namespace tensorflow +#endif // FLOWNET_DATA_AUGMENTATION_H_ diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation.cc b/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation.cc new file mode 100644 index 0000000..b5cc11f --- /dev/null +++ b/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation.cc @@ -0,0 +1,129 @@ +#define EIGEN_USE_THREADS + +#include "flow_augmentation.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +inline int clamp(int f, int a, int b) { + return std::max(a, std::min(f, b)); +} + +template<> +void FillFlowAugmentation(const 
CPUDevice& device, + typename TTypes<float, 4>::Tensor output, + typename TTypes<float, 4>::ConstTensor flows, + typename TTypes<float, 2>::ConstTensor transforms_from_a, + typename TTypes<float, 2>::ConstTensor transforms_from_b) { + const int batch_size = output.dimension(0); + const int out_height = output.dimension(1); + const int out_width = output.dimension(2); + const int src_height = flows.dimension(1); + const int src_width = flows.dimension(2); + const int src_total_count = flows.dimension(0) * flows.dimension(1) * + flows.dimension(2) * flows.dimension(3); + float *output_ptr = output.data(); + const float *flow_ptr = flows.data(); + + for (int n = 0; n < batch_size; n++) { + const float *transMatA = transforms_from_a.data() + n * 6; + const float *transMatB = transforms_from_b.data() + n * 6; + + for (int y = 0; y < out_height; y++) { + int outputIdxOffset = (n * out_height + y) * out_width; + + for (int x = 0; x < out_width; x++) { + // Apply transformation matrix applied to first image + const float xpos1 = x * transMatA[0] + y * transMatA[1] + transMatA[2]; + const float ypos1 = x * transMatA[3] + y * transMatA[4] + transMatA[5]; + + const int srcXIdx = + ((n * src_height + (int)(ypos1 + 0.5)) * src_width + (int)(xpos1 + 0.5)) * 2 + 0; + const int srcYIdx = srcXIdx + 1; + + const float xpos2 = xpos1 + flow_ptr[clamp(srcXIdx, 0, src_total_count - 1)]; + const float ypos2 = ypos1 + flow_ptr[clamp(srcYIdx, 0, src_total_count - 1)]; + + // Apply inverse of the transformation matrix applied to second image + const float xpos3 = xpos2 * transMatB[0] + ypos2 * transMatB[1] + transMatB[2]; + const float ypos3 = xpos2 * transMatB[3] + ypos2 * transMatB[4] + transMatB[5]; + + output_ptr[(outputIdxOffset + x) * 2 + 0] = xpos3 - (float)x; + output_ptr[(outputIdxOffset + x) * 2 + 1] = ypos3 - (float)y; + } + } + } +} + +template<typename Device> +class FlowAugmentation : public OpKernel { + public: + explicit FlowAugmentation(OpKernelConstruction *ctx) : OpKernel(ctx) { + // Get the crop [height, width] tensor and verify its dimensions + OP_REQUIRES_OK(ctx, ctx->GetAttr("crop", &crop_)); + OP_REQUIRES(ctx, crop_.size() == 2, + errors::InvalidArgument("crop must be 2 dimensions")); + } + + void Compute(OpKernelContext *ctx) override { + // Get the input images and transforms and verify their dimensions + const Tensor& flows_t = ctx->input(0); + const Tensor& transforms_from_a_t = ctx->input(1); + const Tensor& transforms_from_b_t = ctx->input(2); + + OP_REQUIRES(ctx, flows_t.dims() == 4, + errors::InvalidArgument("Input images must have rank 4")); + OP_REQUIRES(ctx, + (TensorShapeUtils::IsMatrix(transforms_from_a_t.shape()) && + transforms_from_a_t.dim_size(0) == + flows_t.dim_size(0) && + transforms_from_a_t.dim_size(1) == 6), + errors::InvalidArgument( + "Input transforms_from_a should be num_images x 6")); + OP_REQUIRES(ctx, + (TensorShapeUtils::IsMatrix(transforms_from_b_t.shape()) && + transforms_from_b_t.dim_size(0) == + flows_t.dim_size(0) && + transforms_from_b_t.dim_size(1) == 6), + errors::InvalidArgument( + "Input transforms_from_b should be num_images x 6")); + + // Allocate the memory for the output + Tensor *output_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output( + 0, + TensorShape({ flows_t.dim_size(0), crop_[0], crop_[1], + flows_t.dim_size(3) }), &output_t)); + + // Perform flow augmentation + auto flows = flows_t.tensor<float, 4>(); + auto transforms_from_a = transforms_from_a_t.tensor<float, 2>(); + auto transforms_from_b = transforms_from_b_t.tensor<float, 2>(); + auto 
output = output_t->tensor<float, 4>(); + + FillFlowAugmentation(ctx->eigen_device<Device>(), + output, + flows, + transforms_from_a, + transforms_from_b); + } + + private: + std::vector<int32>crop_; +}; + +REGISTER_KERNEL_BUILDER(Name("FlowAugmentation") + .Device(DEVICE_CPU), + FlowAugmentation<CPUDevice>) + +#if GOOGLE_CUDA +REGISTER_KERNEL_BUILDER(Name("FlowAugmentation") + .Device(DEVICE_GPU), + FlowAugmentation<GPUDevice>) +#endif // GOOGLE_CUDA +} // end namespace tensorflow diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation.h b/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation.h new file mode 100644 index 0000000..7795991 --- /dev/null +++ b/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation.h @@ -0,0 +1,19 @@ +#ifndef FLOWNET_FLOW_AUG_H_ +#define FLOWNET_FLOW_AUG_H_ + +// See docs in ../ops/image_ops.cc. + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +template<class Device> +void FillFlowAugmentation(const Device& device, + typename TTypes<float, 4>::Tensor output, + typename TTypes<float, 4>::ConstTensor flows, + typename TTypes<float, 2>::ConstTensor transforms_from_a, + typename TTypes<float, 2>::ConstTensor transforms_from_b); +} // end namespace tensorflow + +#endif // FLOWNET_FLOW_AUG_H_ diff --git a/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation_gpu.cu.cc b/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation_gpu.cu.cc new file mode 100644 index 0000000..7e10864 --- /dev/null +++ b/Codes/flownet2/src/ops/preprocessing/kernels/flow_augmentation_gpu.cu.cc @@ -0,0 +1,95 @@ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include <stdio.h> +#include <iostream> + +#include "flow_augmentation.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +namespace tensorflow { +typedef Eigen::GpuDevice GPUDevice; + +inline __device__ __host__ int clamp(int f, int a, int b) { + return max(a, min(f, b)); +} + +__global__ void FillFlowAugmentationKernel( + const int32 nthreads, + const float *flow_ptr, + const float *transforms_from_a, + const float *inv_transforms_from_b, + const int src_total_count, const int src_height, const int src_width, + const int batch_size, const int out_height, + const int out_width, float *output_ptr) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + const float x = (float)(index % out_width); + const float y = (float)((index / out_width) % out_height); + const int n = (index / out_width / out_height); + + const int transformIdx = n * 6; + + // Apply transformation matrix applied to second image + const float xpos1 = x * transforms_from_a[transformIdx + 0] + + y * transforms_from_a[transformIdx + 1] + + transforms_from_a[transformIdx + 2]; + const float ypos1 = x * transforms_from_a[transformIdx + 3] + + y * transforms_from_a[transformIdx + 4] + + transforms_from_a[transformIdx + 5]; + + // Caffe, NKHW: ((n * K + k) * H + h) * W + w at point (n, k, h, w) + // TF, NHWK: ((n * H + h) * W + w) * K + k at point (n, h, w, k) + const int srcXIdx = + ((n * src_height + (int)(ypos1 + 0.5)) * src_width + (int)(xpos1 + 0.5)) * + 2 + 0; + const int srcYIdx = srcXIdx + 1; + + const float xpos2 = xpos1 + flow_ptr[clamp(srcXIdx, 0, src_total_count - 1)]; + const float 
ypos2 = ypos1 + flow_ptr[clamp(srcYIdx, 0, src_total_count - 1)]; + + // Apply inverse of the transformation matrix applied to first image + const float xpos3 = xpos2 * inv_transforms_from_b[transformIdx + 0] + + ypos2 * inv_transforms_from_b[transformIdx + 1] + + inv_transforms_from_b[transformIdx + 2]; + const float ypos3 = xpos2 * inv_transforms_from_b[transformIdx + 3] + + ypos2 * inv_transforms_from_b[transformIdx + 4] + + inv_transforms_from_b[transformIdx + 5]; + + output_ptr[((n * out_height + (int)y) * out_width + (int)x) * 2 + 0] = xpos3 - + x; + output_ptr[((n * out_height + (int)y) * out_width + (int)x) * 2 + 1] = ypos3 - + y; + } +} + +template<> +void FillFlowAugmentation(const GPUDevice& device, + typename TTypes<float, 4>::Tensor output, + typename TTypes<float, 4>::ConstTensor flows, + typename TTypes<const float, 2>::ConstTensor transforms_from_a, + typename TTypes<const float, 2>::ConstTensor transforms_from_b) { + const int batch_size = output.dimension(0); + const int out_height = output.dimension(1); + const int out_width = output.dimension(2); + const int depth = 2; + const int total_count = batch_size * out_height * out_width * depth; + const int src_total_count = flows.dimension(0) * flows.dimension(1) * + flows.dimension(2) * flows.dimension(3); + + CudaLaunchConfig config = GetCudaLaunchConfig(total_count / 2, device); + + FillFlowAugmentationKernel << < config.block_count, config.thread_per_block, 0, + device.stream() >> > ( + total_count / 2, flows.data(), transforms_from_a.data(), + transforms_from_b.data(), + src_total_count, flows.dimension(1), flows.dimension(2), batch_size, + out_height, out_width, output.data()); +} +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/Codes/flownet2/src/ops/preprocessing/preprocessing.cc b/Codes/flownet2/src/ops/preprocessing/preprocessing.cc new file mode 100644 index 0000000..086a0d0 --- /dev/null +++ b/Codes/flownet2/src/ops/preprocessing/preprocessing.cc @@ -0,0 +1,96 @@ +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; +using shape_inference::DimensionHandle; + +Status SetOutputToSizedImage(InferenceContext *c) { + ShapeHandle input; + + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); + DimensionHandle batch = c->Dim(input, 0); + DimensionHandle depth = c->Dim(input, 3); + std::vector<int32> crop_; + c->GetAttr("crop", &crop_); + DimensionHandle height = c->MakeDim(crop_[0]); + DimensionHandle width = c->MakeDim(crop_[1]); + c->set_output(0, c->MakeShape({ batch, height, width, depth })); + return Status::OK(); +} + +REGISTER_OP("DataAugmentation") +.Input("image_a: float32") +.Input("image_b: float32") +.Input("global_step: int64") +.Attr("crop: list(int) >= 2") +.Attr("params_a_name: list(string)") +.Attr("params_a_rand_type: list(string)") +.Attr("params_a_exp: list(bool)") +.Attr("params_a_mean: list(float)") +.Attr("params_a_spread: list(float)") +.Attr("params_a_prob: list(float)") +.Attr("params_a_coeff_schedule: list(float)") +.Attr("params_b_name: list(string)") +.Attr("params_b_rand_type: list(string)") +.Attr("params_b_exp: list(bool)") +.Attr("params_b_mean: list(float)") +.Attr("params_b_spread: list(float)") +.Attr("params_b_prob: list(float)") +.Attr("params_b_coeff_schedule: list(float)") +.Output("aug_image_a: float32") +.Output("aug_image_b: float32") 
+.Output("transforms_from_a: float32") +.Output("transforms_from_b: float32") +.SetShapeFn([](InferenceContext *c) { + // Verify input A and input B both have 4 dimensions + ShapeHandle input_shape_a, input_shape_b; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape_a)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &input_shape_b)); + + // TODO: Verify params vectors all have the same length + + // TODO: Move this out of here and into Compute + // Verify input A and input B are the same shape + DimensionHandle batch_size, unused; + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input_shape_a, 0), + c->Value(c->Dim(input_shape_b, 0)), + &batch_size)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input_shape_a, 1), + c->Value(c->Dim(input_shape_b, 1)), &unused)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input_shape_a, 2), + c->Value(c->Dim(input_shape_b, 2)), &unused)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input_shape_a, 3), + c->Value(c->Dim(input_shape_b, 3)), &unused)); + + // Get cropping dimensions + std::vector<int32>crop_; + TF_RETURN_IF_ERROR(c->GetAttr("crop", &crop_)); + + // Reshape input shape to cropped shape + TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape_a, 1, c->MakeDim(crop_[0]), + &input_shape_a)); + TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape_a, 2, c->MakeDim(crop_[1]), + &input_shape_a)); + + // Set output images shapes + c->set_output(0, input_shape_a); + c->set_output(1, input_shape_a); + + // Set output spatial transforms shapes + c->set_output(2, c->MakeShape({ batch_size, 6 })); + c->set_output(3, c->MakeShape({ batch_size, 6 })); + + return Status::OK(); + }); + +REGISTER_OP("FlowAugmentation") +.Input("flows: float32") +.Input("transforms_from_a: float32") +.Input("transforms_from_b: float32") +.Attr("crop: list(int) >= 2") +.Output("transformed_flows: float32") +.SetShapeFn(SetOutputToSizedImage); +} // namespace tensorflow |