This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

The following commit(s) were added to refs/heads/master by this push:
     new f783a66  [MXNET-379] L1 Normalization (#11229)
f783a66 is described below

commit f783a66b1c9f141738ab4f4c0b6f525f61a95d6c
Author: Anirudh <anirudhk...@gmail.com>
AuthorDate: Fri Jun 29 17:06:41 2018 -0700

    [MXNET-379] L1 Normalization (#11229)

    * l1 norm
---
 src/operator/tensor/broadcast_reduce_op.h        | 79 +++++++++++++++++-------
 src/operator/tensor/broadcast_reduce_op_value.cc | 32 ++++++----
 src/operator/tensor/broadcast_reduce_op_value.cu |  4 +-
 tests/python/unittest/test_ndarray.py            | 47 +++++++-------
 tests/python/unittest/test_operator.py           | 44 +++++++++++++
 5 files changed, 143 insertions(+), 63 deletions(-)

diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index e50071b..ac7199a 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -70,7 +70,7 @@ struct NormParam : public dmlc::Parameter<NormParam> {
   bool keepdims;
   DMLC_DECLARE_PARAMETER(NormParam) {
     DMLC_DECLARE_FIELD(ord).set_default(2)
-      .describe("Order of the norm. Currently ord=2 is supported.");
+      .describe("Order of the norm. Currently ord=1 and ord=2 are supported.");
     DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<TShape>())
       .describe(R"code(The axis or axes along which to perform the reduction.
       The default, `axis=()`, will compute over all elements into a
@@ -869,7 +869,7 @@ struct ReduceGrad {
   }
 };

-inline bool L2NormStorageType(const nnvm::NodeAttrs& attrs,
+inline bool LpNormStorageType(const nnvm::NodeAttrs& attrs,
                               const int dev_mask,
                               DispatchMode* dispatch_mode,
                               std::vector<int>* in_attrs,
@@ -889,18 +889,20 @@ inline bool L2NormStorageType(const nnvm::NodeAttrs& attrs,
     dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                      dispatch_mode, DispatchMode::kFCompute);
   }
-  const TShape axis = param.axis.has_value() ? param.axis.value() : TShape();
-  if (!dispatched && (in_stype == kRowSparseStorage || in_stype == kCSRStorage) &&
-      axis.ndim() == 0 && param.ord == 2) {
-    // l2 norm: rsp/csr, axis = () -> dns
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
-                                     DispatchMode::kFComputeEx);
-  }
-  if (!dispatched && in_stype == kCSRStorage && axis.ndim() == 1 && !param.keepdims &&
-      (axis[0] == 0 || axis[0] == 1) && param.ord == 2) {
-    // l2 norm: csr, axis = 0/1 -> dns
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
-                                     dispatch_ex);
+  if (param.ord == 2) {
+    const TShape axis = param.axis.has_value() ? param.axis.value() : TShape();
+    if (!dispatched && (in_stype == kRowSparseStorage || in_stype == kCSRStorage) &&
+        axis.ndim() == 0 && param.ord == 2) {
+      // l2 norm: rsp/csr, axis = () -> dns
+      dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                       DispatchMode::kFComputeEx);
+    }
+    if (!dispatched && in_stype == kCSRStorage && axis.ndim() == 1 && !param.keepdims &&
+        (axis[0] == 0 || axis[0] == 1) && param.ord == 2) {
+      // l2 norm: csr, axis = 0/1 -> dns
+      dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                       dispatch_ex);
+    }
   }
   if (!dispatched) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
@@ -984,13 +986,13 @@ void SqRootForL2(const OpContext& ctx, OpReqType req, const TBlob &output) {
 }

 template<typename xpu>
-void L2NormCompute(const nnvm::NodeAttrs& attrs,
+void LpNormCompute(const nnvm::NodeAttrs& attrs,
                    const OpContext& ctx,
                    const std::vector<TBlob>& inputs,
                    const std::vector<OpReqType>& req,
                    const std::vector<TBlob>& outputs) {
   const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
-  CHECK_EQ(param.ord, 2) << "norm only support ord=2";
+  CHECK(param.ord == 1 || param.ord == 2) << "norm only supports ord=1 and ord=2";
   if (req[0] == kNullOp) return;

   TShape small;
@@ -999,13 +1001,18 @@ void L2NormCompute(const nnvm::NodeAttrs& attrs,
   } else {
     small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, false);
   }
-  ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, mshadow_op::square>(
-      ctx, inputs, req, outputs, small);
-  SqRootForL2<xpu>(ctx, req[0], outputs[0]);
+  if (param.ord == 1) {
+    ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, mshadow_op::abs>(
+        ctx, inputs, req, outputs, small);
+  } else if (param.ord == 2) {
+    ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, mshadow_op::square>(
+        ctx, inputs, req, outputs, small);
+    SqRootForL2<xpu>(ctx, req[0], outputs[0]);
+  }
 }

 template<typename xpu>
-void L2NormGradCompute(const nnvm::NodeAttrs& attrs,
+void LpNormGradCompute(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
                        const std::vector<TBlob>& inputs,
                        const std::vector<OpReqType>& req,
@@ -1021,8 +1028,36 @@ void L2NormGradCompute(const nnvm::NodeAttrs& attrs,
   } else {
     small = ReduceAxesShapeImpl(outputs[0].shape_, param.axis, true, false);
   }
-  ReduceAxesBackwardUseInOutImpl<xpu, mshadow_op::div, false>(ctx, small, inputs,
-                                                              req, outputs);
+  if (param.ord == 1) {
+    TShape src_shape, dst_shape;
+    BroadcastReduceShapeCompact(outputs[0].shape_, small, &src_shape, &dst_shape);
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      if (dst_shape.ndim() == 2) {
+        Tensor<xpu, 2, DType> ograd =
+          inputs[0].get_with_shape<xpu, 2, DType>(dst_shape.get<2>(), s);
+        Tensor<xpu, 2, DType> igrad =
+          outputs[0].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
+        Tensor<xpu, 2, DType> data =
+          inputs[1].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
+        ASSIGN_DISPATCH(igrad, req[0],
+          broadcast_to(ograd, src_shape)*F<mshadow_op::sign>(data));
+      } else {
+        const int ndim = MXNET_SPECIAL_MAX_NDIM;
+        Tensor<xpu, ndim, DType> igrad =
+          outputs[0].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
+        Tensor<xpu, ndim, DType> ograd =
+          inputs[0].get_with_shape<xpu, ndim, DType>(dst_shape.get<ndim>(), s);
+        Tensor<xpu, ndim, DType> data =
+          inputs[1].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
+        ASSIGN_DISPATCH(igrad, req[0],
+          broadcast_to(ograd, src_shape)*F<mshadow_op::sign>(data));
+      }
+    });
+  } else if (param.ord == 2) {
+    ReduceAxesBackwardUseInOutImpl<xpu, mshadow_op::div, false>(ctx, small, inputs,
+                                                                req, outputs);
+  }
 }

 template<typename xpu>
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc
index 7bcc3e9..cde31bf 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cc
+++ b/src/operator/tensor/broadcast_reduce_op_value.cc
@@ -88,9 +88,9 @@ MXNET_ADD_SPARSE_OP_ALIAS(sum)

 Example::

-  data = [[[1,2],[2,3],[1,3]],
-          [[1,4],[4,3],[5,2]],
-          [[7,1],[7,2],[7,3]]]
+  data = [[[1, 2], [2, 3], [1, 3]],
+          [[1, 4], [4, 3], [5, 2]],
+          [[7, 1], [7, 2], [7, 3]]]

   sum(data, axis=1)
   [[  4.   8.]
    [ 10.   9.]
    [ 21.   6.]]

   sum(data, axis=[1,2])
   [ 12.  19.  27.]

-  data = [[1,2,0],
-         [3,0,1],
-         [4,1,0]]
+  data = [[1, 2, 0],
+          [3, 0, 1],
+          [4, 1, 0]]

   csr = cast_storage(data, 'csr')

@@ -280,14 +280,20 @@ MXNET_ADD_SPARSE_OP_ALIAS(norm)

 This operator computes the norm on an NDArray with the specified axis, depending
 on the value of the ord parameter. By default, it computes the L2 norm on the entire
-array.
+array. Currently only ord=2 supports sparse ndarrays.

 Examples::

-  x = [[1, 2],
-       [3, 4]]
+  x = [[[1, 2],
+        [3, 4]],
+       [[2, 2],
+        [5, 6]]]

-  norm(x) = [5.47722578]
+  norm(x, ord=2, axis=1) = [[3.1622777 4.472136 ]
+                            [5.3851647 6.3245554]]
+
+  norm(x, ord=1, axis=1) = [[4., 6.],
+                            [7., 8.]]

   rsp = x.cast_storage('row_sparse')

@@ -303,13 +309,13 @@ Examples::
 .set_attr_parser(ParamParser<NormParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", NormShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", L2NormStorageType)
+.set_attr<FInferStorageType>("FInferStorageType", LpNormStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ReduceGrad{ "_backward_norm" })
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
-.set_attr<FCompute>("FCompute<cpu>", L2NormCompute<cpu>)
+.set_attr<FCompute>("FCompute<cpu>", LpNormCompute<cpu>)
 .set_attr<FComputeEx>("FComputeEx<cpu>", L2NormComputeEx<cpu>)
 .add_argument("data", "NDArray-or-Symbol", "The input")
 .add_arguments(NormParam::__FIELDS__());
@@ -322,7 +328,7 @@ NNVM_REGISTER_OP(_backward_norm)
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
-.set_attr<FCompute>("FCompute<cpu>", L2NormGradCompute<cpu>);
+.set_attr<FCompute>("FCompute<cpu>", LpNormGradCompute<cpu>);

 }  // namespace op

diff --git a/src/operator/tensor/broadcast_reduce_op_value.cu b/src/operator/tensor/broadcast_reduce_op_value.cu
index f7fba68..e2a3840 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cu
+++ b/src/operator/tensor/broadcast_reduce_op_value.cu
@@ -101,11 +101,11 @@ NNVM_REGISTER_OP(_broadcast_backward)
 .set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::sum>);

 NNVM_REGISTER_OP(norm)
-.set_attr<FCompute>("FCompute<gpu>", L2NormCompute<gpu>)
+.set_attr<FCompute>("FCompute<gpu>", LpNormCompute<gpu>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", L2NormComputeEx<gpu>);

 NNVM_REGISTER_OP(_backward_norm)
-.set_attr<FCompute>("FCompute<gpu>", L2NormGradCompute<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", LpNormGradCompute<gpu>);

 }  // namespace op
 }  // namespace mxnet

diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index aeaa0b7..a01514d 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -27,6 +27,7 @@ from mxnet.test_utils import assert_almost_equal, assert_exception
 from mxnet.test_utils import default_context
 from mxnet.test_utils import np_reduce
 from mxnet.test_utils import same
+from mxnet.test_utils import random_sample, rand_shape_nd
 from numpy.testing import assert_allclose
 import mxnet.autograd
@@ -1275,33 +1276,27 @@ def test_ndarray_astype():

 @with_seed()
 def test_norm(ctx=default_context()):
-    np_arr = np.random.uniform(size=(3, 3, 3, 3))
+    def l1norm(input_data, axis=0, keepdims=False):
+        return np.sum(abs(input_data), axis=axis, keepdims=keepdims)
+    def l2norm(input_data, axis=0, keepdims=False):
+        return np.linalg.norm(input_data, axis=axis, keepdims=keepdims)
+
+    in_data_dim = random_sample([4,5,6], 1)[0]
+    in_data_shape = rand_shape_nd(in_data_dim)
+    np_arr = np.random.uniform(-1, 1, in_data_shape).astype(np.float32)
     mx_arr = mx.nd.array(np_arr, ctx=ctx)
-    arr1 = np.linalg.norm(np_arr, keepdims=False)
-    arr2 = mx.nd.norm(mx_arr, keepdims=False)
-    print(arr1)
-    print(arr2.asnumpy())
-    mx.test_utils.assert_almost_equal(arr1, arr2.asnumpy()[0])
-
-    for i in range(4):
-        arr1 = np.linalg.norm(np_arr, axis=i, keepdims=False)
-        arr2 = mx.nd.norm(mx_arr, axis=i, keepdims=False)
-        assert arr1.shape == arr2.shape
-        mx.test_utils.assert_almost_equal(arr1, arr2.asnumpy())
-
-        arr1 = np.linalg.norm(np_arr, axis=i, keepdims=True)
-        arr2 = mx.nd.norm(mx_arr, axis=i, keepdims=True)
-        assert arr1.shape == arr2.shape
-        mx.test_utils.assert_almost_equal(arr1, arr2.asnumpy())
-        if (i < 3):
-            arr1 = np.linalg.norm(np_arr, axis=(i, i+1), keepdims=False)
-            arr2 = mx.nd.norm(mx_arr, axis=(i, i+1), keepdims=False)
-            assert arr1.shape == arr2.shape
-            mx.test_utils.assert_almost_equal(arr1, arr2.asnumpy())
-            arr1 = np.linalg.norm(np_arr, axis=(i, i+1), keepdims=True)
-            arr2 = mx.nd.norm(mx_arr, axis=(i, i+1), keepdims=True)
-            assert arr1.shape == arr2.shape
-            mx.test_utils.assert_almost_equal(arr1, arr2.asnumpy())
+    for ord in [1,2]:
+        for keep_dims in [True, False]:
+            for i in range(4):
+                npy_out = l1norm(np_arr, i, keep_dims) if ord==1 else l2norm(np_arr, i, keep_dims)
+                mx_out = mx.nd.norm(mx_arr, ord=ord, axis=i, keepdims=keep_dims)
+                assert npy_out.shape == mx_out.shape
+                mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy())
+                if (i < 3):
+                    npy_out = l1norm(np_arr, (i, i+1), keep_dims) if ord==1 else l2norm(np_arr, (i, i+1), keep_dims)
+                    mx_out = mx.nd.norm(mx_arr, ord=ord, axis=(i, i+1), keepdims=keep_dims)
+                    assert npy_out.shape == mx_out.shape
+                    mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy())

 @with_seed()
 def test_ndarray_cpu_shared_ctx():

diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 0fa31de..2707d8f 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3012,6 +3012,50 @@ def check_layer_normalization(in_shape, axis, eps, dtype=np.float32, forward_che
                       grad_nodes={'data': req, 'gamma': req, 'beta': req},
                       numeric_eps=1e-2, rtol=1e-2, atol=1e-2)

+@with_seed()
+def test_norm():
+    def l1norm(input_data, axis=0, keepdims=True):
+        return np.sum(abs(input_data), axis=axis, keepdims=keepdims)
+    def l2norm(input_data, axis=0, keepdims=True):
+        return np.linalg.norm(input_data, axis=axis, keepdims=keepdims)
+
+    ctx = default_context()
+    data = mx.symbol.Variable('data')
+    in_data_dim = random_sample([4,5,6], 1)[0]
+    in_shape = rand_shape_nd(in_data_dim)
+    epsilon = 1e-3
+    for order in [1, 2]:
+        for dtype in [np.float16, np.float32, np.float64]:
+            in_data = np.random.uniform(-1, 1, in_shape).astype(dtype)
+            in_data[abs(in_data) < epsilon] = epsilon
+            for i in range(in_data_dim):
+                norm_sym = mx.symbol.norm(data=data, ord=order, axis=i, keepdims=True)
+                npy_out = l1norm(in_data, i) if order == 1 else l2norm(in_data, i)
+                npy_out_backward = np.sign(in_data) if order == 1 else in_data/npy_out
+                check_symbolic_forward(norm_sym, [in_data], [npy_out],
+                                       rtol=1e-2 if dtype == np.float16 else 1e-5,
+                                       atol=1e-2 if dtype == np.float16 else 1e-5, ctx=ctx)
+                check_symbolic_backward(norm_sym, [in_data], [np.ones(npy_out.shape)],
+                                        [npy_out_backward],
+                                        rtol=1e-2 if dtype == np.float16 else 1e-5,
+                                        atol=1e-2 if dtype == np.float16 else 1e-5, ctx=ctx)
+                # check gradient
+                check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, rtol=1e-2, atol=1e-3)
+                if i < in_data_dim-1:
+                    norm_sym = mx.symbol.norm(data=data, ord=order, axis=(i, i+1), keepdims=True)
+                    npy_out = l1norm(in_data, (i, i+1)) if order == 1 else l2norm(in_data, (i, i+1))
+                    npy_out_backward = np.sign(in_data) if order == 1 else in_data/npy_out
+                    check_symbolic_forward(norm_sym, [in_data], [npy_out],
+                                           rtol=1e-2 if dtype == np.float16 else 1e-5,
+                                           atol=1e-2 if dtype == np.float16 else 1e-5, ctx=ctx)
+                    check_symbolic_backward(norm_sym, [in_data], [np.ones(npy_out.shape)],
+                                            [npy_out_backward],
+                                            rtol=1e-2 if dtype == np.float16 else 1e-5,
+                                            atol=1e-2 if dtype == np.float16 else 1e-5, ctx=ctx)
+                    # check gradient
+                    check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, rtol=1e-2, atol=1e-3)
+
+
 def test_layer_norm():
     for dtype, forward_check_eps in zip([np.float16, np.float32, np.float64],
                                         [1E-2, 1E-3, 1E-4]):
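
For reference, the new ord=1 mode can be exercised end to end from the Python
frontend. This is a minimal sketch against the post-commit API (mx.nd.norm with
the ord parameter registered above); the array and its values are illustrative,
not taken from the commit:

    import mxnet as mx
    import numpy as np

    x = mx.nd.array([[1, -2], [-3, 4]])

    # L1 norm: sum of absolute values along the reduced axis.
    print(mx.nd.norm(x, ord=1, axis=1).asnumpy())   # [3. 7.]
    print(np.sum(np.abs(x.asnumpy()), axis=1))      # NumPy reference: [3. 7.]

    # L2 norm, the previous default, is unchanged.
    print(mx.nd.norm(x, ord=2, axis=1).asnumpy())   # [2.236068 5.]

    # The L1 backward pass is sign(data), per LpNormGradCompute above.
    x.attach_grad()
    with mx.autograd.record():
        y = mx.nd.norm(x, ord=1, axis=1)
    y.backward()  # implicit head gradient of ones
    print(x.grad.asnumpy())                         # [[ 1. -1.]
                                                    #  [-1.  1.]]

Note that sign(data) is undefined at zero, which is why test_norm above nudges
every entry with abs(in_data) < epsilon away from the non-differentiable point
before running check_numeric_gradient.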