This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e8cd33b601 [TIR] Update SplitHostDevice to post-process with 
ConvertSSA (#14496)
e8cd33b601 is described below

commit e8cd33b601f081d9c9b4dc86b5c3b36147b37834
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Apr 7 20:35:35 2023 -0500

    [TIR] Update SplitHostDevice to post-process with ConvertSSA (#14496)
    
    * [TIR][Utils] Implemented ConvertSSA as IRModule transform
    
    When passes create new PrimFuncs, such as when `tir.SplitHostDevice`
    separates out a `tir::Stmt` into an independent function, the
    parameters of these new function may alias existing variable
    definitions.  While this is well-defined, because variable definitions
    are not shared across function boundaries, it can give false
    discrepancies from `tvm.ir.assert_structural_equal`.
    
    This commit implements `tvm::tir::transform::ConvertSSA`, which
    ensures unique variable declaration locations across an entire module.
    
    * [TIR] Update SplitHostDevice to post-process with ConvertSSA
    
    Avoid duplicate variable definitions between the host and device
    PrimFunc.
---
 include/tvm/tir/transform.h                        |  13 ++
 src/tir/transforms/ir_utils.cc                     | 186 ++++++++++++++++++---
 src/tir/transforms/split_host_device.cc            |   2 +-
 .../test_tir_transform_split_host_device.py        |  25 +++
 4 files changed, 205 insertions(+), 21 deletions(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index d4f537ff31..35aa392db2 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -176,6 +176,19 @@ TVM_DLL Pass RewriteUnsafeSelect();
  */
 TVM_DLL Pass Simplify();
 
+/*!
+ * \brief Convert an IRModule to SSA form.
+ *
+ * This pass handles cases where the same tir::Var appears in
+ * multiple functions within the same module.  For example, after
+ * extracting a fragment from one function into another, where the
+ * same `tir::Var` may be defined both within the body of the
+ * original function, and as a parameter within the hoisted function.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass ConvertSSA();
+
 /*!
  * \brief Instruments bound checkers.
  *
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index f6e4ac45c6..e80772fda3 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -26,6 +26,7 @@
 #include <tvm/arith/analyzer.h>
 #include <tvm/arith/int_solver.h>
 #include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
 
 #include <unordered_map>
 #include <unordered_set>
@@ -90,13 +91,107 @@ Stmt MergeNest(const std::vector<std::vector<Stmt>>& nest, 
Stmt body) {
 
 class IRConvertSSA final : public StmtExprMutator {
  public:
-  PrimExpr VisitExpr_(const VarNode* op) final {
-    if (scope_.count(op) && !scope_[op].empty()) {
-      return scope_[op].back();
-    } else {
-      return GetRef<PrimExpr>(op);
+  PrimFunc VisitPrimFunc(PrimFunc func) {
+    std::vector<ScopedRedefine> redefines;
+
+    // Remap parameters, if they were used in another function
+    auto params = func->params.Map([&](const tir::Var& var) -> tir::Var {
+      if (defined_.count(var.get())) {
+        const ScopedRedefine& redefine = redefines.emplace_back(this, var);
+        return redefine.new_var;
+      } else {
+        defined_.insert(var.get());
+        return var;
+      }
+    });
+
+    // Remap implicitly defined buffer parameters
+    {
+      std::unordered_set<const VarNode*> defined_params;
+      for (const auto& var : func->params) {
+        defined_params.insert(var.get());
+      }
+      for (const auto& [var, buffer] : func->buffer_map) {
+        static_cast<void>(var);  // gcc 7.x bug, 
https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
+        auto check_expr = [&](const PrimExpr& expr) {
+          auto* var_ptr = expr.as<VarNode>();
+          if (!var_ptr) return;
+          if (defined_params.count(var_ptr)) return;
+
+          if (defined_.count(var_ptr)) {
+            auto var = GetRef<Var>(var_ptr);
+            redefines.emplace_back(this, var);
+          } else {
+            defined_.insert(var_ptr);
+          }
+        };
+        for (const auto& dim : buffer->shape) {
+          check_expr(dim);
+        }
+        for (const auto& stride : buffer->strides) {
+          check_expr(stride);
+        }
+        check_expr(buffer->elem_offset);
+      }
+    }
+
+    // Update the buffer map, based on the redefined parameters
+    auto buffer_map = [&]() {
+      Map<Var, Buffer> buffer_map;
+      bool made_change = false;
+      for (const auto& [var, buffer] : func->buffer_map) {
+        auto new_var = GetRemappedVar(var);
+        auto new_buf = GetRemappedBuffer(buffer);
+
+        made_change = made_change || !var.same_as(new_var) || 
!buffer.same_as(new_buf);
+        buffer_map.Set(new_var, new_buf);
+      }
+      if (made_change) {
+        return buffer_map;
+      } else {
+        return func->buffer_map;
+      }
+    }();
+
+    auto attrs = [&]() -> DictAttrs {
+      Map<String, ObjectRef> dict;
+      bool made_change = false;
+
+      for (const auto& [key, old_value] : func->attrs->dict) {
+        auto value = old_value;
+        if (auto* expr = value.as<PrimExprNode>()) {
+          value = VisitExpr(GetRef<PrimExpr>(expr));
+        } else if (auto* stmt = value.as<StmtNode>()) {
+          value = VisitStmt(GetRef<Stmt>(stmt));
+        }
+
+        made_change = made_change || !value.same_as(old_value);
+        dict.Set(key, value);
+      }
+
+      if (made_change) {
+        return DictAttrs(dict);
+      } else {
+        return func->attrs;
+      }
+    }();
+
+    auto body = VisitStmt(func->body);
+
+    // If anything changed, update the returned function
+    if (!params.same_as(func->params) || !buffer_map.same_as(func->buffer_map) 
||
+        !attrs.same_as(func->attrs) || !body.same_as(func->body)) {
+      func = PrimFunc(params, body, func->ret_type, buffer_map, attrs);
+    }
+
+    // Pop the redefines in reverse order of creation
+    while (redefines.size()) {
+      redefines.pop_back();
     }
+    return func;
   }
+
+  PrimExpr VisitExpr_(const VarNode* op) final { return 
GetRemappedVar(GetRef<Var>(op)); }
   PrimExpr VisitExpr_(const LetNode* op) final {
     const Var& v = op->var;
     if (defined_.count(v.get())) {
@@ -142,18 +237,27 @@ class IRConvertSSA final : public StmtExprMutator {
     return node;
   }
 
+  Var GetRemappedVar(Var var) {
+    if (auto it = scope_.find(var.get()); it != scope_.end() && 
it->second.size()) {
+      return it->second.back();
+    } else {
+      return var;
+    }
+  }
+
   Buffer GetRemappedBuffer(Buffer buf) {
     // Determine the buffer var that should be in the updated buffer,
     // given the current scope.  If no redefines are present, then the
     // buffer var is unchanged.
-    Var new_buffer_var = buf->data;
-    auto var_it = scope_.find(buf->data.get());
-    if (var_it != scope_.end() && !var_it->second.empty()) {
-      new_buffer_var = var_it->second.back();
-    }
+    Var new_buffer_var = GetRemappedVar(buf->data);
+    PrimExpr elem_offset = VisitExpr(buf->elem_offset);
+    auto visit_expr = [this](const PrimExpr& expr) { return VisitExpr(expr); };
+    Array<PrimExpr> shape = buf->shape.Map(visit_expr);
+    Array<PrimExpr> strides = buf->strides.Map(visit_expr);
 
     // If no mapping is required, return the original buffer.
-    if (new_buffer_var.same_as(buf->data)) {
+    if (new_buffer_var.same_as(buf->data) && 
elem_offset.same_as(buf->elem_offset) &&
+        shape.same_as(buf->shape) && strides.same_as(buf->strides)) {
       return buf;
     }
 
@@ -169,9 +273,9 @@ class IRConvertSSA final : public StmtExprMutator {
     // new buffer, pushing it onto the scoped stack of existing
     // buffers.  This will be popped when the new_buffer_var
     // redefinition is popped.
-    Buffer new_buf(new_buffer_var, buf->dtype, buf->shape, buf->strides, 
buf->elem_offset,
-                   buf->name, buf->data_alignment, buf->offset_factor, 
buf->buffer_type,
-                   buf->axis_separators, buf->span);
+    Buffer new_buf(new_buffer_var, buf->dtype, shape, strides, elem_offset, 
buf->name,
+                   buf->data_alignment, buf->offset_factor, buf->buffer_type, 
buf->axis_separators,
+                   buf->span);
     buffers.push_back(new_buf);
     return new_buf;
   }
@@ -239,16 +343,33 @@ class IRConvertSSA final : public StmtExprMutator {
     }
 
     ~ScopedRedefine() {
-      parent->scope_[old_var.get()].pop_back();
-      for (auto& kv : parent->buf_remap_) {
-        std::vector<Buffer>& buffers = kv.second;
-        if (buffers.size() && (buffers.back()->data.get() == new_var.get())) {
-          buffers.pop_back();
+      if (parent) {
+        parent->scope_[old_var.get()].pop_back();
+        for (auto& kv : parent->buf_remap_) {
+          std::vector<Buffer>& buffers = kv.second;
+          if (buffers.size() && (buffers.back()->data.get() == new_var.get())) 
{
+            buffers.pop_back();
+          }
         }
       }
     }
 
-    IRConvertSSA* parent;
+    ScopedRedefine& operator=(const ScopedRedefine&) = delete;
+    ScopedRedefine(const ScopedRedefine&) = delete;
+
+    ScopedRedefine& operator=(ScopedRedefine&& other) {
+      swap(other);
+      return *this;
+    }
+    ScopedRedefine(ScopedRedefine&& other) { swap(other); }
+
+    void swap(ScopedRedefine& other) {
+      std::swap(parent, other.parent);
+      std::swap(old_var, other.old_var);
+      std::swap(new_var, other.new_var);
+    }
+
+    IRConvertSSA* parent{nullptr};
     Var old_var;
     Var new_var;
   };
@@ -447,5 +568,30 @@ std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const 
AttrStmtNode* op) {
   return std::make_pair(op->value, inner->value);
 }
 
+namespace transform {
+Pass ConvertSSA() {
+  auto pass_func = [](IRModule mod, PassContext ctx) {
+    tir::IRConvertSSA converter;
+    Map<GlobalVar, BaseFunc> functions;
+    bool made_change = false;
+    for (auto [gvar, base_func] : mod->functions) {
+      if (auto* ptr = base_func.as<tir::PrimFuncNode>()) {
+        auto updated = converter.VisitPrimFunc(GetRef<tir::PrimFunc>(ptr));
+        if (!updated.same_as(base_func)) {
+          made_change = true;
+          base_func = updated;
+        }
+      }
+      functions.Set(gvar, base_func);
+    }
+    if (made_change) {
+      mod.CopyOnWrite()->functions = std::move(functions);
+    }
+    return mod;
+  };
+  return tvm::transform::CreateModulePass(pass_func, 0, "tir.ConvertSSA", {});
+}
+
+}  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/transforms/split_host_device.cc 
b/src/tir/transforms/split_host_device.cc
index 3696ff84e5..4f47b8ce2b 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -282,7 +282,7 @@ Pass SplitHostDevice() {
       }
     }
     mod->Update(device_mod);
-    return mod;
+    return ConvertSSA()(mod);
   };
 
   return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice", 
{});
diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py 
b/tests/python/unittest/test_tir_transform_split_host_device.py
index f4adac9cf7..680f23e07a 100644
--- a/tests/python/unittest/test_tir_transform_split_host_device.py
+++ b/tests/python/unittest/test_tir_transform_split_host_device.py
@@ -17,6 +17,7 @@
 import tvm
 from tvm import te
 import tvm.testing
+from tvm.script import tir as T, ir as I
 
 
 @tvm.testing.requires_cuda
@@ -48,5 +49,29 @@ def test_split_host_device_func_attr():
     assert fdevice.attrs["tir.is_global_func"].value
 
 
+def test_ssa_across_entire_module():
+    """The host and device functions should not share TIR vars
+
+    Any arguments that are passed from the host to the device should
+    be in terms of independent TIR variables.
+    """
+
+    @I.ir_module
+    class before:
+        @T.prim_func
+        def main():
+            T.func_attr({"global_symbol": "main", "target": T.target("cuda")})
+            for i in range(16):
+                T.attr(0, "device_scope", 0)
+                for j in range(16):
+                    T.evaluate(i)
+
+    after = tvm.tir.transform.SplitHostDevice()(before)
+    loop_var = after["main"].body.loop_var
+    param_var = after["main_kernel0"].params[0]
+
+    assert not loop_var.same_as(param_var)
+
+
 if __name__ == "__main__":
     test_split_host_device_func_attr()

Reply via email to