This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 58e00f3267 [Unity]Add FastMathTransform pass to Relax (#15814)
58e00f3267 is described below
commit 58e00f3267bdbb53a41ef2250c70812efbb15437
Author: Honglin Zhu <[email protected]>
AuthorDate: Wed Oct 11 17:16:23 2023 +0800
[Unity]Add FastMathTransform pass to Relax (#15814)
* delete unused import and add class docstring
* add test for fast math transform
* Update test_fast_math_transform.py
---
python/tvm/relax/transform/__init__.py | 1 +
python/tvm/relax/transform/fast_math.py | 67 ++++++++++++++++++++++++++
tests/python/relax/test_fast_math_transform.py | 59 +++++++++++++++++++++++
3 files changed, 127 insertions(+)
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 5dfdec8ea0..6b00a40105 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -72,6 +72,7 @@ from .transform import (
from .lazy_transform_params import LazyTransformParams
from .optimize_layout_transform import OptimizeLayoutTransform
from .remove_redundant_reshape import RemoveRedundantReshape
+from .fast_math import FastMathTransform
# Import to register the legalization functions.
from . import legalize_ops, tuning_api
diff --git a/python/tvm/relax/transform/fast_math.py b/python/tvm/relax/transform/fast_math.py
new file mode 100644
index 0000000000..2aebd96db3
--- /dev/null
+++ b/python/tvm/relax/transform/fast_math.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument, redefined-argument-from-local
+"""Relax FastMathTransform pass: replace expensive operators with fast approximations."""
+import tvm
+from tvm import topi
+from tvm.ir.module import IRModule
+from tvm.relax import Expr, Call, expr_functor, PyExprMutator
+
+
+@expr_functor.mutator
+class FastMathCodeGenerator(PyExprMutator):
+ """
+ Converts the expensive non linear functions to their fast but approximate
counterparts.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The module to be transformed
+ """
+
+ def __init__(self, mod):
+ super().__init__(mod)
+
+ def visit_call_(self, call: Call) -> Expr:
+ if call.op.name == "relax.nn.softmax":
+ return self.builder_.call_te(topi.nn.fast_softmax, call.args[0],
call.attrs.axis)
+ if call.op.name == "relax.exp":
+ return self.builder_.call_te(topi.fast_exp, call.args[0])
+ if call.op.name == "relax.erf":
+ return self.builder_.call_te(topi.fast_erf, call.args[0])
+ if call.op.name == "relax.tanh":
+ return self.builder_.call_te(topi.fast_tanh, call.args[0])
+
+ return super().visit_call_(call)
+
+
[email protected]_pass(opt_level=0, name="FastMathTransform")
+class FastMathTransform:
+ """
+ Pass to convert the expensive non linear functions to their fast but
approximate counterparts.
+ """
+
+ def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext)
-> IRModule:
+ fast_math_codegen = FastMathCodeGenerator(mod)
+ for gv in mod.functions:
+ func = mod[gv]
+ if not isinstance(func, tvm.relax.Function):
+ continue
+ func = fast_math_codegen.visit_expr(func)
+ fast_math_codegen.builder_.update_func(gv, func)
+
+ return fast_math_codegen.builder_.get()
diff --git a/tests/python/relax/test_fast_math_transform.py b/tests/python/relax/test_fast_math_transform.py
new file mode 100644
index 0000000000..f5b88f312c
--- /dev/null
+++ b/tests/python/relax/test_fast_math_transform.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests to validate the Relax FastMathTransform pass."""
+
+import pytest
+import tvm.testing
+from tvm import relax, topi
+from tvm.ir.base import assert_structural_equal
+from tvm.relax.transform import FastMathTransform
+from tvm.script import ir as I, relax as R
+
+
+def _run_pass_compare_output(Before, Expected):
+ fast_mod = FastMathTransform()(Before)
+ if not relax.analysis.well_formed(fast_mod):
+ print("IRModule is not well-formed")
+ assert_structural_equal(Expected, fast_mod)
+
+
+def test_optimize_transform_layout_pass_one_arg():
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,),
dtype="float32"):
+ lv1: R.Tensor((16,), dtype="float32") = R.nn.softmax(x)
+ lv2: R.Tensor((16,), dtype="float32") = R.exp(lv1)
+ lv3: R.Tensor((16,), dtype="float32") = R.erf(lv2)
+ lv4: R.Tensor((16,), dtype="float32") = R.tanh(lv3)
+ return lv4
+
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((16,), "float32"))
+ with bb.function("main", [x]):
+ lv1 = bb.emit_te(topi.nn.fast_softmax, x)
+ lv2 = bb.emit_te(topi.fast_exp, lv1)
+ lv3 = bb.emit_te(topi.fast_erf, lv2)
+ lv4 = bb.emit_te(topi.fast_tanh, lv3)
+ bb.emit_func_output(lv4)
+ Expected = bb.get()
+
+ _run_pass_compare_output(Before, Expected)
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly, delegating to tvm.testing.main().
+    tvm.testing.main()