This is an automated email from the ASF dual-hosted git repository.
sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 81a850693d [TIR] Use constructor for new PrimFunc in TransformLayout (#16832)
81a850693d is described below
commit 81a850693d3afc3d056d119c1b1c68b4c1aec8a7
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Apr 7 08:09:42 2024 -0500
[TIR] Use constructor for new PrimFunc in TransformLayout (#16832)
Using the constructor applies all initialization steps and
error-checking, where using `CopyOnWrite()` does not. This function
is used as part of the legalization of `relax.op.layout_transform`,
which relies on the annotations produced in the `PrimFunc`
constructor.
---
.../schedule/primitive/layout_transformation.cc | 17 +++---
.../test_transform_legalize_ops_manipulate.py | 64 ++++++++++++++++++++++
2 files changed, 74 insertions(+), 7 deletions(-)
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index 6c6427a906..f1e9106a63 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -1207,17 +1207,20 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
// Step 4: Rewrite buffer_map of the PrimFunc if necessary.
if (!defining_site_sref.defined()) {
GlobalVar g_var;
- GetRootPrimFunc(self->mod, scope_block, &g_var);
+ const auto* old_func = GetRootPrimFunc(self->mod, scope_block, &g_var);
IRModuleNode* new_mod = self->mod.CopyOnWrite();
MapNode* new_map = new_mod->functions.CopyOnWrite();
- PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
- PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
- MapNode* new_buffer_map = new_func->buffer_map.CopyOnWrite();
- for (auto it = new_buffer_map->begin(); it != new_buffer_map->end(); ++it) {
- if ((*it).second.same_as(old_buffer)) {
- (*it).second = new_buffer;
+
+ Map<Var, Buffer> new_buffer_map;
+ for (auto [var, buffer] : old_func->buffer_map) {
+ if (buffer.same_as(old_buffer)) {
+ buffer = new_buffer;
}
+ new_buffer_map.Set(var, buffer);
}
+
+ PrimFunc ref_new_func(old_func->params, old_func->body, old_func->ret_type, new_buffer_map,
+ old_func->attrs, old_func->span);
new_map->at(g_var) = std::move(ref_new_func);
}
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 9b7a8f23c9..dd0208f5db 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -1666,5 +1666,69 @@ def test_layout_transform_with_pad_axis_sep():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_func_struct_info_of_legalized_layout_transform():
+ """PrimFunc shape information must be correct
+
+ This is a regression test. Previously, the legalization of
+ `R.layout_transform` produced a PrimFunc with `FuncStructInfo`
+ different than its actual signature. This resulted in errors
+ when later passes attempted to infer the StructInfo.
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")
+ ) -> R.Tensor((16,), dtype="float32"):
+ R.func_attr({"relax.force_pure": True})
+ with R.dataflow():
+ lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+ x, index_map=lambda i: (i // 4, i % 4), pad_value=None
+ )
+ gv: R.Tensor((4, 4), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ After = tvm.ir.transform.Sequential(
+ [
+ relax.transform.LegalizeOps(),
+ relax.transform.ToNonDataflow(),
+ relax.transform.RemovePurityChecking(),
+ relax.transform.CallTIRRewrite(),
+ ]
+ )(Before)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((16,), dtype="float32"),
+ y: R.Tensor((16,), dtype="float32"),
+ ):
+ R.func_attr({"relax.force_pure": True})
+ cls = Expected
+ alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor(
+ R.shape([4, 4]), R.dtype("float32"), R.prim_value(0), R.str("global")
+ )
+ cls.te_layout_transform(x, alloc)
+ lv = alloc
+ gv = lv
+ return gv
+
+ @T.prim_func(private=True)
+ def te_layout_transform(
+ A: T.Buffer((T.int64(16),), "float32"),
+ te_layout_transform: T.Buffer((T.int64(4), T.int64(4)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for i in range(T.int64(16)):
+ with T.block("te_layout_transform"):
+ vi = T.axis.spatial(T.int64(16), i)
+ te_layout_transform[vi // T.int64(4), vi % T.int64(4)] = A[vi]
+
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
if __name__ == "__main__":
tvm.testing.main()