This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 58e00f3267 [Unity]Add FastMathTransform pass to Relax (#15814)
58e00f3267 is described below

commit 58e00f3267bdbb53a41ef2250c70812efbb15437
Author: Honglin Zhu <[email protected]>
AuthorDate: Wed Oct 11 17:16:23 2023 +0800

    [Unity]Add FastMathTransform pass to Relax (#15814)
    
    * delete unused import and add class docstring
    
    * add test for fast math transform
    
    * Update test_fast_math_transform.py
---
 python/tvm/relax/transform/__init__.py         |  1 +
 python/tvm/relax/transform/fast_math.py        | 67 ++++++++++++++++++++++++++
 tests/python/relax/test_fast_math_transform.py | 59 +++++++++++++++++++++++
 3 files changed, 127 insertions(+)

diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 5dfdec8ea0..6b00a40105 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -72,6 +72,7 @@ from .transform import (
 from .lazy_transform_params import LazyTransformParams
 from .optimize_layout_transform import OptimizeLayoutTransform
 from .remove_redundant_reshape import RemoveRedundantReshape
+from .fast_math import FastMathTransform
 
 # Import to register the legalization functions.
 from . import legalize_ops, tuning_api
diff --git a/python/tvm/relax/transform/fast_math.py b/python/tvm/relax/transform/fast_math.py
new file mode 100644
index 0000000000..2aebd96db3
--- /dev/null
+++ b/python/tvm/relax/transform/fast_math.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
+"""Relax Use Fast Math pass."""
+import tvm
+from tvm import topi
+from tvm.ir.module import IRModule
+from tvm.relax import Expr, Call, expr_functor, PyExprMutator
+
+
+@expr_functor.mutator
+class FastMathCodeGenerator(PyExprMutator):
+    """
+    Converts the expensive non linear functions to their fast but approximate counterparts.
+
+    Calls to the following Relax operators are replaced by calls to the
+    corresponding TOPI fast-math kernels, emitted through ``call_te``:
+
+    - ``relax.nn.softmax`` -> ``topi.nn.fast_softmax`` (the ``axis`` attribute is kept)
+    - ``relax.exp`` -> ``topi.fast_exp``
+    - ``relax.erf`` -> ``topi.fast_erf``
+    - ``relax.tanh`` -> ``topi.fast_tanh``
+
+    Every other call is handled by the default ``PyExprMutator`` logic. This
+    includes calls whose callee is not a primitive ``tvm.ir.Op``, for example
+    a call to another function in the module.
+
+    Parameters
+    ----------
+    mod: IRModule
+        The module to be transformed
+    """
+
+    def __init__(self, mod):
+        super().__init__(mod)
+
+    def visit_call_(self, call: Call) -> Expr:
+        # Run the default mutation first. The arguments used below are then the
+        # already-visited (and possibly remapped) expressions, not stale originals.
+        call = super().visit_call_(call)
+
+        # Only primitive operators are candidates for replacement. Calls through
+        # a GlobalVar or a local Var are not Ops, so they are returned as visited
+        # instead of being probed for an operator name.
+        if not isinstance(call.op, tvm.ir.Op):
+            return call
+
+        op_name = call.op.name
+        if op_name == "relax.nn.softmax":
+            return self.builder_.call_te(topi.nn.fast_softmax, call.args[0], call.attrs.axis)
+        if op_name == "relax.exp":
+            return self.builder_.call_te(topi.fast_exp, call.args[0])
+        if op_name == "relax.erf":
+            return self.builder_.call_te(topi.fast_erf, call.args[0])
+        if op_name == "relax.tanh":
+            return self.builder_.call_te(topi.fast_tanh, call.args[0])
+
+        return call
+
+
+@tvm.transform.module_pass(opt_level=0, name="FastMathTransform")
+class FastMathTransform:
+    """
+    Pass to convert the expensive non linear functions to their fast but approximate counterparts.
+
+    Each ``relax.Function`` in the module is rewritten with
+    ``FastMathCodeGenerator``. Functions of any other kind (e.g. TIR PrimFuncs)
+    are skipped. The returned module is the rewriter's builder module, which
+    also receives the kernels emitted by ``call_te``.
+    """
+
+    def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule:
+        fast_math_codegen = FastMathCodeGenerator(mod)
+        for gv, func in mod.functions.items():
+            if not isinstance(func, tvm.relax.Function):
+                continue
+            updated_func = fast_math_codegen.visit_expr(func)
+            fast_math_codegen.builder_.update_func(gv, updated_func)
+
+        return fast_math_codegen.builder_.get()
diff --git a/tests/python/relax/test_fast_math_transform.py b/tests/python/relax/test_fast_math_transform.py
new file mode 100644
index 0000000000..f5b88f312c
--- /dev/null
+++ b/tests/python/relax/test_fast_math_transform.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests to validate relax fast math tranform pass."""
+
+import pytest
+import tvm.testing
+from tvm import relax, topi
+from tvm.ir.base import assert_structural_equal
+from tvm.relax.transform import FastMathTransform
+from tvm.script import ir as I, relax as R
+
+
+def _run_pass_compare_output(Before, Expected):
+    """Apply FastMathTransform to ``Before`` and check the result.
+
+    The transformed module must be well-formed and structurally equal to
+    ``Expected``. A malformed module fails the test instead of only printing
+    a warning.
+    """
+    fast_mod = FastMathTransform()(Before)
+    assert relax.analysis.well_formed(fast_mod), "IRModule is not well-formed"
+    assert_structural_equal(Expected, fast_mod)
+
+
+def test_optimize_transform_layout_pass_one_arg():
+    """A softmax -> exp -> erf -> tanh chain is rewritten to the TOPI fast kernels."""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
+            lv1: R.Tensor((16,), dtype="float32") = R.nn.softmax(x)
+            lv2: R.Tensor((16,), dtype="float32") = R.exp(lv1)
+            lv3: R.Tensor((16,), dtype="float32") = R.erf(lv2)
+            lv4: R.Tensor((16,), dtype="float32") = R.tanh(lv3)
+            return lv4
+
+    # Expected: the same chain, with each step emitted through its fast kernel in order.
+    fast_kernels = (topi.nn.fast_softmax, topi.fast_exp, topi.fast_erf, topi.fast_tanh)
+    builder = relax.BlockBuilder()
+    inp = relax.Var("x", R.Tensor((16,), "float32"))
+    with builder.function("main", [inp]):
+        result = inp
+        for kernel in fast_kernels:
+            result = builder.emit_te(kernel, result)
+        builder.emit_func_output(result)
+
+    _run_pass_compare_output(Before, builder.get())
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly as a script.
+    tvm.testing.main()

Reply via email to