This is an automated email from the ASF dual-hosted git repository.

kevinthesun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new d992468  [topi][relay] add operation tan to TVM (#4938)
d992468 is described below

commit d992468d80af816f0413fc43c2ee1c02f7fe19c3
Author: Yao Wang <[email protected]>
AuthorDate: Thu Mar 5 17:17:35 2020 -0800

    [topi][relay] add operation tan to TVM (#4938)
    
    * Add relay operation relay.op.tan.
    
    * Update tan implementation in TVM.
    
    * Update tests.
    
    * Add shape function for tan.
    
    * Add missing main test to python/frontend/tensorflow/test_forward.
    
    * Revert, back to sin/cos.
    
    * Revert "Revert, back to sin/cos."
    
    This reverts commit 4da5b503b921585ba9d80944b29136142b575c40.
    
    * Fix implementation of tan in cuda. Do not support tan for float16.
    
    Simplify topi/tests/python/test_topi_math. Add testing for tan with float32 and float64.
    
    Try again to implement tan as sin/cos in llvm.
---
 docs/frontend/tensorflow.rst                     |  1 +
 include/tvm/tir/op.h                             |  1 +
 python/tvm/relay/frontend/mxnet.py               |  1 +
 python/tvm/relay/frontend/tensorflow.py          |  1 +
 python/tvm/relay/frontend/tflite.py              |  8 ++++++++
 python/tvm/relay/op/_tensor.py                   |  2 ++
 python/tvm/relay/op/_tensor_grad.py              |  7 +++++++
 python/tvm/relay/op/tensor.py                    | 15 +++++++++++++++
 python/tvm/te/__init__.py                        |  2 +-
 python/tvm/tir/__init__.py                       |  2 +-
 python/tvm/tir/op.py                             | 16 ++++++++++++++++
 src/relay/op/tensor/unary.cc                     | 11 +++++++++++
 src/target/intrin_rule.cc                        |  3 +++
 src/target/llvm/intrin_rule_llvm.cc              | 14 ++++++++++++++
 src/target/llvm/intrin_rule_nvptx.cc             |  3 +++
 src/target/llvm/intrin_rule_rocm.cc              |  3 +++
 src/target/source/intrin_rule_cuda.cc            | 19 +++++++++++++++++++
 src/tir/ir/expr.cc                               |  2 +-
 tests/python/frontend/tensorflow/test_forward.py | 10 ++++++++++
 tests/python/frontend/tflite/test_forward.py     |  8 ++++++++
 tests/python/relay/test_op_grad_level1.py        |  1 +
 tests/python/relay/test_op_level1.py             |  1 +
 tests/python/unittest/test_testing.py            |  1 +
 topi/include/topi/elemwise.h                     |  1 +
 topi/python/topi/math.py                         | 17 +++++++++++++++++
 topi/src/topi.cc                                 |  5 +++++
 topi/tests/python/test_topi_basic.py             |  1 +
 topi/tests/python/test_topi_math.py              |  2 ++
 28 files changed, 155 insertions(+), 3 deletions(-)

diff --git a/docs/frontend/tensorflow.rst b/docs/frontend/tensorflow.rst
index 87341ab..8a54033 100644
--- a/docs/frontend/tensorflow.rst
+++ b/docs/frontend/tensorflow.rst
@@ -135,6 +135,7 @@ Supported Ops
 - ConcatV2
 - Conv2D
 - Cos
+- Tan
 - CropAndResize
 - DecodeJpeg
 - DepthwiseConv2dNative
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 5172b14..0a714d8 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -515,6 +515,7 @@ TVM_DECLARE_INTRIN_UNARY(sqrt);
 TVM_DECLARE_INTRIN_UNARY(rsqrt);
 TVM_DECLARE_INTRIN_UNARY(log);
 TVM_DECLARE_INTRIN_UNARY(popcount);
+TVM_DECLARE_INTRIN_UNARY(tan);
 TVM_DECLARE_INTRIN_UNARY(cos);
 TVM_DECLARE_INTRIN_UNARY(sin);
 TVM_DECLARE_INTRIN_UNARY(atan);
diff --git a/python/tvm/relay/frontend/mxnet.py 
b/python/tvm/relay/frontend/mxnet.py
index 0020a63..c2bfd75 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -1696,6 +1696,7 @@ _identity_list = [
     "ones_like",
     "where",
     "gather_nd",
+    "tan",
     "cos",
     "sin"
 ]
diff --git a/python/tvm/relay/frontend/tensorflow.py 
b/python/tvm/relay/frontend/tensorflow.py
index 14d2418..24164a3 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1564,6 +1564,7 @@ _convert_map = {
     'LessEqual'                         : _broadcast('less_equal'),
     'Log'                               : AttrCvt('log'),
     'Log1p'                             : _log1p(),
+    'Tan'                               : AttrCvt('tan'),
     'Cos'                               : AttrCvt('cos'),
     'Sin'                               : AttrCvt('sin'),
     'LogicalAnd'                        : _logical('logical_and'),
diff --git a/python/tvm/relay/frontend/tflite.py 
b/python/tvm/relay/frontend/tflite.py
index bc51c91..c2ec4d4 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -68,6 +68,7 @@ class OperatorConverter(object):
             'LOG': self.convert_log,
             'SIN': self.convert_sin,
             'COS': self.convert_cos,
+            'TAN': self.convert_tan,
             'SQRT': self.convert_sqrt,
             'RSQRT': self.convert_rsqrt,
             'NEG': self.convert_neg,
@@ -657,6 +658,13 @@ class OperatorConverter(object):
                 'TFlite quantized SIN operator is not supported yet.')
         return self._convert_unary_elemwise(_op.sin, op)
 
+    def convert_tan(self, op):
+        """Convert TFLite TAN"""
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                'TFlite quantized TAN operator is not supported yet.')
+        return self._convert_unary_elemwise(_op.tan, op)
+
     def convert_cos(self, op):
         """Convert TFLite COS"""
         if self.is_quantized(op):
diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index 9f0906b..4480849 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -27,6 +27,7 @@ from ...hybrid import script
 
 
 register_broadcast_schedule("log")
+register_broadcast_schedule("tan")
 register_broadcast_schedule("cos")
 register_broadcast_schedule("sin")
 register_broadcast_schedule("atan")
@@ -214,3 +215,4 @@ register_shape_func("minimum", False, broadcast_shape_func)
 register_shape_func("sqrt", False, elemwise_shape_func)
 register_shape_func("negative", False, elemwise_shape_func)
 register_shape_func("exp", False, elemwise_shape_func)
+register_shape_func("tan", False, elemwise_shape_func)
diff --git a/python/tvm/relay/op/_tensor_grad.py 
b/python/tvm/relay/op/_tensor_grad.py
index 944e51e..33a1937 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -61,6 +61,13 @@ def log_grad(orig, grad):
     return [grad * ones_like(x) / x]
 
 
+@register_gradient("tan")
+def tan_grad(orig, grad):
+    """Returns [grad / (cos^2(x))]"""
+    x = orig.args[0]
+    return [grad / (cos(x) * cos(x))]
+
+
 @register_gradient("cos")
 def cos_grad(orig, grad):
     """Returns [grad * (-sin(x))]"""
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index ada1f5e..7796918 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -47,6 +47,21 @@ def log(data):
     """
     return _make.log(data)
 
+def tan(data):
+    """Compute elementwise tan of data.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.tan(data)
+
 def cos(data):
     """Compute elementwise cos of data.
 
diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py
index 065cf4e..3a1a124 100644
--- a/python/tvm/te/__init__.py
+++ b/python/tvm/te/__init__.py
@@ -19,7 +19,7 @@
 """
 # expose all operators in tvm tir.op
 from tvm.tir import any, all, min_value, max_value, trace
-from tvm.tir import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, 
floor, ceil
+from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, 
rsqrt, floor, ceil
 from tvm.tir import trunc, abs, round, nearbyint, isnan, power, popcount, 
fmod, if_then_else
 from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, 
floormod
 from tvm.tir import comm_reducer, min, max, sum
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index a5c81ac..aa5871a 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -33,7 +33,7 @@ from .stmt import IfThenElse, Evaluate, Prefetch, 
LoweredFunc, stmt_seq, stmt_li
 
 from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, 
call_extern
 from .op import call_llvm_intrin, all, any, min_value, max_value, trace
-from .op import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, 
floor, ceil
+from .op import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, 
rsqrt, floor, ceil
 from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, 
if_then_else
 from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
 from .op import comm_reducer, min, max, sum
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index a8aef8f..c5b1a0a 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -393,6 +393,22 @@ def log(x):
     """
     return call_pure_intrin(x.dtype, "log", x)
 
+def tan(x):
+    """Take tan of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "tan", x)
+
+
 def cos(x):
     """Take cos of input x.
 
diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc
index 1169fa8..b9d2473 100644
--- a/src/relay/op/tensor/unary.cc
+++ b/src/relay/op/tensor/unary.cc
@@ -51,6 +51,17 @@ RELAY_REGISTER_UNARY_OP("log")
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));
 
 
+RELAY_REGISTER_UNARY_OP("tan")
+.describe(R"code(Returns the tan of input array, computed element-wise.
+
+.. math::
+   Y = tan(X)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tan));
+
+
 RELAY_REGISTER_UNARY_OP("cos")
 .describe(R"code(Returns the cos of input array, computed element-wise.
 
diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc
index 7e9ac71..ade3bbd 100644
--- a/src/target/intrin_rule.cc
+++ b/src/target/intrin_rule.cc
@@ -40,6 +40,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
 .set_body(DispatchExtern<FloatSuffix>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan")
+.set_body(DispatchExtern<FloatSuffix>);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos")
 .set_body(DispatchExtern<FloatSuffix>);
 
diff --git a/src/target/llvm/intrin_rule_llvm.cc 
b/src/target/llvm/intrin_rule_llvm.cc
index 758b0af..6c5a9cd 100644
--- a/src/target/llvm/intrin_rule_llvm.cc
+++ b/src/target/llvm/intrin_rule_llvm.cc
@@ -91,6 +91,20 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount")
 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan")
+.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
+  PrimExpr e = targs[0];
+  const tir::CallNode* call = e.as<tir::CallNode>();
+  CHECK(call != nullptr);
+  const PrimExpr& x = call->args[0];
+  PrimExpr sin_x = tir::CallNode::make(
+      x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic);
+  PrimExpr cos_x = tir::CallNode::make(
+      x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic);
+  PrimExpr tan_x = sin_x / cos_x;
+  *rv = tan_x;
+});
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos")
 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);
 
diff --git a/src/target/llvm/intrin_rule_nvptx.cc 
b/src/target/llvm/intrin_rule_nvptx.cc
index 6f7a89c..6d41d13 100644
--- a/src/target/llvm/intrin_rule_nvptx.cc
+++ b/src/target/llvm/intrin_rule_nvptx.cc
@@ -81,6 +81,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh")
 .set_body(DispatchExternLibDevice);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan")
+.set_body(DispatchExternLibDevice);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos")
 .set_body(DispatchExternLibDevice);
 
diff --git a/src/target/llvm/intrin_rule_rocm.cc 
b/src/target/llvm/intrin_rule_rocm.cc
index 31b7bf1..4e6a661 100644
--- a/src/target/llvm/intrin_rule_rocm.cc
+++ b/src/target/llvm/intrin_rule_rocm.cc
@@ -80,6 +80,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh")
 .set_body(DispatchExternOCML);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan")
+.set_body(DispatchExternOCML);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos")
 .set_body(DispatchExternOCML);
 
diff --git a/src/target/source/intrin_rule_cuda.cc 
b/src/target/source/intrin_rule_cuda.cc
index aed6c86..849e098 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -54,6 +54,22 @@ struct CUDAFastMath : public CUDAMath {
   }
 };
 
+struct CUDAFastMathTan : public CUDAMath {
+  std::string operator()(DataType t, std::string name) const {
+    if (t.lanes() == 1 && t.is_float()) {
+        switch (t.bits()) {
+          case 64: return name;
+          // `__tanf` seems to produce some values too deviant from numpy tan version.
+          // So, let's use just `tanf` instead.
+          case 32: return name + 'f';
+          case 16: LOG(FATAL) << "cuda tan unsupported for float16";
+          default: return "";
+        }
+    }
+    return "";
+  }
+};
+
 struct CUDAPopcount {
   std::string operator()(DataType t, std::string name) const {
     if (t.lanes() == 1 && t.is_uint()) {
@@ -97,6 +113,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
 .set_body(DispatchExtern<CUDAFastMath>);
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan")
+.set_body(DispatchExtern<CUDAFastMathTan>);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos")
 .set_body(DispatchExtern<CUDAFastMath>);
 
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 07cae5e..7572f8d 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -229,7 +229,7 @@ PrimExpr LetNode::make(Var var, PrimExpr value, PrimExpr 
body) {
 
 const char* CallNode::vectorizable_intrinsics[] = {
     "floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt",
-    "log", "sin", "cos", "pow", tir::CallNode::shift_left, 
tir::CallNode::shift_right,
+    "log", "sin", "cos", "pow", "tan", tir::CallNode::shift_left, 
tir::CallNode::shift_right,
     tir::CallNode::likely, tir::CallNode::popcount
 };
 
diff --git a/tests/python/frontend/tensorflow/test_forward.py 
b/tests/python/frontend/tensorflow/test_forward.py
index 42408b7..31c5480 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -2628,6 +2628,15 @@ def test_forward_cos():
     compare_tf_with_tvm([np_data], ['in_data:0'], 'cos:0')
 
 
+def test_forward_tan():
+    """test operator tan """
+    np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
+    tf.reset_default_graph()
+    in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
+    tf.tan(in_data, name="tan")
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'tan:0')
+
+
 def test_forward_sin():
     """test operator sin """
     np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
@@ -3031,6 +3040,7 @@ if __name__ == '__main__':
     test_forward_sign()
     test_forward_log()
     test_forward_log1p()
+    test_forward_tan()
     test_forward_cos()
     test_forward_sin()
     test_forward_negative()
diff --git a/tests/python/frontend/tflite/test_forward.py 
b/tests/python/frontend/tflite/test_forward.py
index 1478393..28216fc 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -723,6 +723,13 @@ def _test_cos(data):
     """ One iteration of cos """
     return _test_unary_elemwise(math_ops.cos, data)
 #######################################################################
+# Tan
+# ---
+
+def _test_tan(data):
+    """ One iteration of tan """
+    return _test_unary_elemwise(math_ops.tan, data)
+#######################################################################
 # Sqrt
 # ----
 
@@ -772,6 +779,7 @@ def test_all_unary_elemwise():
     if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
         _test_forward_unary_elemwise(_test_ceil)
         _test_forward_unary_elemwise(_test_cos)
+        _test_forward_unary_elemwise(_test_tan)
 
 #######################################################################
 # Element-wise
diff --git a/tests/python/relay/test_op_grad_level1.py 
b/tests/python/relay/test_op_grad_level1.py
index 0eb1cec..0579441 100644
--- a/tests/python/relay/test_op_grad_level1.py
+++ b/tests/python/relay/test_op_grad_level1.py
@@ -64,6 +64,7 @@ def test_unary_op():
                         (relay.nn.relu, lambda x: np.where(x < 0, 
np.zeros_like(x), np.ones_like(x))),
                         (tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
                         (tvm.relay.sin, lambda x: np.cos(x)),
+                        (tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)),
                         (tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 
2.0)))]:
         check_single_op(opfunc, ref)
 
diff --git a/tests/python/relay/test_op_level1.py 
b/tests/python/relay/test_op_level1.py
index 0fa0749..c58ff0f 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -76,6 +76,7 @@ def test_unary_op():
                         (relay.nn.relu, relu),
                         (tvm.relay.cos, np.cos),
                         (tvm.relay.sin, np.sin),
+                        (tvm.relay.tan, np.tan),
                         (tvm.relay.atan, np.arctan)]:
         for dtype in ['float16', 'float32']:
             check_single_op(opfunc, ref, dtype)
diff --git a/tests/python/unittest/test_testing.py 
b/tests/python/unittest/test_testing.py
index ecf520d..cfa1384 100644
--- a/tests/python/unittest/test_testing.py
+++ b/tests/python/unittest/test_testing.py
@@ -31,6 +31,7 @@ def test_check_numerical_grads():
         lambda x: (np.sign(np.sin(1/x)), np.zeros_like(x)),
         lambda x: (x*np.sin(1/x), np.sin(1/x) - np.cos(1/x)/x),
         lambda x: (np.sin(1/x), - np.cos(1/x)/(x*x)),
+        lambda x: (np.tan(x), 1.0 / (np.cos(x) * np.cos(x))),
     ]
 
     # Avoid values too close to 0 since singularities of our functions are 
there
diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h
index 3c0822f..26107ea 100644
--- a/topi/include/topi/elemwise.h
+++ b/topi/include/topi/elemwise.h
@@ -55,6 +55,7 @@ TOPI_DECLARE_UNARY_OP(round);
 TOPI_DECLARE_UNARY_OP(trunc);
 TOPI_DECLARE_UNARY_OP(abs);
 TOPI_DECLARE_UNARY_OP(cos);
+TOPI_DECLARE_UNARY_OP(tan);
 TOPI_DECLARE_UNARY_OP(sin);
 TOPI_DECLARE_UNARY_OP(atan);
 TOPI_DECLARE_UNARY_OP(isnan);
diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py
index 4a63c45..3eda88a 100644
--- a/topi/python/topi/math.py
+++ b/topi/python/topi/math.py
@@ -110,6 +110,23 @@ def tanh(x):
 
 
 @tvm.te.tag_scope(tag=tag.ELEMWISE)
+def tan(x):
+    """Take tan of input x.
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        Input argument.
+
+    Returns
+    -------
+    y : tvm.te.Tensor
+        The result.
+    """
+    return te.compute(x.shape, lambda *i: te.tan(x(*i)))
+
+
+@tvm.te.tag_scope(tag=tag.ELEMWISE)
 def cos(x):
     """Take cos of input x.
 
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index 75517b8..add01c2 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -175,6 +175,11 @@ TVM_REGISTER_GLOBAL("topi.erf")
   *rv = erf(args[0]);
   });
 
+TVM_REGISTER_GLOBAL("topi.tan")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = tan(args[0]);
+  });
+
 TVM_REGISTER_GLOBAL("topi.cos")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = cos(args[0]);
diff --git a/topi/tests/python/test_topi_basic.py 
b/topi/tests/python/test_topi_basic.py
index 83f0469..13f1463 100644
--- a/topi/tests/python/test_topi_basic.py
+++ b/topi/tests/python/test_topi_basic.py
@@ -45,6 +45,7 @@ def test_ewise():
     test_apply(topi.rsqrt, "rsqrt")
     test_apply(topi.sin, "sin")
     test_apply(topi.cos, "cos")
+    test_apply(topi.tan, "tan")
     test_apply(topi.atan, "atan")
 
 
diff --git a/topi/tests/python/test_topi_math.py 
b/topi/tests/python/test_topi_math.py
index 30a0f44..3e58518 100644
--- a/topi/tests/python/test_topi_math.py
+++ b/topi/tests/python/test_topi_math.py
@@ -127,6 +127,8 @@ def test_ewise():
     test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
     test_apply(topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 
100, skip_name_check=True)
     test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi)
+    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float32')
+    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float64')
     test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)
     test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32")
     test_isnan(-100, 100)

Reply via email to