This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 3e98901965 [Unity][Legalize] Fix Scalar Constant Legalization (#14127)
3e98901965 is described below
commit 3e989019654ca45c62acb4bdbe4b18ba5a36d0dc
Author: Xiyou Zhou <[email protected]>
AuthorDate: Tue Feb 28 11:24:05 2023 -0800
[Unity][Legalize] Fix Scalar Constant Legalization (#14127)
This PR fixes a loss of data type during legalization.
Previously, when a scalar constant was used in operators such as `multiply`,
it was automatically converted to a native Python value, which could lose its
original data type. For example, a `float16` constant would become a Python
`float` and later be interpreted as `float32`.
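This is now fixed by no longer converting scalar constants to native Python
values by default: they are returned as `tir.FloatImm`/`tir.IntImm` carrying
the original dtype. The `full` and `astype` legalizations still request native
conversion explicitly via `python_native=True`. The native conversion could be
made the default again once scalar prim values are better supported.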
Co-authored-by: Sunghyun Park <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
---
python/tvm/relax/transform/legalize_ops/common.py | 34 ++++++--
.../tvm/relax/transform/legalize_ops/creation.py | 4 +-
.../tvm/relax/transform/legalize_ops/datatype.py | 2 +-
tests/python/relax/test_transform_legalize_ops.py | 97 ++++++++++++++++++++++
4 files changed, 126 insertions(+), 11 deletions(-)
diff --git a/python/tvm/relax/transform/legalize_ops/common.py
b/python/tvm/relax/transform/legalize_ops/common.py
index 85d7fba85c..4407b3fdf3 100644
--- a/python/tvm/relax/transform/legalize_ops/common.py
+++ b/python/tvm/relax/transform/legalize_ops/common.py
@@ -19,6 +19,7 @@ from typing import Callable, Optional, Union
import tvm
from tvm import te
+from tvm.tir import FloatImm, IntImm
from ...block_builder import BlockBuilder
from ...expr import Call, Expr, Constant
@@ -39,11 +40,17 @@ LegalizeFunc = Callable[[BlockBuilder, Call], Expr]
##################### Utilities #####################
-def _try_convert_to_scalar_const(expr: Expr) -> Union[Expr, bool, float, int]:
+def _try_convert_to_scalar_const(
+ expr: Expr, python_native: bool = False
+) -> Union[Expr, FloatImm, IntImm, bool, float, int]:
"""Check if the input Expr is a scalar constant.
- If it is, return its plain value.
+ If it is, return its plain value with the same data type or in native
python type.
If it is not, return the input expr.
+ Note that if the python_native flag is True, the returned value will be in
native python type,
+ this might cause loss of data type for example, a float16 constant will be
converted to float32
+ and a int64 constant will be converted to int32.
+
Parameters
----------
expr : Expr
@@ -51,15 +58,24 @@ def _try_convert_to_scalar_const(expr: Expr) -> Union[Expr,
bool, float, int]:
Returns
--–----
- ret : Union[Expr, bool, float, int]
- Return a Python native value (int/float/bool) if the given
- expr is a scalar constant. Or return the input itself
- if it is not.
+ ret : Union[Expr, FloatImm, IntImm, bool, float, int]
+ Return a FloatImm or IntImm if the given expr is a scalar integer or
float constant, and the
+ python native flag is False. Or return the plain value of the constant
in native python type
+ if the python native flag is True.
+ Or return the input itself if it is not a scalar constant.
"""
if isinstance(expr, Constant) and expr.struct_info.ndim == 0:
- return expr.data.numpy()[()].item()
- else:
- return expr
+ # get the value of the scalar constant
+ value = expr.data.numpy()[()].item()
+ dtype = expr.struct_info.dtype
+ if python_native:
+ return value
+ # preserve the data type of the constant
+ if dtype.startswith("float"):
+ return tvm.tir.FloatImm(dtype, value)
+ elif dtype.startswith("int") or dtype.startswith("uint") or
dtype.startswith("bool"):
+ return tvm.tir.IntImm(dtype, value)
+ return expr
def _call_topi_without_attr(te_func: TEFunc, primfunc_name: Optional[str] =
None) -> LegalizeFunc:
diff --git a/python/tvm/relax/transform/legalize_ops/creation.py
b/python/tvm/relax/transform/legalize_ops/creation.py
index 38ce8427b7..76548fcfb4 100644
--- a/python/tvm/relax/transform/legalize_ops/creation.py
+++ b/python/tvm/relax/transform/legalize_ops/creation.py
@@ -27,7 +27,9 @@ from .common import LegalizeFunc, register_legalize,
_try_convert_to_scalar_cons
def _full(is_like: bool, fill_value: Optional[float], primfunc_name: str) ->
LegalizeFunc:
def full_call_te(bb: BlockBuilder, call: Call) -> Expr:
_fill_value = (
- _try_convert_to_scalar_const(call.args[1]) if fill_value is None
else fill_value
+ _try_convert_to_scalar_const(call.args[1], python_native=True)
+ if fill_value is None
+ else fill_value
)
return bb.call_te(
diff --git a/python/tvm/relax/transform/legalize_ops/datatype.py
b/python/tvm/relax/transform/legalize_ops/datatype.py
index a71e8ca15e..8e1d885775 100644
--- a/python/tvm/relax/transform/legalize_ops/datatype.py
+++ b/python/tvm/relax/transform/legalize_ops/datatype.py
@@ -24,7 +24,7 @@ from .common import _try_convert_to_scalar_const,
register_legalize
@register_legalize("relax.astype")
def _astype(bb: BlockBuilder, call: Call) -> Expr:
- arg = _try_convert_to_scalar_const(call.args[0])
+ arg = _try_convert_to_scalar_const(call.args[0], python_native=True)
if isinstance(arg, Expr): # type: ignore
return bb.call_te(topi.cast, arg, call.attrs.dtype)
else:
diff --git a/tests/python/relax/test_transform_legalize_ops.py
b/tests/python/relax/test_transform_legalize_ops.py
index 91f8cb4259..6a658f2a4f 100644
--- a/tests/python/relax/test_transform_legalize_ops.py
+++ b/tests/python/relax/test_transform_legalize_ops.py
@@ -156,5 +156,102 @@ def test_can_not_legalize():
tvm.ir.assert_structural_equal(After1, Before1)
+def test_legalize_scalar_data_type_preserve():
+ # fmt: off
+ @tvm.script.ir_module
+ class Before0:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float16")):
+ gv: R.Tensor((3, 3), "float16") = R.multiply(x, R.const(1.14514,
"float16"))
+ return gv
+
+ @tvm.script.ir_module
+ class Before1:
+ @R.function
+ def main(x: R.Tensor((3, 3), "uint8")):
+ gv: R.Tensor((3, 3), "uint8") = R.multiply(x, R.const(2, "uint8"))
+ return gv
+
+ @tvm.script.ir_module
+ class Before2:
+ @R.function
+ def main(x: R.Tensor((3, 3), "bool")):
+ gv: R.Tensor((3, 3), "bool") = R.equal(x, R.const(True, "bool"))
+ return gv
+
+ @tvm.script.ir_module
+ class Expected0:
+ @T.prim_func
+ def multiply(
+ rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float16"),
+ T_multiply: T.Buffer((T.int64(3), T.int64(3)), "float16"),
+ ):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(rxplaceholder[v_ax0, v_ax1])
+ T.writes(T_multiply[v_ax0, v_ax1])
+ T_multiply[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] *
T.float16(
+ 1.1455078125
+ )
+
+ @R.function
+ def main(x: R.Tensor((3, 3), dtype="float16")) -> R.Tensor((3, 3),
dtype="float16"):
+ gv = R.call_tir(multiply, (x,), out_sinfo=R.Tensor((3, 3),
dtype="float16"))
+ return gv
+
+ @tvm.script.ir_module
+ class Expected1:
+ @T.prim_func
+ def multiply(
+ rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "uint8"),
+ T_multiply: T.Buffer((T.int64(3), T.int64(3)), "uint8"),
+ ):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(rxplaceholder[v_ax0, v_ax1])
+ T.writes(T_multiply[v_ax0, v_ax1])
+ T_multiply[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] *
T.uint8(2)
+
+ @R.function
+ def main(x: R.Tensor((3, 3), dtype="uint8")) -> R.Tensor((3, 3),
dtype="uint8"):
+ gv = R.call_tir(multiply, (x,), out_sinfo=R.Tensor((3, 3),
dtype="uint8"))
+ return gv
+
+ @tvm.script.ir_module
+ class Expected2:
+ @T.prim_func
+ def equal(
+ rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "bool"),
+ T_equal: T.Buffer((T.int64(3), T.int64(3)), "bool"),
+ ):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ for ax0, ax1 in T.grid(T.int64(3), T.int64(3)):
+ with T.block("T_equal"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(rxplaceholder[v_ax0, v_ax1])
+ T.writes(T_equal[v_ax0, v_ax1])
+ T_equal[v_ax0, v_ax1] = rxplaceholder[v_ax0, v_ax1] ==
tvm.tir.const(True, "bool")
+
+ @R.function
+ def main(x: R.Tensor((3, 3), dtype="bool")) -> R.Tensor((3, 3),
dtype="bool"):
+ gv = R.call_tir(equal, (x,), out_sinfo=R.Tensor((3, 3),
dtype="bool"))
+ return gv
+ # fmt: on
+
+ After0 = LegalizeOps()(Before0)
+ tvm.ir.assert_structural_equal(After0, Expected0)
+ After1 = LegalizeOps()(Before1)
+ tvm.ir.assert_structural_equal(After1, Expected1)
+ After2 = LegalizeOps()(Before2)
+ tvm.ir.assert_structural_equal(After2, Expected2)
+
+
if __name__ == "__main__":
tvm.testing.main()