This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3e30a5f [Bugfix] Handled TransformNode in PassUpBitMaskOr/PassDownBitMaskOr (#10620)
3e30a5f is described below
commit 3e30a5f64486b23e0c9659b3d3432ab983f2b572
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Mar 15 21:49:32 2022 -0500
[Bugfix] Handled TransformNode in PassUpBitMaskOr/PassDownBitMaskOr (#10620)
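Previously, applying a layout transformation to a te.compute whose
computation used a reduction axis would fail during tvm.lower, because
PassUpBitMaskOr/PassDownBitMaskOr did not handle TransformNode relations.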
---
src/te/schedule/message_passing.cc | 21 +++++++++++++++++++++
tests/python/unittest/test_transform_layout.py | 17 +++++++++++++++++
2 files changed, 38 insertions(+)
diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc
index 361cdb1..7041b75 100644
--- a/src/te/schedule/message_passing.cc
+++ b/src/te/schedule/message_passing.cc
@@ -535,6 +535,17 @@ void PassUpBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_sta
} else {
state[s->parent] |= state[s->rebased];
}
+ } else if (const TransformNode* s = rel.as<TransformNode>()) {
+ for (const auto& original_var : s->original_variables) {
+ for (const auto& transformed_var : s->transformed_variables) {
+ if (!state.count(transformed_var)) {
+ ICHECK(allow_missing);
+ continue;
+ }
+ state[original_var] |= state[transformed_var];
+ }
+ }
+
} else if (rel.as<SingletonNode>()) {
} else {
LOG(FATAL) << "unknown relation type";
@@ -581,6 +592,16 @@ void PassDownBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_s
} else {
state[s->rebased] |= state.at(s->parent);
}
+ } else if (const TransformNode* s = rel.as<TransformNode>()) {
+ for (const auto& original_var : s->original_variables) {
+ for (const auto& transformed_var : s->transformed_variables) {
+ if (!state.count(original_var)) {
+ ICHECK(allow_missing);
+ continue;
+ }
+ state[transformed_var] |= state[original_var];
+ }
+ }
} else if (const SingletonNode* s = rel.as<SingletonNode>()) {
state[s->iter] = 0;
} else {
diff --git a/tests/python/unittest/test_transform_layout.py b/tests/python/unittest/test_transform_layout.py
index 55266fd..a3c232d 100755
--- a/tests/python/unittest/test_transform_layout.py
+++ b/tests/python/unittest/test_transform_layout.py
@@ -522,5 +522,22 @@ class TestTransformCache:
tvm.testing.assert_allclose(b.numpy(), b_np)
+def test_transform_with_reduction():
+ # Regression test (#10620): to trigger this failure mode, the
+ # computation must use a reduction axis...
+ A = te.placeholder([16, 32, 64], dtype="float32", name="A")
+ k = te.reduce_axis((0, A.shape[-1]), name="k")
+ B = te.compute(A.shape[:-1], lambda i, j: te.sum(A[i, j, k], axis=[k]))
+ s = te.create_schedule(B.op)
+
+ # ...and the output of the computation must have a layout
+ # transformation applied (here, a transpose of the (i, j) axes).
+ s[B].transform_layout(lambda i, j: [j, i])
+
+ # Before the fix, lowering failed inside the call to
+ # `tvm::te::PassDownBitMaskOr`, which had no TransformNode branch.
+ tvm.lower(s, [A, B])
+
+
if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))