This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 51af454 [Relay][FastMath] Relay pass to use fast exp/tanh (#4873)
51af454 is described below
commit 51af454ad7f97a49b19bd02830edcdff9379c58f
Author: Animesh Jain <[email protected]>
AuthorDate: Sun Mar 1 13:57:24 2020 -0800
[Relay][FastMath] Relay pass to use fast exp/tanh (#4873)
* [Relay][FastMath] Relay pass to use fast exp/tanh
* Adding required_pass to the tests.
* FastMath test changes.
---
include/tvm/relay/transform.h | 7 +++
python/tvm/relay/transform.py | 16 ++++++-
src/relay/backend/build_module.cc | 3 ++
src/relay/op/tensor/unary.cc | 22 +++++++++
src/relay/pass/fast_math.cc | 79 +++++++++++++++++++++++++++++++
src/relay/pass/pattern_util.h | 10 ++++
tests/python/relay/test_pass_fast_math.py | 52 ++++++++++++++++++++
topi/include/topi/elemwise.h | 7 +--
topi/python/topi/math.py | 16 +++++++
topi/src/topi.cc | 5 +-
10 files changed, 211 insertions(+), 6 deletions(-)
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 8d886aa..2862800 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -164,6 +164,13 @@ TVM_DLL Pass PartialEval();
TVM_DLL Pass SimplifyInference();
/*!
+ * \brief Replaces non linear activation functions with their fast but
approximate counterparts.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass FastMath();
+
+/*!
* \brief Infer the type of an expression.
*
* The result of type checking is a new expression with unambigous
diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py
index 45535af..f773835 100644
--- a/python/tvm/relay/transform.py
+++ b/python/tvm/relay/transform.py
@@ -57,7 +57,8 @@ def build_config(opt_level=2,
"CanonicalizeCast": 3,
"EliminateCommonSubexpr": 3,
"CombineParallelConv2D": 4,
- "CombineParallelDense": 4
+ "CombineParallelDense": 4,
+ "FastMath": 4
}
fallback_device : int, str, or tvmContext, optional
@@ -175,11 +176,22 @@ def SimplifyInference():
Returns
-------
ret: tvm.relay.Pass
- The registered to perform operator simplification.
+ The registered pass to perform operator simplification.
"""
return _transform.SimplifyInference()
+def FastMath():
+    """Converts the expensive non-linear functions to their fast but
approximate counterparts.
+
+ Returns
+ -------
+ ret: tvm.relay.Pass
+ The registered pass to perform fast math operations.
+ """
+ return _transform.FastMath()
+
+
def CanonicalizeOps():
"""Canonicalize special operators to basic operators.
This can simplify followed analysis, e.g. expanding bias_add to
diff --git a/src/relay/backend/build_module.cc
b/src/relay/backend/build_module.cc
index ff64d4a..0c0a8b8 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -305,6 +305,9 @@ class RelayBuildModule : public runtime::ModuleNode {
if (targets.size() == 1) {
pass_seqs.push_back(transform::AlterOpLayout());
}
+
+ // Fast math optimizations.
+ pass_seqs.push_back(transform::FastMath());
pass_seqs.push_back(transform::FoldConstant());
// Create a sequential pass and perform optimizations.
diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc
index 2c73458..1169fa8 100644
--- a/src/relay/op/tensor/unary.cc
+++ b/src/relay/op/tensor/unary.cc
@@ -95,6 +95,17 @@ RELAY_REGISTER_UNARY_OP("exp")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));
+RELAY_REGISTER_UNARY_OP("fast_exp")
+.describe(R"code(Returns the fast exponential of the input array, computed element-wise.
+
+.. math::
+  y = \exp(x)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp));
+
+
RELAY_REGISTER_UNARY_OP("erf")
.describe(R"code(Returns the error function value for input array, computed
element-wise.
@@ -250,6 +261,17 @@ RELAY_REGISTER_UNARY_OP("tanh")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh));
+RELAY_REGISTER_UNARY_OP("fast_tanh")
+.describe(R"code(Returns the fast hyperbolic tangent of the input array, computed element-wise.
+
+.. math::
+ Y = sinh(X) / cosh(X)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh));
+
+
RELAY_REGISTER_UNARY_OP("negative")
.describe(R"code(Returns the numeric negative of input array, computed
element-wise.
diff --git a/src/relay/pass/fast_math.cc b/src/relay/pass/fast_math.cc
new file mode 100644
index 0000000..898f760
--- /dev/null
+++ b/src/relay/pass/fast_math.cc
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fast_math.cc
+ * \brief Replaces non linear activation functions with their fast but
approximate counterparts.
+ */
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/op.h>
+#include "pattern_util.h"
+
+namespace tvm {
+namespace relay {
+
+class FastMathMutator : public ExprMutator {
+ public:
+ FastMathMutator()
+ : exp_op_(Op::Get("exp")),
+ tanh_op_(Op::Get("tanh")) {}
+
+ Expr VisitExpr_(const CallNode* n) {
+ auto new_n = ExprMutator::VisitExpr_(n);
+ if (n->op == exp_op_) {
+ return FastExp(new_n.as<CallNode>()->args[0]);
+ } else if (n->op == tanh_op_) {
+ return FastTanh(new_n.as<CallNode>()->args[0]);
+ }
+ return new_n;
+ }
+
+ private:
+ // Cache the following ops. They will be used in the passes repeatedly for
+ // operator equivalence checking so that the registry lookup overhead can be
+ // reduced.
+ const Op& exp_op_;
+ const Op& tanh_op_;
+};
+
+Expr FastMath(const Expr& e) {
+ return FastMathMutator().Mutate(e);
+}
+
+namespace transform {
+
+Pass FastMath() {
+ runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func =
+ [=](Function f, IRModule m, PassContext pc) {
+ return Downcast<Function>(FastMath(f));
+ };
+ return CreateFunctionPass(pass_func, 4, "FastMath",
+ {tir::StringImmNode::make("InferType")});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.FastMath")
+.set_body_typed(FastMath);
+
+} // namespace transform
+
+} // namespace relay
+} // namespace tvm
diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h
index f7d8f9c..85750f5 100644
--- a/src/relay/pass/pattern_util.h
+++ b/src/relay/pass/pattern_util.h
@@ -316,6 +316,16 @@ inline Expr Exp(Expr e) {
return CallNode::make(op, {e});
}
+inline Expr FastExp(Expr e) {
+ static const Op& op = Op::Get("fast_exp");
+ return CallNode::make(op, {e});
+}
+
+inline Expr FastTanh(Expr e) {
+ static const Op& op = Op::Get("fast_tanh");
+ return CallNode::make(op, {e});
+}
+
inline Expr Log(Expr e) {
static const Op& op = Op::Get("log");
return CallNode::make(op, {e});
diff --git a/tests/python/relay/test_pass_fast_math.py
b/tests/python/relay/test_pass_fast_math.py
new file mode 100644
index 0000000..e75316f
--- /dev/null
+++ b/tests/python/relay/test_pass_fast_math.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm.ir import IRModule
+from tvm import relay
+from tvm.relay.transform import FastMath
+
+def test_exp():
+ x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
+ y = relay.exp(x)
+ func = relay.Function([x], y)
+ mod = tvm.IRModule.from_expr(func)
+
+ fast_mod = FastMath()(mod)
+ assert "fast_exp" in fast_mod.astext()
+
+ # Check that FastMath option works for relay.build.
+ with relay.build_config(opt_level=3, required_pass=['FastMath']):
+ fast_mod = relay.optimize(mod, target='llvm', params=None)
+ assert "fast_exp" in fast_mod[0].astext()
+
+def test_tanh():
+ x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
+ y = relay.tanh(x)
+ func = relay.Function([x], y)
+ mod = tvm.IRModule.from_expr(func)
+
+ fast_mod = FastMath()(mod)
+ assert "fast_tanh" in fast_mod.astext()
+
+ # Check that FastMath option works for relay.build.
+ with relay.build_config(opt_level=3, required_pass=['FastMath']):
+ fast_mod = relay.optimize(mod, target='llvm', params=None)
+ assert "fast_tanh" in fast_mod[0].astext()
+
+if __name__ == "__main__":
+ test_exp()
+ test_tanh()
diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h
index e35e3e4..3c0822f 100644
--- a/topi/include/topi/elemwise.h
+++ b/topi/include/topi/elemwise.h
@@ -58,6 +58,7 @@ TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan);
+TOPI_DECLARE_UNARY_OP(tanh);
/*
* \brief Fast_tanh_float implementation from Eigen
@@ -113,9 +114,9 @@ inline Tensor fast_tanh_float(const Tensor& in,
*
* \return A Tensor whose op member is tanh
*/
-inline Tensor tanh(const Tensor& x,
- std::string name = "T_tanh",
- std::string tag = kElementWise) {
+inline Tensor fast_tanh(const Tensor& x,
+ std::string name = "T_fast_tanh",
+ std::string tag = kElementWise) {
if (x->dtype == DataType::Float(32)) {
// invoke fast_tanh_float implementation
return fast_tanh_float(x, name, tag);
diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py
index 5b6b9ab..4a63c45 100644
--- a/topi/python/topi/math.py
+++ b/topi/python/topi/math.py
@@ -467,3 +467,19 @@ def fast_exp(x):
The result.
"""
return cpp.fast_exp(x, x.dtype, tag.ELEMWISE)
+
+
+def fast_tanh(x):
+    """Take the hyperbolic tangent of input x using the fast_tanh implementation
+
+ Parameters
+ ----------
+ x : tvm.Tensor
+ Input argument.
+
+ Returns
+ -------
+ y : tvm.Tensor
+ The result.
+ """
+ return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE)
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index 79e223c..75517b8 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -188,7 +188,10 @@ TVM_REGISTER_GLOBAL("topi.tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tanh(args[0]);
});
-
+TVM_REGISTER_GLOBAL("topi.fast_tanh")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+ *rv = fast_tanh(args[0]);
+ });
TVM_REGISTER_GLOBAL("topi.atan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = atan(args[0]);