Author: Ian Wood Date: 2025-06-03T09:46:37-07:00 New Revision: 8b407ad6fcb1e722d744edb41dac39bf94117b8a
URL: https://github.com/llvm/llvm-project/commit/8b407ad6fcb1e722d744edb41dac39bf94117b8a DIFF: https://github.com/llvm/llvm-project/commit/8b407ad6fcb1e722d744edb41dac39bf94117b8a.diff LOG: Revert "[mlir][tensor] Loosen restrictions on folding dynamic reshapes (#137963)" This reverts commit cb4a407e5c2a8a5972781d2a3be362f437602fae. Added: Modified: mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir mlir/test/Dialect/Tensor/canonicalize.mlir mlir/unittests/Dialect/Utils/CMakeLists.txt Removed: mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp ################################################################################ diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 3b1fdb69e8ef1..ed40a080441bc 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -10,10 +10,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/LogicalResult.h" #include <numeric> #include <optional> @@ -32,329 +28,67 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType, return std::nullopt; } -namespace { -/// A simple struct to represent ReassociationIndices as an inclusive interval. -/// It's designed to be feasibly minimal, so the call sites should manage the -/// validity of the range manually. -struct ReassociationIndexRange { - /// FIXME: Signed type is used for consistency with ReassociationIndices. - /// We should consider refactoring all reassociation utilities to use unsigned - /// types. - int64_t leftIdx = 0, rightIdx = 0; - - /// Util for manual checks of the range's validity - LogicalResult verify() const { - return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure(); - } - - /// Checks range's containment within another range. Treats the edges - /// non-exclusively. - bool isInRange(const ReassociationIndexRange &outerRange) const { - return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx; - } - - unsigned size() const { - assert(succeeded(verify())); - return rightIdx - leftIdx + 1; - } - bool containsSingleIndex() const { return size() == 1; } - - /// Collects indices that do not overlap between this and another range. - ReassociationIndices - getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const { - if (rightIdx < rhs.leftIdx) { - // The intervals do not overlap - concatenate the indices from both. - auto jointFullIndices = getFullIndices(); - jointFullIndices.append(rhs.getFullIndices()); - return jointFullIndices; - } - ReassociationIndices result; - // Handle the chunk left of the overlapping range. - int64_t leftStart = std::min(leftIdx, rhs.leftIdx); - int64_t leftEnd = std::max(leftIdx, rhs.leftIdx); - llvm::append_range(result, llvm::seq(leftStart, leftEnd)); - // Handle the chunk right of the overlapping range. Symmetrically, we should - // skip the edge of the overlap AND include the rightmost index. - int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1; - int64_t rightEnd = std::max(rightIdx, rhs.rightIdx); - if (rightStart < rightEnd) - llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd)); - return result; - } - - /// Converts the range into ReassociationIndices. - ReassociationIndices getFullIndices() const { - ReassociationIndices result; - for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) { - result.push_back(idx); - } - return result; - } -}; -} // namespace - -/// Starting from `sourceStartIdx`, searches `sourceShape` for the first -/// sequence that can be collapsed into a dynamic dimension (at least one must -/// be present in the source). -/// By default, lazily returns once the first dynamic dimension has been found. -/// Setting `matchGreedily` as `true` will also mark all subsequent -/// source dimensions for collapsing into the target. -static FailureOr<ReassociationIndexRange> -findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape, - int64_t sourceStartIdx, - bool matchGreedily = false) { - const unsigned numSourceDims = sourceShape.size(); - ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1}; - std::optional<ReassociationIndexRange> resultRange = std::nullopt; - - ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx}; - for (; iterationRange.isInRange(sourceShapeAsRange); - iterationRange.rightIdx++) { - int64_t sourceSize = sourceShape[iterationRange.rightIdx]; - if (sourceSize == ShapedType::kDynamic) { - resultRange = iterationRange; - break; - } - } - if (!resultRange) - return failure(); - if (matchGreedily) - resultRange->rightIdx = sourceShapeAsRange.rightIdx; - return *resultRange; -} +std::optional<SmallVector<ReassociationIndices>> +mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape, + ArrayRef<int64_t> targetShape) { + if (sourceShape.size() <= targetShape.size()) + return std::nullopt; + unsigned sourceDim = 0; + SmallVector<ReassociationIndices> reassociationMap; + reassociationMap.reserve(targetShape.size()); -/// Starting from `sourceStartIdx`, searches `sourceShape` for the first -/// sequence of static dimensions such that their product matches `targetSize`. -/// By default, lazily returns once the product matches the target size. Setting -/// `matchGreedily` as `true` will append all neighboring unit dimensions -/// (dimensions of 1) to the match. -static FailureOr<ReassociationIndexRange> -findReassociationRangeForSize(ArrayRef<int64_t> sourceShape, - int64_t sourceStartIdx, int64_t targetSize, - bool matchGreedily = false) { - const unsigned numSourceDims = sourceShape.size(); - ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1}; - std::optional<ReassociationIndexRange> resultRange = std::nullopt; - - ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx}; + ReassociationIndices currIndices; int64_t prodOfCollapsedDims = 1; - while (iterationRange.isInRange(sourceShapeAsRange)) { - int64_t sourceSize = sourceShape[iterationRange.rightIdx]; - if (sourceSize == ShapedType::kDynamic) { - // Reassociation for a static dim cannot include a dynamic dim. Reset - // induction variables to essentially restart the loop from the next - // source dimension. - prodOfCollapsedDims = 1; - iterationRange = {iterationRange.rightIdx + 1, - iterationRange.rightIdx + 1}; - continue; - } - prodOfCollapsedDims *= sourceSize; - // If the target size has been exceeded without matching, we need to shift - // the range start right. From the start of the range, roll back the - // multiplication until the target size exceeds the product again. - while (prodOfCollapsedDims > targetSize && - !iterationRange.containsSingleIndex()) { - int64_t frontSourceSize = sourceShape[iterationRange.leftIdx]; - prodOfCollapsedDims /= frontSourceSize; - // Shrink the range rightwards - iterationRange.leftIdx++; - } - // We could've reached the target size with the current dimension, - // also as a result of the above shift to right. - if (prodOfCollapsedDims == targetSize) { - resultRange = iterationRange; + while (sourceDim < sourceShape.size()) { + unsigned targetDim = reassociationMap.size(); + // If we have mapped all the target dimensions stop and handle the remaining + // tail of size-1 dimensions explicitly. + if (targetDim == targetShape.size()) break; - } - // Increment the iteration range - iterationRange.rightIdx++; - } - if (!resultRange) - return failure(); - if (matchGreedily) { - // We now want to collect all unit dimensions directly after the target - // product match. Advance the iterator to avoid OOB when the product match - // happens at the last element. - iterationRange.rightIdx++; - while (iterationRange.isInRange(sourceShapeAsRange) && - sourceShape[iterationRange.rightIdx] == 1) { - resultRange = iterationRange; - iterationRange.rightIdx++; - } - } - return *resultRange; -} - -/// Attempts to find a valid collapsing reassociation of `sourceShape` into -/// `targetShape` through a simple traversal. If successful, an array of source -/// index ranges is returned, correspondingly to each dimension in the target -/// shape. The resulting indices shall fully cover the `sourceShape` without -/// overlaps. -/// -/// The algorithm is essentially a lazy one, searching for non-greedy matches - -/// it will only yield a greedy match for the last target dimension. -/// FIXME: The algorithm can only backtrack when it needs to append an offset -/// for a static target dimension to the preceding dynamic one (this retains the -/// linear complexity). As feasible, consider adding further backtracking -/// routines to enable more reassociations, e.g.: -/// - ?x2x?x2 into ?x2 -static FailureOr<SmallVector<ReassociationIndexRange>> -findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape, - ArrayRef<int64_t> targetShape) { - unsigned numSourceDims = sourceShape.size(), - numTargetDims = targetShape.size(); - assert(numSourceDims > numTargetDims); - ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1}; - - SmallVector<ReassociationIndexRange> reassocRanges; - reassocRanges.reserve(numTargetDims); - // We'll iterate in strides of 2 to enable pseudo-backtracking for simple - // cases, e.g.: - // - ?x2x3x5 into ?x15 - std::optional<int64_t> prevTargetSize = std::nullopt; - for (unsigned targetDimIdx = 0, sourceDimIdx = 0; - targetDimIdx < numTargetDims; ++targetDimIdx) { - int64_t targetSize = targetShape[targetDimIdx]; - // Simply check if there are any subsequent target dimensions left - if not, - // the match must be made greedily. - bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1; - FailureOr<ReassociationIndexRange> sourceRange; - if (targetSize == ShapedType::kDynamic) { - sourceRange = findReassociationRangeForDynamicDim( - sourceShape, sourceDimIdx, shouldMatchGreedily); - } else { - sourceRange = findReassociationRangeForSize( - sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily); - } - // Run sanity checks on the returned index range. - if (failed(sourceRange) || failed(sourceRange->verify()) || - !sourceRange->isInRange(sourceShapeAsRange)) - return failure(); - if (sourceRange->leftIdx > sourceDimIdx) { - // If some source dimensions had to be skipped in order to find a match, - // they must be collapsed into the directly preceding dynamic dimension. - if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic) - return failure(); - reassocRanges.back().rightIdx = sourceRange->leftIdx - 1; + int64_t currTargetShape = targetShape[targetDim]; + while (sourceDim < (sourceShape.size() - 1) && + sourceShape[sourceDim] != ShapedType::kDynamic && + prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) { + prodOfCollapsedDims *= sourceShape[sourceDim]; + currIndices.push_back(sourceDim++); } - // Store the gathered information as required for the next iteration. - prevTargetSize = targetSize; - sourceDimIdx = sourceRange->rightIdx + 1; - reassocRanges.push_back(*sourceRange); - } - // Fail if the source shape wasn't a full match for the target shape. We only - // need to check the last recorded index - any other gaps should have been - // mended by the main loop. - if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx) - return failure(); - return reassocRanges; -} - -/// A variant of `findReassociationRangesForCollapse(...)` that can also scan -/// the shapes right-to-left. -static FailureOr<SmallVector<ReassociationIndexRange>> -findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape, - ArrayRef<int64_t> targetShape, - bool iterateRightToLeft) { - if (!iterateRightToLeft) - return findReassociationRangesForCollapse(sourceShape, targetShape); - // NB: To iterate right-to-left, we currently reverse the shapes and then - // reverse the result back. The reversed shapes must not be temporary, as - // we're passing through an ArrayRef. - // FIXME: It would be preferable to avoid the expensive copies. At the moment, - // this approach is chosen for readability of the main implementation. - std::vector<int64_t> sourceToReverse = sourceShape.vec(), - targetToReverse = targetShape.vec(); - std::reverse(sourceToReverse.begin(), sourceToReverse.end()); - std::reverse(targetToReverse.begin(), targetToReverse.end()); - auto invertedRanges = - findReassociationRangesForCollapse(sourceToReverse, targetToReverse); - if (failed(invertedRanges)) - return failure(); - SmallVector<ReassociationIndexRange> &rangesToInvert = *invertedRanges; - unsigned numSourceDims = sourceShape.size(); - // We have received the ranges for inverted shapes. Now we have to invert - // the ranges back to correspond with the original source shape. - for (auto &range : rangesToInvert) { - int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx; - range.leftIdx = numSourceDims - 1 - invRightIdx; - range.rightIdx = numSourceDims - 1 - invLeftIdx; + // If the current expanded dimension is dynamic, then the collapsed + // dimensions should also be dynamic and product of all previous unprocessed + // dimensions of the expanded shape should be 1. + if (sourceShape[sourceDim] == ShapedType::kDynamic && + (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1)) + return std::nullopt; + + // If the collapsed dim is dynamic, the current expanded dim should also + // be dynamic. + if (currTargetShape == ShapedType::kDynamic && + sourceShape[sourceDim] != ShapedType::kDynamic) + return std::nullopt; + + // For static shapes, if the product of dimensions of the expanded shape + // should match the collapsed dimension shape. + if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape) + return std::nullopt; + + currIndices.push_back(sourceDim++); + reassociationMap.emplace_back(ReassociationIndices{}); + std::swap(reassociationMap.back(), currIndices); + prodOfCollapsedDims = 1; } - // Also invert the ordering of the ranges to correspond with the original - // target shape. - std::reverse(rangesToInvert.begin(), rangesToInvert.end()); - return rangesToInvert; -} - -std::optional<SmallVector<ReassociationIndices>> -mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape, - ArrayRef<int64_t> targetShape) { - unsigned numSourceDims = sourceShape.size(), - numTargetDims = targetShape.size(); - // We're supposed to search for a collapsing reassociation. If the sizes - // match, there's no actual collapsing taking place - it's either a no-op or a - // `tensor.reshape`-style reassociation (that would be beyond the scope of - // this utility). - if (numSourceDims <= numTargetDims) + // All the dimensions in the target must have been processed. + if (reassociationMap.size() != targetShape.size()) return std::nullopt; - // Early handling for scalar target types. - if (numTargetDims == 0) { - ReassociationIndices allSourceIndices; - allSourceIndices.reserve(numSourceDims); - for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims; - ++sourceDimIdx) { - int64_t sourceSize = sourceShape[sourceDimIdx]; - // All source dimensions must be unit or dynamic. - if (sourceSize != 1 && sourceSize != ShapedType::kDynamic) - return std::nullopt; - allSourceIndices.push_back(sourceDimIdx); - } - return SmallVector<ReassociationIndices>{allSourceIndices}; - } - - // Collect source ranges by iterating over the target shape left-to-right. - FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges = - findReassociationRangesForCollapse(sourceShape, targetShape); - if (failed(maybeForwardRanges)) - return std::nullopt; - auto &ranges = *maybeForwardRanges; - // Now do the same in reverse. We need to get another valid reassociation - // through some other strategy, and then compare the results in order to - // disambiguate mixed subshapes, such as: - // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x? - // This leads us to lose some of the reassociation opportunities that can only - // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without - // backtracking, the algorithm will fail right-to-left. However, this is the - // best way to preserve correctness. - FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges = - findReassociationRangesForCollapse(sourceShape, targetShape, - /*iterateRightToLeft=*/true); - if (failed(maybeReverseRanges)) - return std::nullopt; - auto &reverseRanges = *maybeReverseRanges; - - if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims) - return std::nullopt; - // Now we can check for ambiguity of each target dimension's reassociation. If - // successful, we put the full indices into our result map for the target - // shape. - SmallVector<ReassociationIndices> reassociationMap(numTargetDims); - for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims; - ++targetDimIdx) { - ReassociationIndexRange &range = ranges[targetDimIdx]; - ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx]; - // Get non-overlapping indices between the ranges - ReassociationIndices nonMatchingIndices = - range.getNonOverlappingIndicesWith(reverseRange); - // Unit dimensions can be collapsed wherever - this is the only ambiguity - // that we allow. - for (int64_t sourceDimIdx : nonMatchingIndices) { - if (sourceShape[sourceDimIdx] != 1) - return std::nullopt; - } - reassociationMap[targetDimIdx] = range.getFullIndices(); + // Process any remaining entries in the source shape. They all need to be + // 1 or dynamic. + for (; sourceDim < sourceShape.size(); sourceDim++) { + if (sourceShape[sourceDim] != ShapedType::kDynamic && + sourceShape[sourceDim] != 1) + return std::nullopt; + // The map is empty when the target type is a scalar. + if (!reassociationMap.empty()) + reassociationMap.back().push_back(sourceDim); } return reassociationMap; } @@ -581,11 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams( // have proven that these are not sliced. In this case we just take // the full extent of each dimension in the reassociation list. if (linearizedDimensions[it.index()]) { - llvm::append_range(offsetsSizesAndStrides, - llvm::map_range(it.value(), [&](int64_t idx) -> Range { - return {zeroAttr, collapseShapeInputShape[idx], - oneAttr}; - })); + llvm::append_range( + offsetsSizesAndStrides, + llvm::map_range(it.value(), [&](int64_t idx) -> Range { + return {zeroAttr, collapseShapeInputShape[idx], oneAttr}; + })); continue; } diff --git a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir index 6979770154bab..51350e5bc8498 100644 --- a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir +++ b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir @@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> { // ----- // CHECK-LABEL: func.func @unpack_dynamic -// CHECK: tensor.collapse -// CHECK-NOT: linalg.unpack +// CHECK-NOT: tensor.collapse +// CHECK: linalg.unpack func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> { %c32 = arith.constant 32 : index %c0 = arith.constant 0 : index diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 646b2197d9aa6..0abec7e01d184 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1117,7 +1117,7 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf3 // ----- -func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index) +func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index) -> tensor<?x4x?xf32> { %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<?x4x?xf32> into tensor<?x?xf32> @@ -1125,28 +1125,12 @@ func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %ar : tensor<?x?xf32> into tensor<?x4x?xf32> return %1 : tensor<?x4x?xf32> } -// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape +// CHECK-LABEL: @fold_expand_of_collapse_dynamic // CHECK-NOT: tensor.{{.*}}_shape // ----- -func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index) - -> tensor<?x4x?xf32> { - %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]] - : tensor<?x4x?x2xf32> into tensor<?x?xf32> - %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2] - : tensor<?x?xf32> into tensor<?x4x?xf32> - return %1 : tensor<?x4x?xf32> -} -// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape -// CHECK-NOT: tensor.expand_shape -// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]] -// CHECK-SAME: : tensor<?x4x?x2xf32> into tensor<?x4x?xf32> -// CHECK-NEXT: return %[[COLLAPSE]] - -// ----- - -func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) +func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> tensor<?x?x?xf32> { %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32> @@ -1154,22 +1138,7 @@ func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, % : tensor<?x?xf32> into tensor<?x?x?xf32> return %1 : tensor<?x?x?xf32> } -// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic -// CHECK: tensor.collapse_shape -// CHECK: %[[EXPAND:.+]] = tensor.expand_shape -// CHECK: return %[[EXPAND]] - -// ----- - -func.func @no_fold_expand_of_collapse_adjacent_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index) - -> tensor<?x?xf32> { - %0 = tensor.collapse_shape %arg0 [[0, 1, 2]] - : tensor<?x?x?xf32> into tensor<?xf32> - %1 = tensor.expand_shape %0 [[0, 1]] output_shape [%arg1, %arg2] - : tensor<?xf32> into tensor<?x?xf32> - return %1 : tensor<?x?xf32> -} -// CHECK-LABEL: @no_fold_expand_of_collapse_adjacent_dynamic +// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic // CHECK: tensor.collapse_shape // CHECK: %[[EXPAND:.+]] = tensor.expand_shape // CHECK: return %[[EXPAND]] diff --git a/mlir/unittests/Dialect/Utils/CMakeLists.txt b/mlir/unittests/Dialect/Utils/CMakeLists.txt index e921c8bcfb4e5..61b9cdcb3b8f3 100644 --- a/mlir/unittests/Dialect/Utils/CMakeLists.txt +++ b/mlir/unittests/Dialect/Utils/CMakeLists.txt @@ -1,6 +1,5 @@ add_mlir_unittest(MLIRDialectUtilsTests StructuredOpsUtilsTest.cpp - ReshapeOpsUtilsTest.cpp IndexingUtilsTest.cpp ) mlir_target_link_libraries(MLIRDialectUtilsTests diff --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp deleted file mode 100644 index db1a87a4de2d5..0000000000000 --- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp +++ /dev/null @@ -1,203 +0,0 @@ -//===- ReshapeOpsUtilsTest.cpp - ReshapeOpsUtils unit tests ---------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "llvm/ADT/STLExtras.h" -#include "gtest/gtest.h" -#include <optional> - -using namespace mlir; - -/// Helper to make constructing -/// `std::optional<SmallVector<ReassociationIndices>>` more readable. -static std::optional<SmallVector<ReassociationIndices>> -makeOptionalIndices(std::initializer_list<ReassociationIndices> list) { - return std::optional<SmallVector<ReassociationIndices>>(list); -} - -TEST(ReassociationIndicesForCollapse, ScalarTest) { - EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}), - makeOptionalIndices({{0}})); - EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}), - makeOptionalIndices({{0, 1}})); - EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}), - makeOptionalIndices({{0}})); - EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, - ShapedType::kDynamic, 1, - ShapedType::kDynamic}, - {}), - makeOptionalIndices({{0, 1, 2, 3, 4}})); -} - -TEST(ReassociationIndicesForCollapse, ScalarTestFailure) { - EXPECT_EQ(getReassociationIndicesForCollapse({}, {}), std::nullopt); - EXPECT_EQ(getReassociationIndicesForCollapse({}, {1}), std::nullopt); - EXPECT_EQ(getReassociationIndicesForCollapse({2}, {}), std::nullopt); - EXPECT_EQ( - getReassociationIndicesForCollapse({1, 2, ShapedType::kDynamic, 1}, {}), - std::nullopt); -} - -TEST(ReassociationIndicesForCollapse, StaticTest) { - EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {200}), - makeOptionalIndices({{0, 1}})); - EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {10, 600}), - makeOptionalIndices({{0}, {1, 2}})); - EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 30}), - makeOptionalIndices({{0, 1}, {2}})); -} - -TEST(ReassociationIndicesForCollapse, StaticTestFailure) { - // No-op reassociation - EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10, 20}), - std::nullopt); - // Invalid static reassociations - EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10}), std::nullopt); - EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 300}), - std::nullopt); - // Non-collapsing (expanding) reassociation - EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {1, 10, 20, 30}), - std::nullopt); -} - -TEST(ReassociationIndicesForCollapse, StaticTestUnitDims) { - EXPECT_EQ(getReassociationIndicesForCollapse({10, 1}, {10}), - makeOptionalIndices({{0, 1}})); - EXPECT_EQ(getReassociationIndicesForCollapse({1, 20, 30}, {600}), - makeOptionalIndices({{0, 1, 2}})); - EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1}, {1}), - makeOptionalIndices({{0, 1, 2}})); - EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1, 1}, {1, 1, 1}), - makeOptionalIndices({{0}, {1}, {2, 3}})); -} - -TEST(ReassociationIndicesForCollapse, DynamicTest) { - EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1}, - {ShapedType::kDynamic}), - makeOptionalIndices({{0, 1}})); - EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1, 1}, - {ShapedType::kDynamic}), - makeOptionalIndices({{0, 1, 2}})); - EXPECT_EQ(getReassociationIndicesForCollapse( - {1, ShapedType::kDynamic, 1, ShapedType::kDynamic, 1}, - {ShapedType::kDynamic, ShapedType::kDynamic}), - makeOptionalIndices({{0, 1}, {2, 3, 4}})); - EXPECT_EQ( - getReassociationIndicesForCollapse( - {ShapedType::kDynamic, ShapedType::kDynamic}, {ShapedType::kDynamic}), - makeOptionalIndices({{0, 1}})); - EXPECT_EQ(getReassociationIndicesForCollapse( - {1, ShapedType::kDynamic, ShapedType::kDynamic}, - {1, ShapedType::kDynamic}), - makeOptionalIndices({{0}, {1, 2}})); - - EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10}, - {ShapedType::kDynamic}), - makeOptionalIndices({{0, 1}})); - EXPECT_EQ(getReassociationIndicesForCollapse( - {1, ShapedType::kDynamic, ShapedType::kDynamic}, - {ShapedType::kDynamic}), - makeOptionalIndices({{0, 1, 2}})); - EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic}, - {ShapedType::kDynamic}), - makeOptionalIndices({{0, 1}})); - EXPECT_EQ(getReassociationIndicesForCollapse( - {ShapedType::kDynamic, 1, 2, ShapedType::kDynamic, 10}, - {ShapedType::kDynamic, 10}), - makeOptionalIndices({{0, 1, 2, 3}, {4}})); - EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20}, - {ShapedType::kDynamic, 20}), - makeOptionalIndices({{0, 1}, {2}})); - EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic, 20}, - {ShapedType::kDynamic, 20}), - makeOptionalIndices({{0, 1}, {2}})); - EXPECT_EQ(getReassociationIndicesForCollapse( - {ShapedType::kDynamic, 3, 2, 5, 2}, {ShapedType::kDynamic, 20}), - makeOptionalIndices({{0, 1}, {2, 3, 4}})); - EXPECT_EQ(getReassociationIndicesForCollapse( - {10, ShapedType::kDynamic, 20, ShapedType::kDynamic, 1}, - {ShapedType::kDynamic, 20, ShapedType::kDynamic}), - makeOptionalIndices({{0, 1}, {2}, {3, 4}})); - EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 1}, - {ShapedType::kDynamic}), - makeOptionalIndices({{0, 1, 2}})); - EXPECT_EQ(getReassociationIndicesForCollapse( - {ShapedType::kDynamic, ShapedType::kDynamic, 1}, - {ShapedType::kDynamic, ShapedType::kDynamic}), - makeOptionalIndices({{0}, {1, 2}})); - EXPECT_EQ(getReassociationIndicesForCollapse( - {1, ShapedType::kDynamic, ShapedType::kDynamic}, - {ShapedType::kDynamic, ShapedType::kDynamic}), - makeOptionalIndices({{0, 1}, {2}})); - EXPECT_EQ(getReassociationIndicesForCollapse( - {ShapedType::kDynamic, 1, ShapedType::kDynamic}, - {ShapedType::kDynamic, ShapedType::kDynamic}), - makeOptionalIndices({{0}, {1, 2}})); -} - -TEST(ReassociationIndicesForCollapse, DynamicTestFailure) { - EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20}, - {ShapedType::kDynamic, 10}), - std::nullopt); - EXPECT_EQ(getReassociationIndicesForCollapse( - {ShapedType::kDynamic, 10, ShapedType::kDynamic}, - {ShapedType::kDynamic, ShapedType::kDynamic}), - std::nullopt); - EXPECT_EQ(getReassociationIndicesForCollapse( - {20, ShapedType::kDynamic, 10, ShapedType::kDynamic}, - {ShapedType::kDynamic, ShapedType::kDynamic}), - std::nullopt); - EXPECT_EQ(getReassociationIndicesForCollapse( - {ShapedType::kDynamic, 5, 3, 2, 2}, {ShapedType::kDynamic, 20}), - std::nullopt); - EXPECT_EQ( - getReassociationIndicesForCollapse( - {ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic}, - {ShapedType::kDynamic, ShapedType::kDynamic}), - std::nullopt); - EXPECT_EQ(getReassociationIndicesForCollapse( - {ShapedType::kDynamic, ShapedType::kDynamic, 10, 1, - ShapedType::kDynamic}, - {ShapedType::kDynamic, ShapedType::kDynamic}), - std::nullopt); - EXPECT_EQ(getReassociationIndicesForCollapse( - {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic}, - {ShapedType::kDynamic, 10, ShapedType::kDynamic}), - std::nullopt); - EXPECT_EQ(getReassociationIndicesForCollapse( - {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic}, - {ShapedType::kDynamic, 2, 2, ShapedType::kDynamic}), - std::nullopt); - EXPECT_EQ(getReassociationIndicesForCollapse( - {ShapedType::kDynamic, 3, 4, 3, ShapedType::kDynamic}, - {ShapedType::kDynamic, 12, ShapedType::kDynamic}), - std::nullopt); - EXPECT_EQ(getReassociationIndicesForCollapse( - {ShapedType::kDynamic, 8, 4, 2, 16, ShapedType::kDynamic}, - {ShapedType::kDynamic, 32, ShapedType::kDynamic}), - std::nullopt); - - //===----------------------------------------------------------------------===// - // TODO: Reassociation for the following examples can be computed, but isn't - // supported by `getReassociationIndicesForCollapse`. - //===----------------------------------------------------------------------===// - - // TODO: Fails because there's no backtracking when some source dimensions - // remain unmatched at either edge. - EXPECT_EQ(getReassociationIndicesForCollapse( - {ShapedType::kDynamic, 10, ShapedType::kDynamic, 10}, - {ShapedType::kDynamic, 10}), - std::nullopt); - EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 2, 2}, - {1, ShapedType::kDynamic, 2}), - std::nullopt); - EXPECT_EQ(getReassociationIndicesForCollapse({2, 2, ShapedType::kDynamic, 1}, - {2, ShapedType::kDynamic}), - std::nullopt); -} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits