This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b86ccf1  Add erfinv operator for calculating inverse error function 
(#13811)
b86ccf1 is described below

commit b86ccf1be704f5f97e085f3fb21e781bceac884d
Author: Ziyi Mu <dd...@sina.com>
AuthorDate: Tue Jan 22 11:36:31 2019 -0500

    Add erfinv operator for calculating inverse error function (#13811)
    
    * add default behaviour for argmax
    
    * prototype of erfinv
    
    * add test
    
    * gpu support
    
    * Revert "add default behaviour for argmax"
    
    This reverts commit 64e9f1a9e3c9cabf312b8d80b3520b22da31c0b6.
    
    * move erfinv to contrib
    
    * edit copyright
    
    * remove atof
    
    * use std and update license
    
    * add license exclude file
    
    * fix per eric's comments
    
    * change license header
---
 docs/api/python/ndarray/ndarray.md                 |   1 +
 docs/api/python/symbol/symbol.md                   |   1 +
 src/operator/contrib/erfinv-inl.h                  | 105 +++++++++++++++++++++
 src/operator/mshadow_op.h                          |   3 +
 src/operator/operator_tune.cc                      |   4 +-
 src/operator/tensor/elemwise_unary_op_basic.cc     |  16 ++++
 src/operator/tensor/elemwise_unary_op_basic.cu     |   8 ++
 .../nightly/apache_rat_license_check/rat-excludes  |   3 +-
 tests/python/unittest/test_operator.py             |   6 +-
 tools/license_header.py                            |   1 +
 10 files changed, 145 insertions(+), 3 deletions(-)

diff --git a/docs/api/python/ndarray/ndarray.md 
b/docs/api/python/ndarray/ndarray.md
index 6419c4e..2df18c2 100644
--- a/docs/api/python/ndarray/ndarray.md
+++ b/docs/api/python/ndarray/ndarray.md
@@ -659,6 +659,7 @@ The `ndarray` package provides several classes:
     relu
     sigmoid
     erf
+    erfinv
 ```
 
 ### More
diff --git a/docs/api/python/symbol/symbol.md b/docs/api/python/symbol/symbol.md
index 9eba261..0fc2aa7 100644
--- a/docs/api/python/symbol/symbol.md
+++ b/docs/api/python/symbol/symbol.md
@@ -659,6 +659,7 @@ Composite multiple symbols into a new one by an operator.
     relu
     sigmoid
     erf
+    erfinv
 ```
 
 ### More
diff --git a/src/operator/contrib/erfinv-inl.h 
b/src/operator/contrib/erfinv-inl.h
new file mode 100644
index 0000000..8d718ad
--- /dev/null
+++ b/src/operator/contrib/erfinv-inl.h
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2014 Indiana University
+ * All rights reserved.
+ * Written by Prof. Gary L. Pavlis, Dept. of Geol. Sci.,
+ *           Indiana University, Bloomington, IN
+ * This software is licensed under the New BSD license:
+ * Redistribution and use in source and binary forms,
+ * with or without modification, are permitted provided
+ * that the following conditions are met:
+ * Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the
+ * following disclaimer.
+ * Redistributions in binary form must reproduce the
+ * above copyright notice, this list of conditions and
+ * the following disclaimer in the documentation and/or
+ * other materials provided with the distribution.
+ * Neither the name of Indiana University nor
+ * the names of its contributors may be used to endorse
+ * or promote products derived from this software without
+ * specific prior written permission.
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+ * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
+ * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
+ * PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
+ * THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
+ * USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
+ * IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
+ * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ * POSSIBILITY OF SUCH DAMAGE.
+ */
+/*
+ * The next function is taken from
+ * 
https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c.
+ * Output was modified to be inf or -inf when input is 1 or -1.
+ */
+#ifndef MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_
+#define MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_
+
+#define _USE_MATH_DEFINES
+
+#include <mxnet/base.h>
+#include <limits>
+#include "math.h"
+
+namespace mxnet {
+namespace op {
+namespace mshadow_op {
+
+/*! \brief inverse gauss error function */
+struct erfinv : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType v) {
+    /* Compute the inverse error function of v, i.e. x such that erf(x) = v.
+    The work is done in double precision and the result is converted back
+    to DType on return.
+    Special cases handled below: |v| > 1 returns NaN, and |v| == 1 returns
+    infinity carrying the sign of v.
+    Otherwise a rational approximation is used to generate an initial
+    approximation, which is then improved by two steps of Newton's method.
+    Code is a direct translation of the erfinv m file in matlab version 2.0.
+    Author:  Gary L. Pavlis, Indiana University
+    Date:  February 1996
+    */
+    const double central_range = 0.7;
+    double y = static_cast<double>(v);
+    double y_fab = std::fabs(y);
+    /* working variables: x is the running estimate of the result; z, num
+    and dem hold the argument, numerator and denominator of the rational
+    approximation */
+    double x = 0.0;
+    double z, num, dem;
+    /* coefficients in rational expansion: a (numerator) and b (denominator)
+    are used when |y| <= central_range; c (numerator) and d (denominator)
+    are used for the remaining values with |y| < 1 */
+    double a[4]={ 0.886226899, -1.645349621,  0.914624893, -0.140543331};
+    double b[4]={-2.118377725,  1.442710462, -0.329097515,  0.012229801};
+    double c[4]={-1.970840454, -1.624906493,  3.429567803,  1.641345311};
+    double d[2]={ 3.543889200,  1.637067800};
+    if (y_fab > 1.0) {
+      /* |y| > 1: no inverse exists, so return NaN (this needs the IEEE
+      constant) */
+      return DType(std::numeric_limits<double>::quiet_NaN());
+    } else if (y_fab == 1.0) {
+      return DType((std::copysign(1.0, 
y))*std::numeric_limits<double>::infinity());
+    } else if (y_fab <= central_range) {
+            z = y*y;
+            num = (((a[3]*z + a[2])*z + a[1])*z + a[0]);
+            dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0])*z + 1.0);
+            x = y*num/dem;
+    } else {
+            z = std::sqrt(-std::log((1.0-y_fab)/2.0));
+            num = ((c[3]*z + c[2])*z + c[1])*z + c[0];
+            dem = (d[1]*z + d[0])*z + 1.0;
+            x = (std::copysign(1.0, y))*num/dem;
+    }
+    /* Two steps of Newton-Raphson correction applied to erf(x) - y = 0;
+    the divisor (2/sqrt(pi))*exp(-x*x) is the derivative of erf at x */
+    x = x - (std::erf(x) - y)/((2.0/std::sqrt(M_PI))*std::exp(-x*x));
+    x = x - (std::erf(x) - y)/((2.0/std::sqrt(M_PI))*std::exp(-x*x));
+
+    return DType(x);
+  }
+};
+
+}  // namespace mshadow_op
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 0b20a02..f56436b 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -31,6 +31,7 @@
 #include "math_functions-inl.h"
 #include "special_functions-inl.h"
 #include "./operator_tune.h"
