This is an automated email from the ASF dual-hosted git repository.
sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 77391714ab Replacing unary ops with LookUpTable and Take op to improve
performance (#17214)
77391714ab is described below
commit 77391714ab714afcc849fde1378a5a0c62d99c2e
Author: sdalvi-quic <[email protected]>
AuthorDate: Fri Aug 9 00:27:35 2024 -0500
Replacing unary ops with LookUpTable and Take op to improve performance
(#17214)
* Created Look Up Table for unary ops such that the values are computed
during compile time and take op is used to access the values at runtime
* Black formatting for hexagon_unary_ops.py
* minor edit
* Accessed variables with op attributes and op name in the prim func
definition. Added a check that the call node is of call_tir type
---
python/tvm/contrib/hexagon/generate_take_op.py | 98 ++++++
python/tvm/contrib/hexagon/hexagon_unary_ops.py | 97 ++++++
tests/python/contrib/test_hexagon/test_take.py | 393 ++++++++++++++++++++++++
3 files changed, 588 insertions(+)
diff --git a/python/tvm/contrib/hexagon/generate_take_op.py
b/python/tvm/contrib/hexagon/generate_take_op.py
new file mode 100644
index 0000000000..b70eb451a1
--- /dev/null
+++ b/python/tvm/contrib/hexagon/generate_take_op.py
@@ -0,0 +1,98 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring, invalid-name, unnecessary-comprehension,
unused-argument
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.contrib.hexagon import hexagon_unary_ops
+
+
def op_replace(call_node, func) -> bool:
    """Return True when ``call_node`` is a ``relax.call_tir`` whose callee
    PrimFunc is one of the quantized unary ops we can replace with a
    lookup-table + ``take`` implementation.

    Parameters
    ----------
    call_node : relax.Expr
        Candidate call expression.
    func : tir.PrimFunc
        The PrimFunc the call refers to; its ``op_attrs``/``op_name``
        attribute identifies the op.
    """
    if not isinstance(call_node, relax.Call):
        return False
    call_tir_op = tvm.ir.Op.get("relax.call_tir")
    if call_node.op != call_tir_op:
        return False
    # PrimFuncs produced by other lowering paths may not carry these
    # attributes; treat them as non-replaceable instead of raising KeyError.
    attrs = func.attrs
    if attrs is None or "op_attrs" not in attrs or "op_name" not in attrs["op_attrs"]:
        return False
    supported_ops = (
        "qnn.tanh",
        "qnn.sqrt",
        "qnn.rsqrt",
        "qnn.exp",
        "qnn.erf",
        "qnn.sigmoid",
        "qnn.hardswish",
        "qnn.log",
        "qnn.abs",
    )
    return attrs["op_attrs"]["op_name"] in supported_ops
+
+
@relax.expr_functor.mutator
class Tanh2TakeReplace(tvm.relax.PyExprMutator):
    """Mutator that rewrites quantized unary ``call_tir`` nodes into a
    compile-time lookup table plus a runtime ``take`` PrimFunc."""

    def __init__(self, mod: tvm.IRModule) -> None:
        super().__init__(mod)
        self.mod_ = mod

    def transform(self) -> tvm.IRModule:
        """Visit every Relax function in the module and return the rewritten IRModule."""
        for gvar, fn in self.mod_.functions.items():
            # Only Relax functions are rewritten; PrimFuncs etc. are left alone.
            if not isinstance(fn, relax.Function):
                continue
            new_fn = self.visit_expr(fn)
            self.builder_.normalize(new_fn)
            self.builder_.update_func(gvar, new_fn)
        # The BlockBuilder accumulated all updates; hand back the new module.
        return self.builder_.get()

    def visit_call_(self, call_node: relax.Call) -> relax.Call:
        # Guard clauses: only uint8 call_tir nodes to supported unary ops change.
        if call_node.op != tvm.ir.Op.get("relax.call_tir"):
            return call_node

        callee_gv = call_node.args[0]
        prim_fn = self.mod_[callee_gv]

        tir_args = call_node.args[1]
        if tir_args[0].struct_info.dtype != "uint8":
            return call_node
        if not op_replace(call_node, prim_fn):
            return call_node

        inp, inp_scale, inp_zp, out_scale, out_zp = tir_args
        # Build the 256-entry lookup table at compile time.
        lut = hexagon_unary_ops.LUT_generation(
            inp_scale, inp_zp, out_scale, out_zp, callee_gv.name_hint
        )
        # Create the take PrimFunc that indexes the table at runtime.
        take_func = hexagon_unary_ops.generate_take_primfunc(inp, call_node.struct_info)
        take_func = take_func.without_attr("global_symbol")
        take_func_gv = self.builder_.add_func(take_func, "take")
        return relax.call_tir(
            take_func_gv,
            relax.expr.Tuple([tir_args[0], relax.expr.Constant(tvm.nd.array(lut))]),
            call_node.struct_info,
        )
+
+
@tvm.transform.module_pass(opt_level=2, name="replace_tanh_take")
class PassReplaceWithTakeOpPrimFuncs:
    """Module pass wrapper around Tanh2TakeReplace: replaces quantized unary
    ops with a compile-time LUT plus a runtime take PrimFunc."""

    def transform_module(self, mod, ctx):
        # ctx is the PassContext required by the module_pass interface; unused here.
        return Tanh2TakeReplace(mod).transform()
diff --git a/python/tvm/contrib/hexagon/hexagon_unary_ops.py
b/python/tvm/contrib/hexagon/hexagon_unary_ops.py
new file mode 100644
index 0000000000..1bb4d4ba4f
--- /dev/null
+++ b/python/tvm/contrib/hexagon/hexagon_unary_ops.py
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring, invalid-name
+import logging
+import numpy as np
+from scipy import special
+from tvm import te
+
+logger = logging.getLogger(__name__)
+
+######################################################################
+#################### PRIMFUNC FOR LUT and Take Op ####################
+######################################################################
+
+
def saturate(x: te.Tensor, dtype: str):
    """Clamp ``x`` into the representable range of ``dtype``."""
    lo = te.min_value(dtype)
    hi = te.max_value(dtype)
    return te.max(lo, te.min(x, hi))
+
+
def hardswish_func(x):
    """Reference hardswish: ``x * clip(x + 3, 0, 6) / 6`` (elementwise)."""
    shifted = np.clip(np.add(x, 3.0), 0.0, 6.0)
    return x * shifted / 6.0
+
+
def LUT_generation(inp_scale, inp_zp, out_scale, out_zp, op_name) -> list:
    """Build a 256-entry uint8 lookup table for a quantized unary op.

    Entry ``i`` holds ``quantize(op(dequantize(i)))`` so that the unary op
    can be executed at runtime as a single ``take`` into this table.

    Parameters
    ----------
    inp_scale, inp_zp, out_scale, out_zp : relax.Constant
        Scalar (rank-0) quantization parameters.
    op_name : str
        Name containing one of the supported unary op names
        (tanh, rsqrt, sqrt, exp, erf, sigmoid, hardswish, log, abs).

    Returns
    -------
    list of numpy.uint8
        The 256-entry lookup table.

    Raises
    ------
    ValueError
        If a quantization parameter is not rank-0 or the op is unsupported.
    """

    def _scalar(const):
        # Quantization params are loop-invariant; extract each once instead
        # of re-reading (and conditionally re-binding) them 256 times.
        # Previously a non-scalar parameter led to a NameError; fail clearly.
        if const.data.shape != ():
            raise ValueError("expected a scalar (rank-0) quantization parameter")
        return const.data.numpy()[()]

    i_scale = _scalar(inp_scale)
    i_zp = _scalar(inp_zp)
    o_scale = _scalar(out_scale)
    o_zp = _scalar(out_zp)

    LUT = []
    for i in range(256):
        i = np.int32(i)
        # Dequantize the input code, then evaluate the op.
        dequant = (i - i_zp) * i_scale
        # NOTE: "rsqrt" must be tested before "sqrt" ("rsqrt" contains "sqrt").
        if "tanh" in op_name:
            op_val = np.tanh(dequant)
        elif "rsqrt" in op_name:
            op_val = 1 / np.sqrt(dequant)
        elif "sqrt" in op_name:
            op_val = np.sqrt(dequant)
        elif "exp" in op_name:
            op_val = np.exp(dequant)
        elif "erf" in op_name:
            op_val = special.erf(dequant)
        elif "sigmoid" in op_name:
            op_val = 1 / (1 + np.exp(np.negative(dequant)))
        elif "hardswish" in op_name:
            op_val = hardswish_func(dequant)
        elif "log" in op_name:
            op_val = np.log(dequant)
        elif "abs" in op_name:
            op_val = np.abs(dequant)
        else:
            # Previously this only logged and then crashed with an
            # UnboundLocalError on op_val; raise a clear error instead.
            raise ValueError(f"unsupported unary op for LUT generation: {op_name}")

        # Quantize the computed value and clamp it into the uint8 range.
        quant = np.round(op_val / o_scale) + o_zp
        LUT.append(np.maximum(0, np.minimum(quant, 255)).astype(np.uint8))
    return LUT
+
+
def generate_take_primfunc(inp, struct_info):
    """Create a PrimFunc that gathers the LUT entry for every input element."""
    # Input is assumed NHWC — TODO confirm against the callers of this helper.
    batch, height, width, channel = inp.struct_info.shape
    data = te.placeholder(
        (batch, height, width, channel), dtype=struct_info.dtype, name="data"
    )
    lut = te.placeholder((256,), dtype="uint8", name="LUT")
    take = te.compute(
        struct_info.shape,
        lambda *indices: saturate(
            lut[data[indices].astype("uint8")], struct_info.dtype
        ).astype(struct_info.dtype),
        name="take_op",
    )
    return te.create_prim_func([data, lut, take])
diff --git a/tests/python/contrib/test_hexagon/test_take.py
b/tests/python/contrib/test_hexagon/test_take.py
new file mode 100644
index 0000000000..80c2b05339
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/test_take.py
@@ -0,0 +1,393 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring, invalid-name, unused-argument,
not-callable
+import numpy as np
+from scipy import special
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import tir as T, relax as R
+from tvm.contrib.hexagon import generate_take_op
+from tvm.contrib.hexagon import hexagon_unary_ops
+
+from .infrastructure import quantize_np
+
+
+# Tests for the structural and value correctness of replacing unary ops with the
+# take op.
+
+
# Quantized tanh expressed as a call_tir into a stub PrimFunc whose
# "op_attrs"/"op_name" attribute drives the LUT-replacement pass.
@tvm.script.ir_module
class Module_tanh:
    @R.function
    def main(
        input_tanh: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        # args: input, input scale, input zero-point, output scale, output zero-point
        out = R.call_tir(
            Module_tanh.tanh,
            (
                input_tanh,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.002631544131858676, "float32"),
                R.const(0, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def tanh(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        # Body intentionally empty: the pass replaces this call before build.
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.tanh"}})
+
+
# Quantized sqrt expressed as a call_tir into a stub PrimFunc whose
# "op_attrs"/"op_name" attribute drives the LUT-replacement pass.
@tvm.script.ir_module
class Module_sqrt:
    @R.function
    def main(
        input_sqrt: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        # args: input, input scale, input zero-point, output scale, output zero-point
        out = R.call_tir(
            Module_sqrt.sqrt,
            (
                input_sqrt,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.003535157327728918, "float32"),
                R.const(0, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def sqrt(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        # Body intentionally empty: the pass replaces this call before build.
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.sqrt"}})
+
+
# Quantized rsqrt expressed as a call_tir into a stub PrimFunc whose
# "op_attrs"/"op_name" attribute drives the LUT-replacement pass.
@tvm.script.ir_module
class Module_rsqrt:
    @R.function
    def main(
        input_rsqrt: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        # args: input, input scale, input zero-point, output scale, output zero-point
        out = R.call_tir(
            Module_rsqrt.rsqrt,
            (
                input_rsqrt,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.008154160766635542, "float32"),
                R.const(0, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def rsqrt(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        # Body intentionally empty: the pass replaces this call before build.
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.rsqrt"}})
+
+
# Quantized exp expressed as a call_tir into a stub PrimFunc whose
# "op_attrs"/"op_name" attribute drives the LUT-replacement pass.
@tvm.script.ir_module
class Module_exp:
    @R.function
    def main(
        input_exp: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        # args: input, input scale, input zero-point, output scale, output zero-point
        out = R.call_tir(
            Module_exp.exp,
            (
                input_exp,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.008838622987079832, "float32"),
                R.const(0, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def exp(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        # Body intentionally empty: the pass replaces this call before build.
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.exp"}})
+
+
# Quantized erf expressed as a call_tir into a stub PrimFunc whose
# "op_attrs"/"op_name" attribute drives the LUT-replacement pass.
@tvm.script.ir_module
class Module_erf:
    @R.function
    def main(
        input_erf: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        # args: input, input scale, input zero-point, output scale, output zero-point
        out = R.call_tir(
            Module_erf.erf,
            (
                input_erf,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.002939393251118067, "float32"),
                R.const(0, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def erf(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        # Body intentionally empty: the pass replaces this call before build.
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.erf"}})
+
+
# Quantized sigmoid expressed as a call_tir into a stub PrimFunc whose
# "op_attrs"/"op_name" attribute drives the LUT-replacement pass.
@tvm.script.ir_module
class Module_sigmoid:
    @R.function
    def main(
        input_sigmoid: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        # args: input, input scale, input zero-point, output scale, output zero-point
        out = R.call_tir(
            Module_sigmoid.sigmoid,
            (
                input_sigmoid,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.002631544131858676, "float32"),
                R.const(0, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def sigmoid(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        # Body intentionally empty: the pass replaces this call before build.
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.sigmoid"}})
+
+
# Quantized hardswish expressed as a call_tir into a stub PrimFunc whose
# "op_attrs"/"op_name" attribute drives the LUT-replacement pass.
@tvm.script.ir_module
class Module_hardswish:
    @R.function
    def main(
        input_hardswish: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        # args: input, input scale, input zero-point, output scale, output zero-point
        out = R.call_tir(
            Module_hardswish.hardswish,
            (
                input_hardswish,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.0020250332087720325, "float32"),
                R.const(0, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def hardswish(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        # Body intentionally empty: the pass replaces this call before build.
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.hardswish"}})
+
+
# Quantized log expressed as a call_tir into a stub PrimFunc whose
# "op_attrs"/"op_name" attribute drives the LUT-replacement pass.
# Note the non-zero output zero-point (255), unlike the other modules.
@tvm.script.ir_module
class Module_log:
    @R.function
    def main(
        input_log: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        # args: input, input scale, input zero-point, output scale, output zero-point
        out = R.call_tir(
            Module_log.log,
            (
                input_log,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.0057414634248614226, "float32"),
                R.const(255, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def log(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        # Body intentionally empty: the pass replaces this call before build.
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.log"}})
+
+
# Quantized abs expressed as a call_tir into a stub PrimFunc whose
# "op_attrs"/"op_name" attribute drives the LUT-replacement pass.
@tvm.script.ir_module
class Module_abs:
    @R.function
    def main(
        input_abs: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        # args: input, input scale, input zero-point, output scale, output zero-point
        out = R.call_tir(
            Module_abs.abs,
            (
                input_abs,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.0031868210196078434, "float32"),
                R.const(0, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def abs(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        # Body intentionally empty: the pass replaces this call before build.
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.abs"}})
+
+
# data = np.random.random([1, 2, 2, 2]).astype("float32") : the data is
# hard-coded (rather than randomly generated) so that the quantization
# parameters are deterministic and can be used as inputs to the main func.
data = [
    [
        [[0.3034368, 0.60848576], [0.29697746, 0.67340654]],
        [[0.656068, 0.23129226], [0.42117321, 0.81263936]],
    ]
]
dtype = "uint8"

# Quantizing input: scale is returned as float64 and zp is returned as int32.
inp_quant, inp_scale, inp_zero_point = quantize_np(data, dtype)
inp_quant = tvm.nd.array(inp_quant.astype(np.uint8))
+
+
# Test the implementation's value output against a numpy reference. The IR is
# first run through the pass to replace the unary op with a take op, then the
# compiled result is compared value-by-value.
def test_value():
    ops = ["tanh", "sqrt", "rsqrt", "exp", "erf", "sigmoid", "hardswish", "log", "abs"]

    atol_val = 2
    for op_name in ops:
        if op_name == "tanh":
            op_val = np.tanh(data)
            before = Module_tanh
        elif op_name == "sqrt":
            op_val = np.sqrt(data)
            before = Module_sqrt
        elif op_name == "rsqrt":
            op_val = 1 / np.sqrt(data)
            before = Module_rsqrt
        elif op_name == "exp":
            op_val = np.exp(data)
            before = Module_exp
        elif op_name == "erf":
            op_val = special.erf(data)
            before = Module_erf
        elif op_name == "sigmoid":
            op_val = 1 / (1 + np.exp(np.negative(data)))
            # NOTE(review): atol_val is never reset, so the looser tolerance of
            # 15 also applies to hardswish, log, and abs below — confirm this
            # loosening is intended and not an accidental carry-over.
            atol_val = 15
            before = Module_sigmoid
        elif op_name == "hardswish":
            op_val = hexagon_unary_ops.hardswish_func(data)
            before = Module_hardswish
        elif op_name == "log":
            op_val = np.log(data)
            before = Module_log
        elif op_name == "abs":
            op_val = np.abs(data)
            before = Module_abs

        # Quantizing output: scale is returned as float64 and zp is returned as int32
        out_quant, _, _ = quantize_np(op_val, dtype)

        # Run the replacement pass, build for LLVM, and execute on the VM.
        after = generate_take_op.PassReplaceWithTakeOpPrimFuncs()(before)
        target = tvm.target.Target("llvm", host="llvm")
        ex = relax.build(after, target, exec_mode="compiled")
        vm = relax.VirtualMachine(ex, tvm.cpu())
        res = vm["main"](inp_quant)

        tvm.testing.assert_allclose(res.numpy(), out_quant, atol=atol_val)
        print("Passed Value : ", op_name)
+
+
# Structural test: after the pass runs, each module's "main" must have been
# rewritten, so it can no longer be structurally equal to the original.
def test_structural():
    modules = (
        Module_tanh,
        Module_sqrt,
        Module_rsqrt,
        Module_exp,
        Module_erf,
        Module_sigmoid,
        Module_hardswish,
        Module_log,
        Module_abs,
    )
    for module in modules:
        transformed = generate_take_op.PassReplaceWithTakeOpPrimFuncs()(module)
        assert not tvm.ir.structural_equal(transformed["main"], module["main"])
    print("Passed Structural")