This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 42f37ff780 [TIR] Expose UndefinedVars to Python (#15165)
42f37ff780 is described below

commit 42f37ff78010c94144aaa4bb3f8180286bacb904
Author: Siyuan Feng <[email protected]>
AuthorDate: Tue Jun 27 23:43:23 2023 +0800

    [TIR] Expose UndefinedVars to Python (#15165)
    
    This PR exposes the UndefinedVars analysis to Python. This is useful for
    iterator and index access analysis.
---
 python/tvm/tir/analysis/analysis.py      | 26 +++++++++++++++++++++++---
 src/tir/analysis/var_use_def_analysis.cc |  9 +++++++++
 2 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 1a5f8b9781..493c3d957b 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -16,7 +16,7 @@
 # under the License.
 """Wrapping existing analysis utils."""
 # pylint: disable=invalid-name
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 import tvm
 from tvm import Object
@@ -211,7 +211,7 @@ def calculate_allocated_bytes(
     ----------
     func_or_mod: Union[PrimFunc, IRModule]
         The function or module to be detected. If a module is passed, allocated
-        memory is calcualted for all PrimFuncs inside the module
+        memory is calculated for all PrimFuncs inside the module
 
     Returns
     -------
@@ -266,6 +266,26 @@ def estimate_tir_flops(stmt_or_mod: Union[Stmt, IRModule]) -> float:
 # introduce a cycling dependency. We make do with Object.
 
 
+def undefined_vars(node: Union[Stmt, PrimExpr], defs: Optional[List[Var]] = 
None) -> List[Var]:
+    """Find undefined vars in a TIR statement or expression.
+
+    Parameters
+    ----------
+    node: Union[Stmt, PrimExpr]
+        The TIR statement or expression to be checked.
+
+    defs: Optional[List[Var]]
+        The vars that are already defined. None is treated as an empty list.
+
+    Returns
+    -------
+    result : List[Var]
+        The undefined vars.
+    """
+    defs = defs or []  # the FFI call below always receives a list, never None
+    return _ffi_api.UndefinedVars(node, defs)  # type: ignore # pylint: 
disable=no-member
+
+
 def get_prim_func_arg_and_result_memory_constraints(
     func: PrimFunc, relay_func_type: Object
 ) -> List[str]:
@@ -388,7 +408,7 @@ def find_anchor_block(mod: IRModule) -> Block:
 
 
 def get_vtcm_compaction_passes() -> List[tvm.transform.Pass]:
-    """Utility function to get the list of lowering passes to be applied to calculate thecompacted
+    """Utility function to get the list of lowering passes to be applied to calculate the compacted
     VTCM allocation size
 
     Returns
diff --git a/src/tir/analysis/var_use_def_analysis.cc b/src/tir/analysis/var_use_def_analysis.cc
index 0d5a4be8ed..456cdf8963 100644
--- a/src/tir/analysis/var_use_def_analysis.cc
+++ b/src/tir/analysis/var_use_def_analysis.cc
@@ -199,5 +199,14 @@ Array<Var> UndefinedVars(const PrimExpr& expr, const Array<Var>& args) {
   return m.undefined_;
 }
 
+TVM_REGISTER_GLOBAL("tir.analysis.UndefinedVars").set_body([](TVMArgs args, 
TVMRetValue* rv) {  // dispatches on the runtime type of args[0]
+  if (args.size() == 2 && args[0].IsObjectRef<Stmt>()) {  // (stmt, args) form
+    *rv = UndefinedVars(args[0].AsObjectRef<Stmt>(), args[1]);
+  } else if (args.size() == 2 && args[0].IsObjectRef<PrimExpr>()) {  // (expr, args) form
+    *rv = UndefinedVars(args[0].AsObjectRef<PrimExpr>(), args[1]);
+  } else {  // wrong argument count, or args[0] is neither a Stmt nor a PrimExpr
+    LOG(FATAL) << "either UndefinedVars(stmt, args) or UndefinedVars(expr, 
args) is expected";
+  }
+});
 }  // namespace tir
 }  // namespace tvm

Reply via email to