This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3e30a5f  [Bugfix] Handled TransformNode in PassUpBitMaskOr/PassDownBitMaskOr (#10620)
3e30a5f is described below

commit 3e30a5f64486b23e0c9659b3d3432ab983f2b572
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Mar 15 21:49:32 2022 -0500

    [Bugfix] Handled TransformNode in PassUpBitMaskOr/PassDownBitMaskOr (#10620)
    
    Previously, a layout transformation applied to a te.compute whose
    computation used a reduction axis would fail.
---
 src/te/schedule/message_passing.cc             | 21 +++++++++++++++++++++
 tests/python/unittest/test_transform_layout.py | 17 +++++++++++++++++
 2 files changed, 38 insertions(+)

diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc
index 361cdb1..7041b75 100644
--- a/src/te/schedule/message_passing.cc
+++ b/src/te/schedule/message_passing.cc
@@ -535,6 +535,17 @@ void PassUpBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_sta
       } else {
         state[s->parent] |= state[s->rebased];
       }
+    } else if (const TransformNode* s = rel.as<TransformNode>()) {
+      for (const auto& original_var : s->original_variables) {
+        for (const auto& transformed_var : s->transformed_variables) {
+          if (!state.count(transformed_var)) {
+            ICHECK(allow_missing);
+            continue;
+          }
+          state[original_var] |= state[transformed_var];
+        }
+      }
+
     } else if (rel.as<SingletonNode>()) {
     } else {
       LOG(FATAL) << "unknown relation type";
@@ -581,6 +592,16 @@ void PassDownBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_s
       } else {
         state[s->rebased] |= state.at(s->parent);
       }
+    } else if (const TransformNode* s = rel.as<TransformNode>()) {
+      for (const auto& original_var : s->original_variables) {
+        for (const auto& transformed_var : s->transformed_variables) {
+          if (!state.count(original_var)) {
+            ICHECK(allow_missing);
+            continue;
+          }
+          state[transformed_var] |= state[original_var];
+        }
+      }
     } else if (const SingletonNode* s = rel.as<SingletonNode>()) {
       state[s->iter] = 0;
     } else {
diff --git a/tests/python/unittest/test_transform_layout.py b/tests/python/unittest/test_transform_layout.py
index 55266fd..a3c232d 100755
--- a/tests/python/unittest/test_transform_layout.py
+++ b/tests/python/unittest/test_transform_layout.py
@@ -522,5 +522,22 @@ class TestTransformCache:
             tvm.testing.assert_allclose(b.numpy(), b_np)
 
 
+def test_transform_with_reduction():
+    # Regression test for #10620. To trigger the failure, the computation
+    # must use a reduction axis (k, over the last dimension of A),
+    A = te.placeholder([16, 32, 64], dtype="float32", name="A")
+    k = te.reduce_axis((0, A.shape[-1]), name="k")
+    B = te.compute(A.shape[:-1], lambda i, j: te.sum(A[i, j, k], axis=[k]))
+    s = te.create_schedule(B.op)
+
+    # and its output must have a layout transformation applied (here, a
+    # transpose of B's two axes).
+    s[B].transform_layout(lambda i, j: [j, i])
+
+    # Before the fix, tvm.lower failed inside `tvm::te::PassDownBitMaskOr`,
+    # which did not handle the TransformNode relation.
+    tvm.lower(s, [A, B])
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(sys.argv))

Reply via email to