This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 1bf0f1b61e [Unity][TVMScript] Use explicit `R.shape` in TVMScript (#13979)
1bf0f1b61e is described below
commit 1bf0f1b61e79215b100575fbc583db11e9c86710
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Feb 14 10:42:01 2023 +0800
[Unity][TVMScript] Use explicit `R.shape` in TVMScript (#13979)
Since `sinfo_args` was introduced in CallNode, the implicit shape constructor
is no longer widely used in TVMScript. This PR removes implicit shape
construction, because a bare tuple of integers is ambiguous between a shape
and a tuple.
---
python/tvm/relax/utils.py | 16 ++--------
python/tvm/script/ir_builder/relax/ir.py | 18 +++++++++++
python/tvm/script/parser/relax/entry.py | 22 ++++++++++---
src/script/printer/relax/expr.cc | 2 +-
src/script/printer/relax/struct_info.cc | 14 ++++++++-
.../relax/test_backend_transform_shape_lower.py | 2 +-
tests/python/relax/test_transform.py | 2 +-
tests/python/relax/test_tvmscript_parser.py | 36 ++++++++++++++++------
tests/python/relax/test_tvmscript_printer_relax.py | 4 +--
tests/python/relax/test_vm_build.py | 6 ++--
tests/python/relax/test_vm_codegen_only.py | 14 +++++----
11 files changed, 93 insertions(+), 43 deletions(-)
diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py
index 5bfb0d87bf..0bb82c79f4 100644
--- a/python/tvm/relax/utils.py
+++ b/python/tvm/relax/utils.py
@@ -23,7 +23,7 @@ from .. import tir
from ..runtime import String, convert_to_object
from ..tir import PrimExpr
from . import _ffi_api
-from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm
+from .expr import Expr, Function, PrimValue, StringImm
from .expr import Tuple as rx_Tuple
@@ -74,14 +74,12 @@ def convert_to_expr(value: Any) -> Expr:
1. Return the input itself if it's already a `relax.Expr`;
2. Return `relax.PrimValue` if the input is a `PrimExpr`;
3. Return `relax.StringImm` if the input is `tvm.String` or `str`;
- 4. Return `relax.ShapeExpr` if the input is a tuple/list of `PrimExpr` w/
int dtype;
- 5. Return `relax.Tuple` if the input is a tuple/list of `Expr`.
+ 4. Return `relax.Tuple` if the input is a tuple/list of `Expr`.
Notes
-----
1. `tvm.tir.StringImm` is not allowed because of ambiguity,
which can be either `relax.StringImm` or `relax.PrimValue`.
- 2. We regard empty tuple/list as `relax.Tuple` instead of `relax.ShapeExpr`
"""
if isinstance(value, int):
return PrimValue(tir.IntImm("int64", value))
@@ -102,16 +100,8 @@ def convert_to_expr(value: Any) -> Expr:
# Case 3
if isinstance(tvm_value, String):
return StringImm(value)
- # Case 4 & 5
+ # Case 4
if isinstance(value, (tuple, list)):
- # Note 2
- if len(value) == 0:
- return rx_Tuple([])
- # Case 4
- opt_prim_value = [convert_to_object(v) for v in value]
- if all([isinstance(v, PrimExpr) and v.dtype.startswith("int") for v in
opt_prim_value]):
- return ShapeExpr(value)
- # Case 5
# `convert_to_expr` ensures that all elements are `Expr` if no
exception raises
return rx_Tuple([convert_to_expr(v) for v in value])
raise TypeError(f"Cannot convert {value} with type {type(value)} to
`relax.Expr`")
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 0692ec5683..0e6595cb45 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -329,6 +329,23 @@ def tuple(*fields: Expr) -> Expr:
return relax.Tuple(fields) # type: ignore[attr-defined] # pylint:
disable=no-member
+############################### R.shape ################################
+
+
+def shape(value: List[PrimExpr]) -> Expr:
+ """Create a ShapeExpr.
+ Parameters
+ ----------
+ value : List[PrimExpr]
+ The values (dimensions) of the shape.
+ Returns
+ -------
+ res : Expr
+ The created ShapeExpr.
+ """
+ return relax.ShapeExpr(value) # pylint: disable=no-member # type: ignore
+
+
############################### PrimValue ##############################
@@ -407,6 +424,7 @@ __all__ = [
"prim_value",
"print",
"reshape",
+ "shape",
"shape_of",
"str",
"tuple",
diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py
index d93f9a2826..7e51264cb3 100644
--- a/python/tvm/script/parser/relax/entry.py
+++ b/python/tvm/script/parser/relax/entry.py
@@ -22,6 +22,7 @@ from typing import Dict, List, Optional, Set, TypeVar, Union
from tvm.relax import (
Expr,
+ ShapeExpr,
FuncStructInfo,
Function,
ObjectStructInfo,
@@ -84,17 +85,22 @@ class TensorProxy(StructInfoProxy):
def __init__(
self,
- shape: Optional[List[Union[PrimExpr, str]]] = None,
+ shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None,
dtype: Optional[str] = None,
ndim: int = -1,
) -> None:
self.shape = shape
+ if isinstance(shape, Expr) and not isinstance(shape, ShapeExpr):
+ raise ValueError(
+ "Only ShapeExpr is allowed as shape expr, but got: "
+ f"{shape} with type: {type(shape)}"
+ )
self.dtype = dtype
self.ndim = ndim
super().__init__()
def get_symbolic_vars(self) -> Set[str]:
- if self.shape is None:
+ if self.shape is None or isinstance(self.shape, Expr):
return {}
else:
return {s for s in self.shape if isinstance(s, str) and
s.isidentifier()}
@@ -102,6 +108,8 @@ class TensorProxy(StructInfoProxy):
def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) ->
TensorStructInfo:
if self.shape is None:
return TensorStructInfo(None, self.dtype, self.ndim)
+ elif isinstance(self.shape, ShapeExpr):
+ return TensorStructInfo(self.shape, self.dtype, self.ndim)
else:
if dict_globals is None and any([isinstance(s, str) for s in
self.shape]):
raise ValueError(
@@ -113,7 +121,7 @@ class TensorProxy(StructInfoProxy):
def Tensor(
- shape: Optional[List[Union[PrimExpr, str]]] = None,
+ shape: Optional[Union[List[Union[PrimExpr, str]], ShapeExpr]] = None,
dtype: Optional[str] = None,
ndim: int = -1,
) -> TensorProxy:
@@ -124,8 +132,12 @@ def Tensor(
dtype = shape
shape = None
- if shape is not None and not isinstance(shape, (tuple, list)):
- raise ValueError(f"shape must be a list or tuple, but got: {shape}")
+ if (
+ shape is not None
+ and not isinstance(shape, (tuple, list))
+ and not isinstance(shape, ShapeExpr)
+ ):
+ raise ValueError(f"shape must be a list/tuple or a ShapeExpr, but got:
{shape}")
return TensorProxy(shape, dtype, ndim)
diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc
index a786932fc3..66d7d187d0 100644
--- a/src/script/printer/relax/expr.cc
+++ b/src/script/printer/relax/expr.cc
@@ -71,7 +71,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
for (int i = 0, l = n->values.size(); i < l; ++i) {
values_doc.push_back(PrintShapeVar(n->values[i],
values_p->ArrayIndex(i), d));
}
- return TupleDoc(values_doc);
+ return Relax(d, "shape")->Call({ListDoc(values_doc)});
});
Optional<ExprDoc> SpecialScalar(const runtime::NDArray& n, const ObjectPath&
p) {
diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc
index 6f4a66c991..c541619ec8 100644
--- a/src/script/printer/relax/struct_info.cc
+++ b/src/script/printer/relax/struct_info.cc
@@ -89,7 +89,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Array<String> kwargs_keys;
Array<ExprDoc> kwargs_values;
if (n->shape.defined()) {
- args.push_back(d->AsDoc<ExprDoc>(n->shape.value(),
n_p->Attr("shape")));
+ // Dig into the ShapeExpr and print its values as a plain tuple, so the `R.shape` prefix is omitted inside `R.Tensor(...)`
+ if (const auto* shape =
n->shape.value().as<relax::ShapeExprNode>()) {
+ auto shape_expr = GetRef<relax::ShapeExpr>(shape);
+ ObjectPath shape_p = n_p->Attr("shape")->Attr("values");
+ Array<ExprDoc> shape_docs;
+ for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i)
{
+ shape_docs.push_back(
+ PrintShapeVar(shape_expr->values[i],
shape_p->ArrayIndex(i), d));
+ }
+ args.push_back(TupleDoc(shape_docs));
+ } else {
+ args.push_back(d->AsDoc<ExprDoc>(n->shape.value(),
n_p->Attr("shape")));
+ }
}
if (!n->IsUnknownDtype()) {
kwargs_keys.push_back("dtype");
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py
index 0bf0f175dd..5cd104dd01 100644
--- a/tests/python/relax/test_backend_transform_shape_lower.py
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -167,7 +167,7 @@ def test_symbolic_compute():
n = T.Var("n", "int64")
k = T.Var("k", "int64")
z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None))
- return (k + 1, m, 2)
+ return R.shape([k + 1, m, 2])
# slot assignment:
# 0: n, 1: m, 2:k, 3: k+1
diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py
index 624b7877cd..12dd095c6b 100644
--- a/tests/python/relax/test_transform.py
+++ b/tests/python/relax/test_transform.py
@@ -109,7 +109,7 @@ def test_vm_builtin_lower():
@R.function
def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
m, n = T.var("int64"), T.var("int64")
- alloc = R.builtin.alloc_tensor((m, n), runtime_device_index=0,
dtype="float32")
+ alloc = R.builtin.alloc_tensor(R.shape([m, n]),
runtime_device_index=0, dtype="float32")
_ = R.call_packed(
"test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2,
dtype="float32"))
)
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 34b02fdbb8..c9a16fbcac 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -22,10 +22,9 @@ import tvm
import tvm.script
import tvm.testing
from tvm import IRModule, relax, tir, topi
-from tvm.relax import DynTensorType
-from tvm.script import ir as I
-from tvm.script import relax as R
-from tvm.script import tir as T
+from tvm.script.parser import ir as I
+from tvm.script.parser import relax as R
+from tvm.script.parser import tir as T
def _check(
@@ -202,6 +201,23 @@ def test_relax_tensor_op():
_check(foo, bb.get()["foo"])
+def test_relax_base_op():
+ @R.function
+ def foo(x: R.Tensor((4, 4), "float32")):
+ alloc = R.builtin.alloc_tensor(R.shape([4, 4]),
runtime_device_index=0, dtype="float32")
+ shape = R.shape_of(alloc)
+ return shape
+
+ x = relax.Var("x", R.Tensor((4, 4), "float32"))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", (x,)):
+ alloc = bb.emit(relax.op.builtin.alloc_tensor(relax.ShapeExpr((4, 4)),
"float32", 0))
+ shape = bb.emit(relax.op.shape_of(alloc))
+ bb.emit_func_output(shape)
+ # TODO(yongwww): check disabled because the literal `0` is printed as `R.prim_value(0)` in the IR
+ # _check(foo, bb.get()["foo"])
+
+
def test_symbolic_shape():
@R.function
def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):
@@ -274,7 +290,7 @@ def test_match_cast():
y0 = R.match_cast(y, R.Tensor([n], "float32"))
gv = y0
R.output(gv)
- return (x0, (m, n * 2))
+ return (x0, R.shape([m, n * 2]))
x = relax.Var("x", R.Tensor("float32"))
y = relax.Var("y", R.Tensor("float32"))
@@ -314,7 +330,7 @@ def test_tuple_return_2():
def foo(x: R.Tensor("float32", ndim=2)):
n, m = T.var("int64"), T.var("int64")
x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
- return (x0, (n + 1, m, 1))
+ return (x0, R.shape([n + 1, m, 1]))
x = relax.Var("x", R.Tensor("float32", ndim=2))
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
@@ -332,7 +348,7 @@ def test_tuple_binding():
n, m = T.var("int64"), T.var("int64")
x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
t0 = (x, x0)
- t1 = (x, (n, m), t0)
+ t1 = (x, R.shape([n, m]), t0)
return t1
x = relax.Var("x", R.Tensor("float32", ndim=2))
@@ -965,9 +981,9 @@ def test_vm_ops():
def foo(x: R.Tensor(("m", "n"), dtype="float32")):
m = T.var("int64")
n = T.var("int64")
- storage = R.vm.alloc_storage((4 * m * n,), dtype="float32",
runtime_device_index=0)
- alloc = R.vm.alloc_tensor(storage, (m, n), offset=0, dtype="float32")
- tensor = R.builtin.alloc_tensor((m, n), dtype="float32",
runtime_device_index=0)
+ storage = R.vm.alloc_storage(R.shape([4 * m * n]), dtype="float32",
runtime_device_index=0)
+ alloc = R.vm.alloc_tensor(storage, shape=R.shape([m, n]), offset=0,
dtype="float32")
+ tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32",
runtime_device_index=0)
_ = R.vm.call_tir_dyn("te_func", (x, tensor, (m, n)))
gv = tensor
return alloc, gv
diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py
index 58596f968f..db90c66422 100644
--- a/tests/python/relax/test_tvmscript_printer_relax.py
+++ b/tests/python/relax/test_tvmscript_printer_relax.py
@@ -292,7 +292,7 @@ c: R.Tensor((1, z, 3), dtype="float32")
def test_shape_expr():
obj = relax.ShapeExpr([1, 2, 3])
- _assert_print(obj, "(1, 2, 3)")
+ _assert_print(obj, "R.shape([1, 2, 3])")
def test_call():
@@ -304,7 +304,7 @@ def test_call():
"""
x = T.Var("x", "int64")
a: R.Tensor((1, x, 3), dtype="float32")
-R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"),
tir_vars=(x,))
+R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"),
tir_vars=R.shape([x]))
""",
)
diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py
index 534d2308da..0a881691ac 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -88,7 +88,7 @@ def test_vm_compile_stage2(exec_mode):
def foo(x: R.Tensor(dtype="float32")) -> R.Shape:
n, m = T.var("int64"), T.var("int64")
_ = R.match_cast(x, R.Tensor((n, m), "float32"))
- return (n * 2, m * 3)
+ return R.shape([n * 2, m * 3])
mod = TestVMCompileStage2
target = tvm.target.Target("llvm", host="llvm")
@@ -511,9 +511,9 @@ def test_lower_memory_alloc_storage_tensor(exec_mode):
@R.function
def main(x: R.Tensor((2, 3), dtype="float32")):
storage = R.memory.alloc_storage(
- (24,), virtual_device_index=0, storage_scope="global",
dtype="float32"
+ R.shape([24]), virtual_device_index=0, storage_scope="global",
dtype="float32"
)
- y = R.memory.alloc_tensor(storage, 0, (2, 3), dtype="float32")
+ y = R.memory.alloc_tensor(storage, 0, R.shape([2, 3]),
dtype="float32")
_ = copy(x, y)
return y
diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py
index b5e7709177..4b79ecf70f 100644
--- a/tests/python/relax/test_vm_codegen_only.py
+++ b/tests/python/relax/test_vm_codegen_only.py
@@ -18,13 +18,15 @@
Restrictions: all shape lowered, explicit allocation.
"""
-import tvm
-import pytest
import numpy as np
-from tvm import relax, TVMError
-from tvm.script import relax as R, tir as T
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode
from tvm.relax.testing.vm import check_saved_func
-from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode
+from tvm.script import relax as R
+from tvm.script import tir as T
EXEC_MODE = ["bytecode"]
@@ -312,7 +314,7 @@ def test_vm_builtin_reshape(exec_mode):
def main(x: R.Tensor((3, 4), "float32")):
R.func_attr({"global_symbol": "main"})
y = R.call_packed(
- "vm.builtin.reshape", x, (6, 2), sinfo_args=R.Tensor((6, 2),
"float32")
+ "vm.builtin.reshape", x, R.shape([6, 2]),
sinfo_args=R.Tensor((6, 2), "float32")
)
return y