This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 42f37ff780 [TIR] Expose UndefinedVars to Python (#15165)
42f37ff780 is described below
commit 42f37ff78010c94144aaa4bb3f8180286bacb904
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Jun 27 23:43:23 2023 +0800
[TIR] Expose UndefinedVars to Python (#15165)
This PR exposes the UndefinedVars analysis to Python. This is useful for
iterator and index access analysis.
---
python/tvm/tir/analysis/analysis.py | 26 +++++++++++++++++++++++---
src/tir/analysis/var_use_def_analysis.cc | 9 +++++++++
2 files changed, 32 insertions(+), 3 deletions(-)
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 1a5f8b9781..493c3d957b 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -16,7 +16,7 @@
# under the License.
"""Wrapping existing analysis utils."""
# pylint: disable=invalid-name
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
import tvm
from tvm import Object
@@ -211,7 +211,7 @@ def calculate_allocated_bytes(
----------
func_or_mod: Union[PrimFunc, IRModule]
The function or module to be detected. If a module is passed, allocated
- memory is calcualted for all PrimFuncs inside the module
+ memory is calculated for all PrimFuncs inside the module
Returns
-------
@@ -266,6 +266,26 @@ def estimate_tir_flops(stmt_or_mod: Union[Stmt, IRModule]) -> float:
# introduce a cycling dependency. We make do with Object.
+def undefined_vars(node: Union[Stmt, PrimExpr], defs: Optional[List[Var]] = None) -> List[Var]:
+ """Find undefined vars in a TIR statement or expression.
+
+ Parameters
+ ----------
+ node: Union[Stmt, PrimExpr]
+ The TIR statement or expression to be checked.
+
+ defs: Optional[List[Var]]
+ The vars that is defined
+
+ Returns
+ -------
+ result : List[Var]
+ The undefined vars.
+ """
+ defs = defs or []
+    return _ffi_api.UndefinedVars(node, defs)  # type: ignore # pylint: disable=no-member
+
+
def get_prim_func_arg_and_result_memory_constraints(
func: PrimFunc, relay_func_type: Object
) -> List[str]:
@@ -388,7 +408,7 @@ def find_anchor_block(mod: IRModule) -> Block:
def get_vtcm_compaction_passes() -> List[tvm.transform.Pass]:
-    """Utility function to get the list of lowering passes to be applied to calculate thecompacted
+    """Utility function to get the list of lowering passes to be applied to calculate the compacted
VTCM allocation size
Returns
diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc
index 0d5a4be8ed..456cdf8963 100644
--- a/src/tir/analysis/var_use_def_analysis.cc
+++ b/src/tir/analysis/var_use_def_analysis.cc
@@ -199,5 +199,14 @@ Array<Var> UndefinedVars(const PrimExpr& expr, const Array<Var>& args) {
return m.undefined_;
}
+TVM_REGISTER_GLOBAL("tir.analysis.UndefinedVars").set_body([](TVMArgs args, TVMRetValue* rv) {
+ if (args.size() == 2 && args[0].IsObjectRef<Stmt>()) {
+ *rv = UndefinedVars(args[0].AsObjectRef<Stmt>(), args[1]);
+ } else if (args.size() == 2 && args[0].IsObjectRef<PrimExpr>()) {
+ *rv = UndefinedVars(args[0].AsObjectRef<PrimExpr>(), args[1]);
+ } else {
+    LOG(FATAL) << "either UndefinedVars(stmt, args) or UndefinedVars(expr, args) is expected";
+ }
+});
} // namespace tir
} // namespace tvm