llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Momchil Velikov (momchil-velikov)

<details>
<summary>Changes</summary>

Previously, slices were sometimes marked as non-contiguous when they were
actually contiguous. This occurred when the vector type had leading unit
dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only
the trailing `n` dimensions of the memref need to be contiguous, not a number
of trailing dimensions equal to the full vector rank.
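
For illustration, here is a minimal MLIR sketch of the case being fixed (it mirrors the new `@discontig_mem_contig_slice` test added below; the function name and operands are illustrative only): the vector has two leading unit dimensions, so only the innermost memref dimension needs to be contiguous, even though the outer strides are not.

```mlir
func.func @discontig_mem_contig_slice(
    %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x1x8xi32>) {
  %c0 = arith.constant 0 : index
  // The two leading vector dims are unit-sized, so only the trailing memref
  // dim has to be contiguous; this write is a contiguous slice.
  vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
    vector<1x1x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
  return
}
```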

This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern`
flattens `transfer_read` and `transfer_write` ops. The pattern used to
collapse a number of dimensions equal to the vector rank, which is incorrect
when the leading dimensions are unit-sized.

This patch fixes the issue by collapsing only as many trailing memref 
dimensions as are actually contiguous.
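
As a sketch of the effect (taken from the updated test expectations in the diff; `%mem`, `%idx` and `%pad` are placeholder names), a `transfer_read` of `vector<1x2x6xi32>` from `memref<1x43x4x6xi32>` now collapses all four memref dimensions instead of only the trailing three:

```mlir
// Before: memref.collapse_shape %mem [[0], [1, 2, 3]]
//             : memref<1x43x4x6xi32> into memref<1x1032xi32>
// After: all contiguous trailing dims are collapsed.
%collapsed = memref.collapse_shape %mem [[0, 1, 2, 3]]
    : memref<1x43x4x6xi32> into memref<1032xi32>
%read = vector.transfer_read %collapsed[%idx], %pad {in_bounds = [true]}
    : memref<1032xi32>, vector<12xi32>
```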

---
Full diff: https://github.com/llvm/llvm-project/pull/142422.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Utils/IndexingUtils.h (+2-1) 
- (modified) mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h (+28-26) 
- (modified) mlir/lib/Dialect/Utils/IndexingUtils.cpp (+4-2) 
- (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+6-2) 
- (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+8-17) 
- (modified) mlir/test/Dialect/Vector/vector-transfer-flatten.mlir (+78-30) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 99218f491ddef..a1c8ec2db056a 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -40,7 +40,8 @@ class ArrayAttr;
 /// Assuming `sizes` is `[s0, .. sn]`, return the vector<int64_t>
 ///   `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
 ///
-/// `sizes` elements are asserted to be non-negative.
+/// `sizes` element `s0` is asserted to be kDynamic or non-negative.
+/// `sizes` elements `s1` to `sn` are asserted to be non-negative.
 ///
 /// Return an empty vector if `sizes` is empty.
 SmallVector<int64_t> computeSuffixProduct(ArrayRef<int64_t> sizes);
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 6609b28d77b6c..ed06d7a029494 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -49,35 +49,37 @@ FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
 
 /// Return true if `vectorType` is a contiguous slice of `memrefType`.
 ///
-/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
-/// checked (the other dims are not relevant). Note that for `vectorType` to be
-/// a contiguous slice of `memrefType`, the trailing dims of the latter have
-/// to be contiguous - this is checked by looking at the corresponding strides.
+/// The leading unit dimensions of the vector type are ignored as they
+/// are not relevant to the result. Let N be the number of the vector
+/// dimensions after ignoring a leading sequence of unit ones.
 ///
-/// There might be some restriction on the leading dim of `VectorType`:
+/// For `vectorType` to be a contiguous slice of `memrefType`
+///   a) the N trailing dimensions of the latter must be contiguous, and
+///   b) the trailing N dimensions of `vectorType` and `memrefType`,
+///      except the first of them, must match.
 ///
-/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
-///         of `memrefType` then the leading dim of `vectorType` can be
-///         arbitrary.
-///
-///        Ex. 1.1 contiguous slice, perfect match
-///          vector<4x3x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
-///          vector<2x3x2xi32> from memref<5x4x3x2xi32>
-///
-/// Case 2. If an "internal" dim of `vectorType` does not match the
-///         corresponding trailing dim in `memrefType` then the remaining
-///         leading dims of `vectorType` have to be 1 (the first non-matching
-///         dim can be arbitrary).
+/// Examples:
 ///
-///        Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
-///          vector<2x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.2  contiguous slice, 2 != 3 and the leading dim == <1>
-///          vector<1x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
-///          vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
-///         vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
+///   Ex.1 contiguous slice, perfect match
+///     vector<4x3x2xi32> from memref<5x4x3x2xi32>
+///   Ex.2 contiguous slice, the leading dim does not match (2 != 4)
+///     vector<2x3x2xi32> from memref<5x4x3x2xi32>
+///   Ex.3 non-contiguous slice, 2 != 3
+///     vector<2x2x2xi32> from memref<5x4x3x2xi32>
+///   Ex.4 contiguous slice, leading unit dimension of the vector ignored,
+///        2 != 3 (allowed)
+///     vector<1x2x2xi32> from memref<5x4x3x2xi32>
+///   Ex.5. contiguous slice, leading two unit dims of the vector ignored,
+///         2 != 3 (allowed)
+///     vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
+///   Ex.6. non-contiguous slice, 2 != 3, no leading sequence of unit dims
+///     vector<2x1x2x2xi32> from memref<5x4x3x2xi32>
+///   Ex.7 contiguous slice, memref needs to be contiguous only on the last
+///        dimension
+///     vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
+///   Ex.8 non-contiguous slice, memref needs to be contiguous on the last
+///        two dimensions, and it isn't
+///     vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
 bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
 
 /// Returns an iterator for all positions in the leading dimensions of `vType`
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 8de77e2c3cb08..bb719d46a215a 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -69,8 +69,10 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
 
//===----------------------------------------------------------------------===//
 
 SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
-  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
-         "sizes must be nonnegative");
+  assert(sizes.size() == 0 ||
+         ((sizes[0] == ShapedType::kDynamic || sizes[0] >= 0) &&
+          llvm::all_of(sizes.drop_front(), [](int64_t s) { return s >= 0; })) &&
+             "sizes must be nonnegative");
   int64_t unit = 1;
   return ::computeSuffixProductImpl(sizes, unit);
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 7dbb7a334fe62..709716365f825 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -630,7 +630,9 @@ class FlattenContiguousRowMajorTransferReadPattern
     if (transferReadOp.getMask())
       return failure();
 
-    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+    // Determine the first memref dimension to collapse
+    int64_t firstDimToCollapse =
+        sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
 
     // 1. Collapse the source memref
     Value collapsedSource =
@@ -722,7 +724,9 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (transferWriteOp.getMask())
       return failure();
 
-    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+    // Determine the first memref dimension to collapse
+    int64_t firstDimToCollapse =
+        sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
 
     // 1. Collapse the source memref
     Value collapsedSource =
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d5dd6f2027be8..5f8f3b6adf9db 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -258,29 +258,20 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
   if (vectorType.isScalable())
     return false;
 
-  ArrayRef<int64_t> vectorShape = vectorType.getShape();
-  auto vecRank = vectorType.getRank();
+  // Ignore a leading contiguous sequence of unit dimensions in the vector.
+  ArrayRef<int64_t> vectorShape =
+      vectorType.getShape().drop_while([](auto v) { return v == 1; });
+  auto vecRank = vectorShape.size();
 
   if (!memrefType.areTrailingDimsContiguous(vecRank))
     return false;
 
-  // Extract the trailing dims and strides of the input memref
+  // Extract the trailing dims of the input memref
   auto memrefShape = memrefType.getShape().take_back(vecRank);
 
-  // Compare the dims of `vectorType` against `memrefType` (in reverse).
-  // In the most basic case, all dims will match.
-  auto firstNonMatchingDim =
-      std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
-                    memrefShape.rbegin(), memrefShape.rend());
-  if (firstNonMatchingDim.first == vectorShape.rend())
-    return true;
-
-  // One non-matching dim is still fine, however the remaining leading dims of
-  // `vectorType` need to be 1.
-  SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
-                                   vectorShape.rend());
-
-  return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
+  // Compare the dims of `vectorType` against `memrefType`.
+  // All of the dimensions except the first must match.
+  return llvm::equal(vectorShape.drop_front(), memrefShape.drop_front());
 }
 
 std::optional<StaticTileOffsetRange>
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 5b2f2ab1f2cef..594f7ce371347 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -116,10 +116,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
 // CHECK-SAME:      %[[MEM:.*]]: memref<1x43x4x6xi32>
 // CHECK:           %[[C_0:.*]] = arith.constant 0 : i32
-// CHECK:           %[[C_0_IDX:.*]] = arith.constant 0 : index
-// CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK:           %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME-LITERAL: [[0, 1, 2, 3]]
+// CHECK-SAME:         : memref<1x43x4x6xi32> into memref<1032xi32>
 // CHECK:           %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK:           %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1032xi32>, vector<12xi32>
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
 //   CHECK-128B-NOT:   memref.collapse_shape
@@ -170,16 +171,18 @@ func.func @transfer_read_leading_dynamic_dims(
   return %res : vector<8x4xi8>
 }
 
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
+
 // CHECK-LABEL: func @transfer_read_leading_dynamic_dims
 // CHECK-SAME:    %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
 // CHECK:         %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK:         %[[C0:.+]] = arith.constant 0 : index
-// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_2]]]
 // CHECK:         %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
-// CHECK-SAME:    [%[[IDX_1]], %[[IDX_2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME:    [%[[IDX_1]], %[[COLLAPSED_IDX]]], %[[C0_I8]]
 // CHECK-SAME:    {in_bounds = [true]}
-// CHECK-SAME:      : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK-SAME:      : memref<?x?xi8, {{.+}}>, vector<32xi8>
 // CHECK:         %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
 // CHECK:         return %[[RES]] : vector<8x4xi8>
 
@@ -210,13 +213,12 @@ func.func @transfer_read_dynamic_dim_to_flatten(
 // CHECK-SAME:    %[[IDX_2:arg1]]
 // CHECK-SAME:    %[[MEM:arg2]]
 // CHECK:              %[[C0_I32:.*]] = arith.constant 0 : i32
-// CHECK:              %[[C0:.*]] = arith.constant 0 : index
 // CHECK:              %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME-LITERAL:   [[0], [1, 2, 3]]
-// CHECK-SAME:           memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK-SAME-LITERAL:   [[0, 1, 2, 3]]
+// CHECK-SAME:           memref<1x?x4x6xi32> into memref<?xi32>
 // CHECK:              %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK:              %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
-// CHECK-SAME:           %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK:              %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[COLLAPSED_IDX]]],
+// CHECK-SAME:           %[[C0_I32]] {in_bounds = [true]} : memref<?xi32>, vector<12xi32>
 // CHECK:              %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
 // CHECK:              return %[[RESULT]] : vector<1x2x6xi32>
 
@@ -397,11 +399,10 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
 // CHECK-SAME:      %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
 // CHECK-SAME:      %[[MEM:.*]]: memref<1x43x4x6xi32>,
 // CHECK-SAME:      %[[VEC:.*]]: vector<1x2x6xi32>) {
-// CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG:       %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
-// CHECK-DAG:       %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK-DAG:       %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<1x43x4x6xi32> into memref<1032xi32>
 // CHECK:           %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK:           vector.transfer_write %[[SC]], %[[CS]]{{\[}}%[[C0]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x1032xi32>
+// CHECK:           vector.transfer_write %[[SC]], %[[CS]][%[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1032xi32>
 
 // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices(
 //   CHECK-128B-NOT:   memref.collapse_shape
@@ -449,16 +450,18 @@ func.func @transfer_write_leading_dynamic_dims(
   return
 }
 
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
+
 // CHECK-LABEL: func @transfer_write_leading_dynamic_dims
 // CHECK-SAME:    %[[VEC:.+]]: vector<8x4xi8>, %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
-// CHECK:         %[[C0:.+]] = arith.constant 0 : index
-// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-SAME:      : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[ARG3]]]
 // CHECK:         %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<8x4xi8> to vector<32xi8>
 // CHECK:         vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
-// CHECK-SAME:      [%[[ARG2]], %[[ARG3]], %[[C0]]]
+// CHECK-SAME:      [%[[ARG2]], %[[COLLAPSED_IDX]]]
 // CHECK-SAME:      {in_bounds = [true]}
-// CHECK-SAME:      : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
+// CHECK-SAME:      : vector<32xi8>, memref<?x?xi8, {{.+}}>
 
 // CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
 //       CHECK-128B:   memref.collapse_shape
@@ -488,14 +491,13 @@ func.func @transfer_write_dynamic_to_flatten(
 // CHECK-SAME:    %[[VEC:arg2]]: vector<1x2x6xi32>
 // CHECK-SAME:    %[[MEM:arg3]]: memref<1x?x4x6xi32>
 
-// CHECK:              %[[C0:.*]] = arith.constant 0 : index
 // CHECK:              %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME-LITERAL:   [[0], [1, 2, 3]]
-// CHECK-SAME:           : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK-SAME-LITERAL:   [[0, 1, 2, 3]]
+// CHECK-SAME:           : memref<1x?x4x6xi32> into memref<?xi32>
 // CHECK:              %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
 // CHECK:              %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK:              vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
-// CHECK-SAME:           {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
+// CHECK:              vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[COLLAPSED_IDX]]]
+// CHECK-SAME:           {in_bounds = [true]} : vector<12xi32>, memref<?xi32>
 
 // CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
 //   CHECK-128B-NOT:   memref.collapse_shape
@@ -573,8 +575,12 @@ func.func @negative_out_of_bound_transfer_read(
     memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
   return %res : vector<5x4x3x2xi8>
 }
-// CHECK:     func.func @negative_out_of_bound_transfer_read
-// CHECK-NOT:   memref.collapse_shape
+// CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
+// CHECK-NOT:     memref.collapse_shape
+
+// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_read
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
 
 // -----
 
@@ -585,5 +591,47 @@ func.func @negative_out_of_bound_transfer_write(
     vector<1x1x3x2xi8>, memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
   return
 }
-// CHECK:     func.func @negative_out_of_bound_transfer_write
-// CHECK-NOT:   memref.collapse_shape
+// CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
+// CHECK-NOT:     memref.collapse_shape
+
+// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
+//   CHECK-128B-NOT:   memref.collapse_shape
+//   CHECK-128B-NOT:   vector.shape_cast
+
+// -----
+
+func.func @discontig_mem_contig_slice(
+    %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x1x8xi32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x1x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+  return
+}
+
+// CHECK-LABEL: func.func @discontig_mem_contig_slice
+// CHECK-SAME:    %[[MEM:.+]]: memref<8x8x8xi32, strided<[128, 16, 1]>>
+// CHECK-SAME:    %[[VEC:.+]]: vector<1x1x8xi32>
+// CHECK:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x8xi32> to vector<8xi32>
+// CHECK:       vector.transfer_write %[[VEC_1D]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
+// CHECK-SAME:    : vector<8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+
+// CHECK-128B-LABEL: func.func @discontig_mem_contig_slice
+//   CHECK-128B-NOT:   vector.shape_cast
+
+// -----
+
+func.func @discontig_mem_discontig_slice(
+    %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x2x8xi32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+  return
+}
+
+// CHECK-LABEL: func.func @discontig_mem_discontig_slice
+// CHECK-NOT:    vector.shape_cast
+
+// CHECK-128B-LABEL: func.func @discontig_mem_discontig_slice
+//   CHECK-128B-NOT:   vector.shape_cast
+

``````````

</details>


https://github.com/llvm/llvm-project/pull/142422