This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ebbb3e7c81 [Unity] Lower `shape_of` to a builtin (#14093)
ebbb3e7c81 is described below

commit ebbb3e7c818d720a565817da49644b90432cb3b2
Author: Yuchen Jin <[email protected]>
AuthorDate: Wed Feb 22 15:57:49 2023 -0800

    [Unity] Lower `shape_of` to a builtin (#14093)
    
    This PR lowers the `shape_of` op to a Relax VM builtin, and changes a utility
function to take StructInfo as input.
    
    Co-authored-by: Steven S. Lyubomirsky <[email protected]>
---
 include/tvm/relax/utils.h                  |  8 ++--
 src/relax/backend/vm/vm_builtin_lower.cc   | 10 +++++
 src/relax/utils.cc                         |  5 ++-
 tests/python/relax/test_relax_operators.py | 62 ++++++++++++++++++++++++++++++
 4 files changed, 79 insertions(+), 6 deletions(-)

diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h
index b3cc76768d..dd0200623a 100644
--- a/include/tvm/relax/utils.h
+++ b/include/tvm/relax/utils.h
@@ -25,7 +25,6 @@
 #define TVM_RELAX_UTILS_H_
 
 #include <tvm/ir/module.h>
-#include <tvm/relax/expr.h>
 #include <tvm/runtime/logging.h>
 
 #include <algorithm>
@@ -110,9 +109,10 @@ class NameTable {
 TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
 
 /*!
- * \brief Check if the given type is a boolean scalar type (tensor of rank 0 
with a boolean dtype).
+ * \brief Check if the given StructInfo is for a boolean scalar (tensor of 
rank 0 with a boolean
+ * dtype).
  *
- * \param ty The input type.
+ * \param sinfo The input StructInfo.
  * \param permit_unknown_rank If true, it will permit the input type to have 
unknown rank
  *   (ndim of -1), which will require a dynamic check.
  * \param permit_unknown_dtype If true, it will permit the input type to have 
an unknown dtype
@@ -121,7 +121,7 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, 
Expr>& binds);
  * \return True iff the input type is a boolean scalar type (or, depending on 
options, has unknown
  *   rank or dtype)
  */
-TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true,
+TVM_DLL bool IsBoolStructInfo(const StructInfo& sinfo, bool 
permit_unknown_rank = true,
                               bool permit_unknown_dtype = true);
 
 /*!
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc 
b/src/relax/backend/vm/vm_builtin_lower.cc
index 6613b39626..00d8512dc6 100644
--- a/src/relax/backend/vm/vm_builtin_lower.cc
+++ b/src/relax/backend/vm/vm_builtin_lower.cc
@@ -53,6 +53,8 @@ class VMBuiltinLowerMutator : public ExprMutator {
       return CallTIRDyn(call);
     } else if (call->op == reshape_op_) {
       return Reshape(call);
+    } else if (call->op == shape_of_op_) {
+      return ShapeOf(call);
     } else if (call->op == make_closure_op_) {
       return MakeClosure(call);
     } else if (call->op == invoke_closure_op_) {
@@ -132,6 +134,12 @@ class VMBuiltinLowerMutator : public ExprMutator {
     return Call(builtin_reshape_, call_node->args, Attrs(), 
{GetStructInfo(call_node)});
   }
 
+  Expr ShapeOf(const Call& call_node) {
+    ICHECK(call_node->args.size() == 1);
+    ICHECK(call_node->struct_info_.defined());
+    return Call(builtin_shape_of_, call_node->args, Attrs(), 
{GetStructInfo(call_node)});
+  }
+
   Expr MakeClosure(const Call& call_node) {
     ICHECK(call_node->args.size() == 2);
     ICHECK(call_node->args[0]->IsInstance<GlobalVarNode>());
@@ -173,6 +181,7 @@ class VMBuiltinLowerMutator : public ExprMutator {
   // object to pattern match.
   const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
   const Op& reshape_op_ = Op::Get("relax.reshape");
+  const Op& shape_of_op_ = Op::Get("relax.shape_of");
   const Op& make_closure_op_ = Op::Get("relax.make_closure");
   const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
   const Op& alloc_tensor_op_ = Op::Get("relax.builtin.alloc_tensor");
@@ -187,6 +196,7 @@ class VMBuiltinLowerMutator : public ExprMutator {
   const ExternFunc 
builtin_compute_alloc_shape_{"vm.builtin.compute_alloc_shape"};
   const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"};
   const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
+  const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"};
   const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
   const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
 };
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index 110bdb5c8c..1cf64cbf64 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -67,8 +67,9 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& 
args_map) {
   }
 }
 
-bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool 
permit_unknown_dtype) {
-  const DynTensorTypeNode* tt = ty.as<DynTensorTypeNode>();
+bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank,
+                      bool permit_unknown_dtype) {
+  const TensorStructInfoNode* tt = sinfo.as<TensorStructInfoNode>();
   if (!tt) {
     return false;
   }
diff --git a/tests/python/relax/test_relax_operators.py 
b/tests/python/relax/test_relax_operators.py
new file mode 100644
index 0000000000..7b0b98fea9
--- /dev/null
+++ b/tests/python/relax/test_relax_operators.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+import tempfile
+
+import numpy as np
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm._ffi.base import TVMError
+from tvm.script import relax as R
+
+
+def run_cpu(mod, func_name, *input):
+    target = tvm.target.Target("llvm")
+    ex = relax.vm.build(mod, target)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    return vm[func_name](*input)
+
+
[email protected]_module
+class ShapeOfTest:
+    @R.function
+    def get_shape(t: R.Tensor(ndim=-1, dtype="int32")) -> R.Shape(ndim=-1):
+        return R.shape_of(t)
+
+    @R.function
+    def get_shape_const() -> R.Shape(ndim=-1):
+        x: R.Tensor((), "int32") = R.const(1, dtype="int32")
+        return R.shape_of(x)
+
+
+def test_op_shape_of():
+    const_shape = run_cpu(ShapeOfTest, "get_shape_const")
+    assert const_shape == tvm.runtime.ShapeTuple([])
+
+    scalar_shape = run_cpu(ShapeOfTest, "get_shape", tvm.nd.array(np.array(1, 
dtype="int32")))
+    assert scalar_shape == tvm.runtime.ShapeTuple([])
+
+    tensor_shape = run_cpu(
+        ShapeOfTest, "get_shape", tvm.nd.array(np.zeros((1, 2, 
3)).astype("int32"))
+    )
+    assert tensor_shape == tvm.runtime.ShapeTuple([1, 2, 3])
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to