This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
     new b89d2fed6a  [Unity][Ops] Support for erf in relax (#15445)
b89d2fed6a is described below

commit b89d2fed6a70678e61a76bc7238c4808939fe566
Author: Josh Fromm <jwfr...@octoml.ai>
AuthorDate: Tue Aug 1 01:26:55 2023 -0700

    [Unity][Ops] Support for erf in relax (#15445)

    * Add support for relax erf

    * Formatting

    * Allow standard block builder var name
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py        | 13 +------------
 python/tvm/relax/op/unary.py                           | 16 ++++++++++++++++
 python/tvm/relax/transform/legalize_ops/unary.py       | 19 ++++++++++++++++++-
 python/tvm/script/ir_builder/relax/ir.py               |  2 ++
 src/relax/distributed/transform/propagate_sharding.cc  | 12 ++++++------
 src/relax/op/distributed/unary.cc                      |  1 +
 src/relax/op/tensor/unary.cc                           |  1 +
 src/relax/op/tensor/unary.h                            |  3 +++
 tests/python/relax/test_frontend_onnx.py               |  9 +++++----
 tests/python/relax/test_op_unary.py                    |  1 +
 10 files changed, 54 insertions(+), 23 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 9711ee2578..74eb904c4f 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -587,18 +587,7 @@ class Erf(OnnxOpConverter):
 
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
-        x = inputs[0]
-        sqrt2 = relax.const(_np.sqrt(2), x.struct_info.dtype)
-        # TODO: replace with erf operator once it is implemented
-        mul = relax.op.multiply(x, sqrt2)
-        gelu = relax.op.nn.gelu(mul)
-        mul_2 = relax.op.multiply(gelu, sqrt2)
-        return bb.normalize(
-            relax.op.add(
-                relax.op.divide(mul_2, x),
-                relax.const(-1, x.struct_info.dtype),
-            )
-        )
+        return relax.op.erf(inputs[0])
 
 
 class CumSum(OnnxOpConverter):
diff --git a/python/tvm/relax/op/unary.py b/python/tvm/relax/op/unary.py
index 78051452e2..11b78dbcc7 100644
--- a/python/tvm/relax/op/unary.py
+++ b/python/tvm/relax/op/unary.py
@@ -534,6 +534,22 @@ def clip(x: Expr, min: Expr, max: Expr) -> Expr:
     return _ffi_api.clip(x, min, max)  # type: ignore
 
 
+def erf(x: Expr) -> Expr:
+    """Computes the error function of the input.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The input data
+
+    Returns
+    -------
+    result : relax.Expr
+        Computed error function for each element.
+    """
+    return _ffi_api.erf(x)  # type: ignore
+
+
 ###################### Check operators ######################
 
 
diff --git a/python/tvm/relax/transform/legalize_ops/unary.py b/python/tvm/relax/transform/legalize_ops/unary.py
index f948f18dd3..33752b9bd3 100644
--- a/python/tvm/relax/transform/legalize_ops/unary.py
+++ b/python/tvm/relax/transform/legalize_ops/unary.py
@@ -14,8 +14,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=invalid-name,unused-argument
 """Default legalization function for unary operators."""
-from tvm import topi
+from tvm import topi, te
+
+from ...block_builder import BlockBuilder
+from ...expr import Call, Expr
 from .common import _call_topi_without_attr, register_legalize
 
 # To avoid conflict of IRModule function name and libc function name, we add
@@ -47,3 +51,16 @@ register_legalize("relax.sqrt", _call_topi_without_attr(topi.sqrt, "tir_sqrt"))
 register_legalize("relax.tan", _call_topi_without_attr(topi.tan, "tir_tan"))
 register_legalize("relax.tanh", _call_topi_without_attr(topi.tanh, "tir_tanh"))
 register_legalize("relax.clip", _call_topi_without_attr(topi.clip, "tir_clip"))
+
+
+@register_legalize("relax.erf")
+def _erf(bb: BlockBuilder, call: Call) -> Expr:
+    def te_erf(x: te.Tensor):
+        dtype = x.dtype
+        if dtype == "float16":
+            erf = topi.math.cast(topi.erf(topi.math.cast(x, "float32")), "float16")
+        else:
+            erf = topi.erf(x)
+        return erf
+
+    return bb.call_te(te_erf, call.args[0], primfunc_name_hint="erf")
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 49d41765fd..5bb0374d35 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -138,6 +138,7 @@ from tvm.relax.op import (
     subtract,
     tan,
     tanh,
+    erf,
     tile,
     tril,
     triu,
@@ -703,4 +704,5 @@ __all__ = [
     "zeros",
     "zeros_like",
     "nn",
+    "erf",
 ]
diff --git a/src/relax/distributed/transform/propagate_sharding.cc b/src/relax/distributed/transform/propagate_sharding.cc
index 676274a6fc..beed7bd989 100644
--- a/src/relax/distributed/transform/propagate_sharding.cc
+++ b/src/relax/distributed/transform/propagate_sharding.cc
@@ -56,12 +56,12 @@ void CollectAxisGraphBinary(const VarBindingNode* binding, const CallNode* call,
 void CollectAxisGraphUnary(const VarBindingNode* binding, const CallNode* call,
                            AxisGroupGraph* axis_group_graph) {
   const std::vector<std::string> unary_op_names = {
-      "abs",   "acos",     "acosh",    "asin",   "asinh", "atan",
-      "atanh", "ceil",     "cos",      "cosh",   "exp",   "floor",
-      "log",   "negative", "nn.relu",  "round",  "rsqrt", "sigmoid",
-      "sign",  "sin",      "sinh",     "square", "sqrt",  "tan",
-      "tanh",  "clip",     "isfinite", "isinf",  "isnan", "dist.annotate_sharding",
-      "nn.gelu"};
+      "abs",   "acos",     "acosh",    "asin",   "asinh", "atan",
+      "atanh", "ceil",     "cos",      "cosh",   "exp",   "floor",
+      "log",   "negative", "nn.relu",  "round",  "rsqrt", "sigmoid",
+      "sign",  "sin",      "sinh",     "square", "sqrt",  "tan",
+      "tanh",  "clip",     "isfinite", "isinf",  "isnan", "dist.annotate_sharding",
+      "erf",   "nn.gelu"};
   for (const auto& op_name : unary_op_names) {
     const Op& unary_op = Op::Get("relax." + op_name);
     if (call->op.same_as(unary_op)) {
diff --git a/src/relax/op/distributed/unary.cc b/src/relax/op/distributed/unary.cc
index 0ef0d9ffa1..4e62d93eec 100644
--- a/src/relax/op/distributed/unary.cc
+++ b/src/relax/op/distributed/unary.cc
@@ -54,6 +54,7 @@ RELAX_REGISTER_UNARY_ARITH_DIST_INFER_STRUCT_INFO(square, /*require_float_dtype=
 RELAX_REGISTER_UNARY_ARITH_DIST_INFER_STRUCT_INFO(sqrt, /*require_float_dtype=*/true);
 RELAX_REGISTER_UNARY_ARITH_DIST_INFER_STRUCT_INFO(tan, /*require_float_dtype=*/true);
 RELAX_REGISTER_UNARY_ARITH_DIST_INFER_STRUCT_INFO(tanh, /*require_float_dtype=*/true);
+RELAX_REGISTER_UNARY_ARITH_DIST_INFER_STRUCT_INFO(erf, /*require_float_dtype=*/true);
 
 RELAX_REGISTER_UNARY_CHECK_DIST_INFER_STRUCT_INFO(isfinite);
 RELAX_REGISTER_UNARY_CHECK_DIST_INFER_STRUCT_INFO(isinf);
diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc
index 6eef44821d..35ad0de2fc 100644
--- a/src/relax/op/tensor/unary.cc
+++ b/src/relax/op/tensor/unary.cc
@@ -62,6 +62,7 @@ RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(square, /*require_float_dtype=*/false);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sqrt, /*require_float_dtype=*/true);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tan, /*require_float_dtype=*/true);
 RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tanh, /*require_float_dtype=*/true);
+RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(erf, /*require_float_dtype=*/true);
 
 // relax.clip
 TVM_REGISTER_OP("relax.clip")
diff --git a/src/relax/op/tensor/unary.h b/src/relax/op/tensor/unary.h
index 5f92ed0b0c..dfbad37897 100644
--- a/src/relax/op/tensor/unary.h
+++ b/src/relax/op/tensor/unary.h
@@ -147,6 +147,9 @@ Expr isinf(Expr x);
 /*! \brief Check if input value is Nan. */
 Expr isnan(Expr x);
 
+/*! \brief Apply error function to input value. */
+Expr erf(Expr x);
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index c5c094e115..3467e5bba2 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -174,15 +174,15 @@ def test_sanitize(input_names, expected_names):
         assert param.name_hint == expected_names[i]
 
 
-def verify_unary(op_name, shape, attrs={}, domain=None):
+def verify_unary(op_name, shape, attrs={}, domain=None, dtype=TensorProto.FLOAT):
     test_node = helper.make_node(op_name, ["x"], ["y"], **attrs, domain=domain)
     graph = helper.make_graph(
         [test_node],
         "elemwise_test",
         inputs=[
-            helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
+            helper.make_tensor_value_info("x", dtype, shape),
         ],
-        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, shape)],
+        outputs=[helper.make_tensor_value_info("y", dtype, shape)],
     )
 
     model = helper.make_model(graph, producer_name="elemwise_test")
@@ -562,7 +562,8 @@ def test_pow():
 
 
 def test_erf():
-    verify_unary("Erf", [32, 32])
+    verify_unary("Erf", [32, 32], dtype=TensorProto.FLOAT)
+    verify_unary("Erf", [32, 32], dtype=TensorProto.FLOAT16)
 
 
 @pytest.mark.parametrize("reverse", [False])
diff --git a/tests/python/relax/test_op_unary.py b/tests/python/relax/test_op_unary.py
index 9bfb8612ef..3a6c14fd66 100644
--- a/tests/python/relax/test_op_unary.py
+++ b/tests/python/relax/test_op_unary.py
@@ -53,6 +53,7 @@ def test_op_correctness():
     assert relax.op.tan(x).op == Op.get("relax.tan")
     assert relax.op.tanh(x).op == Op.get("relax.tanh")
     assert relax.op.clip(x, 0, 6).op == Op.get("relax.clip")
+    assert relax.op.erf(x).op == Op.get("relax.erf")
 
     x = relax.Var("x", R.Tensor((2, 3), "int32"))
     assert relax.op.bitwise_not(x).op == Op.get("relax.bitwise_not")
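
Note (not part of the commit): a minimal sketch of how the new operator can be
exercised end to end, assuming a TVM unity build that includes this commit.
The function name "main", the llvm target, and the 32x32 float32 shape are
illustrative choices, not taken from the diff.

    # Build a small Relax function that calls relax.op.erf, legalize it
    # (which routes through the te_erf rule registered above), and run it.
    import numpy as np
    import tvm
    from tvm import relax

    bb = relax.BlockBuilder()
    x = relax.Var("x", relax.TensorStructInfo((32, 32), "float32"))
    with bb.function("main", [x]):
        gv = bb.emit(relax.op.erf(x))
        bb.emit_func_output(gv)
    mod = bb.get()

    # LegalizeOps lowers relax.erf to the TE/TOPI implementation; for
    # float16 inputs the rule above computes erf in float32 and casts back.
    mod = relax.transform.LegalizeOps()(mod)

    ex = relax.build(mod, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    data = np.random.randn(32, 32).astype("float32")
    out = vm["main"](tvm.nd.array(data))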
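
Note (not part of the commit): the deleted ONNX frontend workaround relied on
the exact gelu identity gelu(z) = z/2 * (1 + erf(z / sqrt(2))); substituting
z = x*sqrt(2) gives sqrt(2)*gelu(x*sqrt(2))/x - 1 = erf(x), which is what the
removed code computed. A quick NumPy/SciPy check of that identity (my own
sketch, using scipy.special.erf as the reference):

    # Verify: sqrt(2) * gelu(x * sqrt(2)) / x - 1 == erf(x)
    import numpy as np
    from scipy.special import erf  # reference erf

    x = np.linspace(-3.0, 3.0, 13)
    x = x[x != 0]  # the old rewrite divided by x, so x == 0 was a weak point
    z = x * np.sqrt(2)
    gelu = z / 2 * (1 + erf(z / np.sqrt(2)))  # exact (erf-based) gelu
    np.testing.assert_allclose(np.sqrt(2) * gelu / x - 1, erf(x), atol=1e-12)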