This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 3e98901965 [Unity][Legalize] Fix Scalar Constant Legalization (#14127)
3e98901965 is described below

commit 3e989019654ca45c62acb4bdbe4b18ba5a36d0dc
Author: Xiyou Zhou <[email protected]>
AuthorDate: Tue Feb 28 11:24:05 2023 -0800

    [Unity][Legalize] Fix Scalar Constant Legalization (#14127)
    
    This PR fixes the loss of data type during legalization.
Previously, a constant scalar used in operators like `multiply` was
automatically converted to a native Python value, which could lose its
original data type. For example, a `float16` constant may become a Python
`float` and later be interpreted as `float32`.
    
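    This is now fixed by keeping scalar constants as typed `FloatImm`/`IntImm`
values by default instead of converting them to native Python values (the
native conversion remains available via `python_native=True`). The native
conversion could become the default again once we have better support for
scalar prim values.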
    
    Co-authored-by: Sunghyun Park <[email protected]>
    Co-authored-by: Wuwei Lin <[email protected]>
---
 python/tvm/relax/transform/legalize_ops/common.py  | 34 ++++++--
 .../tvm/relax/transform/legalize_ops/creation.py   |  4 +-
 .../tvm/relax/transform/legalize_ops/datatype.py   |  2 +-
 tests/python/relax/test_transform_legalize_ops.py  | 97 ++++++++++++++++++++++
 4 files changed, 126 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relax/transform/legalize_ops/common.py 
b/python/tvm/relax/transform/legalize_ops/common.py
index 85d7fba85c..4407b3fdf3 100644
--- a/python/tvm/relax/transform/legalize_ops/common.py
+++ b/python/tvm/relax/transform/legalize_ops/common.py
@@ -19,6 +19,7 @@ from typing import Callable, Optional, Union
 
 import tvm
 from tvm import te
+from tvm.tir import FloatImm, IntImm
 from ...block_builder import BlockBuilder
 from ...expr import Call, Expr, Constant
 
@@ -39,11 +40,17 @@ LegalizeFunc = Callable[[BlockBuilder, Call], Expr]
 ##################### Utilities #####################
 
 
-def _try_convert_to_scalar_const(expr: Expr) -> Union[Expr, bool, float, int]:
+def _try_convert_to_scalar_const(
+    expr: Expr, python_native: bool = False
+) -> Union[Expr, FloatImm, IntImm, bool, float, int]:
     """Check if the input Expr is a scalar constant.
-    If it is, return its plain value.
+    If it is, return its plain value with the same data type or in native 
python type.
     If it is not, return the input expr.
 
+    Note that if the python_native flag is True, the returned value will be in 
native python type,
+    this might cause loss of data type for example, a float16 constant will be 
converted to float32
+    and a int64 constant will be converted to int32.
+
     Parameters
     ----------
     expr : Expr
@@ -51,15 +58,24 @@ def _try_convert_to_scalar_const(expr: Expr) -> Union[Expr, 
bool, float, int]:
 
     Returns
     --–----
-    ret : Union[Expr, bool, float, int]
-        Return a Python native value (int/float/bool) if the given
-        expr is a scalar constant. Or return the input itself
-        if it is not.
+    ret : Union[Expr, FloatImm, IntImm, bool, float, int]
+        Return a FloatImm or IntImm if the given expr is a scalar integer or 
float constant, and the
+        python native flag is False. Or return the plain value of the constant 
in native python type
+        if the python native flag is True.
+        Or return the input itself if it is not a scalar constant.
     """
     if isinstance(expr, Constant) and expr.struct_info.ndim == 0:
-        return expr.data.numpy()[()].item()
-    else:
-        return expr
+        # get the value of the scalar constant
+        value = expr.data.numpy()[()].item()
+        dtype = expr.struct_info.dtype
+        if python_native:
+            return value
+        # preserve the data type of the constant
+        if dtype.startswith("float"):
+            return tvm.tir.FloatImm(dtype, value)
+        elif dtype.startswith("int") or dtype.startswith("uint") or 
dtype.startswith("bool"):
+            return tvm.tir.IntImm(dtype, value)
+    return expr
 
 
 def _call_topi_without_attr(te_func: TEFunc, primfunc_name: Optional[str] = 
None) -> LegalizeFunc:
diff --git a/python/tvm/relax/transform/legalize_ops/creation.py 
b/python/tvm/relax/transform/legalize_ops/creation.py
index 38ce8427b7..76548fcfb4 100644
--- a/python/tvm/relax/transform/legalize_ops/creation.py
+++ b/python/tvm/relax/transform/legalize_ops/creation.py
@@ -27,7 +27,9 @@ from .common import LegalizeFunc, register_legalize, 
_try_convert_to_scalar_cons
 def _full(is_like: bool, fill_value: Optional[float], primfunc_name: str) -> 
LegalizeFunc:
     def full_call_te(bb: BlockBuilder, call: Call) -> Expr:
         _fill_value = (
-            _try_convert_to_scalar_const(call.args[1]) if fill_value is None 
else fill_value
+            _try_convert_to_scalar_const(call.args[1], python_native=True)
+            if fill_value is None
+            else fill_value
         )
 
         return bb.call_te(
diff --git a/python/tvm/relax/transform/legalize_ops/datatype.py 
b/python/tvm/relax/transform/legalize_ops/datatype.py
index a71e8ca15e..8e1d885775 100644
--- a/python/tvm/relax/transform/legalize_ops/datatype.py
+++ b/python/tvm/relax/transform/legalize_ops/datatype.py
@@ -24,7 +24,7 @@ from .common import _try_convert_to_scalar_const, 
register_legalize
 
 @register_legalize("relax.astype")
 def _astype(bb: BlockBuilder, call: Call) -> Expr:
-    arg = _try_convert_to_scalar_const(call.args[0])
+    arg = _try_convert_to_scalar_const(call.args[0], python_native=True)
     if isinstance(arg, Expr):  # type: ignore
         return bb.call_te(topi.cast, arg, call.attrs.dtype)
     else:
diff --git a/tests/python/relax/test_transform_legalize_ops.py 
b/tests/python/relax/test_transform_legalize_ops.py
index 91f8cb4259..6a658f2a4f 100644
--- a/tests/python/relax/test_transform_legalize_ops.py
+++ b/tests/python/relax/test_transform_legalize_ops.py
@@ -156,5 +156,102 @@ def test_can_not_legalize():
     tvm.ir.assert_structural_equal(After1, Before1)
 
 
+def test_legalize_scalar_data_type_preserve():
+    # fmt: off
+    @tvm.script.ir_module
+    class Before0:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float16")):
+            gv: R.Tensor((3, 3), "float16") = R.multiply(x, R.const(1.14514, 
"float16"))
+            return gv
+
+    @tvm.script.ir_module
+    class Before1:
+        @R.function
+        def main(x: R.Tensor((3, 3), "uint8")):
+            gv: R.Tensor((3, 3), "uint8") = R.multiply(x, R.const(2, "uint8"))
+            return gv
+
+    @tvm.script.ir_module
+    class Before2:
+        @R.function
+        def main(x: R.Tensor((3, 3), "bool")):
+            gv: R.Tensor((3, 3), "bool") = R.equal(x, R.const(True, "bool"))
+            return gv
+
+    @tvm.script.ir_module
+    class Expected0:
+        @T.prim_func
+        def multiply(
+            rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float16"),
+            T_multiply: T.Buffer((T.int64(3), T.int64(3)), "float16"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(rxplaceholder[v_ax0, v_ax1])
+                    T.writes(T_multiply[v_ax0, v_ax1])
+                    T_multiply[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] * 
T.float16(
+                        1.1455078125
+                    )
+
+        @R.function
+        def main(x: R.Tensor((3, 3), dtype="float16")) -> R.Tensor((3, 3), 
dtype="float16"):
+            gv = R.call_tir(multiply, (x,), out_sinfo=R.Tensor((3, 3), 
dtype="float16"))
+            return gv
+
+    @tvm.script.ir_module
+    class Expected1:
+        @T.prim_func
+        def multiply(
+            rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "uint8"),
+            T_multiply: T.Buffer((T.int64(3), T.int64(3)), "uint8"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(rxplaceholder[v_ax0, v_ax1])
+                    T.writes(T_multiply[v_ax0, v_ax1])
+                    T_multiply[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] * 
T.uint8(2)
+
+        @R.function
+        def main(x: R.Tensor((3, 3), dtype="uint8")) -> R.Tensor((3, 3), 
dtype="uint8"):
+            gv = R.call_tir(multiply, (x,), out_sinfo=R.Tensor((3, 3), 
dtype="uint8"))
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:
+        @T.prim_func
+        def equal(
+            rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "bool"),
+            T_equal: T.Buffer((T.int64(3), T.int64(3)), "bool"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
+                with T.block("T_equal"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(rxplaceholder[v_ax0, v_ax1])
+                    T.writes(T_equal[v_ax0, v_ax1])
+                    T_equal[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] == 
tvm.tir.const(True, "bool")
+
+        @R.function
+        def main(x: R.Tensor((3, 3), dtype="bool")) -> R.Tensor((3, 3), 
dtype="bool"):
+            gv = R.call_tir(equal, (x,), out_sinfo=R.Tensor((3, 3), 
dtype="bool"))
+            return gv
+    # fmt: on
+
+    After0 = LegalizeOps()(Before0)
+    tvm.ir.assert_structural_equal(After0, Expected0)
+    After1 = LegalizeOps()(Before1)
+    tvm.ir.assert_structural_equal(After1, Expected1)
+    After2 = LegalizeOps()(Before2)
+    tvm.ir.assert_structural_equal(After2, Expected2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to