This is an automated email from the ASF dual-hosted git repository.

sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 77391714ab Replacing unary ops with LookUpTable and Take op to improve 
performance (#17214)
77391714ab is described below

commit 77391714ab714afcc849fde1378a5a0c62d99c2e
Author: sdalvi-quic <[email protected]>
AuthorDate: Fri Aug 9 00:27:35 2024 -0500

    Replacing unary ops with LookUpTable and Take op to improve performance 
(#17214)
    
    * Created a look-up table (LUT) for unary ops so that the values are computed at compile time and a take op is used to read them at runtime
    
    * Black formatting for hexagon_unary_ops.py
    
    * minor edit
    
    * Accessed variables via the op attributes and op name in the prim func definition. Added a check that the call node is a call_tir call
---
 python/tvm/contrib/hexagon/generate_take_op.py  |  98 ++++++
 python/tvm/contrib/hexagon/hexagon_unary_ops.py |  97 ++++++
 tests/python/contrib/test_hexagon/test_take.py  | 393 ++++++++++++++++++++++++
 3 files changed, 588 insertions(+)

diff --git a/python/tvm/contrib/hexagon/generate_take_op.py 
b/python/tvm/contrib/hexagon/generate_take_op.py
new file mode 100644
index 0000000000..b70eb451a1
--- /dev/null
+++ b/python/tvm/contrib/hexagon/generate_take_op.py
@@ -0,0 +1,98 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring, invalid-name, unnecessary-comprehension, 
unused-argument
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.contrib.hexagon import hexagon_unary_ops
+
+
# Names (as stored in a PrimFunc's ``op_attrs["op_name"]``) of the quantized
# unary ops whose call_tir can be replaced by a LUT lookup (take op).
_LUT_REPLACEABLE_OPS = frozenset(
    [
        "qnn.tanh",
        "qnn.sqrt",
        "qnn.rsqrt",
        "qnn.exp",
        "qnn.erf",
        "qnn.sigmoid",
        "qnn.hardswish",
        "qnn.log",
        "qnn.abs",
    ]
)


def op_replace(call_node, func) -> bool:
    """Return True if ``call_node`` is a call_tir to a LUT-replaceable unary op.

    Parameters
    ----------
    call_node : relax.Expr
        The expression being visited.
    func : tvm.tir.PrimFunc
        The callee of ``call_node``, as looked up in the IRModule.

    Returns
    -------
    bool
        True when ``call_node`` is a ``relax.call_tir`` call and ``func`` carries
        an ``op_attrs`` attribute whose ``op_name`` is one of the supported
        quantized unary ops. False otherwise, including when ``func`` has no
        attributes or no ``op_attrs``/``op_name`` entry.
    """
    if not isinstance(call_node, relax.Call):
        return False
    if call_node.op != tvm.ir.Op.get("relax.call_tir"):
        return False
    # PrimFuncs that were not tagged with op attributes (e.g. hand-written
    # kernels) are simply not replaceable; do not raise on them.
    attrs = getattr(func, "attrs", None)
    if attrs is None or "op_attrs" not in attrs:
        return False
    op_attrs = attrs["op_attrs"]
    if "op_name" not in op_attrs:
        return False
    return str(op_attrs["op_name"]) in _LUT_REPLACEABLE_OPS
+
+
@relax.expr_functor.mutator
class Tanh2TakeReplace(tvm.relax.PyExprMutator):
    """Replace quantized unary-op call_tir calls with a LUT + take-op call.

    For each ``relax.call_tir`` with a uint8 input whose callee is accepted by
    ``op_replace``, a 256-entry look-up table is computed at compile time from
    the call's constant quantization parameters. The call is then rewritten
    into a call_tir to a generated "take" PrimFunc that indexes that table
    with the input values.
    """

    def __init__(self, mod: tvm.IRModule) -> None:
        super().__init__(mod)
        self.mod_ = mod

    def transform(self) -> tvm.IRModule:
        """Rewrite every Relax function in the module and return the new module."""
        for global_var, func in self.mod_.functions.items():
            # Skip non-relax functions (e.g. the PrimFuncs being replaced).
            if not isinstance(func, relax.Function):
                continue
            updated_func = self.visit_expr(func)
            # normalize() returns the normalized expression; keep its result
            # rather than discarding it.
            updated_func = self.builder_.normalize(updated_func)
            self.builder_.update_func(global_var, updated_func)
        # The BlockBuilder holds the updated module, including the newly added
        # take PrimFuncs.
        return self.builder_.get()

    def visit_call_(self, call_node: relax.Call) -> relax.Call:
        """Rewrite a replaceable unary-op call_tir into a LUT take call_tir.

        Any other call is returned unchanged. This covers calls that are not
        call_tir, have a non-uint8 input, target an unsupported op, or have
        non-constant quantization parameters.
        """
        call_tir_op = tvm.ir.Op.get("relax.call_tir")
        if call_node.op != call_tir_op:
            return call_node

        var = call_node.args[0]
        tir_args = call_node.args[1]
        # The callee can only be looked up through a GlobalVar. The arguments
        # are unpacked positionally below, so they must be an explicit tuple
        # of (input, in_scale, in_zp, out_scale, out_zp).
        if not isinstance(var, tvm.ir.GlobalVar):
            return call_node
        if not isinstance(tir_args, relax.expr.Tuple) or len(tir_args.fields) != 5:
            return call_node

        inp, inp_scale, inp_zp, out_scale, out_zp = tir_args.fields
        if getattr(inp.struct_info, "dtype", None) != "uint8":
            return call_node

        func = self.mod_[var]
        if not op_replace(call_node, func):
            return call_node
        # The LUT is computed at compile time, so every quantization parameter
        # must be a constant.
        quant_params = (inp_scale, inp_zp, out_scale, out_zp)
        if not all(isinstance(p, relax.expr.Constant) for p in quant_params):
            return call_node

        # Choose the tabulated op from the PrimFunc's op name, which is the
        # attribute op_replace just validated. The callee's symbol name need
        # not mention the op at all.
        op_name = str(func.attrs["op_attrs"]["op_name"])
        # LUT node creation
        lut = hexagon_unary_ops.LUT_generation(inp_scale, inp_zp, out_scale, out_zp, op_name)

        # Take operation node creation
        take_func = hexagon_unary_ops.generate_take_primfunc(inp, call_node.struct_info)
        take_func = take_func.without_attr("global_symbol")
        take_func_gv = self.builder_.add_func(take_func, "take")
        return relax.call_tir(
            take_func_gv,
            relax.expr.Tuple([inp, relax.expr.Constant(tvm.nd.array(lut))]),
            call_node.struct_info,
        )
+
+
@tvm.transform.module_pass(opt_level=2, name="replace_tanh_take")
class PassReplaceWithTakeOpPrimFuncs:
    """Module pass (opt_level 2) that replaces supported quantized unary ops
    with LUT-based take ops by running ``Tanh2TakeReplace`` over the module.
    """

    def transform_module(self, mod, ctx):
        """Return ``mod`` rewritten by ``Tanh2TakeReplace``; ``ctx`` is unused."""
        rewriter = Tanh2TakeReplace(mod)
        return rewriter.transform()
diff --git a/python/tvm/contrib/hexagon/hexagon_unary_ops.py 
b/python/tvm/contrib/hexagon/hexagon_unary_ops.py
new file mode 100644
index 0000000000..1bb4d4ba4f
--- /dev/null
+++ b/python/tvm/contrib/hexagon/hexagon_unary_ops.py
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring, invalid-name
+import logging
+import numpy as np
+from scipy import special
+from tvm import te
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

######################################################################
#################### PRIMFUNC FOR LUT and Take Op ####################
######################################################################
+
+
def saturate(x: te.Tensor, dtype: str):
    """Clamp ``x`` into the representable range of ``dtype``.

    Builds the TE expression ``max(min_value(dtype), min(x, max_value(dtype)))``.
    """
    upper_clamped = te.min(x, te.max_value(dtype))
    return te.max(te.min_value(dtype), upper_clamped)
+
+
def hardswish_func(x):
    """NumPy reference hard-swish: ``x * clip(x + 3, 0, 6) / 6``."""
    gate = np.clip(np.add(x, 3.0), 0.0, 6.0)
    return x * gate / 6.0
+
+
def _scalar_from_constant(const, name):
    """Return the NumPy scalar stored in the 0-d relax constant ``const``.

    Raises
    ------
    ValueError
        If ``const`` is not 0-d. Per-channel (non-scalar) quantization
        parameters are not supported by LUT generation.
    """
    if const.data.shape != ():
        raise ValueError(
            f"LUT generation requires a scalar {name}, got shape {tuple(const.data.shape)}"
        )
    return const.data.numpy()[()]


def _apply_unary_op(op_name, x):
    """Apply the floating-point unary op selected by ``op_name`` to ``x``.

    ``op_name`` is matched by substring (e.g. "qnn.tanh" selects tanh).
    "rsqrt" is tested before "sqrt" because "sqrt" is a substring of it.

    Raises
    ------
    ValueError
        If ``op_name`` does not name a supported unary op.
    """
    if "tanh" in op_name:
        return np.tanh(x)
    if "rsqrt" in op_name:
        return 1 / np.sqrt(x)
    if "sqrt" in op_name:
        return np.sqrt(x)
    if "exp" in op_name:
        return np.exp(x)
    if "erf" in op_name:
        return special.erf(x)
    if "sigmoid" in op_name:
        return 1 / (1 + np.exp(np.negative(x)))
    if "hardswish" in op_name:
        return hardswish_func(x)
    if "log" in op_name:
        return np.log(x)
    if "abs" in op_name:
        return np.abs(x)
    raise ValueError(f"Unsupported unary op for LUT generation: {op_name}")


def LUT_generation(inp_scale, inp_zp, out_scale, out_zp, op_name) -> list:
    """Build the 256-entry uint8 look-up table for a quantized unary op.

    Entry ``i`` is
    ``clip(round(op((i - inp_zp) * inp_scale) / out_scale) + out_zp, 0, 255)``.
    In words: dequantize the uint8 code ``i``, apply the op, and requantize.

    Parameters
    ----------
    inp_scale, inp_zp, out_scale, out_zp : relax.Constant
        0-d constants holding the input and output quantization parameters.
    op_name : str
        Name containing the op to tabulate (e.g. "qnn.tanh").

    Returns
    -------
    list of numpy.uint8
        The 256 table entries, indexed by the input uint8 code.

    Raises
    ------
    ValueError
        If a quantization parameter is not a scalar, or if ``op_name`` is not
        a supported unary op.
    """
    # The quantization parameters are loop-invariant, so read them once.
    i_scale = _scalar_from_constant(inp_scale, "input scale")
    i_zp = _scalar_from_constant(inp_zp, "input zero point")
    o_scale = _scalar_from_constant(out_scale, "output scale")
    o_zp = _scalar_from_constant(out_zp, "output zero point")

    codes = np.arange(256, dtype=np.int32)
    # Codes outside an op's domain (e.g. log/rsqrt of 0) yield inf, which the
    # clamp below maps to 0 or 255, so the NumPy warnings are silenced.
    # NOTE(review): nan entries (e.g. sqrt of a negative dequantized value)
    # are cast to uint8 with platform-defined results; confirm the intended
    # value for such codes.
    with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
        # Dequantize in float64, then compute the op value.
        dequant = (codes - i_zp).astype(np.float64) * np.float64(i_scale)
        op_val = _apply_unary_op(op_name, dequant)
        # Requantize and saturate to the uint8 range.
        quant = np.round(op_val / np.float64(o_scale)) + np.float64(o_zp)
        lut = np.maximum(0, np.minimum(quant, 255)).astype(np.uint8)
    return list(lut)
+
+
def generate_take_primfunc(inp, struct_info):
    """Build a PrimFunc that maps every input element through a 256-entry LUT.

    Parameters
    ----------
    inp : relax.Expr
        The quantized input tensor. Its struct_info supplies the shape and
        dtype of the ``data`` buffer.
    struct_info : relax.TensorStructInfo
        Struct info of the replaced call's output. It supplies the output
        shape and dtype.

    Returns
    -------
    tvm.tir.PrimFunc
        A function over buffers ``(data, LUT, take_op)`` computing
        ``take_op[idx] = saturate(LUT[uint8(data[idx])], out_dtype)``.
    """
    # The data buffer must match the argument actually passed at the call
    # site, so take its shape (any rank) and dtype from the input rather
    # than from the output.
    data = te.placeholder(
        tuple(inp.struct_info.shape), dtype=inp.struct_info.dtype, name="data"
    )
    lut = te.placeholder((256,), dtype="uint8", name="LUT")
    take = te.compute(
        struct_info.shape,
        lambda *indices: saturate(
            (lut[data[indices].astype("uint8")]), struct_info.dtype
        ).astype(struct_info.dtype),
        name="take_op",
    )
    return te.create_prim_func([data, lut, take])
diff --git a/tests/python/contrib/test_hexagon/test_take.py 
b/tests/python/contrib/test_hexagon/test_take.py
new file mode 100644
index 0000000000..80c2b05339
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/test_take.py
@@ -0,0 +1,393 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring, invalid-name, unused-argument, 
not-callable
+import numpy as np
+from scipy import special
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import tir as T, relax as R
+from tvm.contrib.hexagon import generate_take_op
+from tvm.contrib.hexagon import hexagon_unary_ops
+
+from .infrastructure import quantize_np
+
+
# Testing the structural and value correctness on replacing unary op with take op.


# IRModule whose "main" makes a single call_tir to a body-less PrimFunc tagged
# with op_attrs {"op_name": "qnn.tanh"}. The call passes the uint8 input,
# followed by the constant input scale/zero point and output scale/zero point.
@tvm.script.ir_module
class Module_tanh:
    @R.function
    def main(
        input_tanh: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        out = R.call_tir(
            Module_tanh.tanh,
            (
                input_tanh,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.002631544131858676, "float32"),
                R.const(0, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def tanh(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.tanh"}})
+
+
# Single call_tir to a body-less PrimFunc tagged op_name "qnn.sqrt", with
# constant input/output quantization parameters.
@tvm.script.ir_module
class Module_sqrt:
    @R.function
    def main(
        input_sqrt: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        out = R.call_tir(
            Module_sqrt.sqrt,
            (
                input_sqrt,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.003535157327728918, "float32"),
                R.const(0, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def sqrt(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.sqrt"}})
+
+
# Single call_tir to a body-less PrimFunc tagged op_name "qnn.rsqrt", with
# constant input/output quantization parameters.
@tvm.script.ir_module
class Module_rsqrt:
    @R.function
    def main(
        input_rsqrt: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        out = R.call_tir(
            Module_rsqrt.rsqrt,
            (
                input_rsqrt,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.008154160766635542, "float32"),
                R.const(0, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def rsqrt(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.rsqrt"}})
+
+
# Single call_tir to a body-less PrimFunc tagged op_name "qnn.exp", with
# constant input/output quantization parameters.
@tvm.script.ir_module
class Module_exp:
    @R.function
    def main(
        input_exp: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        out = R.call_tir(
            Module_exp.exp,
            (
                input_exp,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.008838622987079832, "float32"),
                R.const(0, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def exp(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.exp"}})
+
+
# Single call_tir to a body-less PrimFunc tagged op_name "qnn.erf", with
# constant input/output quantization parameters.
@tvm.script.ir_module
class Module_erf:
    @R.function
    def main(
        input_erf: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        out = R.call_tir(
            Module_erf.erf,
            (
                input_erf,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.002939393251118067, "float32"),
                R.const(0, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def erf(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.erf"}})
+
+
# Single call_tir to a body-less PrimFunc tagged op_name "qnn.sigmoid", with
# constant input/output quantization parameters.
@tvm.script.ir_module
class Module_sigmoid:
    @R.function
    def main(
        input_sigmoid: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        out = R.call_tir(
            Module_sigmoid.sigmoid,
            (
                input_sigmoid,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.002631544131858676, "float32"),
                R.const(0, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def sigmoid(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.sigmoid"}})
+
+
# Single call_tir to a body-less PrimFunc tagged op_name "qnn.hardswish", with
# constant input/output quantization parameters.
@tvm.script.ir_module
class Module_hardswish:
    @R.function
    def main(
        input_hardswish: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        out = R.call_tir(
            Module_hardswish.hardswish,
            (
                input_hardswish,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.0020250332087720325, "float32"),
                R.const(0, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def hardswish(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.hardswish"}})
+
+
# Single call_tir to a body-less PrimFunc tagged op_name "qnn.log", with
# constant input/output quantization parameters. The output zero point is
# 255, since log of inputs in (0, 1) is negative.
@tvm.script.ir_module
class Module_log:
    @R.function
    def main(
        input_log: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        out = R.call_tir(
            Module_log.log,
            (
                input_log,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.0057414634248614226, "float32"),
                R.const(255, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def log(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.log"}})
+
+
# Single call_tir to a body-less PrimFunc tagged op_name "qnn.abs", with
# constant input/output quantization parameters.
@tvm.script.ir_module
class Module_abs:
    @R.function
    def main(
        input_abs: R.Tensor((1, 2, 2, 2), dtype="uint8"),
    ) -> R.Tensor((1, 2, 2, 2), dtype="uint8"):
        out = R.call_tir(
            Module_abs.abs,
            (
                input_abs,
                R.const(0.003186821002586215, "float32"),
                R.const(0, "int32"),
                R.const(0.0031868210196078434, "float32"),
                R.const(0, "int32"),
            ),
            out_sinfo=R.Tensor((1, 2, 2, 2), dtype="uint8"),
        )
        return out

    @T.prim_func
    def abs(
        rxplaceholder: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
        rxplaceholder_1: T.Buffer((), "float32"),
        rxplaceholder_2: T.Buffer((), "int32"),
        rxplaceholder_3: T.Buffer((), "float32"),
        rxplaceholder_4: T.Buffer((), "int32"),
        compute: T.Buffer((T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "uint8"),
    ):
        T.func_attr({"tir.noalias": True, "op_attrs": {"op_name": "qnn.abs"}})
+
+
# Fixed float input of shape (1, 2, 2, 2). It is hardcoded rather than drawn with
# np.random.random([1, 2, 2, 2]).astype("float32"), so that its quantization
# parameters are known in advance and can be used as constant inputs to each
# module's main function (e.g. input scale 0.003186821... == max(data) / 255).
data = [
    [
        [[0.3034368, 0.60848576], [0.29697746, 0.67340654]],
        [[0.656068, 0.23129226], [0.42117321, 0.81263936]],
    ]
]
dtype = "uint8"

# Quantizing input : scale is returned as float64 and zp is returned as int32
# (only the quantized values are used below; they are the input to every "main").
inp_quant, inp_scale, inp_zero_point = quantize_np(data, dtype)
inp_quant = tvm.nd.array(inp_quant.astype(np.uint8))
+
+
# Test the implementation's output values against NumPy. The IR is first run
# through the pass that replaces the unary op with a take op; the built module
# is then executed and compared with the quantized NumPy reference.
def test_value():
    """Check LUT/take outputs against NumPy-computed quantized references."""
    # op name -> (float reference, IRModule under test, absolute tolerance).
    # Each op carries its own tolerance, so a looser bound for one op cannot
    # leak into the ops checked after it.
    cases = {
        "tanh": (np.tanh, Module_tanh, 2),
        "sqrt": (np.sqrt, Module_sqrt, 2),
        "rsqrt": (lambda x: 1 / np.sqrt(x), Module_rsqrt, 2),
        "exp": (np.exp, Module_exp, 2),
        "erf": (special.erf, Module_erf, 2),
        # Sigmoid needs a wider tolerance, presumably because Module_sigmoid's
        # hard-coded output scale (identical to Module_tanh's) differs from the
        # reference quantization of sigmoid(data). TODO confirm.
        "sigmoid": (lambda x: 1 / (1 + np.exp(np.negative(x))), Module_sigmoid, 15),
        "hardswish": (hexagon_unary_ops.hardswish_func, Module_hardswish, 2),
        "log": (np.log, Module_log, 2),
        "abs": (np.abs, Module_abs, 2),
    }

    target = tvm.target.Target("llvm", host="llvm")
    for op_name, (reference, before, atol_val) in cases.items():
        op_val = reference(data)
        # Quantizing output : scale is returned as float64 and zp is returned as int32
        out_quant, _, _ = quantize_np(op_val, dtype)

        after = generate_take_op.PassReplaceWithTakeOpPrimFuncs()(before)
        ex = relax.build(after, target, exec_mode="compiled")
        vm = relax.VirtualMachine(ex, tvm.cpu())
        res = vm["main"](inp_quant)

        tvm.testing.assert_allclose(res.numpy(), out_quant, atol=atol_val)
        print("Passed Value : ", op_name)
+
+
# Structural check: after the pass, each module's "main" must differ from the
# original, i.e. the unary-op call_tir was replaced by a take-op call_tir.
def test_structural():
    """Assert the replacement pass rewrites ``main`` in every test module."""
    modules = (
        Module_tanh,
        Module_sqrt,
        Module_rsqrt,
        Module_exp,
        Module_erf,
        Module_sigmoid,
        Module_hardswish,
        Module_log,
        Module_abs,
    )
    for before in modules:
        rewritten = generate_take_op.PassReplaceWithTakeOpPrimFuncs()(before)
        main_unchanged = tvm.ir.structural_equal(rewritten["main"], before["main"])
        assert not main_unchanged
    print("Passed Structural")

Reply via email to