This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2f7c097594 [TIR] Allow VerifyWellFormed to accept IRModule (#15247)
2f7c097594 is described below

commit 2f7c0975940ddc6f3c526424b904ed5bb28fc41a
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Jul 6 08:04:00 2023 -0500

    [TIR] Allow VerifyWellFormed to accept IRModule (#15247)
    
    Previously, the calling code needed to iterate over all functions in a
    module.  This commit adds an overload that accepts `const IRModule&`,
    allowing it to be called more easily.  This also provides an API that
    can be extended to validate behavior across an entire
    IRModule (e.g. requiring that internal function calls have the correct
    argument types).
---
 python/tvm/tir/analysis/analysis.py                |  8 +++----
 src/tir/analysis/verify_well_formed.cc             | 25 +++++++++++++++++++++-
 .../test_tir_analysis_verify_well_formed.py        |  1 +
 3 files changed, 29 insertions(+), 5 deletions(-)

diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py
index 493c3d957b..8d7e81d7d0 100644
--- a/python/tvm/tir/analysis/analysis.py
+++ b/python/tvm/tir/analysis/analysis.py
@@ -349,14 +349,14 @@ def apply_prim_func_arg_and_result_memory_constraints(
     )
 
 
-def verify_well_formed(func: PrimFunc, assert_mode: bool = True) -> bool:
+def verify_well_formed(obj: Union[PrimFunc, IRModule], assert_mode: bool = True) -> bool:
     """Verify if the given TIR is well-formed. The verification includes:
         - Check if expressions not contain vars that is defined outside the block.
 
     Parameters
     ----------
-    func: tvm.tir.PrimFunc
-        The function to be verified.
+    obj: Union[tvm.tir.PrimFunc, tvm.ir.IRModule]
+        The function or module to be verified.
 
     assert_mode: bool
         The indicator if it raises an error when the function is not well-formed.
@@ -366,7 +366,7 @@ def verify_well_formed(func: PrimFunc, assert_mode: bool = True) -> bool:
     result: bool
         Whether it is a well-formed TIR function.
     """
-    return _ffi_api.VerifyWellFormed(func, assert_mode)  # type: ignore # pylint: disable=no-member
+    return _ffi_api.VerifyWellFormed(obj, assert_mode)  # type: ignore # pylint: disable=no-member
 
 
 def OOBChecker():
diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc
index e0318e1408..898183533c 100644
--- a/src/tir/analysis/verify_well_formed.cc
+++ b/src/tir/analysis/verify_well_formed.cc
@@ -27,6 +27,7 @@
 #include <tvm/tir/stmt_functor.h>
 
 #include "../ir/functor_common.h"
+#include "tvm/ir/module.h"
 
 namespace tvm {
 namespace tir {
@@ -142,7 +143,29 @@ bool VerifyWellFormed(const PrimFunc& func, bool assert_mode) {
   return true;
 }
 
-TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed").set_body_typed(VerifyWellFormed);
+bool VerifyWellFormed(const IRModule& mod, bool assert_mode) {
+  for (const auto& [gvar, base_func] : mod->functions) {
+    if (auto prim_func = base_func.as<PrimFunc>()) {
+      bool res = VerifyWellFormed(prim_func.value(), assert_mode);
+      if (!res) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+TVM_REGISTER_GLOBAL("tir.analysis.VerifyWellFormed")
+    .set_body_typed([](const ObjectRef& obj, bool assert_mode) {
+      if (auto opt = obj.as<PrimFunc>()) {
+        return VerifyWellFormed(opt.value(), assert_mode);
+      } else if (auto opt = obj.as<IRModule>()) {
+        return VerifyWellFormed(opt.value(), assert_mode);
+      } else {
+        LOG(FATAL) << "Expected VerifyWellFormed argument to be a PrimFunc or IRModule, but found "
+                   << obj->GetTypeKey();
+      }
+    });
 
 }  // namespace tir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_tir_analysis_verify_well_formed.py b/tests/python/unittest/test_tir_analysis_verify_well_formed.py
index 023d5f5f31..4f88cc8be1 100644
--- a/tests/python/unittest/test_tir_analysis_verify_well_formed.py
+++ b/tests/python/unittest/test_tir_analysis_verify_well_formed.py
@@ -36,6 +36,7 @@ def test_pass_simple():
                 C[i, j] = B[i, j] * 2.0
 
     assert tvm.tir.analysis.verify_well_formed(element_wise)
+    assert tvm.tir.analysis.verify_well_formed(tvm.IRModule.from_expr(element_wise))
 
 
 def test_fail_use_out_loop_var():

Reply via email to