anirudhacharya closed pull request #11178: [MXNET-379] L1 Norm operator URL: https://github.com/apache/incubator-mxnet/pull/11178
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/src/operator/l1_normalization-inl.h b/src/operator/l1_normalization-inl.h new file mode 100644 index 00000000000..8481c05c737 --- /dev/null +++ b/src/operator/l1_normalization-inl.h @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2016 by Contributors + * \file l1_normalization-inl.h + * \brief L1 normalization operator with instance, channel and spatial modes +*/ +#ifndef MXNET_OPERATOR_L1_NORMALIZATION_INL_H_ +#define MXNET_OPERATOR_L1_NORMALIZATION_INL_H_ + +#include <dmlc/logging.h> +#include <dmlc/parameter.h> +#include <mxnet/operator.h> +#include <map> +#include <vector> +#include <string> +#include <utility> +#include "./operator_common.h" +#include "./mshadow_op.h" + +namespace mxnet { +namespace op { + +namespace l1_normalization { +enum L1NormalizationOpInputs {kData}; +enum L1NormalizationOpOutputs {kOut, kNorm}; +enum L1NormalizationOpType {kInstance, kChannel, kSpatial}; +enum L1NormalizationBackResource {kTempSpace}; +} // namespace l1_normalization + +/*! \brief Parameters of the L1 normalization operator: the stabilising eps and the reduction mode. */ +struct L1NormalizationParam : public dmlc::Parameter<L1NormalizationParam> { + float eps; + int mode; + DMLC_DECLARE_PARAMETER(L1NormalizationParam) { + DMLC_DECLARE_FIELD(eps).set_default(1e-10f) + .describe("A small constant for numerical stability."); + DMLC_DECLARE_FIELD(mode) + .add_enum("instance", l1_normalization::kInstance) + .add_enum("spatial", l1_normalization::kSpatial) + .add_enum("channel", l1_normalization::kChannel) + .set_default(l1_normalization::kInstance) + .describe("Specify the dimension along which to compute L1 norm."); + } +}; + +/** + * \brief This is the implementation of l1 normalization operator. + * \tparam xpu The device that the operator will be executed on. + * \tparam DType The element type of the tensors the operator works on. 
+ */ +template<typename xpu, typename DType> +class L1NormalizationOp : public Operator { +public: + explicit L1NormalizationOp(L1NormalizationParam p) { + this->param_ = p; + } + + /*! + * \brief Forward pass: out = data / (sum(abs(data)) + eps), reduced over the axes chosen by mode. + * The eps-shifted norm is written to the kNorm output so that Backward can reuse it. + */ + virtual void Forward(const OpContext &ctx, + const std::vector<TBlob> &in_data, + const std::vector<OpReqType> &req, + const std::vector<TBlob> &out_data, + const std::vector<TBlob> &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + if (req[l1_normalization::kOut] == kNullOp) return; + CHECK_EQ(req[l1_normalization::kOut], kWriteTo); + CHECK_EQ(in_data.size(), 1U); + CHECK_EQ(out_data.size(), 2U); + Stream<xpu> *s = ctx.get_stream<xpu>(); + TShape orig_shape = in_data[l1_normalization::kData].shape_; + if (param_.mode == l1_normalization::kInstance) { + Shape<2> dshape = Shape2(orig_shape[0], + orig_shape.ProdShape(1, orig_shape.ndim())); + Tensor<xpu, 2, DType> data = in_data[l1_normalization::kData] + .get_with_shape<xpu, 2, DType>(dshape, s); + Tensor<xpu, 2, DType> out = out_data[l1_normalization::kOut] + .get_with_shape<xpu, 2, DType>(dshape, s); + Tensor<xpu, 1, DType> norm = out_data[l1_normalization::kNorm].get<xpu, 1, DType>(s); + norm = sumall_except_dim<0>(F<mxnet::op::mshadow_op::abs>(data)); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch( + s, norm.size(0), norm.dptr_, norm.dptr_, DType(param_.eps)); + }); + out = data / broadcast<0>(norm, out.shape_); + } else if (param_.mode == l1_normalization::kChannel) { + CHECK_GE(orig_shape.ndim(), 3U); + Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1], + orig_shape.ProdShape(2, orig_shape.ndim())); + Tensor<xpu, 3, DType> data = in_data[l1_normalization::kData] + .get_with_shape<xpu, 3, DType>(dshape, s); + Tensor<xpu, 3, DType> out = out_data[l1_normalization::kOut] + .get_with_shape<xpu, 3, DType>(dshape, s); + Shape<2> norm_shape = Shape2(dshape[0], dshape[2]); + Tensor<xpu, 2, DType> norm = 
out_data[l1_normalization::kNorm] + .get_with_shape<xpu, 2, DType>(norm_shape, s); + norm = reduce_with_axis<red::sum, false>(F<mxnet::op::mshadow_op::abs>(data), 1); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch( + s, norm.size(0) * norm.size(1), norm.dptr_, norm.dptr_, DType(param_.eps)); + }); + out = data / broadcast_with_axis(norm, 0, orig_shape[1]); + } else if (param_.mode == l1_normalization::kSpatial) { + CHECK_GE(orig_shape.ndim(), 3U); + Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1], + orig_shape.ProdShape(2, orig_shape.ndim())); + Tensor<xpu, 3, DType> data = in_data[l1_normalization::kData] + .get_with_shape<xpu, 3, DType>(dshape, s); + Tensor<xpu, 3, DType> out = out_data[l1_normalization::kOut] + .get_with_shape<xpu, 3, DType>(dshape, s); + Shape<2> norm_shape = Shape2(dshape[0], dshape[1]); + Tensor<xpu, 2, DType> norm = out_data[l1_normalization::kNorm] + .get_with_shape<xpu, 2, DType>(norm_shape, s); + norm = reduce_with_axis<red::sum, false>(F<mxnet::op::mshadow_op::abs>(data), 2); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch( + s, norm.size(0) * norm.size(1), norm.dptr_, norm.dptr_, DType(param_.eps)); + }); + out = data / broadcast_with_axis(norm, 1, dshape[2]); + } else { + LOG(FATAL) << "Unexpected mode in l1 normalization"; + } + } + + /*! + * \brief Backward pass. With y = x / n and n = sum(abs(x)) + eps, the input gradient is + * dx = (dy - sign(x) * sum(dy * y)) / n, where the sum runs over the axes reduced in Forward. + * sign(y) is used in place of sign(x); the two agree whenever n > 0. + */ + virtual void Backward(const OpContext &ctx, + const std::vector<TBlob> &out_grad, + const std::vector<TBlob> &in_data, + const std::vector<TBlob> &out_data, + const std::vector<OpReqType> &req, + const std::vector<TBlob> &in_grad, + const std::vector<TBlob> &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), 1U); + CHECK(in_data.size() == 1U && in_grad.size() == 1U); + CHECK_EQ(req.size(), 1U); + + Stream<xpu> *s = ctx.get_stream<xpu>(); + TShape orig_shape = 
out_data[l1_normalization::kOut].shape_; + if (param_.mode == l1_normalization::kInstance) { + Shape<2> dshape = Shape2(orig_shape[0], + orig_shape.ProdShape(1, orig_shape.ndim())); + // data is the forward output y (out_data), not the operator input + Tensor<xpu, 2, DType> data = out_data[l1_normalization::kOut] + .get_with_shape<xpu, 2, DType>(dshape, s); + Tensor<xpu, 2, DType> grad_in = in_grad[l1_normalization::kData] + .get_with_shape<xpu, 2, DType>(dshape, s); + Tensor<xpu, 2, DType> grad_out = out_grad[l1_normalization::kOut] + .get_with_shape<xpu, 2, DType>(dshape, s); + Tensor<xpu, 1, DType> norm = out_data[l1_normalization::kNorm].get<xpu, 1, DType>(s); + Tensor<xpu, 1, DType> temp = ctx.requested[l1_normalization::kTempSpace] + .get_space_typed<xpu, 1, DType>(mshadow::Shape1(data.shape_[0]), s); + // temp = sum(dy * y) per instance; reduced before grad_in is written because the two may alias + temp = sumall_except_dim<0>(grad_out * data); + Assign(grad_in, req[l1_normalization::kData], + (grad_out - F<mshadow_op::sign>(data) * broadcast<0>(temp, data.shape_)) + / broadcast<0>(norm, data.shape_)); + } else if (param_.mode == l1_normalization::kChannel) { + CHECK_GE(orig_shape.ndim(), 3U); + Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1], + orig_shape.ProdShape(2, orig_shape.ndim())); + // data is the forward output y (out_data), not the operator input + Tensor<xpu, 3, DType> data = out_data[l1_normalization::kOut] + .get_with_shape<xpu, 3, DType>(dshape, s); + Tensor<xpu, 3, DType> grad_in = in_grad[l1_normalization::kData] + .get_with_shape<xpu, 3, DType>(dshape, s); + Tensor<xpu, 3, DType> grad_out = out_grad[l1_normalization::kOut] + .get_with_shape<xpu, 3, DType>(dshape, s); + Shape<2> norm_shape = Shape2(dshape[0], dshape[2]); + Tensor<xpu, 2, DType> norm = out_data[l1_normalization::kNorm] + .get_with_shape<xpu, 2, DType>(norm_shape, s); + Tensor<xpu, 2, DType> temp = ctx.requested[l1_normalization::kTempSpace] + .get_space_typed<xpu, 2, DType>(mshadow::Shape2(data.shape_[0], data.shape_[2]), s); + // temp = sum(dy * y) over the channel axis, matching the reduction used in Forward + temp = reduce_with_axis<red::sum, false>(grad_out * data, 1); + Assign(grad_in, req[l1_normalization::kData], + (grad_out - F<mshadow_op::sign>(data) * broadcast_with_axis(temp, 0, orig_shape[1])) + / broadcast_with_axis(norm, 0, orig_shape[1])); + } else if (param_.mode == l1_normalization::kSpatial) { + CHECK_GE(orig_shape.ndim(), 3U); + Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1], + orig_shape.ProdShape(2, 
orig_shape.ndim())); + // data is the forward output y (out_data), not the operator input + Tensor<xpu, 3, DType> data = out_data[l1_normalization::kOut] + .get_with_shape<xpu, 3, DType>(dshape, s); + Tensor<xpu, 3, DType> grad_in = in_grad[l1_normalization::kData] + .get_with_shape<xpu, 3, DType>(dshape, s); + Tensor<xpu, 3, DType> grad_out = out_grad[l1_normalization::kOut] + .get_with_shape<xpu, 3, DType>(dshape, s); + Shape<2> norm_shape = Shape2(dshape[0], dshape[1]); + Tensor<xpu, 2, DType> norm = out_data[l1_normalization::kNorm] + .get_with_shape<xpu, 2, DType>(norm_shape, s); + Tensor<xpu, 2, DType> temp = ctx.requested[l1_normalization::kTempSpace] + .get_space_typed<xpu, 2, DType>(mshadow::Shape2(data.shape_[0], data.shape_[1]), s); + // temp = sum(dy * y) over the flattened spatial axis, matching the reduction used in Forward + temp = reduce_with_axis<red::sum, false>(grad_out * data, 2); + Assign(grad_in, req[l1_normalization::kData], + (grad_out - F<mshadow_op::sign>(data) * broadcast_with_axis(temp, 1, dshape[2])) + / broadcast_with_axis(norm, 1, dshape[2])); + } else { + LOG(FATAL) << "Unexpected mode in l1 normalization"; + } + } + + private: + L1NormalizationParam param_; +}; // class L1NormalizationOp + +// Declare factory function, used for dispatch specialization +template<typename xpu> +Operator* CreateOp(L1NormalizationParam param, int dtype); + +#if DMLC_USE_CXX11 +class L1NormalizationProp : public OperatorProperty { + public: + std::vector<std::string> ListArguments() const override { + return {"data"}; + } + + std::vector<std::string> ListOutputs() const override { + return {"output", "norm"}; + } + + int NumVisibleOutputs() const override { + return 1; + } + + void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override { + param_.Init(kwargs); + } + + std::map<std::string, std::string> GetParams() const override { + return param_.__DICT__(); + } + + bool InferType(std::vector<int> *in_type, + std::vector<int> *out_type, + std::vector<int> *aux_type) const override { + int dtype = (*in_type)[0]; + type_assign(&dtype, (*out_type)[0]); + type_assign(&dtype, (*out_type)[1]); + + TYPE_ASSIGN_CHECK(*in_type, 0, dtype); + TYPE_ASSIGN_CHECK(*out_type, 0, dtype); + TYPE_ASSIGN_CHECK(*out_type, 1, dtype); + return 
dtype != -1; + } + + bool InferShape(std::vector<TShape> *in_shape, + std::vector<TShape> *out_shape, + std::vector<TShape> *aux_shape) const override { + using namespace mshadow; + CHECK_EQ(in_shape->size(), 1U) << "L1Normalization layer only accepts data as input"; + const TShape &dshape = (*in_shape)[l1_normalization::kData]; + // the data shape must be known before the output shapes can be inferred + if ((*in_shape)[l1_normalization::kData].ndim() == 0) return false; + out_shape->clear(); + out_shape->push_back(dshape); + if (param_.mode == l1_normalization::kInstance) { + // one norm value per instance (first axis) + out_shape->push_back(Shape1(dshape[0])); + } else if (param_.mode == l1_normalization::kChannel) { + CHECK_GE(dshape.ndim(), 3U) << "At lease 3 dimensions required in channel mode"; + // norm keeps the data shape with the channel axis collapsed to 1 + TShape norm_shape = dshape; + norm_shape[1] = 1; + out_shape->push_back(norm_shape); + } else if (param_.mode == l1_normalization::kSpatial) { + CHECK_GE(dshape.ndim(), 3U) << "At lease 3 dimensions required in spatial mode"; + // one norm value per (instance, channel) pair + out_shape->push_back(Shape2(dshape[0], dshape[1])); + } else { + return false; + } + return true; + } + + OperatorProperty* Copy() const override { + L1NormalizationProp* norm_sym = new L1NormalizationProp(); + norm_sym->param_ = this->param_; + return norm_sym; + } + + std::string TypeString() const override { + return "L1Normalization"; + } + + // declare dependency and inplace optimization options + std::vector<int> DeclareBackwardDependency( + const std::vector<int> &out_grad, + const std::vector<int> &in_data, + const std::vector<int> &out_data) const override { + return {out_grad[l1_normalization::kOut], + out_data[l1_normalization::kOut], + out_data[l1_normalization::kNorm]}; + } + + std::vector<std::pair<int, void*> > BackwardInplaceOption( + const std::vector<int> &out_grad, + const std::vector<int> &in_data, + const std::vector<int> &out_data, + const std::vector<void*> &in_grad) const override { + return {{out_grad[l1_normalization::kOut], in_grad[l1_normalization::kData]}}; + } + + std::vector<ResourceRequest> 
BackwardResource( + const std::vector<TShape> &in_shape) const override { + return {ResourceRequest::kTempSpace}; + } + + Operator* CreateOperator(Context ctx) const override { + LOG(FATAL) << "Not Implemented."; + return NULL; + } + + Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape, + std::vector<int> *in_type) const override; + + private: + L1NormalizationParam param_; +}; // class L1NormalizationProp +#endif // DMLC_USE_CXX11 +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_L1_NORMALIZATION_INL_H_ diff --git a/src/operator/l1_normalization.cc b/src/operator/l1_normalization.cc new file mode 100644 index 00000000000..bd72b76e10e --- /dev/null +++ b/src/operator/l1_normalization.cc @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2015 by Contributors + * \file l1_normalization.cc + * \brief l1 normalization operator +*/ +#include "./l1_normalization-inl.h" +namespace mxnet { +namespace op { +template<> +Operator* CreateOp<cpu>(L1NormalizationParam param, int dtype) { + Operator* op = NULL; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + op = new L1NormalizationOp<cpu, DType>(param); + }); + return op; +} + +// DO_BIND_DISPATCH comes from operator_common.h +Operator* L1NormalizationProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape, + std::vector<int> *in_type) const { + DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); +} + +DMLC_REGISTER_PARAMETER(L1NormalizationParam); + +MXNET_REGISTER_OP_PROPERTY(L1Normalization, L1NormalizationProp) +.describe(R"code(Normalize the input array using the L1 norm. + +For 1-D NDArray, it computes:: + + out = data / (sum(abs(data)) + eps) + +For N-D NDArray, if the input array has shape (N, N, ..., N), + +with ``mode`` = ``instance``, it normalizes each instance in the multidimensional +array by its L1 norm.:: + + for i in 0...N + out[i,:,:,...,:] = data[i,:,:,...,:] / (sum(abs(data[i,:,:,...,:])) + eps) + +with ``mode`` = ``channel``, it normalizes each channel in the array by its L1 norm.:: + + for i in 0...N + out[:,i,:,...,:] = data[:,i,:,...,:] / (sum(abs(data[:,i,:,...,:])) + eps) + +with ``mode`` = ``spatial``, it normalizes the cross channel norm for each position +in the array by its L1 norm.:: + + for dim in 2...N + for i in 0...N + out[.....,i,...] 
= take(out, indices=i, axis=dim) / (sum(abs(out), indices=i, axis=dim) + eps) + -dim- + +Example:: + + x = [[[1,2], + [3,4]], + [[2,2], + [5,6]]] + + L1Normalization(x, mode='instance') + =[[[0.1 , 0.2 ], + [0.3 , 0.4 ]], + [[0.13333333, 0.13333333], + [0.33333333, 0.4 ]]] + + L1Normalization(x, mode='channel') + =[[[0.25 , 0.33333333], + [0.75 , 0.66666667]], + [[0.28571429, 0.25 ], + [0.71428571, 0.75 ]]] + + L1Normalization(x, mode='spatial') + =[[[0.33333333, 0.66666667], + [0.42857143, 0.57142857]], + [[0.5 , 0.5 ], + [0.45454545, 0.54545455]]] + +)code" ADD_FILELINE) +.add_argument("data", "NDArray-or-Symbol", "Input array to normalize.") +.add_arguments(L1NormalizationParam::__FIELDS__()); +} // namespace op +} // namespace mxnet diff --git a/src/operator/l1_normalization.cu b/src/operator/l1_normalization.cu new file mode 100644 index 00000000000..1ca1b118083 --- /dev/null +++ b/src/operator/l1_normalization.cu @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2015 by Contributors + * \file l1_normalization.cu + * \brief l1 normalization operator +*/ +#include "./l1_normalization-inl.h" +namespace mxnet { +namespace op { +template<> +Operator* CreateOp<gpu>(L1NormalizationParam param, int dtype) { + Operator* op = NULL; + MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { + op = new L1NormalizationOp<gpu, DType>(param); + }); + return op; +} +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 1eb23cc9228..2af17b6eb60 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2787,6 +2787,38 @@ def check_l2_normalization(in_shape, mode, dtype, norm_eps=1e-10): # check gradient check_numeric_gradient(out, [in_data], numeric_eps=1e-3, rtol=1e-2, atol=1e-3) +def check_l1_normalization(in_shape, mode, dtype, norm_eps=1e-10): + ctx = default_context() + data = mx.symbol.Variable('data') + out = mx.symbol.L1Normalization(data=data, mode=mode, eps=norm_eps) + in_data = np.random.uniform(-1, 1, in_shape).astype(dtype) + # calculate numpy results + if mode == 'channel': + assert in_data.ndim > 2 + np_norm1 = np.linalg.norm(in_data, ord=1, axis=1) + norm_eps + np_norm1 = np.repeat(1. / np.expand_dims(np_norm1, axis=1), in_data.shape[1], axis=1) + np_out = np.multiply(in_data, np_norm1) + elif mode == 'spatial': + assert in_data.ndim > 2 + s = in_data.shape + np_norm1 = np.linalg.norm(in_data.reshape((s[0], s[1], -1)), ord=1, axis=2) + norm_eps + np_norm1 = np.repeat(1. / np_norm1[:, np.newaxis], in_data.size / s[0] / s[1], axis=2) + np_out = np.multiply(in_data, np_norm1.reshape(s)) + elif mode == 'instance': + assert in_data.ndim > 1 + s = in_data.shape + np_norm1 = np.linalg.norm(in_data.reshape((s[0], -1)), ord=1, axis=1) + norm_eps + np_norm1 = np.repeat(1. 
/ np_norm1[:, np.newaxis], in_data.size / s[0], axis=1) + np_out = np.multiply(in_data, np_norm1.reshape(s)) + else: + raise RuntimeError('Unknown l2 normalization mode') + exe = out.simple_bind(ctx=ctx, data=in_data.shape) + output = exe.forward(is_train=True, data=in_data) + # compare numpy + mxnet + assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-2 if dtype is 'float16' else 1e-5, atol=1e-5) + # check gradient + check_numeric_gradient(out, [in_data], numeric_eps=1e-3, rtol=1e-2, atol=1e-3) + # TODO(szha): Seeding this masks failures. We need to do a deep dive for failures without this seed. @with_seed(1234) @@ -2800,6 +2832,16 @@ def test_l2_normalization(): for width in [5, 7]: check_l2_normalization((nbatch, nchannel, height, width), mode, dtype) +@with_seed(1234) +def test_l1_normalization(): + for dtype in ['float16', 'float32', 'float64']: + for mode in ['channel', 'spatial', 'instance']: + for nbatch in [1, 4]: + for nchannel in [3, 5]: + for height in [4, 6]: + check_l1_normalization((nbatch, nchannel, height), mode, dtype) + for width in [5, 7]: + check_l1_normalization((nbatch, nchannel, height, width), mode, dtype) def check_layer_normalization(in_shape, axis, eps, dtype=np.float32, forward_check_eps=1E-3): def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5): ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected] With regards, Apache Git Services