+#include "./contrib/erfinv-inl.h"
 
 #ifdef __CUDACC__
 #include <cuda_fp16.h>
@@ -169,6 +170,8 @@ struct softrelu : public mxnet_op::tunable {
 
 MXNET_UNARY_MATH_OP(softrelu_grad, -math::expm1(-a));
 
+MXNET_UNARY_MATH_OP(erfinv_grad, 0.5 * math::sqrt(PI) * 
math::exp(math::sqr(erfinv::Map(a))));
+
 MXNET_UNARY_MATH_OP(erf_grad, 2.0 / math::sqrt(PI) * math::exp(-(a * a)));
 
 MXNET_SIMPLE_UNARY_MATH_OP(erf);
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 2018e80..56d35b2 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -234,9 +234,11 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log2); 
 // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log2_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log10);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log10_grad);  // NOLINT()
-IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::erf);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::erf_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::erfinv);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::erfinv_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sin);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sin_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sinh);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sinh_grad);  // NOLINT()
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc 
b/src/operator/tensor/elemwise_unary_op_basic.cc
index c0d420f..d0079b5 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -916,6 +916,22 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_erf)
 .set_attr<FCompute>("FCompute<cpu>",
                     ElemwiseBinaryOp::Compute<cpu, 
unary_bwd<mshadow_op::erf_grad>>);
 
