https://github.com/momchil-velikov created https://github.com/llvm/llvm-project/pull/142422
Previously, slices were sometimes marked as non-contiguous when they were actually contiguous. This occurred when the vector type had leading unit dimensions, e.g., `vector<1x1x...x1xd0xd1x...xdn-1xT>`. In such cases, only the trailing `n` dimensions of the memref need to be contiguous, rather than as many dimensions as the full vector rank.

This affects how `FlattenContiguousRowMajorTransfer{Read,Write}Pattern` flattens `transfer_read` and `transfer_write` ops. The pattern used to collapse a number of dimensions equal to the vector rank, which may be incorrect when leading dimensions are unit-sized.

This patch fixes the issue by collapsing only as many trailing memref dimensions as are actually contiguous.

From d9a470e098553dbac74e81f98e0077718f6d9ed1 Mon Sep 17 00:00:00 2001
From: Momchil Velikov <momchil.veli...@arm.com>
Date: Mon, 2 Jun 2025 15:13:13 +0000
Subject: [PATCH] [MLIR] Fix incorrect slice contiguity inference in `vector::isContiguousSlice`
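To make the new rule easier to check against the examples in the updated doc comment, here is a small standalone C++ model of it. It is a sketch, not MLIR code: `MemRefModel`, `kDynamicSize`, `suffixProduct`, `numContiguousTrailingDims`, `isContiguousSliceModel` and `linearizeTrailing` are hypothetical names. It assumes strides are known integers (a dynamic stride simply breaks contiguity) and that a vector made only of unit dimensions counts as contiguous. `numContiguousTrailingDims` plays the role the patch gives to `getMaxCollapsableTrailingDims()`, whose exact behaviour in MLIR may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical sentinel for a dynamic size or stride in this model.
constexpr int64_t kDynamicSize = std::numeric_limits<int64_t>::min();

// Hypothetical, simplified stand-in for a strided MemRefType.
struct MemRefModel {
  std::vector<int64_t> shape;
  std::vector<int64_t> strides;
};

// [s0, s1, ..., sn] -> [s1*...*sn, ..., sn, 1]. s0 is never multiplied in,
// so it may be kDynamicSize (the relaxation made to computeSuffixProduct).
std::vector<int64_t> suffixProduct(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> result(sizes.size(), 1);
  for (std::size_t i = sizes.size(); i-- > 1;)
    result[i - 1] = result[i] * sizes[i];
  return result;
}

// Counts the trailing dims whose strides match a dense row-major layout.
std::size_t numContiguousTrailingDims(const MemRefModel &m) {
  std::size_t n = 0;
  int64_t expected = 1; // stride a dense layout would give the current dim
  for (std::size_t i = m.shape.size(); i-- > 0;) {
    if (m.strides[i] != expected)
      break;
    ++n;
    if (m.shape[i] == kDynamicSize)
      break; // outer strides cannot be checked against an unknown size
    expected *= m.shape[i];
  }
  return n;
}

// Model of the new rule: drop the leading unit dims of the vector, then
// require (a) the N trailing memref dims to be contiguous and (b) all of the
// remaining vector dims except the first to match the memref's trailing dims.
bool isContiguousSliceModel(const MemRefModel &m,
                            const std::vector<int64_t> &vecShape) {
  std::size_t first = 0;
  while (first < vecShape.size() && vecShape[first] == 1)
    ++first;
  const std::size_t n = vecShape.size() - first;
  if (n == 0)
    return true; // only unit dims: a single element (assumption of the sketch)
  if (n > m.shape.size() || numContiguousTrailingDims(m) < n)
    return false;
  const std::size_t offset = m.shape.size() - n;
  for (std::size_t k = 1; k < n; ++k)
    if (vecShape[first + k] != m.shape[offset + k])
      return false;
  return true;
}

// Folds the indices of the trailing `n` (collapsed) dims into one index.
int64_t linearizeTrailing(const std::vector<int64_t> &shape,
                          const std::vector<int64_t> &idx, std::size_t n) {
  assert(n <= shape.size() && n <= idx.size());
  const std::vector<int64_t> sizes(
      shape.end() - static_cast<std::ptrdiff_t>(n), shape.end());
  const std::vector<int64_t> weights = suffixProduct(sizes);
  const std::size_t base = idx.size() - n;
  int64_t linear = 0;
  for (std::size_t k = 0; k < n; ++k)
    linear += idx[base + k] * weights[k];
  return linear;
}
```

With a dense `memref<5x4x3x2xi32>` (strides `[24, 6, 2, 1]`) the model agrees with Ex.1 to Ex.6. With `strided<[8, 4, 1]>` only the innermost dimension is contiguous, so `vector<1x1x2xi32>` qualifies while `vector<1x2x2xi32>` does not (Ex.7, Ex.8). For `memref<?x?x8x4xi8>` with an assumed layout of strides `[?, 32, 4, 1]`, three trailing dimensions are contiguous, so collapsing starts at dimension 1. `linearizeTrailing` then weights the index of dimension 1 by 8 * 4 = 32, which is the `s0 * 32` map in the updated CHECK lines. The leading size of the collapsed group never enters a product, which is why `computeSuffixProduct` can accept a dynamic `s0`.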
---
 .../mlir/Dialect/Utils/IndexingUtils.h        |   3 +-
 .../mlir/Dialect/Vector/Utils/VectorUtils.h   |  54 ++++-----
 mlir/lib/Dialect/Utils/IndexingUtils.cpp      |   6 +-
 .../Transforms/VectorTransferOpTransforms.cpp |   8 +-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp |  25 ++--
 .../Vector/vector-transfer-flatten.mlir       | 108 +++++++++++++-----
 6 files changed, 126 insertions(+), 78 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index 99218f491ddef..a1c8ec2db056a 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -40,7 +40,8 @@ class ArrayAttr;
 /// Assuming `sizes` is `[s0, .. sn]`, return the vector<int64_t>
 /// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
 ///
-/// `sizes` elements are asserted to be non-negative.
+/// `sizes` element `s0` is asserted to be kDynamic or non-negative.
+/// `sizes` elements `s1` to `sn` are asserted to be non-negative.
 ///
 /// Return an empty vector if `sizes` is empty.
 SmallVector<int64_t> computeSuffixProduct(ArrayRef<int64_t> sizes);
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
index 6609b28d77b6c..ed06d7a029494 100644
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -49,35 +49,37 @@ FailureOr<std::pair<int, int>> isTranspose2DSlice(vector::TransposeOp op);
 
 /// Return true if `vectorType` is a contiguous slice of `memrefType`.
 ///
-/// Only the N = vectorType.getRank() trailing dims of `memrefType` are
-/// checked (the other dims are not relevant). Note that for `vectorType` to be
-/// a contiguous slice of `memrefType`, the trailing dims of the latter have
-/// to be contiguous - this is checked by looking at the corresponding strides.
+/// The leading unit dimensions of the vector type are ignored as they
+/// are not relevant to the result. Let N be the number of the vector
+/// dimensions after ignoring a leading sequence of unit ones.
 ///
-/// There might be some restriction on the leading dim of `VectorType`:
+/// For `vectorType` to be a contiguous slice of `memrefType`
+///   a) the N trailing dimensions of the latter must be contiguous, and
+///   b) the trailing N dimensions of `vectorType` and `memrefType`,
+///      except the first of them, must match.
 ///
-/// Case 1. If all the trailing dims of `vectorType` match the trailing dims
-///         of `memrefType` then the leading dim of `vectorType` can be
-///         arbitrary.
-///
-///        Ex. 1.1 contiguous slice, perfect match
-///          vector<4x3x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 1.2 contiguous slice, the leading dim does not match (2 != 4)
-///          vector<2x3x2xi32> from memref<5x4x3x2xi32>
-///
-/// Case 2. If an "internal" dim of `vectorType` does not match the
-///         corresponding trailing dim in `memrefType` then the remaining
-///         leading dims of `vectorType` have to be 1 (the first non-matching
-///         dim can be arbitrary).
+/// Examples:
 ///
-///        Ex. 2.1 non-contiguous slice, 2 != 3 and the leading dim != <1>
-///          vector<2x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.2 contiguous slice, 2 != 3 and the leading dim == <1>
-///          vector<1x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.3. contiguous slice, 2 != 3 and the leading dims == <1x1>
-///          vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
-///        Ex. 2.4. non-contiguous slice, 2 != 3 and the leading dims != <1x1>
-///          vector<2x1x2x2xi32> from memref<5x4x3x2xi32>)
+/// Ex.1 contiguous slice, perfect match
+///   vector<4x3x2xi32> from memref<5x4x3x2xi32>
+/// Ex.2 contiguous slice, the leading dim does not match (2 != 4)
+///   vector<2x3x2xi32> from memref<5x4x3x2xi32>
+/// Ex.3 non-contiguous slice, 2 != 3
+///   vector<2x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex.4 contiguous slice, leading unit dimension of the vector ignored,
+///      2 != 3 (allowed)
+///   vector<1x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex.5 contiguous slice, leading two unit dims of the vector ignored,
+///      2 != 3 (allowed)
+///   vector<1x1x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex.6 non-contiguous slice, 2 != 3, no leading sequence of unit dims
+///   vector<2x1x2x2xi32> from memref<5x4x3x2xi32>
+/// Ex.7 contiguous slice, memref needs to be contiguous only on the last
+///      dimension
+///   vector<1x1x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
+/// Ex.8 non-contiguous slice, memref needs to be contiguous on the last
+///      two dimensions, and it isn't
+///   vector<1x2x2xi32> from memref<2x2x2xi32, strided<[8, 4, 1]>>
 bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
 
 /// Returns an iterator for all positions in the leading dimensions of `vType`
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index 8de77e2c3cb08..bb719d46a215a 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -69,8 +69,10 @@ SmallVector<ExprType> delinearizeImpl(ExprType linearIndex,
 //===----------------------------------------------------------------------===//
 
 SmallVector<int64_t> mlir::computeSuffixProduct(ArrayRef<int64_t> sizes) {
-  assert(llvm::all_of(sizes, [](int64_t s) { return s >= 0; }) &&
-         "sizes must be nonnegative");
+  assert(sizes.size() == 0 ||
+         ((sizes[0] == ShapedType::kDynamic || sizes[0] >= 0) &&
+          llvm::all_of(sizes.drop_front(), [](int64_t s) { return s >= 0; })) &&
+             "sizes must be nonnegative");
   int64_t unit = 1;
   return ::computeSuffixProductImpl(sizes, unit);
 }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 7dbb7a334fe62..709716365f825 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -630,7 +630,9 @@ class FlattenContiguousRowMajorTransferReadPattern
     if (transferReadOp.getMask())
       return failure();
 
-    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+    // Determine the first memref dimension to collapse
+    int64_t firstDimToCollapse =
+        sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
 
     // 1. Collapse the source memref
     Value collapsedSource =
@@ -722,7 +724,9 @@ class FlattenContiguousRowMajorTransferWritePattern
     if (transferWriteOp.getMask())
       return failure();
 
-    int64_t firstDimToCollapse = sourceType.getRank() - vectorType.getRank();
+    // Determine the first memref dimension to collapse
+    int64_t firstDimToCollapse =
+        sourceType.getRank() - sourceType.getMaxCollapsableTrailingDims();
 
     // 1. Collapse the source memref
     Value collapsedSource =
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d5dd6f2027be8..5f8f3b6adf9db 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -258,29 +258,20 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
   if (vectorType.isScalable())
     return false;
 
-  ArrayRef<int64_t> vectorShape = vectorType.getShape();
-  auto vecRank = vectorType.getRank();
+  // Ignore a leading contiguous sequence of unit dimensions in the vector.
+  ArrayRef<int64_t> vectorShape =
+      vectorType.getShape().drop_while([](auto v) { return v == 1; });
+  auto vecRank = vectorShape.size();
 
   if (!memrefType.areTrailingDimsContiguous(vecRank))
     return false;
 
-  // Extract the trailing dims and strides of the input memref
+  // Extract the trailing dims of the input memref
   auto memrefShape = memrefType.getShape().take_back(vecRank);
 
-  // Compare the dims of `vectorType` against `memrefType` (in reverse).
-  // In the most basic case, all dims will match.
-  auto firstNonMatchingDim =
-      std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
-                    memrefShape.rbegin(), memrefShape.rend());
-  if (firstNonMatchingDim.first == vectorShape.rend())
-    return true;
-
-  // One non-matching dim is still fine, however the remaining leading dims of
-  // `vectorType` need to be 1.
-  SmallVector<int64_t> leadingDims(++firstNonMatchingDim.first,
-                                   vectorShape.rend());
-
-  return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
+  // Compare the dims of `vectorType` against `memrefType`.
+  // All of the dimensions, except the first, must match.
+  return llvm::equal(vectorShape.drop_front(), memrefShape.drop_front());
 }
 
 std::optional<StaticTileOffsetRange>
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 5b2f2ab1f2cef..594f7ce371347 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -116,10 +116,11 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
 // CHECK-SAME: %[[MEM:.*]]: memref<1x43x4x6xi32>
 // CHECK: %[[C_0:.*]] = arith.constant 0 : i32
-// CHECK: %[[C_0_IDX:.*]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK: %[[COLLAPSED_IN:.*]] = memref.collapse_shape %[[MEM]]
+// CHECK-SAME-LITERAL: [[0, 1, 2, 3]]
+// CHECK-SAME: : memref<1x43x4x6xi32> into memref<1032xi32>
 // CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$ATTR_0]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[C_0_IDX]], %[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[COLLAPSED_IN]][%[[COLLAPSED_IDX]]], %[[C_0]] {in_bounds = [true]} : memref<1032xi32>, vector<12xi32>
 
 // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
 // CHECK-128B-NOT: memref.collapse_shape
@@ -170,16 +171,18 @@ func.func @transfer_read_leading_dynamic_dims(
   return %res : vector<8x4xi8>
 }
 
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
+
 // CHECK-LABEL: func @transfer_read_leading_dynamic_dims
 // CHECK-SAME: %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[IDX_1:.+]]: index, %[[IDX_2:.+]]: index
 // CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[IDX_2]]]
 // CHECK: %[[VEC1D:.+]] = vector.transfer_read %[[COLLAPSED]]
-// CHECK-SAME: [%[[IDX_1]], %[[IDX_2]], %[[C0]]], %[[C0_I8]]
+// CHECK-SAME: [%[[IDX_1]], %[[COLLAPSED_IDX]]], %[[C0_I8]]
 // CHECK-SAME: {in_bounds = [true]}
-// CHECK-SAME: : memref<?x?x32xi8, {{.+}}>, vector<32xi8>
+// CHECK-SAME: : memref<?x?xi8, {{.+}}>, vector<32xi8>
 // CHECK: %[[RES:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
 // CHECK: return %[[RES]] : vector<8x4xi8>
 
@@ -210,13 +213,12 @@ func.func @transfer_read_dynamic_dim_to_flatten(
 // CHECK-SAME: %[[IDX_2:arg1]]
 // CHECK-SAME: %[[MEM:arg2]]
 // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME-LITERAL: [[0], [1, 2, 3]]
-// CHECK-SAME: memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK-SAME-LITERAL: [[0, 1, 2, 3]]
+// CHECK-SAME: memref<1x?x4x6xi32> into memref<?xi32>
 // CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
-// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[COLLAPSED_IDX]]],
-// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<1x?xi32>, vector<12xi32>
+// CHECK: %[[VEC_1D:.*]] = vector.transfer_read %[[COLLAPSED]][%[[COLLAPSED_IDX]]],
+// CHECK-SAME: %[[C0_I32]] {in_bounds = [true]} : memref<?xi32>, vector<12xi32>
 // CHECK: %[[RESULT:.*]] = vector.shape_cast %[[VEC_1D]] : vector<12xi32> to vector<1x2x6xi32>
 // CHECK: return %[[RESULT]] : vector<1x2x6xi32>
 
@@ -397,11 +399,10 @@ func.func @transfer_write_dims_mismatch_non_zero_indices(
 // CHECK-SAME: %[[IDX_1:.*]]: index, %[[IDX_2:.*]]: index,
 // CHECK-SAME: %[[MEM:.*]]: memref<1x43x4x6xi32>,
 // CHECK-SAME: %[[VEC:.*]]: vector<1x2x6xi32>) {
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[IDX:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[IDX_1]], %[[IDX_2]]]
-// CHECK-DAG: %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
+// CHECK-DAG: %[[CS:.*]] = memref.collapse_shape %[[MEM]] {{\[\[}}0, 1, 2, 3]] : memref<1x43x4x6xi32> into memref<1032xi32>
 // CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK: vector.transfer_write %[[SC]], %[[CS]]{{\[}}%[[C0]], %[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1x1032xi32>
+// CHECK: vector.transfer_write %[[SC]], %[[CS]][%[[IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<1032xi32>
 
 // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_zero_indices(
 // CHECK-128B-NOT: memref.collapse_shape
@@ -449,16 +450,18 @@ func.func @transfer_write_leading_dynamic_dims(
   return
 }
 
+// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 * 32)>
+
 // CHECK-LABEL: func @transfer_write_leading_dynamic_dims
 // CHECK-SAME: %[[VEC:.+]]: vector<8x4xi8>, %[[MEM:.+]]: memref<?x?x8x4xi8, {{.+}}>, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1], [2, 3]{{\]}}
-// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?x32xi8, {{.+}}>
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]] {{\[}}[0], [1, 2, 3]{{\]}}
+// CHECK-SAME: : memref<?x?x8x4xi8, {{.+}}> into memref<?x?xi8, {{.+}}>
+// CHECK: %[[COLLAPSED_IDX:.+]] = affine.apply #[[$MAP]]()[%[[ARG3]]]
 // CHECK: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<8x4xi8> to vector<32xi8>
 // CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]
-// CHECK-SAME: [%[[ARG2]], %[[ARG3]], %[[C0]]]
+// CHECK-SAME: [%[[ARG2]], %[[COLLAPSED_IDX]]]
 // CHECK-SAME: {in_bounds = [true]}
-// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>
+// CHECK-SAME: : vector<32xi8>, memref<?x?xi8, {{.+}}>
 
 // CHECK-128B-LABEL: func @transfer_write_leading_dynamic_dims
 // CHECK-128B: memref.collapse_shape
@@ -488,14 +491,13 @@ func.func @transfer_write_dynamic_to_flatten(
 // CHECK-SAME: %[[VEC:arg2]]: vector<1x2x6xi32>
 // CHECK-SAME: %[[MEM:arg3]]: memref<1x?x4x6xi32>
 
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
 // CHECK: %[[COLLAPSED_MEM:.*]] = memref.collapse_shape %[[MEM]]
-// CHECK-SAME-LITERAL: [[0], [1, 2, 3]]
-// CHECK-SAME: : memref<1x?x4x6xi32> into memref<1x?xi32>
+// CHECK-SAME-LITERAL: [[0, 1, 2, 3]]
+// CHECK-SAME: : memref<1x?x4x6xi32> into memref<?xi32>
 // CHECK: %[[COLLAPSED_IDX:.*]] = affine.apply #[[$MAP]]()[%[[IDX_1]], %[[IDX_2]]]
 // CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x2x6xi32> to vector<12xi32>
-// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[COLLAPSED_IDX]]]
-// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<1x?xi32>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[COLLAPSED_IDX]]]
+// CHECK-SAME: {in_bounds = [true]} : vector<12xi32>, memref<?xi32>
 
 // CHECK-128B-LABEL: func @transfer_write_dynamic_to_flatten
 // CHECK-128B-NOT: memref.collapse_shape
@@ -573,8 +575,12 @@ func.func @negative_out_of_bound_transfer_read(
     memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>, vector<5x4x3x2xi8>
   return %res : vector<5x4x3x2xi8>
 }
-// CHECK: func.func @negative_out_of_bound_transfer_read
-// CHECK-NOT: memref.collapse_shape
+// CHECK-LABEL: func.func @negative_out_of_bound_transfer_read
+// CHECK-NOT: memref.collapse_shape
+
+// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_read
+// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
 
 // -----
 
@@ -585,5 +591,47 @@ func.func @negative_out_of_bound_transfer_write(
     vector<1x1x3x2xi8>, memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>
   return
 }
-// CHECK: func.func @negative_out_of_bound_transfer_write
-// CHECK-NOT: memref.collapse_shape
+// CHECK-LABEL: func.func @negative_out_of_bound_transfer_write
+// CHECK-NOT: memref.collapse_shape
+
+// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
+// CHECK-128B-NOT: memref.collapse_shape
+// CHECK-128B-NOT: vector.shape_cast
+
+// -----
+
+func.func @discontig_mem_contig_slice(
+    %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x1x8xi32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x1x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+  return
+}
+
+// CHECK-LABEL: func.func @discontig_mem_contig_slice
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8x8xi32, strided<[128, 16, 1]>>
+// CHECK-SAME: %[[VEC:.+]]: vector<1x1x8xi32>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x8xi32> to vector<8xi32>
+// CHECK: vector.transfer_write %[[VEC_1D]], %[[MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
+// CHECK-SAME: : vector<8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+
+// CHECK-128B-LABEL: func.func @discontig_mem_contig_slice
+// CHECK-128B-NOT: vector.shape_cast
+
+// -----
+
+func.func @discontig_mem_discontig_slice(
+    %mem : memref<8x8x8xi32, strided<[128, 16, 1]>>, %vec : vector<1x2x8xi32>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %vec, %mem [%c0, %c0, %c0] {in_bounds = [true, true, true]} :
+    vector<1x2x8xi32>, memref<8x8x8xi32, strided<[128, 16, 1]>>
+  return
+}
+
+// CHECK-LABEL: func.func @discontig_mem_discontig_slice
+// CHECK-NOT: vector.shape_cast
+
+// CHECK-128B-LABEL: func.func @discontig_mem_discontig_slice
+// CHECK-128B-NOT: vector.shape_cast
+

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits