This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f783a66  [MXNET-379] L1 Normalization (#11229)
f783a66 is described below

commit f783a66b1c9f141738ab4f4c0b6f525f61a95d6c
Author: Anirudh <anirudhk...@gmail.com>
AuthorDate: Fri Jun 29 17:06:41 2018 -0700

    [MXNET-379] L1 Normalization (#11229)
    
    * l1 norm
---
 src/operator/tensor/broadcast_reduce_op.h        | 79 +++++++++++++++++-------
 src/operator/tensor/broadcast_reduce_op_value.cc | 32 ++++++----
 src/operator/tensor/broadcast_reduce_op_value.cu |  4 +-
 tests/python/unittest/test_ndarray.py            | 47 +++++++-------
 tests/python/unittest/test_operator.py           | 44 +++++++++++++
 5 files changed, 143 insertions(+), 63 deletions(-)

diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h
index e50071b..ac7199a 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -70,7 +70,7 @@ struct NormParam : public dmlc::Parameter<NormParam> {
   bool keepdims;
   DMLC_DECLARE_PARAMETER(NormParam) {
     DMLC_DECLARE_FIELD(ord).set_default(2)
-      .describe("Order of the norm. Currently ord=2 is supported.");
+      .describe("Order of the norm. Currently ord=1 and ord=2 are supported.");
     DMLC_DECLARE_FIELD(axis).set_default(dmlc::optional<TShape>())
       .describe(R"code(The axis or axes along which to perform the reduction.
       The default, `axis=()`, will compute over all elements into a
@@ -869,7 +869,7 @@ struct ReduceGrad {
   }
 };
 
-inline bool L2NormStorageType(const nnvm::NodeAttrs& attrs,
+inline bool LpNormStorageType(const nnvm::NodeAttrs& attrs,
                               const int dev_mask,
                               DispatchMode* dispatch_mode,
                               std::vector<int>* in_attrs,
@@ -889,18 +889,20 @@ inline bool L2NormStorageType(const nnvm::NodeAttrs& attrs,
    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
                                      DispatchMode::kFCompute);
   }
-  const TShape axis = param.axis.has_value() ? param.axis.value() : TShape();
-  if (!dispatched && (in_stype == kRowSparseStorage || in_stype == kCSRStorage) &&
-      axis.ndim() == 0 && param.ord == 2) {
-    // l2 norm: rsp/csr, axis = () -> dns
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
-                                     DispatchMode::kFComputeEx);
-  }
-  if (!dispatched && in_stype == kCSRStorage && axis.ndim() == 1 && !param.keepdims &&
-      (axis[0] == 0 || axis[0] == 1) && param.ord == 2) {
-    // l2 norm: csr, axis = 0/1 -> dns
-    dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
-                                     dispatch_ex);
+  if (param.ord == 2) {
+    const TShape axis = param.axis.has_value() ? param.axis.value() : TShape();
+    if (!dispatched && (in_stype == kRowSparseStorage || in_stype == kCSRStorage) &&
+        axis.ndim() == 0 && param.ord == 2) {
+      // l2 norm: rsp/csr, axis = () -> dns
+      dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                       DispatchMode::kFComputeEx);
+    }
+    if (!dispatched && in_stype == kCSRStorage && axis.ndim() == 1 && !param.keepdims &&
+        (axis[0] == 0 || axis[0] == 1) && param.ord == 2) {
+      // l2 norm: csr, axis = 0/1 -> dns
+      dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode,
+                                       dispatch_ex);
+    }
   }
   if (!dispatched) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
@@ -984,13 +986,13 @@ void SqRootForL2(const OpContext& ctx, OpReqType req, const TBlob &output) {
 }
 
 template<typename xpu>
-void L2NormCompute(const nnvm::NodeAttrs& attrs,
+void LpNormCompute(const nnvm::NodeAttrs& attrs,
                    const OpContext& ctx,
                    const std::vector<TBlob>& inputs,
                    const std::vector<OpReqType>& req,
                    const std::vector<TBlob>& outputs) {
   const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
-  CHECK_EQ(param.ord, 2) << "norm only support ord=2";
+  CHECK(param.ord == 1 || param.ord == 2) << "norm only supports ord=1 and ord=2";
   if (req[0] == kNullOp) return;
 
   TShape small;
@@ -999,13 +1001,18 @@ void L2NormCompute(const nnvm::NodeAttrs& attrs,
   } else {
     small = ReduceAxesShapeImpl(inputs[0].shape_, param.axis, true, false);
   }
-  ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, mshadow_op::square>(
-      ctx, inputs, req, outputs, small);
-  SqRootForL2<xpu>(ctx, req[0], outputs[0]);
+  if (param.ord == 1) {
+    ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, mshadow_op::abs>(
+        ctx, inputs, req, outputs, small);
+  } else if (param.ord == 2) {
+    ReduceAxesComputeImpl<xpu, mshadow::red::sum, false, mshadow_op::square>(
+        ctx, inputs, req, outputs, small);
+    SqRootForL2<xpu>(ctx, req[0], outputs[0]);
+  }
 }
 
 template<typename xpu>
-void L2NormGradCompute(const nnvm::NodeAttrs& attrs,
+void LpNormGradCompute(const nnvm::NodeAttrs& attrs,
                        const OpContext& ctx,
                        const std::vector<TBlob>& inputs,
                        const std::vector<OpReqType>& req,
@@ -1021,8 +1028,36 @@ void L2NormGradCompute(const nnvm::NodeAttrs& attrs,
   } else {
     small = ReduceAxesShapeImpl(outputs[0].shape_, param.axis, true, false);
   }
-  ReduceAxesBackwardUseInOutImpl<xpu, mshadow_op::div, false>(ctx, small, inputs,
-                                                              req, outputs);
+  if (param.ord == 1) {
+    TShape src_shape, dst_shape;
+    BroadcastReduceShapeCompact(outputs[0].shape_, small, &src_shape, &dst_shape);
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      if (dst_shape.ndim() == 2) {
+        Tensor<xpu, 2, DType> ograd =
+          inputs[0].get_with_shape<xpu, 2, DType>(dst_shape.get<2>(), s);
+        Tensor<xpu, 2, DType> igrad =
+          outputs[0].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
+        Tensor<xpu, 2, DType> data =
+          inputs[1].get_with_shape<xpu, 2, DType>(src_shape.get<2>(), s);
+        ASSIGN_DISPATCH(igrad, req[0],
+          broadcast_to(ograd, src_shape)*F<mshadow_op::sign>(data));
+      } else {
+        const int ndim = MXNET_SPECIAL_MAX_NDIM;
+        Tensor<xpu, ndim, DType> igrad =
+          outputs[0].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
+        Tensor<xpu, ndim, DType> ograd =
+          inputs[0].get_with_shape<xpu, ndim, DType>(dst_shape.get<ndim>(), s);
+        Tensor<xpu, ndim, DType> data =
+          inputs[1].get_with_shape<xpu, ndim, DType>(src_shape.get<ndim>(), s);
+        ASSIGN_DISPATCH(igrad, req[0],
+          broadcast_to(ograd, src_shape)*F<mshadow_op::sign>(data));
+      }
+    });
+  } else if (param.ord == 2) {
+    ReduceAxesBackwardUseInOutImpl<xpu, mshadow_op::div, false>(ctx, small, inputs,
+                                                                req, outputs);
+  }
 }
 
 template<typename xpu>
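
For reference, a minimal NumPy sketch of the forward and backward rules the
hunks above implement for the dense path (illustrative only, not MXNet API;
lp_norm and lp_norm_grad are hypothetical names):

    import numpy as np

    def lp_norm(x, order, axis=None, keepdims=False):
        # L1: sum of absolute values; L2: square root of the sum of squares.
        if order == 1:
            return np.sum(np.abs(x), axis=axis, keepdims=keepdims)
        if order == 2:
            return np.sqrt(np.sum(np.square(x), axis=axis, keepdims=keepdims))
        raise ValueError("only order 1 and 2 are supported")

    def lp_norm_grad(ograd, x, norm, order):
        # `norm` is the keepdims forward output, so it broadcasts against x.
        if order == 1:
            return ograd * np.sign(x)  # d|x|/dx = sign(x)
        if order == 2:
            return ograd * x / norm    # d(||x||_2)/dx = x / ||x||_2
        raise ValueError("only order 1 and 2 are supported")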
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cc b/src/operator/tensor/broadcast_reduce_op_value.cc
index 7bcc3e9..cde31bf 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cc
+++ b/src/operator/tensor/broadcast_reduce_op_value.cc
@@ -88,9 +88,9 @@ MXNET_ADD_SPARSE_OP_ALIAS(sum)
 
 Example::
 
-  data = [[[1,2],[2,3],[1,3]],
-          [[1,4],[4,3],[5,2]],
-          [[7,1],[7,2],[7,3]]]
+  data = [[[1, 2], [2, 3], [1, 3]],
+          [[1, 4], [4, 3], [5, 2]],
+          [[7, 1], [7, 2], [7, 3]]]
 
   sum(data, axis=1)
   [[  4.   8.]
@@ -100,9 +100,9 @@ Example::
   sum(data, axis=[1,2])
   [ 12.  19.  27.]
 
-  data = [[1,2,0],
-          [3,0,1],
-          [4,1,0]]
+  data = [[1, 2, 0],
+          [3, 0, 1],
+          [4, 1, 0]]
 
   csr = cast_storage(data, 'csr')
 
@@ -280,14 +280,20 @@ MXNET_ADD_SPARSE_OP_ALIAS(norm)
 
This operator computes the norm on an NDArray with the specified axis, depending
on the value of the ord parameter. By default, it computes the L2 norm on the entire
-array.
+array. Currently, only ord=2 is supported for sparse ndarrays.
 
 Examples::
 
-  x = [[1, 2],
-       [3, 4]]
+  x = [[[1, 2],
+        [3, 4]],
+       [[2, 2],
+        [5, 6]]]
 
-  norm(x) = [5.47722578]
+  norm(x, ord=2, axis=1) = [[3.1622777 4.472136 ]
+                            [5.3851647 6.3245554]]
+
+  norm(x, ord=1, axis=1) = [[4., 6.],
+                            [7., 8.]]
 
   rsp = x.cast_storage('row_sparse')
 
@@ -303,13 +309,13 @@ Examples::
 .set_attr_parser(ParamParser<NormParam>)
 .set_attr<nnvm::FInferShape>("FInferShape", NormShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
-.set_attr<FInferStorageType>("FInferStorageType", L2NormStorageType)
+.set_attr<FInferStorageType>("FInferStorageType", LpNormStorageType)
 .set_attr<nnvm::FGradient>("FGradient", ReduceGrad{ "_backward_norm" })
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
-.set_attr<FCompute>("FCompute<cpu>", L2NormCompute<cpu>)
+.set_attr<FCompute>("FCompute<cpu>", LpNormCompute<cpu>)
 .set_attr<FComputeEx>("FComputeEx<cpu>", L2NormComputeEx<cpu>)
 .add_argument("data", "NDArray-or-Symbol", "The input")
 .add_arguments(NormParam::__FIELDS__());
@@ -322,7 +328,7 @@ NNVM_REGISTER_OP(_backward_norm)
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
   })
-.set_attr<FCompute>("FCompute<cpu>", L2NormGradCompute<cpu>);
+.set_attr<FCompute>("FCompute<cpu>", LpNormGradCompute<cpu>);
 
 
 }  // namespace op
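
The registration above backs the Python-level mx.nd.norm. A short usage sketch
matching the updated docstring example (assumes an MXNet build that includes
this commit):

    import mxnet as mx

    x = mx.nd.array([[[1, 2], [3, 4]],
                     [[2, 2], [5, 6]]])
    print(mx.nd.norm(x, ord=1, axis=1))  # [[4. 6.] [7. 8.]]
    print(mx.nd.norm(x, ord=2, axis=1))  # [[3.1622777 4.472136 ]
                                         #  [5.3851647 6.3245554]]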
diff --git a/src/operator/tensor/broadcast_reduce_op_value.cu b/src/operator/tensor/broadcast_reduce_op_value.cu
index f7fba68..e2a3840 100644
--- a/src/operator/tensor/broadcast_reduce_op_value.cu
+++ b/src/operator/tensor/broadcast_reduce_op_value.cu
@@ -101,11 +101,11 @@ NNVM_REGISTER_OP(_broadcast_backward)
.set_attr<FCompute>("FCompute<gpu>", ReduceAxesCompute<gpu, mshadow::red::sum>);
 
 NNVM_REGISTER_OP(norm)
-.set_attr<FCompute>("FCompute<gpu>", L2NormCompute<gpu>)
+.set_attr<FCompute>("FCompute<gpu>", LpNormCompute<gpu>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", L2NormComputeEx<gpu>);
 
 NNVM_REGISTER_OP(_backward_norm)
-.set_attr<FCompute>("FCompute<gpu>", L2NormGradCompute<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", LpNormGradCompute<gpu>);
 
 }  // namespace op
 }  // namespace mxnet
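
The GPU registrations reuse the same templated kernels, so the new ord=1 path
should work unchanged on CUDA devices. A quick smoke test (assumes a CUDA
build and at least one visible GPU):

    import mxnet as mx

    x = mx.nd.random.uniform(-1, 1, shape=(3, 4), ctx=mx.gpu(0))
    print(mx.nd.norm(x, ord=1, axis=1))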
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index aeaa0b7..a01514d 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -27,6 +27,7 @@ from mxnet.test_utils import assert_almost_equal, assert_exception
 from mxnet.test_utils import default_context
 from mxnet.test_utils import np_reduce
 from mxnet.test_utils import same
+from mxnet.test_utils import random_sample, rand_shape_nd
 from numpy.testing import assert_allclose
 import mxnet.autograd
 
@@ -1275,33 +1276,27 @@ def test_ndarray_astype():
 
 @with_seed()
 def test_norm(ctx=default_context()):
-    np_arr = np.random.uniform(size=(3, 3, 3, 3))
+    def l1norm(input_data, axis=0, keepdims=False):
+        return np.sum(abs(input_data), axis=axis, keepdims=keepdims)
+    def l2norm(input_data, axis=0, keepdims=False):
+        return np.linalg.norm(input_data, axis=axis, keepdims=keepdims)
+
+    in_data_dim = random_sample([4, 5, 6], 1)[0]
+    in_data_shape = rand_shape_nd(in_data_dim)
+    np_arr = np.random.uniform(-1, 1, in_data_shape).astype(np.float32)
     mx_arr = mx.nd.array(np_arr, ctx=ctx)
-    arr1 = np.linalg.norm(np_arr, keepdims=False)
-    arr2 = mx.nd.norm(mx_arr, keepdims=False)
-    print(arr1)
-    print(arr2.asnumpy())
-    mx.test_utils.assert_almost_equal(arr1, arr2.asnumpy()[0])
-
-    for i in range(4):
-        arr1 = np.linalg.norm(np_arr, axis=i, keepdims=False)
-        arr2 = mx.nd.norm(mx_arr, axis=i, keepdims=False)
-        assert arr1.shape == arr2.shape
-        mx.test_utils.assert_almost_equal(arr1, arr2.asnumpy())
-
-        arr1 = np.linalg.norm(np_arr, axis=i, keepdims=True)
-        arr2 = mx.nd.norm(mx_arr, axis=i, keepdims=True)
-        assert arr1.shape == arr2.shape
-        mx.test_utils.assert_almost_equal(arr1, arr2.asnumpy())
-        if (i < 3):
-            arr1 = np.linalg.norm(np_arr, axis=(i, i+1), keepdims=False)
-            arr2 = mx.nd.norm(mx_arr, axis=(i, i+1), keepdims=False)
-            assert arr1.shape == arr2.shape
-            mx.test_utils.assert_almost_equal(arr1, arr2.asnumpy())
-            arr1 = np.linalg.norm(np_arr, axis=(i, i+1), keepdims=True)
-            arr2 = mx.nd.norm(mx_arr, axis=(i, i+1), keepdims=True)
-            assert arr1.shape == arr2.shape
-            mx.test_utils.assert_almost_equal(arr1, arr2.asnumpy())
+    for order in [1, 2]:
+        for keep_dims in [True, False]:
+            for i in range(in_data_dim):
+                npy_out = l1norm(np_arr, i, keep_dims) if order == 1 else l2norm(np_arr, i, keep_dims)
+                mx_out = mx.nd.norm(mx_arr, ord=order, axis=i, keepdims=keep_dims)
+                assert npy_out.shape == mx_out.shape
+                mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy())
+                if i < in_data_dim - 1:
+                    npy_out = l1norm(np_arr, (i, i+1), keep_dims) if order == 1 else l2norm(np_arr, (i, i+1), keep_dims)
+                    mx_out = mx.nd.norm(mx_arr, ord=order, axis=(i, i+1), keepdims=keep_dims)
+                    assert npy_out.shape == mx_out.shape
+                    mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy())
 
 @with_seed()
 def test_ndarray_cpu_shared_ctx():
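
A standalone version of the comparison the rewritten test performs, for
readers who want to reproduce it outside the test harness (the shape here is
arbitrary):

    import numpy as np
    import mxnet as mx

    np_arr = np.random.uniform(-1, 1, (3, 4, 5)).astype(np.float32)
    mx_arr = mx.nd.array(np_arr)
    for axis in range(np_arr.ndim):
        np.testing.assert_allclose(
            mx.nd.norm(mx_arr, ord=1, axis=axis).asnumpy(),
            np.sum(np.abs(np_arr), axis=axis), rtol=1e-5, atol=1e-5)
        np.testing.assert_allclose(
            mx.nd.norm(mx_arr, ord=2, axis=axis).asnumpy(),
            np.linalg.norm(np_arr, axis=axis), rtol=1e-5, atol=1e-5)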
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 0fa31de..2707d8f 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3012,6 +3012,50 @@ def check_layer_normalization(in_shape, axis, eps, dtype=np.float32, forward_che
                               grad_nodes={'data': req, 'gamma': req, 'beta': req},
                                numeric_eps=1e-2, rtol=1e-2, atol=1e-2)
 
+@with_seed()
+def test_norm():
+    def l1norm(input_data, axis=0, keepdims=True):
+        return np.sum(abs(input_data), axis=axis, keepdims=keepdims)
+    def l2norm(input_data, axis=0, keepdims=True):
+        return np.linalg.norm(input_data, axis=axis, keepdims=keepdims)
+
+    ctx = default_context()
+    data = mx.symbol.Variable('data')
+    in_data_dim = random_sample([4, 5, 6], 1)[0]
+    in_shape = rand_shape_nd(in_data_dim)
+    epsilon = 1e-3
+    for order in [1, 2]:
+        for dtype in [np.float16, np.float32, np.float64]:
+            in_data = np.random.uniform(-1, 1, in_shape).astype(dtype)
+            in_data[abs(in_data) < epsilon] = epsilon
+            for i in range(in_data_dim):
+                norm_sym = mx.symbol.norm(data=data, ord=order, axis=i, keepdims=True)
+                npy_out = l1norm(in_data, i) if order == 1 else l2norm(in_data, i)
+                npy_out_backward = np.sign(in_data) if order == 1 else in_data / npy_out
+                check_symbolic_forward(norm_sym, [in_data], [npy_out],
+                                       rtol=1e-2 if dtype == np.float16 else 1e-5,
+                                       atol=1e-2 if dtype == np.float16 else 1e-5, ctx=ctx)
+                check_symbolic_backward(norm_sym, [in_data], [np.ones(npy_out.shape)],
+                                        [npy_out_backward],
+                                        rtol=1e-2 if dtype == np.float16 else 1e-5,
+                                        atol=1e-2 if dtype == np.float16 else 1e-5, ctx=ctx)
+                # check gradient
+                check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, rtol=1e-2, atol=1e-3)
+                if i < in_data_dim - 1:
+                    norm_sym = mx.symbol.norm(data=data, ord=order, axis=(i, i+1), keepdims=True)
+                    npy_out = l1norm(in_data, (i, i+1)) if order == 1 else l2norm(in_data, (i, i+1))
+                    npy_out_backward = np.sign(in_data) if order == 1 else in_data / npy_out
+                    check_symbolic_forward(norm_sym, [in_data], [npy_out],
+                                           rtol=1e-2 if dtype == np.float16 else 1e-5,
+                                           atol=1e-2 if dtype == np.float16 else 1e-5, ctx=ctx)
+                    check_symbolic_backward(norm_sym, [in_data], [np.ones(npy_out.shape)],
+                                            [npy_out_backward],
+                                            rtol=1e-2 if dtype == np.float16 else 1e-5,
+                                            atol=1e-2 if dtype == np.float16 else 1e-5, ctx=ctx)
+                    # check gradient
+                    check_numeric_gradient(norm_sym, [in_data], numeric_eps=epsilon, rtol=1e-2, atol=1e-3)
+
+
 def test_layer_norm():
     for dtype, forward_check_eps in zip([np.float16, np.float32, np.float64],
                                         [1E-2, 1E-3, 1E-4]):
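
For intuition on what check_numeric_gradient verifies for ord=1: away from
zero, the analytic gradient of sum(|x|) is sign(x), which a central finite
difference recovers. A self-contained NumPy check (eps, shape, and the 0.1
clamp are illustrative choices):

    import numpy as np

    x = np.random.uniform(-1, 1, (4, 5))
    x[np.abs(x) < 0.1] = 0.1  # keep inputs away from the kink at zero
    eps = 1e-3
    numeric = np.empty_like(x)
    it = np.nditer(x, flags=['multi_index'])
    while not it.finished:
        xp, xm = x.copy(), x.copy()
        xp[it.multi_index] += eps
        xm[it.multi_index] -= eps
        numeric[it.multi_index] = (np.abs(xp).sum() - np.abs(xm).sum()) / (2 * eps)
        it.iternext()
    np.testing.assert_allclose(numeric, np.sign(x), rtol=1e-2, atol=1e-3)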
