This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new b89d2fed6a [Unity][Ops] Support for erf in relax (#15445)
b89d2fed6a is described below

commit b89d2fed6a70678e61a76bc7238c4808939fe566
Author: Josh Fromm <jwfr...@octoml.ai>
AuthorDate: Tue Aug 1 01:26:55 2023 -0700

    [Unity][Ops] Support for erf in relax (#15445)
    
    * Add support for relax erf
    
    * Formatting
    
    * Allow standard block builder var name
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py       | 13 +------------
 python/tvm/relax/op/unary.py                          | 16 ++++++++++++++++
 python/tvm/relax/transform/legalize_ops/unary.py      | 19 ++++++++++++++++++-
 python/tvm/script/ir_builder/relax/ir.py              |  2 ++
 src/relax/distributed/transform/propagate_sharding.cc | 12 ++++++------
 src/relax/op/distributed/unary.cc                     |  1 +
 src/relax/op/tensor/unary.cc                          |  1 +
 src/relax/op/tensor/unary.h                           |  3 +++
 tests/python/relax/test_frontend_onnx.py              |  9 +++++----
 tests/python/relax/test_op_unary.py                   |  1 +
 10 files changed, 54 insertions(+), 23 deletions(-)
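
For context, here is a minimal sketch of how the new operator can be exercised from Python, assuming the usual Unity BlockBuilder / LegalizeOps / VM workflow (illustrative only, not part of the commit):

    import numpy as np
    import tvm
    from tvm import relax

    # Build a small Relax module that applies erf to a float32 tensor.
    bb = relax.BlockBuilder()
    x = relax.Var("x", relax.TensorStructInfo((32, 32), "float32"))
    with bb.function("main", [x]):
        with bb.dataflow():
            y = bb.emit(relax.op.erf(x))
            out = bb.emit_output(y)
        bb.emit_func_output(out)
    mod = bb.get()

    # Lower relax.erf to TIR via the new legalization and run on CPU.
    mod = relax.transform.LegalizeOps()(mod)
    ex = relax.build(mod, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    data = np.random.randn(32, 32).astype("float32")
    result = vm["main"](tvm.nd.array(data))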

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 9711ee2578..74eb904c4f 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -587,18 +587,7 @@ class Erf(OnnxOpConverter):
 
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
-        x = inputs[0]
-        sqrt2 = relax.const(_np.sqrt(2), x.struct_info.dtype)
-        # TODO: replace with erf operator once it is implemented
-        mul = relax.op.multiply(x, sqrt2)
-        gelu = relax.op.nn.gelu(mul)
-        mul_2 = relax.op.multiply(gelu, sqrt2)
-        return bb.normalize(
-            relax.op.add(
-                relax.op.divide(mul_2, x),
-                relax.const(-1, x.struct_info.dtype),
-            )
-        )
+        return relax.op.erf(inputs[0])
 
 
 class CumSum(OnnxOpConverter):
diff --git a/python/tvm/relax/op/unary.py b/python/tvm/relax/op/unary.py
index 78051452e2..11b78dbcc7 100644
--- a/python/tvm/relax/op/unary.py
+++ b/python/tvm/relax/op/unary.py
@@ -534,6 +534,22 @@ def clip(x: Expr, min: Expr, max: Expr) -> Expr:
     return _ffi_api.clip(x, min, max)  # type: ignore
 
 
+def erf(x: Expr) -> Expr:
+    """Computes the error function of the input.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The input data
+
+    Returns
+    -------
+    result : relax.Expr
+        Computed error function for each element.
+    """
+    return _ffi_api.erf(x)  # type: ignore
+
+
 ###################### Check operators ######################
 
 
diff --git a/python/tvm/relax/transform/legalize_ops/unary.py b/python/tvm/relax/transform/legalize_ops/unary.py
index f948f18dd3..33752b9bd3 100644
--- a/python/tvm/relax/transform/legalize_ops/unary.py
+++ b/python/tvm/relax/transform/legalize_ops/unary.py
@@ -14,8 +14,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=invalid-name,unused-argument
 """Default legalization function for unary operators."""
-from tvm import topi
+from tvm import topi, te
+
+from ...block_builder import BlockBuilder
+from ...expr import Call, Expr
 from .common import _call_topi_without_attr, register_legalize
 
 # To avoid conflict of IRModule function name and libc function name, we add
@@ -47,3 +51,16 @@ register_legalize("relax.sqrt", _call_topi_without_attr(topi.sqrt, "tir_sqrt"))
 register_legalize("relax.tan", _call_topi_without_attr(topi.tan, "tir_tan"))
 register_legalize("relax.tanh", _call_topi_without_attr(topi.tanh, "tir_tanh"))
 register_legalize("relax.clip", _call_topi_without_attr(topi.clip, "tir_clip"))
+
+
+@register_legalize("relax.erf")
+def _erf(bb: BlockBuilder, call: Call) -> Expr:
+    def te_erf(x: te.Tensor):
+        dtype = x.dtype
+        if dtype == "float16":
+            erf = topi.math.cast(topi.erf(topi.math.cast(x, "float32")), "float16")
+        else:
+            erf = topi.erf(x)
+        return erf
+
+    return bb.call_te(te_erf, call.args[0], primfunc_name_hint="erf")
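
The float16 branch above evaluates erf in float32 and casts the result back, presumably because a native float16 erf lowering is not widely available; a rough NumPy/SciPy reference for that behavior (hypothetical helper, not part of this change) is:

    import numpy as np
    from scipy import special  # SciPy is used only as a numerical reference

    def reference_erf_fp16(x_fp16: np.ndarray) -> np.ndarray:
        """Mirror the legalization: compute erf in float32, cast back to float16."""
        return special.erf(x_fp16.astype("float32")).astype("float16")

A legalized-and-built float16 module should agree with this reference within float16 tolerance (e.g. rtol/atol around 1e-3).
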
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 49d41765fd..5bb0374d35 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -138,6 +138,7 @@ from tvm.relax.op import (
     subtract,
     tan,
     tanh,
+    erf,
     tile,
     tril,
     triu,
@@ -703,4 +704,5 @@ __all__ = [
     "zeros",
     "zeros_like",
     "nn",
+    "erf",
 ]
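
With erf exported from the ir_builder namespace above, the operator also becomes writable directly in TVMScript; a minimal sketch, assuming the standard Relax TVMScript decorators (illustrative module name):

    from tvm.script import ir as I
    from tvm.script import relax as R

    @I.ir_module
    class ErfModule:
        @R.function
        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"):
            # R.erf resolves to relax.op.erf through the export added here.
            with R.dataflow():
                y = R.erf(x)
                R.output(y)
            return y
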
diff --git a/src/relax/distributed/transform/propagate_sharding.cc b/src/relax/distributed/transform/propagate_sharding.cc
index 676274a6fc..beed7bd989 100644
--- a/src/relax/distributed/transform/propagate_sharding.cc
+++ b/src/relax/distributed/transform/propagate_sharding.cc
@@ -56,12 +56,12 @@ void CollectAxisGraphBinary(const VarBindingNode* binding, const CallNode* call,
 void CollectAxisGraphUnary(const VarBindingNode* binding, const CallNode* call,
                            AxisGroupGraph* axis_group_graph) {
   const std::vector<std::string> unary_op_names = {
-      "abs",    "acos",     "acosh",    "asin",   "asinh", "atan",
-      "atanh",  "ceil",     "cos",      "cosh",   "exp",   "floor",
-      "log",    "negative", "nn.relu",  "round",  "rsqrt", "sigmoid",
-      "sign",   "sin",      "sinh",     "square", "sqrt",  "tan",
-      "tanh",   "clip",     "isfinite", "isinf",  "isnan", 
"dist.annotate_sharding",
-      "nn.gelu"};
+      "abs",   "acos",     "acosh",    "asin",   "asinh", "atan",
+      "atanh", "ceil",     "cos",      "cosh",   "exp",   "floor",
+      "log",   "negative", "nn.relu",  "round",  "rsqrt", "sigmoid",
+      "sign",  "sin",      "sinh",     "square", "sqrt",  "tan",
+      "tanh",  "clip",     "isfinite", "isinf",  "isnan", 
"dist.annotate_sharding",
+      "erf",   "nn.gelu"};
   for (const auto& op_name : unary_op_names) {
     const Op& unary_op = Op::Get("relax." + op_name);
     if (call->op.same_as(unary_op)) {
diff --git a/src/relax/op/distributed/unary.cc b/src/relax/op/distributed/unary.cc
index 0ef0d9ffa1..4e62d93eec 100644
--- a/src/relax/op/distributed/unary.cc
+++ b/src/relax/op/distributed/unary.cc
@@ -54,6 +54,7 @@ RELAX_REGISTER_UNARY_ARITH_DIST_INFER_STRUCT_INFO(square, /*require_float_dtype=
 RELAX_REGISTER_UNARY_ARITH_DIST_INFER_STRUCT_INFO(sqrt, /*require_float_dtype=*/true);
 RELAX_REGISTER_UNARY_ARITH_DIST_INFER_STRUCT_INFO(tan, /*require_float_dtype=*/true);
 RELAX_REGISTER_UNARY_ARITH_DIST_INFER_STRUCT_INFO(tanh, /*require_float_dtype=*/true);
+RELAX_REGISTER_UNARY_ARITH_DIST_INFER_STRUCT_INFO(erf, /*require_float_dtype=*/true);
 
 RELAX_REGISTER_UNARY_CHECK_DIST_INFER_STRUCT_INFO(isfinite);
 RELAX_REGISTER_UNARY_CHECK_DIST_INFER_STRUCT_INFO(isinf);
diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc
index 6eef44821d..35ad0de2fc 100644
--- a/src/relax/op/tensor/unary.cc
+++ b/src/relax/op/tensor/unary.cc
@@ -62,6 +62,7 @@ RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(square, /*require_float_dtype=*/false);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sqrt, /*require_float_dtype=*/true);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tan, /*require_float_dtype=*/true);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tanh, /*require_float_dtype=*/true);
+RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(erf, /*require_float_dtype=*/true);
 
 // relax.clip
 TVM_REGISTER_OP("relax.clip")
diff --git a/src/relax/op/tensor/unary.h b/src/relax/op/tensor/unary.h
index 5f92ed0b0c..dfbad37897 100644
--- a/src/relax/op/tensor/unary.h
+++ b/src/relax/op/tensor/unary.h
@@ -147,6 +147,9 @@ Expr isinf(Expr x);
 /*! \brief Check if input value is Nan. */
 Expr isnan(Expr x);
 
+/*! \brief Apply error function to input value. */
+Expr erf(Expr x);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index c5c094e115..3467e5bba2 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -174,15 +174,15 @@ def test_sanitize(input_names, expected_names):
         assert param.name_hint == expected_names[i]
 
 
-def verify_unary(op_name, shape, attrs={}, domain=None):
+def verify_unary(op_name, shape, attrs={}, domain=None, dtype=TensorProto.FLOAT):
     test_node = helper.make_node(op_name, ["x"], ["y"], **attrs, domain=domain)
     graph = helper.make_graph(
         [test_node],
         "elemwise_test",
         inputs=[
-            helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
+            helper.make_tensor_value_info("x", dtype, shape),
         ],
-        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)],
+        outputs=[helper.make_tensor_value_info("y", dtype, shape)],
     )
 
     model = helper.make_model(graph, producer_name="elemwise_test")
@@ -562,7 +562,8 @@ def test_pow():
 
 
 def test_erf():
-    verify_unary("Erf", [32, 32])
+    verify_unary("Erf", [32, 32], dtype=TensorProto.FLOAT)
+    verify_unary("Erf", [32, 32], dtype=TensorProto.FLOAT16)
 
 
 @pytest.mark.parametrize("reverse", [False])
diff --git a/tests/python/relax/test_op_unary.py b/tests/python/relax/test_op_unary.py
index 9bfb8612ef..3a6c14fd66 100644
--- a/tests/python/relax/test_op_unary.py
+++ b/tests/python/relax/test_op_unary.py
@@ -53,6 +53,7 @@ def test_op_correctness():
     assert relax.op.tan(x).op == Op.get("relax.tan")
     assert relax.op.tanh(x).op == Op.get("relax.tanh")
     assert relax.op.clip(x, 0, 6).op == Op.get("relax.clip")
+    assert relax.op.erf(x).op == Op.get("relax.erf")
 
     x = relax.Var("x", R.Tensor((2, 3), "int32"))
     assert relax.op.bitwise_not(x).op == Op.get("relax.bitwise_not")
