This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e375c311da [Arith][IndexMap] Correct MapShape result for small vectorized dims (#12927)
e375c311da is described below

commit e375c311dac3c4ec0636c1bd0e203c3cd70f7f23
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Oct 6 13:51:02 2022 -0500

    [Arith][IndexMap] Correct MapShape result for small vectorized dims (#12927)
    
    Prior to this commit, `IndexMap::MapShape` could produce incorrect
    results when the split factor is greater than the size of the
    dimension being split.  For example, a buffer of shape `[N]`
    transformed with `lambda i: [i//4, i%4]` should result in shape
    `[ceildiv(N,4), 4]`.  However, for `N<4`, the transformed shape was
    instead `[1, N%4]`.  This results in unexpected shapes when attempting
    to prepare a buffer for vectorized access.
    
    This commit preferentially uses the result of `arith::DetectIterMap`
    to determine the mapped buffer shape, similar to what is done when
    computing the inverse.  The previous approach in `MapShape`, which
    relied on `arith::EvalSet`, is kept as a fallback for transformations
    that aren't recognized by `arith::DetectIterMap`.
---
 src/tir/ir/index_map.cc                 | 45 +++++++++++++++++++++++++--------
 tests/python/unittest/test_index_map.py |  9 ++++++-
 2 files changed, 43 insertions(+), 11 deletions(-)

diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc
index 6d982b510a..a25ecdd040 100644
--- a/src/tir/ir/index_map.cc
+++ b/src/tir/ir/index_map.cc
@@ -169,23 +169,48 @@ Array<Range> IndexMapNode::MapRanges(const Array<Range>& 
ranges, arith::Analyzer
     input_iters.Set(initial_indices[i], ranges[i]);
   }
 
-  std::unordered_map<const VarNode*, arith::IntSet> dom_map;
-  for (size_t i = 0; i < initial_indices.size(); i++) {
-    dom_map[initial_indices[i].get()] = arith::IntSet::FromRange(ranges[i]);
-  }
-
   arith::Analyzer local_analyzer;
   if (!analyzer) {
     analyzer = &local_analyzer;
   }
 
+  auto iter_map = DetectIterMap(final_indices, input_iters, /* predicate = */ 
1,
+                                /*check_level=*/arith::IterMapLevel::NoCheck, 
analyzer,
+                                /*simplify_trivial_iterators=*/false);
   Array<Range> output;
-  for (const auto& final_index : final_indices) {
-    auto int_set = arith::EvalSet(final_index, dom_map);
-    output.push_back(Range::FromMinExtent(analyzer->Simplify(int_set.min()),
-                                          analyzer->Simplify(int_set.max() - 
int_set.min() + 1)));
-  }
+  if (iter_map->indices.size()) {
+    // Preferred route, requires the map to be expressible as an
+    // affine sum.  Since the terms are orthogonal, the extent of the
+    // sum is the extent of the largest term.
+    for (const auto& index : iter_map->indices) {
+      Optional<PrimExpr> extent = NullOpt;
+      for (const auto& term : index->args) {
+        PrimExpr term_extent = term->extent * term->scale;
+        if (extent.defined()) {
+          extent = tvm::max(extent.value(), term_extent);
+        } else {
+          extent = term_extent;
+        }
+      }
+      output.push_back(Range::FromMinExtent(index->base, extent.value_or(1)));
+    }
 
+  } else {
+    // Fall-back method, more general but can ignore intended padding.
+    // For example, [N] mapped through i=>[i//4,i%4] should have shape
+    // [ceildiv(N,4), 4].  However, for N<4, this method instead
+    // results in a shape [1, N].
+    std::unordered_map<const VarNode*, arith::IntSet> dom_map;
+    for (size_t i = 0; i < initial_indices.size(); i++) {
+      dom_map[initial_indices[i].get()] = arith::IntSet::FromRange(ranges[i]);
+    }
+
+    for (const auto& final_index : final_indices) {
+      auto int_set = arith::EvalSet(final_index, dom_map);
+      output.push_back(Range::FromMinExtent(analyzer->Simplify(int_set.min()),
+                                            analyzer->Simplify(int_set.max() - 
int_set.min() + 1)));
+    }
+  }
   return output;
 }
 
diff --git a/tests/python/unittest/test_index_map.py 
b/tests/python/unittest/test_index_map.py
index 804d04d0b0..6882c2b426 100644
--- a/tests/python/unittest/test_index_map.py
+++ b/tests/python/unittest/test_index_map.py
@@ -104,7 +104,7 @@ padding_test_case = tvm.testing.parameter(
             forward=lambda i: [i // 4, i % 4],
             inverse=lambda i, j: [4 * i + j],
             pre_shape=[dynamic_N],
-            post_shape=[(dynamic_N - 1) // 4 + 1, 4],
+            post_shape=[(dynamic_N - dynamic_N % (-4)) // 4, 4],
             padding=lambda i, j: tvm.tir.And(
                 dynamic_N % (-4) != 0,
                 tvm.tir.And(i == dynamic_N // 4, j >= dynamic_N % 4),
@@ -162,6 +162,13 @@ padding_test_case = tvm.testing.parameter(
             post_shape=[8, 4, 4],
             padding=lambda j, i, k: tvm.tir.And(i == 0, j * 4 + k < 5),
         ),
+        "outer_loop_extent_one": dict(
+            forward=lambda i: [i // 4, i % 4],
+            inverse=lambda i, j: [i * 4 + j],
+            pre_shape=[3],
+            post_shape=[1, 4],
+            padding=lambda i, j: 3 <= j,
+        ),
     }
 )
 

Reply via email to