This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e8cd33b601 [TIR] Update SplitHostDevice to post-process with
ConvertSSA (#14496)
e8cd33b601 is described below
commit e8cd33b601f081d9c9b4dc86b5c3b36147b37834
Author: Eric Lunderberg <[email protected]>
AuthorDate: Fri Apr 7 20:35:35 2023 -0500
[TIR] Update SplitHostDevice to post-process with ConvertSSA (#14496)
* [TIR][Utils] Implemented ConvertSSA as IRModule transform
When passes create new PrimFuncs, such as when `tir.SplitHostDevice`
separates out a `tir::Stmt` into an independent function, the
parameters of these new functions may alias existing variable
definitions. While this is well-defined, because variable definitions
are not shared across function boundaries, it can produce spurious
mismatches in `tvm.ir.assert_structural_equal`.
This commit implements `tvm::tir::transform::ConvertSSA`, which
ensures unique variable declaration locations across an entire module.
* [TIR] Update SplitHostDevice to post-process with ConvertSSA
Avoid duplicate variable definitions between the host and device
PrimFuncs.
---
include/tvm/tir/transform.h | 13 ++
src/tir/transforms/ir_utils.cc | 186 ++++++++++++++++++---
src/tir/transforms/split_host_device.cc | 2 +-
.../test_tir_transform_split_host_device.py | 25 +++
4 files changed, 205 insertions(+), 21 deletions(-)
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index d4f537ff31..35aa392db2 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -176,6 +176,19 @@ TVM_DLL Pass RewriteUnsafeSelect();
*/
TVM_DLL Pass Simplify();
+/*!
+ * \brief Convert an IRModule to be SSA form.
+ *
+ * This pass handles cases where the same tir::Var appears in
+ * multiple functions within the same module. For example, after
+ * extracting a fragment from one function into another, where the
+ * same `tir::Var` may be defined both as within the body of the
+ * original function, and as a parameter within the hoisted function.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass ConvertSSA();
+
/*!
* \brief Instruments bound checkers.
*
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index f6e4ac45c6..e80772fda3 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -26,6 +26,7 @@
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_solver.h>
#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
#include <unordered_map>
#include <unordered_set>
@@ -90,13 +91,107 @@ Stmt MergeNest(const std::vector<std::vector<Stmt>>& nest,
Stmt body) {
class IRConvertSSA final : public StmtExprMutator {
public:
- PrimExpr VisitExpr_(const VarNode* op) final {
- if (scope_.count(op) && !scope_[op].empty()) {
- return scope_[op].back();
- } else {
- return GetRef<PrimExpr>(op);
+ PrimFunc VisitPrimFunc(PrimFunc func) {
+ std::vector<ScopedRedefine> redefines;
+
+ // Remap parameters, if they were used in another function
+ auto params = func->params.Map([&](const tir::Var& var) -> tir::Var {
+ if (defined_.count(var.get())) {
+ const ScopedRedefine& redefine = redefines.emplace_back(this, var);
+ return redefine.new_var;
+ } else {
+ defined_.insert(var.get());
+ return var;
+ }
+ });
+
+ // Remap implicitly defined buffer parameters
+ {
+ std::unordered_set<const VarNode*> defined_params;
+ for (const auto& var : func->params) {
+ defined_params.insert(var.get());
+ }
+ for (const auto& [var, buffer] : func->buffer_map) {
+ static_cast<void>(var); // gcc 7.x bug,
https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
+ auto check_expr = [&](const PrimExpr& expr) {
+ auto* var_ptr = expr.as<VarNode>();
+ if (!var_ptr) return;
+ if (defined_params.count(var_ptr)) return;
+
+ if (defined_.count(var_ptr)) {
+ auto var = GetRef<Var>(var_ptr);
+ redefines.emplace_back(this, var);
+ } else {
+ defined_.insert(var_ptr);
+ }
+ };
+ for (const auto& dim : buffer->shape) {
+ check_expr(dim);
+ }
+ for (const auto& stride : buffer->strides) {
+ check_expr(stride);
+ }
+ check_expr(buffer->elem_offset);
+ }
+ }
+
+ // Update the buffer map, based on the redefined parameters
+ auto buffer_map = [&]() {
+ Map<Var, Buffer> buffer_map;
+ bool made_change = false;
+ for (const auto& [var, buffer] : func->buffer_map) {
+ auto new_var = GetRemappedVar(var);
+ auto new_buf = GetRemappedBuffer(buffer);
+
+ made_change = made_change || !var.same_as(new_var) ||
!buffer.same_as(new_buf);
+ buffer_map.Set(new_var, new_buf);
+ }
+ if (made_change) {
+ return buffer_map;
+ } else {
+ return func->buffer_map;
+ }
+ }();
+
+ auto attrs = [&]() -> DictAttrs {
+ Map<String, ObjectRef> dict;
+ bool made_change = false;
+
+ for (const auto& [key, old_value] : func->attrs->dict) {
+ auto value = old_value;
+ if (auto* expr = value.as<PrimExprNode>()) {
+ value = VisitExpr(GetRef<PrimExpr>(expr));
+ } else if (auto* stmt = value.as<StmtNode>()) {
+ value = VisitStmt(GetRef<Stmt>(stmt));
+ }
+
+ made_change = made_change || !value.same_as(old_value);
+ dict.Set(key, value);
+ }
+
+ if (made_change) {
+ return DictAttrs(dict);
+ } else {
+ return func->attrs;
+ }
+ }();
+
+ auto body = VisitStmt(func->body);
+
+ // If anything changed, update the returned function
+ if (!params.same_as(func->params) || !buffer_map.same_as(func->buffer_map)
||
+ !attrs.same_as(func->attrs) || !body.same_as(func->body)) {
+ func = PrimFunc(params, body, func->ret_type, buffer_map, attrs);
+ }
+
+ // Pop the redefines in reverse order of creation
+ while (redefines.size()) {
+ redefines.pop_back();
}
+ return func;
}
+
+ PrimExpr VisitExpr_(const VarNode* op) final { return
GetRemappedVar(GetRef<Var>(op)); }
PrimExpr VisitExpr_(const LetNode* op) final {
const Var& v = op->var;
if (defined_.count(v.get())) {
@@ -142,18 +237,27 @@ class IRConvertSSA final : public StmtExprMutator {
return node;
}
+ Var GetRemappedVar(Var var) {
+ if (auto it = scope_.find(var.get()); it != scope_.end() &&
it->second.size()) {
+ return it->second.back();
+ } else {
+ return var;
+ }
+ }
+
Buffer GetRemappedBuffer(Buffer buf) {
// Determine the buffer var that should be in the updated buffer,
// given the current scope. If no redefines are present, then the
// buffer var is unchanged.
- Var new_buffer_var = buf->data;
- auto var_it = scope_.find(buf->data.get());
- if (var_it != scope_.end() && !var_it->second.empty()) {
- new_buffer_var = var_it->second.back();
- }
+ Var new_buffer_var = GetRemappedVar(buf->data);
+ PrimExpr elem_offset = VisitExpr(buf->elem_offset);
+ auto visit_expr = [this](const PrimExpr& expr) { return VisitExpr(expr); };
+ Array<PrimExpr> shape = buf->shape.Map(visit_expr);
+ Array<PrimExpr> strides = buf->strides.Map(visit_expr);
// If no mapping is required, return the original buffer.
- if (new_buffer_var.same_as(buf->data)) {
+ if (new_buffer_var.same_as(buf->data) &&
elem_offset.same_as(buf->elem_offset) &&
+ shape.same_as(buf->shape) && strides.same_as(buf->strides)) {
return buf;
}
@@ -169,9 +273,9 @@ class IRConvertSSA final : public StmtExprMutator {
// new buffer, pushing it onto the scoped stack of existing
// buffers. This will be popped when the new_buffer_var
// redefinition is popped.
- Buffer new_buf(new_buffer_var, buf->dtype, buf->shape, buf->strides,
buf->elem_offset,
- buf->name, buf->data_alignment, buf->offset_factor,
buf->buffer_type,
- buf->axis_separators, buf->span);
+ Buffer new_buf(new_buffer_var, buf->dtype, shape, strides, elem_offset,
buf->name,
+ buf->data_alignment, buf->offset_factor, buf->buffer_type,
buf->axis_separators,
+ buf->span);
buffers.push_back(new_buf);
return new_buf;
}
@@ -239,16 +343,33 @@ class IRConvertSSA final : public StmtExprMutator {
}
~ScopedRedefine() {
- parent->scope_[old_var.get()].pop_back();
- for (auto& kv : parent->buf_remap_) {
- std::vector<Buffer>& buffers = kv.second;
- if (buffers.size() && (buffers.back()->data.get() == new_var.get())) {
- buffers.pop_back();
+ if (parent) {
+ parent->scope_[old_var.get()].pop_back();
+ for (auto& kv : parent->buf_remap_) {
+ std::vector<Buffer>& buffers = kv.second;
+ if (buffers.size() && (buffers.back()->data.get() == new_var.get()))
{
+ buffers.pop_back();
+ }
}
}
}
- IRConvertSSA* parent;
+ ScopedRedefine& operator=(const ScopedRedefine&) = delete;
+ ScopedRedefine(const ScopedRedefine&) = delete;
+
+ ScopedRedefine& operator=(ScopedRedefine&& other) {
+ swap(other);
+ return *this;
+ }
+ ScopedRedefine(ScopedRedefine&& other) { swap(other); }
+
+ void swap(ScopedRedefine& other) {
+ std::swap(parent, other.parent);
+ std::swap(old_var, other.old_var);
+ std::swap(new_var, other.new_var);
+ }
+
+ IRConvertSSA* parent{nullptr};
Var old_var;
Var new_var;
};
@@ -447,5 +568,30 @@ std::pair<PrimExpr, PrimExpr> GetAsyncWaitAttributes(const
AttrStmtNode* op) {
return std::make_pair(op->value, inner->value);
}
+namespace transform {
+Pass ConvertSSA() {
+ auto pass_func = [](IRModule mod, PassContext ctx) {
+ tir::IRConvertSSA converter;
+ Map<GlobalVar, BaseFunc> functions;
+ bool made_change = false;
+ for (auto [gvar, base_func] : mod->functions) {
+ if (auto* ptr = base_func.as<tir::PrimFuncNode>()) {
+ auto updated = converter.VisitPrimFunc(GetRef<tir::PrimFunc>(ptr));
+ if (!updated.same_as(base_func)) {
+ made_change = true;
+ base_func = updated;
+ }
+ }
+ functions.Set(gvar, base_func);
+ }
+ if (made_change) {
+ mod.CopyOnWrite()->functions = std::move(functions);
+ }
+ return mod;
+ };
+ return tvm::transform::CreateModulePass(pass_func, 0, "tir.ConvertSSA", {});
+}
+
+} // namespace transform
} // namespace tir
} // namespace tvm
diff --git a/src/tir/transforms/split_host_device.cc
b/src/tir/transforms/split_host_device.cc
index 3696ff84e5..4f47b8ce2b 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -282,7 +282,7 @@ Pass SplitHostDevice() {
}
}
mod->Update(device_mod);
- return mod;
+ return ConvertSSA()(mod);
};
return tvm::transform::CreateModulePass(pass_func, 0, "tir.SplitHostDevice",
{});
diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py
b/tests/python/unittest/test_tir_transform_split_host_device.py
index f4adac9cf7..680f23e07a 100644
--- a/tests/python/unittest/test_tir_transform_split_host_device.py
+++ b/tests/python/unittest/test_tir_transform_split_host_device.py
@@ -17,6 +17,7 @@
import tvm
from tvm import te
import tvm.testing
+from tvm.script import tir as T, ir as I
@tvm.testing.requires_cuda
@@ -48,5 +49,29 @@ def test_split_host_device_func_attr():
assert fdevice.attrs["tir.is_global_func"].value
+def test_ssa_across_entire_module():
+ """The host and device functions should not share TIR vars
+
+ Any arguments that are passed from the host to the device should
+ be in terms of independent TIR variables.
+ """
+
+ @I.ir_module
+ class before:
+ @T.prim_func
+ def main():
+ T.func_attr({"global_symbol": "main", "target": T.target("cuda")})
+ for i in range(16):
+ T.attr(0, "device_scope", 0)
+ for j in range(16):
+ T.evaluate(i)
+
+ after = tvm.tir.transform.SplitHostDevice()(before)
+ loop_var = after["main"].body.loop_var
+ param_var = after["main_kernel0"].params[0]
+
+ assert not loop_var.same_as(param_var)
+
+
if __name__ == "__main__":
test_split_host_device_func_attr()