This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new db01567b09 [Unity][Op] introduce `shape_to_tensor` op (#14447)
db01567b09 is described below

commit db01567b09070c41ab7868c7037b81ae828a26a1
Author: Sunghyun Park <[email protected]>
AuthorDate: Mon Apr 3 10:44:52 2023 -0700

    [Unity][Op] introduce `shape_to_tensor` op (#14447)
    
    * feat: introduce `shape_to_tensor` op
    
    * lint
    
    * lint
    
    * reflect feedback
---
 python/tvm/relax/op/base.py                        | 28 +++++++++++++++
 python/tvm/script/ir_builder/relax/ir.py           |  2 ++
 src/relax/op/op.cc                                 | 23 +++++++++++++
 src/relax/op/op_common.cc                          |  2 +-
 src/relax/transform/fold_constant.cc               | 19 ++++++++++
 tests/python/relax/test_relax_operators.py         | 40 +++++++++++++++++++++-
 tests/python/relax/test_transform_fold_constant.py | 34 ++++++++++++++++--
 tests/python/relax/test_tvmscript_parser.py        | 15 ++++++++
 8 files changed, 159 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index becd3f2a0f..d6e8b29b6d 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -19,6 +19,7 @@ from typing import Union, List, Tuple, Optional
 
 
 import tvm
+import tvm.runtime
 from tvm.runtime.object import Object
 
 from . import _ffi_api
@@ -253,6 +254,19 @@ def render_object(val: tvm.Object) -> str:
     return str(val)
 
 
[email protected]_func("relax.run.shape_to_tensor")
+def relax_shape_to_tensor(shape_tuple: tvm.runtime.ShapeTuple) -> 
tvm.nd.NDArray:
+    """
+    Takes a ShapeTuple and converts it to an NDArray.
+
+    Parameters
+    ----------
+    shape_tuple: tvm.runtime.ShapeTuple
+        Shape tuple that we want to convert to NDArray at runtime
+    """
+    return tvm.nd.array([int(v) for v in shape_tuple])
+
+
 @tvm.register_func("relax.run.print")
 def relax_print(format_str: str, *format_args: tvm.Object) -> None:
     """
@@ -416,3 +430,17 @@ def tensor_to_shape(expr: Expr) -> Expr:
         A relax Call, which transforms the tensor values to the shape
     """
     return _ffi_api.tensor_to_shape(expr)  # type: ignore # pylint: 
disable=no-member
+
+
+def shape_to_tensor(expr: Expr) -> Expr:
+    """Convert shape to tensor expr.
+    Parameters
+    ----------
+    expr : Expr
+        The input Expr
+    Returns
+    -------
+    result : Expr
+        A relax Call, which transforms the shape values to the tensor
+    """
+    return _ffi_api.shape_to_tensor(expr)  # type: ignore # pylint: 
disable=no-member
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index f9104c9430..2f8a37a4e1 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -100,6 +100,7 @@ from tvm.relax.op import (
     repeat,
     reshape,
     tensor_to_shape,
+    shape_to_tensor,
     round,
     shape_of,
     std,
@@ -619,6 +620,7 @@ __all__ = [
     "repeat",
     "reshape",
     "tensor_to_shape",
+    "shape_to_tensor",
     "round",
     "shape",
     "shape_of",
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index b353cce27a..c641c45922 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -341,6 +341,29 @@ Expr MakeTensorToShape(Expr expr) {
 
 
TVM_REGISTER_GLOBAL("relax.op.tensor_to_shape").set_body_typed(MakeTensorToShape);
 
+// shape_to_tensor
+StructInfo ReturnShapeToTensorStructInfo(const Call& call, const BlockBuilder& 
ctx) {
+  ICHECK(call->args.size() == 1);
+  ICHECK(call->args[0]->struct_info_.defined());
+  const auto* sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[0]);
+  ICHECK(sinfo);
+  int32_t ndim = sinfo->ndim;
+  return TensorStructInfo(ShapeExpr({PrimExpr(ndim)}), DataType::Int(64));
+}
+
+RELAY_REGISTER_OP("relax.shape_to_tensor")
+    .set_num_inputs(1)
+    .add_argument("input", "Expr", "The input expression")
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
ReturnShapeToTensorStructInfo)
+    .set_attr<FCallPacked>("FCallPacked", "relax.run.shape_to_tensor");
+
+Expr MakeShapeToTensor(Expr expr) {
+  static const Op& op = Op::Get("relax.shape_to_tensor");
+  return Call(op, {expr}, {}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.shape_to_tensor").set_body_typed(MakeShapeToTensor);
+
 // alloc_tensor
 
 StructInfo InferStructInfoAllocateTensor(const Call& call, const BlockBuilder& 
ctx) {
diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc
index 0421c957ca..0997a3623e 100644
--- a/src/relax/op/op_common.cc
+++ b/src/relax/op/op_common.cc
@@ -38,7 +38,7 @@ Array<TensorStructInfo> GetInputTensorStructInfo(const Call& 
call, const BlockBu
     if (sinfo == nullptr) {
       ctx->ReportFatal(Diagnostic::Error(call)
                        << op << " requires the input " << 
op->arguments[i]->name
-                       << " to be Tensor. However, the given one is "
+                       << " to be Tensor. However, the given one has a "
                        << call->args[i]->struct_info_->GetTypeKey());
     }
     input_tensor_sinfo.push_back(GetRef<TensorStructInfo>(sinfo));
diff --git a/src/relax/transform/fold_constant.cc 
b/src/relax/transform/fold_constant.cc
index 315b3dc1f2..db30900cd2 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -279,6 +279,25 @@ class ConstantFolder : public ExprMutator {
           }
           return ShapeExpr(shape_values);
         }
+      } else if (op->name == "relax.shape_to_tensor") {
+        // Special handling for "relax.shape_to_tensor" since it is 
implemented in PackedFunc.
+        // TODO(sunggg): revisit this when we extend ConstantFolding to fold 
PackedFunc.
+        Expr arg = post_call->args[0];
+        ShapeExpr shape = Downcast<ShapeExpr>(arg);
+        Array<PrimExpr> values = shape->values;
+        Array<Integer> arr;
+        bool is_known = true;
+        for (size_t i = 0; i < values.size(); i++) {
+          PrimExpr val = values[i];
+          arr.push_back(GetRef<IntImm>(val.as<IntImmNode>()));
+          is_known &= (val.dtype() == DataType::Int(64));
+        }
+        if (is_known) {
+          const auto* func = 
tvm::runtime::Registry::Get("relax.run.shape_to_tensor");
+          ICHECK(func != nullptr);
+          runtime::NDArray vals = (*func)(arr);
+          return Constant(vals);
+        }
       }
     }
 
diff --git a/tests/python/relax/test_relax_operators.py 
b/tests/python/relax/test_relax_operators.py
index f197eaa9ab..776abbce76 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -23,7 +23,7 @@ import tvm
 import tvm.testing
 from tvm import relax
 from tvm._ffi.base import TVMError
-from tvm.script import relax as R
+from tvm.script import relax as R, tir as T
 
 
 @tvm.script.ir_module
@@ -193,5 +193,43 @@ def test_op_shape_of():
     assert constrained_shape == tvm.runtime.ShapeTuple([1])
 
 
[email protected]_module
+class ShapeToTensorTest:
+    @R.function
+    def const_shape(shape: R.Shape(ndim=-1)) -> R.Tensor(ndim=-1):
+        return R.shape_to_tensor(shape)
+
+    @R.function
+    def symbolic_shape(shape: R.Shape(("m", "n"))) -> R.Tensor(ndim=-1):
+        m = T.int64()
+        n = T.int64()
+        return R.shape_to_tensor(shape)
+
+
+def test_op_shape_to_tensor():
+    # Check struct info
+    isinstance(ShapeToTensorTest["const_shape"].body.struct_info, 
tvm.relax.TensorStructInfo)
+    assert ShapeToTensorTest["const_shape"].body.struct_info.ndim == 1
+    isinstance(ShapeToTensorTest["symbolic_shape"].body.struct_info, 
tvm.relax.TensorStructInfo)
+    assert ShapeToTensorTest["symbolic_shape"].body.struct_info.ndim == 1
+
+    # Check its functionality
+    out2d = run_cpu(ShapeToTensorTest, "const_shape", 
tvm.runtime.ShapeTuple([3, 2]))
+    assert isinstance(out2d, tvm.runtime.ndarray.NDArray)
+    assert np.array_equal(out2d.numpy(), np.array([3, 2]))
+
+    out3d = run_cpu(ShapeToTensorTest, "const_shape", 
tvm.runtime.ShapeTuple([3, 3, 2]))
+    assert isinstance(out3d, tvm.runtime.ndarray.NDArray)
+    assert np.array_equal(out3d.numpy(), np.array([3, 3, 2]))
+
+    out4d = run_cpu(ShapeToTensorTest, "const_shape", 
tvm.runtime.ShapeTuple([3, 3, 2, 2]))
+    assert isinstance(out4d, tvm.runtime.ndarray.NDArray)
+    assert np.array_equal(out4d.numpy(), np.array([3, 3, 2, 2]))
+
+    outs = run_cpu(ShapeToTensorTest, "symbolic_shape", 
tvm.runtime.ShapeTuple([3, 2]))
+    assert isinstance(outs, tvm.runtime.ndarray.NDArray)
+    assert np.array_equal(outs.numpy(), np.array([3, 2]))
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fold_constant.py 
b/tests/python/relax/test_transform_fold_constant.py
index ebd4348b64..b8ad5c4487 100644
--- a/tests/python/relax/test_transform_fold_constant.py
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -20,7 +20,7 @@ from tvm import relax
 import numpy as np
 
 import tvm.script
-from tvm.script import tir as T, relax as R
+from tvm.script import ir as I, tir as T, relax as R
 
 
 def gen_mod(mod, name, binding):
@@ -368,7 +368,6 @@ def 
test_fold_multiple_relax_ops_with_data_dependent_reshape():
 
         @R.function
         def expected(data: R.Tensor((256,), "float32")) -> R.Tensor((16, 16), 
dtype="float32"):
-            R.func_attr({"global_symbol": "main"})
             with R.dataflow():
                 gv: R.Tensor((16, 16), dtype="float32") = R.reshape(data, 
R.shape([16, 16]))
                 R.output(gv)
@@ -420,5 +419,36 @@ def 
test_unsupported_fold_ops_legalized_to_multiple_calls():
     register_legalize("relax.nn.relu", relu_legalize)
 
 
+def test_fold_shape_computation():
+    @I.ir_module
+    class Module:
+        @R.function
+        def before(
+            data: R.Tensor((5, 4, 3, 2), dtype="float32"),
+            indices: R.Tensor((1,), dtype="int64"),
+        ) -> R.Tensor((1, 1), dtype="int64"):
+            with R.dataflow():
+                lv: R.Tensor((4,), dtype="int64") = 
R.shape_to_tensor(R.shape([5, 4, 3, 2]))
+                lv1: R.Tensor((1,), dtype="int64") = R.take(lv, indices, 
axis=0)
+                lv2: R.Tensor((1, 1), dtype="int64") = R.expand_dims(lv1, 
axis=[0])
+                gv: R.Tensor((1, 1), dtype="int64") = R.concat((lv2,), axis=0)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def expected(
+            data: R.Tensor((5, 4, 3, 2), dtype="float32"), new_shape: 
R.Tensor((1, 1), "int64")
+        ) -> R.Tensor((1, 1), dtype="int64"):
+            return new_shape
+
+    before = gen_mod(Module, "before", {"indices": 
tvm.nd.array(np.array([0]).astype("int64"))})
+    after = relax.transform.FoldConstant()(before)
+    np_take = np.take([5, 4, 3, 2], [0], axis=0)
+    np_expand = np.expand_dims(np_take, axis=[0])
+    np_concat = np.concatenate([np_expand], axis=0)
+    expected = gen_mod(Module, "expected", {"new_shape": 
tvm.nd.array(np_concat)})
+    tvm.ir.assert_structural_equal(after, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser.py 
b/tests/python/relax/test_tvmscript_parser.py
index 0e0905ffbc..9b8865b943 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -337,6 +337,21 @@ def test_relax_base_op():
     _check(foo, bb.get()["foo"])
 
 
+def test_relax_shape_to_tensor():
+    @R.function
+    def foo(x: R.Shape((4, 4))):
+        tensor = R.shape_to_tensor(x)
+        return tensor
+
+    x = relax.Var("x", R.Shape((4, 4)))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", (x,)):
+        tensor = bb.emit(relax.op.shape_to_tensor(x))
+        bb.emit_func_output(tensor)
+
+    _check(foo, bb.get()["foo"])
+
+
 def test_symbolic_shape():
     @R.function
     def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), 
"float32"):

Reply via email to