This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new db01567b09 [Unity][Op] introduce `shape_to_tensor` op (#14447)
db01567b09 is described below
commit db01567b09070c41ab7868c7037b81ae828a26a1
Author: Sunghyun Park <[email protected]>
AuthorDate: Mon Apr 3 10:44:52 2023 -0700
[Unity][Op] introduce `shape_to_tensor` op (#14447)
* feat: introduce `shape_to_tensor` op
* lint
* lint
* reflect feedback
---
python/tvm/relax/op/base.py | 28 +++++++++++++++
python/tvm/script/ir_builder/relax/ir.py | 2 ++
src/relax/op/op.cc | 23 +++++++++++++
src/relax/op/op_common.cc | 2 +-
src/relax/transform/fold_constant.cc | 19 ++++++++++
tests/python/relax/test_relax_operators.py | 40 +++++++++++++++++++++-
tests/python/relax/test_transform_fold_constant.py | 34 ++++++++++++++++--
tests/python/relax/test_tvmscript_parser.py | 15 ++++++++
8 files changed, 159 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index becd3f2a0f..d6e8b29b6d 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -19,6 +19,7 @@ from typing import Union, List, Tuple, Optional
import tvm
+import tvm.runtime
from tvm.runtime.object import Object
from . import _ffi_api
@@ -253,6 +254,19 @@ def render_object(val: tvm.Object) -> str:
return str(val)
+@tvm.register_func("relax.run.shape_to_tensor")
+def relax_shape_to_tensor(shape_tuple: tvm.runtime.ShapeTuple) ->
tvm.nd.NDArray:
+ """
+ Takes a ShapeTuple and convert it to NDArray.
+
+ Parameters
+ ----------
+ shape_tuple: tvm.runtime.ShapeTuple
+ Shape tuple that we want to convert to NDArray at runtime
+ """
+ return tvm.nd.array([int(v) for v in shape_tuple])
+
+
@tvm.register_func("relax.run.print")
def relax_print(format_str: str, *format_args: tvm.Object) -> None:
"""
@@ -416,3 +430,17 @@ def tensor_to_shape(expr: Expr) -> Expr:
A relax Call, which transforms the tensor values to the shape
"""
return _ffi_api.tensor_to_shape(expr) # type: ignore # pylint:
disable=no-member
+
+
+def shape_to_tensor(expr: Expr) -> Expr:
+ """Convert shape to tensor expr.
+ Parameters
+ ----------
+ expr : Expr
+ The input Expr
+ Returns
+ -------
+ result : Expr
+ A relax Call, which transforms the shape values to the tensor
+ """
+ return _ffi_api.shape_to_tensor(expr) # type: ignore # pylint:
disable=no-member
diff --git a/python/tvm/script/ir_builder/relax/ir.py
b/python/tvm/script/ir_builder/relax/ir.py
index f9104c9430..2f8a37a4e1 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -100,6 +100,7 @@ from tvm.relax.op import (
repeat,
reshape,
tensor_to_shape,
+ shape_to_tensor,
round,
shape_of,
std,
@@ -619,6 +620,7 @@ __all__ = [
"repeat",
"reshape",
"tensor_to_shape",
+ "shape_to_tensor",
"round",
"shape",
"shape_of",
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index b353cce27a..c641c45922 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -341,6 +341,29 @@ Expr MakeTensorToShape(Expr expr) {
TVM_REGISTER_GLOBAL("relax.op.tensor_to_shape").set_body_typed(MakeTensorToShape);
+// shape_to_tensor
+StructInfo ReturnShapeToTensorStructInfo(const Call& call, const BlockBuilder&
ctx) {
+ ICHECK(call->args.size() == 1);
+ ICHECK(call->args[0]->struct_info_.defined());
+ const auto* sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[0]);
+ ICHECK(sinfo);
+ int32_t ndim = sinfo->ndim;
+ return TensorStructInfo(ShapeExpr({PrimExpr(ndim)}), DataType::Int(64));
+}
+
+RELAY_REGISTER_OP("relax.shape_to_tensor")
+ .set_num_inputs(1)
+ .add_argument("input", "Expr", "The input expression")
+ .set_attr<FInferStructInfo>("FInferStructInfo",
ReturnShapeToTensorStructInfo)
+ .set_attr<FCallPacked>("FCallPacked", "relax.run.shape_to_tensor");
+
+Expr MakeShapeToTensor(Expr expr) {
+ static const Op& op = Op::Get("relax.shape_to_tensor");
+ return Call(op, {expr}, {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.shape_to_tensor").set_body_typed(MakeShapeToTensor);
+
// alloc_tensor
StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder&
ctx) {
diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc
index 0421c957ca..0997a3623e 100644
--- a/src/relax/op/op_common.cc
+++ b/src/relax/op/op_common.cc
@@ -38,7 +38,7 @@ Array<TensorStructInfo> GetInputTensorStructInfo(const Call&
call, const BlockBu
if (sinfo == nullptr) {
ctx->ReportFatal(Diagnostic::Error(call)
<< op << " requires the input " <<
op->arguments[i]->name
- << " to be Tensor. However, the given one is "
+ << " to be Tensor. However, the given one has a "
<< call->args[i]->struct_info_->GetTypeKey());
}
input_tensor_sinfo.push_back(GetRef<TensorStructInfo>(sinfo));
diff --git a/src/relax/transform/fold_constant.cc
b/src/relax/transform/fold_constant.cc
index 315b3dc1f2..db30900cd2 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -279,6 +279,25 @@ class ConstantFolder : public ExprMutator {
}
return ShapeExpr(shape_values);
}
+ } else if (op->name == "relax.shape_to_tensor") {
+ // Special handling for "relax.shape_to_tensor" since it is
implemented in PackedFunc.
+ // TODO(sunggg): revisit this when we extend ConstantFolding to fold
PackedFunc.
+ Expr arg = post_call->args[0];
+ ShapeExpr shape = Downcast<ShapeExpr>(arg);
+ Array<PrimExpr> values = shape->values;
+ Array<Integer> arr;
+ bool is_known = true;
+ for (size_t i = 0; i < values.size(); i++) {
+ PrimExpr val = values[i];
+ arr.push_back(GetRef<IntImm>(val.as<IntImmNode>()));
+ is_known &= (val.dtype() == DataType::Int(64));
+ }
+ if (is_known) {
+ const auto* func =
tvm::runtime::Registry::Get("relax.run.shape_to_tensor");
+ ICHECK(func != nullptr);
+ runtime::NDArray vals = (*func)(arr);
+ return Constant(vals);
+ }
}
}
diff --git a/tests/python/relax/test_relax_operators.py
b/tests/python/relax/test_relax_operators.py
index f197eaa9ab..776abbce76 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -23,7 +23,7 @@ import tvm
import tvm.testing
from tvm import relax
from tvm._ffi.base import TVMError
-from tvm.script import relax as R
+from tvm.script import relax as R, tir as T
@tvm.script.ir_module
@@ -193,5 +193,43 @@ def test_op_shape_of():
assert constrained_shape == tvm.runtime.ShapeTuple([1])
+@tvm.script.ir_module
+class ShapeToTensorTest:
+ @R.function
+ def const_shape(shape: R.Shape(ndim=-1)) -> R.Tensor(ndim=-1):
+ return R.shape_to_tensor(shape)
+
+ @R.function
+ def symbolic_shape(shape: R.Shape(("m", "n"))) -> R.Tensor(ndim=-1):
+ m = T.int64()
+ n = T.int64()
+ return R.shape_to_tensor(shape)
+
+
+def test_op_shape_to_tensor():
+ # Check struct info
+ isinstance(ShapeToTensorTest["const_shape"].body.struct_info,
tvm.relax.TensorStructInfo)
+ assert ShapeToTensorTest["const_shape"].body.struct_info.ndim == 1
+ isinstance(ShapeToTensorTest["symbolic_shape"].body.struct_info,
tvm.relax.TensorStructInfo)
+ assert ShapeToTensorTest["symbolic_shape"].body.struct_info.ndim == 1
+
+ # Check its functionality
+ out2d = run_cpu(ShapeToTensorTest, "const_shape",
tvm.runtime.ShapeTuple([3, 2]))
+ assert isinstance(out2d, tvm.runtime.ndarray.NDArray)
+ assert np.array_equal(out2d.numpy(), np.array([3, 2]))
+
+ out3d = run_cpu(ShapeToTensorTest, "const_shape",
tvm.runtime.ShapeTuple([3, 3, 2]))
+ assert isinstance(out3d, tvm.runtime.ndarray.NDArray)
+ assert np.array_equal(out3d.numpy(), np.array([3, 3, 2]))
+
+ out4d = run_cpu(ShapeToTensorTest, "const_shape",
tvm.runtime.ShapeTuple([3, 3, 2, 2]))
+ assert isinstance(out4d, tvm.runtime.ndarray.NDArray)
+ assert np.array_equal(out4d.numpy(), np.array([3, 3, 2, 2]))
+
+ outs = run_cpu(ShapeToTensorTest, "symbolic_shape",
tvm.runtime.ShapeTuple([3, 2]))
+ assert isinstance(outs, tvm.runtime.ndarray.NDArray)
+ assert np.array_equal(outs.numpy(), np.array([3, 2]))
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fold_constant.py
b/tests/python/relax/test_transform_fold_constant.py
index ebd4348b64..b8ad5c4487 100644
--- a/tests/python/relax/test_transform_fold_constant.py
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -20,7 +20,7 @@ from tvm import relax
import numpy as np
import tvm.script
-from tvm.script import tir as T, relax as R
+from tvm.script import ir as I, tir as T, relax as R
def gen_mod(mod, name, binding):
@@ -368,7 +368,6 @@ def
test_fold_multiple_relax_ops_with_data_dependent_reshape():
@R.function
def expected(data: R.Tensor((256,), "float32")) -> R.Tensor((16, 16),
dtype="float32"):
- R.func_attr({"global_symbol": "main"})
with R.dataflow():
gv: R.Tensor((16, 16), dtype="float32") = R.reshape(data,
R.shape([16, 16]))
R.output(gv)
@@ -420,5 +419,36 @@ def
test_unsupported_fold_ops_legalized_to_multiple_calls():
register_legalize("relax.nn.relu", relu_legalize)
+def test_fold_shape_computation():
+ @I.ir_module
+ class Module:
+ @R.function
+ def before(
+ data: R.Tensor((5, 4, 3, 2), dtype="float32"),
+ indices: R.Tensor((1,), dtype="int64"),
+ ) -> R.Tensor((1, 1), dtype="int64"):
+ with R.dataflow():
+ lv: R.Tensor((4,), dtype="int64") =
R.shape_to_tensor(R.shape([5, 4, 3, 2]))
+ lv1: R.Tensor((1,), dtype="int64") = R.take(lv, indices,
axis=0)
+ lv2: R.Tensor((1, 1), dtype="int64") = R.expand_dims(lv1,
axis=[0])
+ gv: R.Tensor((1, 1), dtype="int64") = R.concat((lv2,), axis=0)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def expected(
+ data: R.Tensor((5, 4, 3, 2), dtype="float32"), new_shape:
R.Tensor((1, 1), "int64")
+ ) -> R.Tensor((1, 1), dtype="int64"):
+ return new_shape
+
+ before = gen_mod(Module, "before", {"indices":
tvm.nd.array(np.array([0]).astype("int64"))})
+ after = relax.transform.FoldConstant()(before)
+ np_take = np.take([5, 4, 3, 2], [0], axis=0)
+ np_expand = np.expand_dims(np_take, axis=[0])
+ np_concat = np.concatenate([np_expand], axis=0)
+ expected = gen_mod(Module, "expected", {"new_shape":
tvm.nd.array(np_concat)})
+ tvm.ir.assert_structural_equal(after, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser.py
b/tests/python/relax/test_tvmscript_parser.py
index 0e0905ffbc..9b8865b943 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -337,6 +337,21 @@ def test_relax_base_op():
_check(foo, bb.get()["foo"])
+def test_relax_shape_to_tensor():
+ @R.function
+ def foo(x: R.Shape((4, 4))):
+ tensor = R.shape_to_tensor(x)
+ return tensor
+
+ x = relax.Var("x", R.Shape((4, 4)))
+ bb = relax.BlockBuilder()
+ with bb.function("foo", (x,)):
+ tensor = bb.emit(relax.op.shape_to_tensor(x))
+ bb.emit_func_output(tensor)
+
+ _check(foo, bb.get()["foo"])
+
+
def test_symbolic_shape():
@R.function
def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"),
"float32"):