This is an automated email from the ASF dual-hosted git repository.

moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a57397e  [Bugfix] Simultaneous layout transform and axis separators. (#10553)
a57397e is described below

commit a57397e9a282a3faedf843b67a9b119a5d7ce975
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Mar 10 10:01:21 2022 -0600

    [Bugfix] Simultaneous layout transform and axis separators. (#10553)
    
    Previously, SchedulePostProcToPrimFunc would first generate the map
    from buffer object to layout transformation, and then update the
    buffers with their axis separators.  However, it did not replace the
    buffer objects used as keys in the layout transformation map, so the
    transformation wasn't applied to the updated buffers.
    
    This PR correctly updates the layout transformation map, and
    adds a unit test to catch this failure mode.
---
 src/te/schedule/schedule_postproc_to_primfunc.cc | 11 ++++
 tests/python/unittest/test_transform_layout.py   | 70 +++++++++++++++++-------
 2 files changed, 60 insertions(+), 21 deletions(-)

diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc
index 0cf6e54..c7d5d7a 100644
--- a/src/te/schedule/schedule_postproc_to_primfunc.cc
+++ b/src/te/schedule/schedule_postproc_to_primfunc.cc
@@ -256,6 +256,9 @@ class AxisSeparatorsAttrUnwrapper : StmtExprMutator {
       auto pass = AxisSeparatorsAttrUnwrapper(axis_separators_map);
       write_ptr->buffer_map = pass.UpdateExternBufferMap(func->buffer_map);
       write_ptr->body = pass(func->body);
+      if (auto map = func->attrs.GetAttr<Map<Buffer, 
Array<IndexMap>>>("layout_transform_map")) {
+        func = WithAttr(std::move(func), "layout_transform_map", 
pass.UpdateIndexMap(map.value()));
+      }
     }
 
     return func;
@@ -272,6 +275,14 @@ class AxisSeparatorsAttrUnwrapper : StmtExprMutator {
     return output;
   }
 
+  Map<Buffer, Array<IndexMap>> UpdateIndexMap(const Map<Buffer, 
Array<IndexMap>>& orig) {
+    Map<Buffer, Array<IndexMap>> output;
+    for (const auto& kv : orig) {
+      output.Set(GetRemappedBuffer(kv.first), kv.second);
+    }
+    return output;
+  }
+
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     auto ret = StmtExprMutator::VisitStmt_(op);
     op = ret.as<AttrStmtNode>();
diff --git a/tests/python/unittest/test_transform_layout.py b/tests/python/unittest/test_transform_layout.py
index 5cac01d..55266fd 100755
--- a/tests/python/unittest/test_transform_layout.py
+++ b/tests/python/unittest/test_transform_layout.py
@@ -222,28 +222,54 @@ class TestCompareAgainstExplicitReshape:
 
 class Test2DPhysicalLayout:
     transform_A = tvm.testing.parameter(
-        by_dict={
-            "2d_A": True,
-            "1d_A": False,
-        }
+        "1d_A",
+        "2d_A",
+        "2d_rev_A",
     )
     transform_B = tvm.testing.parameter(
-        by_dict={
-            "2d_B": True,
-            "1d_B": False,
-        }
+        "1d_B",
+        "2d_B",
+        "2d_rev_B",
     )
 
     @staticmethod
-    def extract_loop_vars(stmt):
-        output = []
+    def extract_logical_indices(stmt):
+        output = {}
 
+        # Since the for loops can be reordered by the layout
+        # transformation, identify the loop corresponding to each
+        # pre-transformation axis based on the iteration extent.
         def callback(node):
             if isinstance(node, tvm.tir.For):
-                output.append(node.loop_var)
+                output[node.loop_var] = node.extent.value
 
         post_order_visit(stmt, callback)
-        return output[::-1]
+        return sorted(output, key=output.get)
+
+    def get_transform(self, name):
+        name = name[:-2]
+        if name == "1d":
+            return None
+        elif name == "2d":
+            return lambda i, j, k: [i, j, te.AXIS_SEPARATOR, k]
+        elif name == "2d_rev":
+            return lambda i, j, k: [k, j, te.AXIS_SEPARATOR, i]
+        else:
+            raise ValueError(f"Unknown transformation: {name}")
+
+    def transform_indices(self, name, logical_shape, logical_index_vars):
+        name = name[:-2]
+
+        i, j, k = logical_index_vars
+
+        if name == "1d":
+            return [i * (logical_shape[1] * logical_shape[2]) + j * 
logical_shape[2] + k]
+        elif name == "2d":
+            return [i * logical_shape[1] + j, k]
+        elif name == "2d_rev":
+            return [k * logical_shape[1] + j, i]
+        else:
+            raise ValueError(f"Unknown transformation: {name}")
 
     def test_2d_physical(self, dtype, transform_A, transform_B):
         logical_shape = (2, 3, 4)
@@ -252,11 +278,13 @@ class Test2DPhysicalLayout:
 
         s = te.create_schedule(B.op)
 
-        if transform_A:
-            s[A].transform_layout(lambda i, j, k: [i, j, te.AXIS_SEPARATOR, k])
+        func = self.get_transform(transform_A)
+        if func:
+            s[A].transform_layout(func)
 
-        if transform_B:
-            s[B].transform_layout(lambda i, j, k: [i, j, te.AXIS_SEPARATOR, k])
+        func = self.get_transform(transform_B)
+        if func:
+            s[B].transform_layout(func)
 
         # If the two buffers are accessed with the same indices, CSE
         # will replace them with a Let binding.  Since this makes it
@@ -265,17 +293,17 @@ class Test2DPhysicalLayout:
         with 
tvm.transform.PassContext(disabled_pass=["tir.CommonSubexprElimTIR"]):
             mod = tvm.lower(s, [A, B])
 
-        i, j, k = self.extract_loop_vars(mod["main"].body)
-        indices_1d = [i * (logical_shape[1] * logical_shape[2]) + j * 
logical_shape[2] + k]
-        indices_2d = [i * logical_shape[1] + j, k]
+        logical_index_vars = self.extract_logical_indices(mod["main"].body)
+        expected_indices_A = self.transform_indices(transform_A, 
logical_shape, logical_index_vars)
+        expected_indices_B = self.transform_indices(transform_B, 
logical_shape, logical_index_vars)
 
         def callback(node):
             if type(node) in [tvm.tir.BufferLoad, tvm.tir.BufferStore]:
                 name = node.buffer.name
                 if name == "A":
-                    expected_indices = indices_2d if transform_A else 
indices_1d
+                    expected_indices = expected_indices_A
                 elif name == "B":
-                    expected_indices = indices_2d if transform_B else 
indices_1d
+                    expected_indices = expected_indices_B
                 else:
                     raise RuntimeError(f"Unexpected buffer: {name}")
 

Reply via email to