szha closed pull request #13059: Refactor L2_normalization
URL: https://github.com/apache/incubator-mxnet/pull/13059
This is a PR merged from a forked repository.
Because GitHub hides the original diff of a pull request from a fork once it
is merged, the diff is reproduced below for the sake of provenance:
diff --git a/src/operator/l2_normalization-inl.h b/src/operator/l2_normalization-inl.h
index d53e0c5caf9..c7e71424ada 100644
--- a/src/operator/l2_normalization-inl.h
+++ b/src/operator/l2_normalization-inl.h
@@ -216,7 +216,7 @@ class L2NormalizationOp : public Operator {
}
}
- private:
+ protected:
L2NormalizationParam param_;
}; // class L2NormalizationOp
diff --git a/src/operator/l2_normalization.cc b/src/operator/l2_normalization.cc
index f2f485ae6d1..6801a0a2057 100644
--- a/src/operator/l2_normalization.cc
+++ b/src/operator/l2_normalization.cc
@@ -23,13 +23,111 @@
* \brief l2 normalization operator
*/
#include "./l2_normalization-inl.h"
+
+/* VisualStudio only supports openmp 2.0 */
+#ifdef _MSC_VER
+#define collapse(x)
+#endif
+
namespace mxnet {
namespace op {
+
+template<typename DType>
+class L2NormalizationOpCPU : public L2NormalizationOp<cpu, DType> {
+ public:
+ explicit L2NormalizationOpCPU(L2NormalizationParam p)
+ : L2NormalizationOp<cpu, DType>(p) {}
+ void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
+ const std::vector<OpReqType> &req,
+ const std::vector<TBlob> &out_data,
+ const std::vector<TBlob> &aux_args) override {
+ using namespace mshadow;
+ using namespace mshadow::expr;
+ if (req[l2_normalization::kOut] == kNullOp) return;
+ CHECK_EQ(req[l2_normalization::kOut], kWriteTo);
+ CHECK_EQ(in_data.size(), 1U);
+ CHECK_EQ(out_data.size(), 2U);
+ Stream<cpu> *s = ctx.get_stream<cpu>();
+ TShape orig_shape = in_data[l2_normalization::kData].shape_;
+ auto omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+ if (this->param_.mode == l2_normalization::kInstance) {
+ Shape<2> dshape = Shape2(orig_shape[0],
+ orig_shape.ProdShape(1, orig_shape.ndim()));
+ Tensor<cpu, 2, DType> data = in_data[l2_normalization::kData]
+ .get_with_shape<cpu, 2, DType>(dshape, s);
+ Tensor<cpu, 2, DType> out = out_data[l2_normalization::kOut]
+ .get_with_shape<cpu, 2, DType>(dshape, s);
+ Tensor<cpu, 1, DType> norm = out_data[l2_normalization::kNorm].get<cpu, 1, DType>(s);
+#pragma omp parallel for num_threads(omp_threads)
+ for (int shape0 = 0; shape0 < static_cast<int>(dshape[0]); shape0++) {
+ norm[shape0] = DType(this->param_.eps);
+ for (int shape1 = 0; shape1 < static_cast<int>(dshape[1]); shape1++) {
+ norm[shape0] += data[shape0][shape1] * data[shape0][shape1];
+ }
+ norm[shape0] = std::sqrt(norm[shape0]);
+ for (int shape1 = 0; shape1 < static_cast<int>(dshape[1]); shape1++) {
+ out[shape0][shape1] = data[shape0][shape1] / norm[shape0];
+ }
+ }
+ } else if (this->param_.mode == l2_normalization::kChannel) {
+ CHECK_GE(orig_shape.ndim(), 3U);
+ Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
+ orig_shape.ProdShape(2, orig_shape.ndim()));
+ Tensor<cpu, 3, DType> data = in_data[l2_normalization::kData]
+ .get_with_shape<cpu, 3, DType>(dshape, s);
+ Tensor<cpu, 3, DType> out = out_data[l2_normalization::kOut]
+ .get_with_shape<cpu, 3, DType>(dshape, s);
+ Shape<2> norm_shape = Shape2(dshape[0], dshape[2]);
+ Tensor<cpu, 2, DType> norm = out_data[l2_normalization::kNorm]
+ .get_with_shape<cpu, 2, DType>(norm_shape, s);
+#pragma omp parallel for num_threads(omp_threads) collapse(2)
+ for (int shape0 = 0; shape0 < static_cast<int>(dshape[0]); shape0++) {
+ for (int shape2 = 0; shape2 < static_cast<int>(dshape[2]); shape2++) {
+ norm[shape0][shape2] = DType(this->param_.eps);
+ for (int shape1 = 0; shape1 < static_cast<int>(dshape[1]); shape1++) {
+ norm[shape0][shape2] += data[shape0][shape1][shape2] * data[shape0][shape1][shape2];
+ }
+ norm[shape0][shape2] = std::sqrt(norm[shape0][shape2]);
+ for (int shape1 = 0; shape1 < static_cast<int>(dshape[1]); shape1++) {
+ out[shape0][shape1][shape2] = data[shape0][shape1][shape2] / norm[shape0][shape2];
+ }
+ }
+ }
+ } else if (this->param_.mode == l2_normalization::kSpatial) {
+ CHECK_GE(orig_shape.ndim(), 3U);
+ Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
+ orig_shape.ProdShape(2, orig_shape.ndim()));
+ Tensor<cpu, 3, DType> data = in_data[l2_normalization::kData]
+ .get_with_shape<cpu, 3, DType>(dshape, s);
+ Tensor<cpu, 3, DType> out = out_data[l2_normalization::kOut]
+ .get_with_shape<cpu, 3, DType>(dshape, s);
+ Shape<2> norm_shape = Shape2(dshape[0], dshape[1]);
+ Tensor<cpu, 2, DType> norm = out_data[l2_normalization::kNorm]
+ .get_with_shape<cpu, 2, DType>(norm_shape, s);
+#pragma omp parallel for num_threads(omp_threads) collapse(2)
+ for (int shape0 = 0; shape0 < static_cast<int>(dshape[0]); shape0++) {
+ for (int shape1 = 0; shape1 < static_cast<int>(dshape[1]); shape1++) {
+ norm[shape0][shape1] = DType(this->param_.eps);
+ for (int shape2 = 0; shape2 < static_cast<int>(dshape[2]); shape2++) {
+ norm[shape0][shape1] += data[shape0][shape1][shape2] * data[shape0][shape1][shape2];
+ }
+ norm[shape0][shape1] = std::sqrt(norm[shape0][shape1]);
+ for (int shape2 = 0; shape2 < static_cast<int>(dshape[2]); shape2++) {
+ out[shape0][shape1][shape2] = data[shape0][shape1][shape2] / norm[shape0][shape1];
+ }
+ }
+ }
+ } else {
+ LOG(FATAL) << "Unexpected mode in l2 normalization";
+ }
+ }
+};
+
template<>
Operator* CreateOp<cpu>(L2NormalizationParam param, int dtype) {
Operator* op = nullptr;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
- op = new L2NormalizationOp<cpu, DType>(param);
+ op = new L2NormalizationOpCPU<DType>(param);
});
return op;
}
@@ -37,7 +135,7 @@ Operator* CreateOp<cpu>(L2NormalizationParam param, int dtype) {
// DO_BIND_DISPATCH comes from static_operator_common.h
Operator* L2NormalizationProp::CreateOperatorEx(Context ctx,
std::vector<TShape> *in_shape,
std::vector<int> *in_type)
const {
- DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
+ DO_BIND_DISPATCH(CreateOp, this->param_, in_type->at(0));
}
DMLC_REGISTER_PARAMETER(L2NormalizationParam);
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services