piiswrong closed pull request #10633: [MXNET-346] Hard Sigmoid Operator
URL: https://github.com/apache/incubator-mxnet/pull/10633
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index 37710843544..2422e5ba6a4 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -388,6 +388,94 @@ void CastCompute(const nnvm::NodeAttrs& attrs,
   });
 }
 
+struct HardSigmoidParam : public dmlc::Parameter<HardSigmoidParam> {
+  real_t alpha;
+  real_t beta;
+  DMLC_DECLARE_PARAMETER(HardSigmoidParam) {
+    DMLC_DECLARE_FIELD(alpha)
+    .set_default(0.2)
+    .describe("Slope of hard sigmoid");
+    DMLC_DECLARE_FIELD(beta)
+    .set_default(0.5)
+    .describe("Bias of hard sigmoid.");
+  }
+};
+
+template<int req>
+struct hard_sigmoid_forward {
+  template<typename DType>
+    MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
+                                    const real_t alpha, const real_t beta) {
+      DType result = DType(alpha * in_data[i] + beta);
+      result = (DType(1) < result) ? DType(1) : result;
+      result = (DType(0) > result) ? DType(0) : result;
+      KERNEL_ASSIGN(out_data[i], req, result);
+    }
+};
+
+template<int req>
+struct hard_sigmoid_backward {
+  template<typename DType>
+    MSHADOW_XINLINE static void Map(int i, DType* in_grad, const DType* in_data,
+                                    const DType* out_grad, const real_t alpha, const real_t beta) {
+      DType out_val = DType(alpha) * in_data[i] + DType(beta);
+      DType grad = (out_val > DType(0) && out_val < DType(1)) ?
+                   (out_grad[i] * DType(alpha)) : DType(0);
+      KERNEL_ASSIGN(in_grad[i], req, grad);
+    }
+};
+
+
+template<typename xpu>
+void HardSigmoidForward(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<TBlob>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  CHECK(req[0] != kNullOp);
+  using namespace mshadow;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+  const HardSigmoidParam& param = nnvm::get<HardSigmoidParam>(attrs.parsed);
+  using namespace mxnet_op;
+  MSHADOW_REAL_TYPE_SWITCH(out_data.type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+      Kernel<hard_sigmoid_forward<req_type>, xpu>::Launch(
+        s, out_data.Size(), out_data.dptr<DType>(), in_data.dptr<DType>(),
+        param.alpha, param.beta);
+    });
+  });
+}
+
+template<typename xpu>
+void HardSigmoidBackward(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<TBlob>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  using namespace mshadow;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob& out_grad = inputs[0];
+  const TBlob& in_data = inputs[1];
+  const TBlob& in_grad = outputs[0];
+  const HardSigmoidParam& param = nnvm::get<HardSigmoidParam>(attrs.parsed);
+  using namespace mxnet_op;
+  MSHADOW_REAL_TYPE_SWITCH(in_data.type_flag_, DType, {
+    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+      Kernel<hard_sigmoid_backward<req_type>, xpu>::Launch(
+        s, in_grad.Size(), in_grad.dptr<DType>(), in_data.dptr<DType>(),
+        out_grad.dptr<DType>(), param.alpha, param.beta);
+    });
+  });
+}
+
 /*! \brief Unary compute */
 #define MXNET_OPERATOR_REGISTER_UNARY(__name$)                      \
   NNVM_REGISTER_OP(__name$)                                         \
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index e711148898f..11886eb7d87 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -107,6 +107,35 @@ The storage type of ``sigmoid`` output is always dense
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sigmoid,
                                                unary_bwd<mshadow_op::sigmoid_grad>);
+
+DMLC_REGISTER_PARAMETER(HardSigmoidParam);
+MXNET_OPERATOR_REGISTER_UNARY(hard_sigmoid)
+.describe(R"code(Computes hard sigmoid of x element-wise.
+
+.. math::
+   y = max(0, min(1, alpha * x + beta))
+
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<HardSigmoidParam>)
+.set_attr<FCompute>("FCompute<cpu>", HardSigmoidForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_hard_sigmoid"})
+.add_arguments(HardSigmoidParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_hard_sigmoid)
+.set_attr_parser(ParamParser<HardSigmoidParam>)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
+  [](const NodeAttrs& attrs){
+    return std::vector<bool>{true};
+  })
+.set_attr<FCompute>("FCompute<cpu>", HardSigmoidBackward<cpu>);
+
 // softsign
 MXNET_OPERATOR_REGISTER_UNARY(softsign)
 MXNET_ADD_SPARSE_OP_ALIAS(softsign)
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu
index 8dfa9af74ce..4843c88deed 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cu
+++ b/src/operator/tensor/elemwise_unary_op_basic.cu
@@ -40,6 +40,12 @@ NNVM_REGISTER_OP(_backward_sigmoid)
 .set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<
   gpu, unary_bwd<mshadow_op::sigmoid_grad>>);
 
+NNVM_REGISTER_OP(hard_sigmoid)
+.set_attr<FCompute>("FCompute<gpu>", HardSigmoidForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_hard_sigmoid)
+.set_attr<FCompute>("FCompute<gpu>", HardSigmoidBackward<gpu>);
+
 // softsign
 NNVM_REGISTER_OP(softsign)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, mshadow_op::softsign>);
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index d0089e6f3dd..66c29051040 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -584,6 +584,37 @@ def fsigmoid(a):
     check_symbolic_forward(y, [xa], [ya])
     check_symbolic_backward(y, [xa], [np.ones(shape)], [ya * (1 - ya)])
 
+@with_seed()
+def test_hard_sigmoid():
+    def fhardsigmoid(a, alpha=0.2, beta=0.5):
+        return np.maximum(np.zeros(a.shape, dtype=a.dtype),
+                          np.minimum(np.ones(a.shape, dtype=a.dtype), alpha*a+beta))
+    def fhardsigmoid_grad(a, out_grad, alpha=0.2, beta=0.5):
+        orig_out = fhardsigmoid(a, alpha, beta)
+        res = out_grad * alpha
+        res[orig_out <= 0.0] = 0.0
+        res[orig_out >= 1.0] = 0.0
+        return res
+    shape = (3, 4)
+    x = mx.symbol.Variable("x")
+    y = mx.sym.hard_sigmoid(x)
+    for dtype in [np.float16, np.float32, np.float64]:
+        if dtype is np.float16:
+            rtol = 1e-2
+            atol = 1e-4
+        else:
+            rtol = 1e-3
+            atol = 1e-5
+        xa = np.random.uniform(low=-3.0,high=3.0,size=shape).astype(dtype)
+        # function not differentiable at x=2.5 and -2.5
+        xa[xa == 2.5] = xa[xa == 2.5] - 1e-2
+        xa[xa == -2.5] = xa[xa == -2.5] - 1e-2
+        ya = fhardsigmoid(xa)
+        grad_xa = fhardsigmoid_grad(xa, np.ones(shape))
+        check_numeric_gradient(y, [xa], numeric_eps=1e-3, rtol=rtol, atol=atol)
+        check_symbolic_forward(y, [xa], [ya], rtol=rtol, atol=atol)
+        check_symbolic_backward(y, [xa], [np.ones(shape)], [grad_xa], rtol=rtol, atol=atol)
+
 @with_seed()
 def test_softsign():
     def fsoftsign(a):
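
For reference, a minimal imperative-mode sketch of the new operator (not part of the PR). It assumes the auto-generated NDArray binding mx.nd.hard_sigmoid and the mx.autograd package; the expected values follow directly from y = max(0, min(1, alpha*x + beta)) and the backward kernel above, with the defaults alpha=0.2, beta=0.5:

    import mxnet as mx

    # one sample point in each regime: saturated low, linear, saturated high
    x = mx.nd.array([-3.0, 0.0, 3.0])
    x.attach_grad()
    with mx.autograd.record():
        y = mx.nd.hard_sigmoid(x)  # defaults: alpha=0.2, beta=0.5
    y.backward()

    print(y)       # expected: [0.  0.5 1. ]
    print(x.grad)  # expected: [0.  0.2 0. ] -- alpha inside (0, 1), zero where the output saturates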


 

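A symbolic-mode sketch with non-default alpha/beta (the unit test above only exercises the defaults); the shapes, the CPU context, and the tolerances here are illustrative assumptions:

    import numpy as np
    import mxnet as mx

    x = mx.sym.Variable("x")
    y = mx.sym.hard_sigmoid(x, alpha=0.25, beta=0.5)

    exe = y.simple_bind(ctx=mx.cpu(), x=(2, 3))
    data = np.random.uniform(-3.0, 3.0, size=(2, 3)).astype(np.float32)
    out = exe.forward(is_train=False, x=data)[0]

    # reference computation, mirroring fhardsigmoid() from the test above
    ref = np.clip(0.25 * data + 0.5, 0.0, 1.0)
    assert np.allclose(out.asnumpy(), ref, rtol=1e-5, atol=1e-6)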
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services