+// erfinv
+MXNET_OPERATOR_REGISTER_UNARY(erfinv)
+.describe(R"code(Returns element-wise inverse gauss error function of the 
input.
+
+Example::
+
+   erfinv([0, 0.5, -1.]) = [0., 0.4769, -inf]
+
+)code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::erfinv>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_erfinv"});
+
+MXNET_OPERATOR_REGISTER_BINARY(_backward_erfinv)
+.set_attr<FCompute>("FCompute<cpu>",
+                    ElemwiseBinaryOp::Compute<cpu, 
unary_bwd<mshadow_op::erfinv_grad>>);
+
 // rcbrt
 MXNET_OPERATOR_REGISTER_UNARY(rcbrt)
 .describe(R"code(Returns element-wise inverse cube-root value of the input.
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu 
b/src/operator/tensor/elemwise_unary_op_basic.cu
index 14f2be0..642cb0e 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cu
+++ b/src/operator/tensor/elemwise_unary_op_basic.cu
@@ -62,6 +62,14 @@ NNVM_REGISTER_OP(_backward_erf)
 .set_attr<FCompute>("FCompute<gpu>",
                     ElemwiseBinaryOp::Compute<gpu, 
unary_bwd<mshadow_op::erf_grad>>);
 
+// erfinv
+NNVM_REGISTER_OP(erfinv)
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::Compute<gpu, 
mshadow_op::erfinv>);
+
+NNVM_REGISTER_OP(_backward_erfinv)
+.set_attr<FCompute>("FCompute<gpu>",
+                    ElemwiseBinaryOp::Compute<gpu, 
unary_bwd<mshadow_op::erfinv_grad>>);
+
 // copy
 NNVM_REGISTER_OP(_copy)
 .set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>)
diff --git a/tests/nightly/apache_rat_license_check/rat-excludes 
b/tests/nightly/apache_rat_license_check/rat-excludes
index 5969f01..782ef40 100755
--- a/tests/nightly/apache_rat_license_check/rat-excludes
+++ b/tests/nightly/apache_rat_license_check/rat-excludes
@@ -35,6 +35,7 @@ _mask.pyx
 coco.py
 base.pyi
 special_functions-inl.h
+erfinv-inl.h
 im2col.cuh
 im2col.h
 pool.h
@@ -49,4 +50,4 @@ deformable_im2col.h
 REQUIRE
 include/*
 .*.iml
-.*.json.ref
\ No newline at end of file
+.*.json.ref
diff --git a/tests/python/unittest/test_operator.py 
b/tests/python/unittest/test_operator.py
index cb19fd8..cda801c 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3500,7 +3500,11 @@ def test_special_functions_using_scipy():
 
     # erf
     mathematical_core("erf", lambda x: mx.sym.erf(x), lambda x: 
scipy_special.erf(x),
-                     lambda x: 2.0 / math.sqrt(math.pi) * math.exp(-(x ** 2)), 
0.5, 0.5)
+                     lambda x: 2.0 / math.sqrt(math.pi) * np.exp(-(x ** 2)), 
0.5, 0.5)
+
+    # erfinv
+    mathematical_core("erfinv", lambda x: mx.sym.erfinv(x), lambda x: 
scipy_special.erfinv(x),
+                     lambda x: 0.5 * math.sqrt(math.pi) * 
np.exp(scipy_special.erfinv(x) ** 2), 0.5, 0.5)
 
 
 def rounding(name, forward_mxnet_call, forward_numpy_call, data_init=5., 
grad_init=2.):
diff --git a/tools/license_header.py b/tools/license_header.py
index 199d56c..11cc928 100755
--- a/tools/license_header.py
+++ b/tools/license_header.py
@@ -84,6 +84,7 @@ _WHITE_LIST = [
                'src/operator/nn/im2col.cuh',
 
                # Licenses in headers
+               'src/operator/contrib/erfinv-inl.h',
                'docs/_static/searchtools_custom.js',
                'docs/_static/js/clipboard.js',
                'docs/_static/js/clipboard.min.js',

Reply via email to