This is an automated email from the ASF dual-hosted git repository.

sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 81a850693d [TIR] Use constructor for new PrimFunc in TransformLayout (#16832)
81a850693d is described below

commit 81a850693d3afc3d056d119c1b1c68b4c1aec8a7
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Apr 7 08:09:42 2024 -0500

    [TIR] Use constructor for new PrimFunc in TransformLayout (#16832)
    
    Using the constructor applies all initialization steps and
    error-checking, whereas using `CopyOnWrite()` does not.  This function
    is used as part of the legalization of `relax.op.layout_transform`,
    which relies on the annotations produced in the `PrimFunc`
    constructor.
---
 .../schedule/primitive/layout_transformation.cc    | 17 +++---
 .../test_transform_legalize_ops_manipulate.py      | 64 ++++++++++++++++++++++
 2 files changed, 74 insertions(+), 7 deletions(-)

diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index 6c6427a906..f1e9106a63 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -1207,17 +1207,20 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
   // Step 4: Rewrite buffer_map of the PrimFunc if necessary.
   if (!defining_site_sref.defined()) {
     GlobalVar g_var;
-    GetRootPrimFunc(self->mod, scope_block, &g_var);
+    const auto* old_func = GetRootPrimFunc(self->mod, scope_block, &g_var);
     IRModuleNode* new_mod = self->mod.CopyOnWrite();
     MapNode* new_map = new_mod->functions.CopyOnWrite();
-    PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
-    PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
-    MapNode* new_buffer_map = new_func->buffer_map.CopyOnWrite();
-    for (auto it = new_buffer_map->begin(); it != new_buffer_map->end(); ++it) 
{
-      if ((*it).second.same_as(old_buffer)) {
-        (*it).second = new_buffer;
+
+    Map<Var, Buffer> new_buffer_map;
+    for (auto [var, buffer] : old_func->buffer_map) {
+      if (buffer.same_as(old_buffer)) {
+        buffer = new_buffer;
       }
+      new_buffer_map.Set(var, buffer);
     }
+
+    PrimFunc ref_new_func(old_func->params, old_func->body, 
old_func->ret_type, new_buffer_map,
+                          old_func->attrs, old_func->span);
     new_map->at(g_var) = std::move(ref_new_func);
   }
 
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 9b7a8f23c9..dd0208f5db 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -1666,5 +1666,69 @@ def test_layout_transform_with_pad_axis_sep():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_func_struct_info_of_legalized_layout_transform():
+    """Legalized layout_transform PrimFunc must match its FuncStructInfo
+
+    Regression test.  Previously, legalizing `R.layout_transform`
+    produced a PrimFunc whose `FuncStructInfo` differed from its
+    actual signature, causing errors when later passes attempted
+    to infer the StructInfo.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), 
dtype="float32")
+        ) -> R.Tensor((16,), dtype="float32"):
+            R.func_attr({"relax.force_pure": True})
+            with R.dataflow():
+                lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    x, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                gv: R.Tensor((4, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    After = tvm.ir.transform.Sequential(
+        [
+            relax.transform.LegalizeOps(),
+            relax.transform.ToNonDataflow(),
+            relax.transform.RemovePurityChecking(),
+            relax.transform.CallTIRRewrite(),
+        ]
+    )(Before)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((16,), dtype="float32"),
+            y: R.Tensor((16,), dtype="float32"),
+        ):
+            R.func_attr({"relax.force_pure": True})
+            cls = Expected
+            alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([4, 4]), R.dtype("float32"), R.prim_value(0), 
R.str("global")
+            )
+            cls.te_layout_transform(x, alloc)
+            lv = alloc
+            gv = lv
+            return gv
+
+        @T.prim_func(private=True)
+        def te_layout_transform(
+            A: T.Buffer((T.int64(16),), "float32"),
+            te_layout_transform: T.Buffer((T.int64(4), T.int64(4)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for i in range(T.int64(16)):
+                with T.block("te_layout_transform"):
+                    vi = T.axis.spatial(T.int64(16), i)
+                    te_layout_transform[vi // T.int64(4), vi % T.int64(4)] = 
A[vi]
+
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to