This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ebbb3e7c81 [Unity] Lower `shape_of` to a builtin (#14093)
ebbb3e7c81 is described below
commit ebbb3e7c818d720a565817da49644b90432cb3b2
Author: Yuchen Jin <[email protected]>
AuthorDate: Wed Feb 22 15:57:49 2023 -0800
[Unity] Lower `shape_of` to a builtin (#14093)
This PR lowers shape_of op to a Relax VM builtin, and changes a utility
function to take StructInfo as input.
Co-authored-by: Steven S. Lyubomirsky <[email protected]>
---
include/tvm/relax/utils.h | 8 ++--
src/relax/backend/vm/vm_builtin_lower.cc | 10 +++++
src/relax/utils.cc | 5 ++-
tests/python/relax/test_relax_operators.py | 62 ++++++++++++++++++++++++++++++
4 files changed, 79 insertions(+), 6 deletions(-)
diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h
index b3cc76768d..dd0200623a 100644
--- a/include/tvm/relax/utils.h
+++ b/include/tvm/relax/utils.h
@@ -25,7 +25,6 @@
 #define TVM_RELAX_UTILS_H_
 
 #include <tvm/ir/module.h>
-#include <tvm/relax/expr.h>
 #include <tvm/runtime/logging.h>
 
 #include <algorithm>
@@ -110,9 +109,10 @@ class NameTable {
 TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
 
 /*!
- * \brief Check if the given type is a boolean scalar type (tensor of rank 0 with a boolean dtype).
+ * \brief Check if the given StructInfo is for a boolean scalar (tensor of rank 0 with a boolean
+ * dtype).
  *
- * \param ty The input type.
+ * \param sinfo The input StructInfo.
  * \param permit_unknown_rank If true, it will permit the input type to have unknown rank
  *   (ndim of -1), which will require a dynamic check.
  * \param permit_unknown_dtype If true, it will permit the input type to have an unknown dtype
@@ -121,7 +121,7 @@ TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
  * \return True iff the input type is a boolean scalar type (or, depending on options, has unknown
  *   rank or dtype)
  */
-TVM_DLL bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank = true,
+TVM_DLL bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank = true,
                               bool permit_unknown_dtype = true);
 
 /*!
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/vm_builtin_lower.cc
index 6613b39626..00d8512dc6 100644
--- a/src/relax/backend/vm/vm_builtin_lower.cc
+++ b/src/relax/backend/vm/vm_builtin_lower.cc
@@ -53,6 +53,8 @@ class VMBuiltinLowerMutator : public ExprMutator {
       return CallTIRDyn(call);
     } else if (call->op == reshape_op_) {
       return Reshape(call);
+    } else if (call->op == shape_of_op_) {
+      return ShapeOf(call);
     } else if (call->op == make_closure_op_) {
       return MakeClosure(call);
     } else if (call->op == invoke_closure_op_) {
@@ -132,6 +134,12 @@ class VMBuiltinLowerMutator : public ExprMutator {
     return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
   }
 
+  Expr ShapeOf(const Call& call_node) {
+    ICHECK(call_node->args.size() == 1);
+    ICHECK(call_node->struct_info_.defined());
+    return Call(builtin_shape_of_, call_node->args, Attrs(), {GetStructInfo(call_node)});
+  }
+
   Expr MakeClosure(const Call& call_node) {
     ICHECK(call_node->args.size() == 2);
     ICHECK(call_node->args[0]->IsInstance<GlobalVarNode>());
@@ -173,6 +181,7 @@ class VMBuiltinLowerMutator : public ExprMutator {
   // object to pattern match.
   const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
   const Op& reshape_op_ = Op::Get("relax.reshape");
+  const Op& shape_of_op_ = Op::Get("relax.shape_of");
   const Op& make_closure_op_ = Op::Get("relax.make_closure");
   const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
   const Op& alloc_tensor_op_ = Op::Get("relax.builtin.alloc_tensor");
@@ -187,6 +196,7 @@ class VMBuiltinLowerMutator : public ExprMutator {
   const ExternFunc builtin_compute_alloc_shape_{"vm.builtin.compute_alloc_shape"};
   const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"};
   const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
+  const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"};
   const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
   const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
 };
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index 110bdb5c8c..1cf64cbf64 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -67,8 +67,9 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
   }
 }
 
-bool IsBoolScalarType(const Type& ty, bool permit_unknown_rank, bool permit_unknown_dtype) {
-  const DynTensorTypeNode* tt = ty.as<DynTensorTypeNode>();
+bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank,
+                      bool permit_unknown_dtype) {
+  const TensorStructInfoNode* tt = sinfo.as<TensorStructInfoNode>();
   if (!tt) {
     return false;
   }
diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py
new file mode 100644
index 0000000000..7b0b98fea9
--- /dev/null
+++ b/tests/python/relax/test_relax_operators.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import sys
+import tempfile
+
+import numpy as np
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm._ffi.base import TVMError
+from tvm.script import relax as R
+
+
+def run_cpu(mod, func_name, *input):
+    target = tvm.target.Target("llvm")
+    ex = relax.vm.build(mod, target)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    return vm[func_name](*input)
+
+
+@tvm.script.ir_module
+class ShapeOfTest:
+    @R.function
+    def get_shape(t: R.Tensor(ndim=-1, dtype="int32")) -> R.Shape(ndim=-1):
+        return R.shape_of(t)
+
+    @R.function
+    def get_shape_const() -> R.Shape(ndim=-1):
+        x: R.Tensor((), "int32") = R.const(1, dtype="int32")
+        return R.shape_of(x)
+
+
+def test_op_shape_of():
+    const_shape = run_cpu(ShapeOfTest, "get_shape_const")
+    assert const_shape == tvm.runtime.ShapeTuple([])
+
+    scalar_shape = run_cpu(ShapeOfTest, "get_shape", tvm.nd.array(np.array(1, dtype="int32")))
+    assert scalar_shape == tvm.runtime.ShapeTuple([])
+
+    tensor_shape = run_cpu(
+        ShapeOfTest, "get_shape", tvm.nd.array(np.zeros((1, 2, 3)).astype("int32"))
+    )
+    assert tensor_shape == tvm.runtime.ShapeTuple([1, 2, 3])
+
+
+if __name__ == "__main__":
+    tvm.testing.main()