Author: Ian Wood
Date: 2025-06-03T09:46:37-07:00
New Revision: 8b407ad6fcb1e722d744edb41dac39bf94117b8a

URL: https://github.com/llvm/llvm-project/commit/8b407ad6fcb1e722d744edb41dac39bf94117b8a
DIFF: https://github.com/llvm/llvm-project/commit/8b407ad6fcb1e722d744edb41dac39bf94117b8a.diff

LOG: Revert "[mlir][tensor] Loosen restrictions on folding dynamic reshapes (#137963)"

This reverts commit cb4a407e5c2a8a5972781d2a3be362f437602fae.

Added: 
    

Modified: 
    mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
    mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
    mlir/test/Dialect/Tensor/canonicalize.mlir
    mlir/unittests/Dialect/Utils/CMakeLists.txt

Removed: 
    mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp


################################################################################
diff  --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
index 3b1fdb69e8ef1..ed40a080441bc 100644
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -10,10 +10,6 @@
 
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
-#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/LogicalResult.h"
 
 #include <numeric>
 #include <optional>
@@ -32,329 +28,67 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
   return std::nullopt;
 }
 
-namespace {
-/// A simple struct to represent ReassociationIndices as an inclusive interval.
-/// It's designed to be feasibly minimal, so the call sites should manage the
-/// validity of the range manually.
-struct ReassociationIndexRange {
-  /// FIXME: Signed type is used for consistency with ReassociationIndices.
-  /// We should consider refactoring all reassociation utilities to use unsigned
-  /// types.
-  int64_t leftIdx = 0, rightIdx = 0;
-
-  /// Util for manual checks of the range's validity
-  LogicalResult verify() const {
-    return leftIdx >= 0 && (leftIdx <= rightIdx) ? success() : failure();
-  }
-
-  /// Checks range's containment within another range. Treats the edges
-  /// non-exclusively.
-  bool isInRange(const ReassociationIndexRange &outerRange) const {
-    return leftIdx >= outerRange.leftIdx && rightIdx <= outerRange.rightIdx;
-  }
-
-  unsigned size() const {
-    assert(succeeded(verify()));
-    return rightIdx - leftIdx + 1;
-  }
-  bool containsSingleIndex() const { return size() == 1; }
-
-  /// Collects indices that do not overlap between this and another range.
-  ReassociationIndices
-  getNonOverlappingIndicesWith(ReassociationIndexRange &rhs) const {
-    if (rightIdx < rhs.leftIdx) {
-      // The intervals do not overlap - concatenate the indices from both.
-      auto jointFullIndices = getFullIndices();
-      jointFullIndices.append(rhs.getFullIndices());
-      return jointFullIndices;
-    }
-    ReassociationIndices result;
-    // Handle the chunk left of the overlapping range.
-    int64_t leftStart = std::min(leftIdx, rhs.leftIdx);
-    int64_t leftEnd = std::max(leftIdx, rhs.leftIdx);
-    llvm::append_range(result, llvm::seq(leftStart, leftEnd));
-    // Handle the chunk right of the overlapping range. Symmetrically, we should
-    // skip the edge of the overlap AND include the rightmost index.
-    int64_t rightStart = std::min(rightIdx, rhs.rightIdx) + 1;
-    int64_t rightEnd = std::max(rightIdx, rhs.rightIdx);
-    if (rightStart < rightEnd)
-      llvm::append_range(result, llvm::seq_inclusive(rightStart, rightEnd));
-    return result;
-  }
-
-  /// Converts the range into ReassociationIndices.
-  ReassociationIndices getFullIndices() const {
-    ReassociationIndices result;
-    for (int64_t idx = leftIdx; idx <= rightIdx; ++idx) {
-      result.push_back(idx);
-    }
-    return result;
-  }
-};
-} // namespace
-
-/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
-/// sequence that can be collapsed into a dynamic dimension (at least one must
-/// be present in the source).
-/// By default, lazily returns once the first dynamic dimension has been found.
-/// Setting `matchGreedily` as `true` will also mark all subsequent
-/// source dimensions for collapsing into the target.
-static FailureOr<ReassociationIndexRange>
-findReassociationRangeForDynamicDim(ArrayRef<int64_t> sourceShape,
-                                    int64_t sourceStartIdx,
-                                    bool matchGreedily = false) {
-  const unsigned numSourceDims = sourceShape.size();
-  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
-  std::optional<ReassociationIndexRange> resultRange = std::nullopt;
-
-  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
-  for (; iterationRange.isInRange(sourceShapeAsRange);
-       iterationRange.rightIdx++) {
-    int64_t sourceSize = sourceShape[iterationRange.rightIdx];
-    if (sourceSize == ShapedType::kDynamic) {
-      resultRange = iterationRange;
-      break;
-    }
-  }
-  if (!resultRange)
-    return failure();
-  if (matchGreedily)
-    resultRange->rightIdx = sourceShapeAsRange.rightIdx;
-  return *resultRange;
-}
+std::optional<SmallVector<ReassociationIndices>>
+mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
+                                         ArrayRef<int64_t> targetShape) {
+  if (sourceShape.size() <= targetShape.size())
+    return std::nullopt;
+  unsigned sourceDim = 0;
+  SmallVector<ReassociationIndices> reassociationMap;
+  reassociationMap.reserve(targetShape.size());
 
-/// Starting from `sourceStartIdx`, searches `sourceShape` for the first
-/// sequence of static dimensions such that their product matches `targetSize`.
-/// By default, lazily returns once the product matches the target size. Setting
-/// `matchGreedily` as `true` will append all neighboring unit dimensions
-/// (dimensions of 1) to the match.
-static FailureOr<ReassociationIndexRange>
-findReassociationRangeForSize(ArrayRef<int64_t> sourceShape,
-                              int64_t sourceStartIdx, int64_t targetSize,
-                              bool matchGreedily = false) {
-  const unsigned numSourceDims = sourceShape.size();
-  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
-  std::optional<ReassociationIndexRange> resultRange = std::nullopt;
-
-  ReassociationIndexRange iterationRange{sourceStartIdx, sourceStartIdx};
+  ReassociationIndices currIndices;
   int64_t prodOfCollapsedDims = 1;
-  while (iterationRange.isInRange(sourceShapeAsRange)) {
-    int64_t sourceSize = sourceShape[iterationRange.rightIdx];
-    if (sourceSize == ShapedType::kDynamic) {
-      // Reassociation for a static dim cannot include a dynamic dim. Reset
-      // induction variables to essentially restart the loop from the next
-      // source dimension.
-      prodOfCollapsedDims = 1;
-      iterationRange = {iterationRange.rightIdx + 1,
-                        iterationRange.rightIdx + 1};
-      continue;
-    }
-    prodOfCollapsedDims *= sourceSize;
-    // If the target size has been exceeded without matching, we need to shift
-    // the range start right. From the start of the range, roll back the
-    // multiplication until the target size exceeds the product again.
-    while (prodOfCollapsedDims > targetSize &&
-           !iterationRange.containsSingleIndex()) {
-      int64_t frontSourceSize = sourceShape[iterationRange.leftIdx];
-      prodOfCollapsedDims /= frontSourceSize;
-      // Shrink the range rightwards
-      iterationRange.leftIdx++;
-    }
-    // We could've reached the target size with the current dimension,
-    // also as a result of the above shift to right.
-    if (prodOfCollapsedDims == targetSize) {
-      resultRange = iterationRange;
+  while (sourceDim < sourceShape.size()) {
+    unsigned targetDim = reassociationMap.size();
+    // If we have mapped all the target dimensions stop and handle the remaining
+    // tail of size-1 dimensions explicitly.
+    if (targetDim == targetShape.size())
       break;
-    }
-    // Increment the iteration range
-    iterationRange.rightIdx++;
-  }
-  if (!resultRange)
-    return failure();
-  if (matchGreedily) {
-    // We now want to collect all unit dimensions directly after the target
-    // product match. Advance the iterator to avoid OOB when the product match
-    // happens at the last element.
-    iterationRange.rightIdx++;
-    while (iterationRange.isInRange(sourceShapeAsRange) &&
-           sourceShape[iterationRange.rightIdx] == 1) {
-      resultRange = iterationRange;
-      iterationRange.rightIdx++;
-    }
-  }
-  return *resultRange;
-}
-
-/// Attempts to find a valid collapsing reassociation of `sourceShape` into
-/// `targetShape` through a simple traversal. If successful, an array of source
-/// index ranges is returned, correspondingly to each dimension in the target
-/// shape. The resulting indices shall fully cover the `sourceShape` without
-/// overlaps.
-///
-/// The algorithm is essentially a lazy one, searching for non-greedy matches -
-/// it will only yield a greedy match for the last target dimension.
-/// FIXME: The algorithm can only backtrack when it needs to append an offset
-/// for a static target dimension to the preceding dynamic one (this retains the
-/// linear complexity). As feasible, consider adding further backtracking
-/// routines to enable more reassociations, e.g.:
-/// - ?x2x?x2 into ?x2
-static FailureOr<SmallVector<ReassociationIndexRange>>
-findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
-                                   ArrayRef<int64_t> targetShape) {
-  unsigned numSourceDims = sourceShape.size(),
-           numTargetDims = targetShape.size();
-  assert(numSourceDims > numTargetDims);
-  ReassociationIndexRange sourceShapeAsRange{0, numSourceDims - 1};
-
-  SmallVector<ReassociationIndexRange> reassocRanges;
-  reassocRanges.reserve(numTargetDims);
-  // We'll iterate in strides of 2 to enable pseudo-backtracking for simple
-  // cases, e.g.:
-  // - ?x2x3x5 into ?x15
-  std::optional<int64_t> prevTargetSize = std::nullopt;
-  for (unsigned targetDimIdx = 0, sourceDimIdx = 0;
-       targetDimIdx < numTargetDims; ++targetDimIdx) {
-    int64_t targetSize = targetShape[targetDimIdx];
-    // Simply check if there are any subsequent target dimensions left - if not,
-    // the match must be made greedily.
-    bool shouldMatchGreedily = targetDimIdx == numTargetDims - 1;
-    FailureOr<ReassociationIndexRange> sourceRange;
-    if (targetSize == ShapedType::kDynamic) {
-      sourceRange = findReassociationRangeForDynamicDim(
-          sourceShape, sourceDimIdx, shouldMatchGreedily);
-    } else {
-      sourceRange = findReassociationRangeForSize(
-          sourceShape, sourceDimIdx, targetSize, shouldMatchGreedily);
-    }
 
-    // Run sanity checks on the returned index range.
-    if (failed(sourceRange) || failed(sourceRange->verify()) ||
-        !sourceRange->isInRange(sourceShapeAsRange))
-      return failure();
-    if (sourceRange->leftIdx > sourceDimIdx) {
-      // If some source dimensions had to be skipped in order to find a match,
-      // they must be collapsed into the directly preceding dynamic dimension.
-      if (!prevTargetSize || prevTargetSize != ShapedType::kDynamic)
-        return failure();
-      reassocRanges.back().rightIdx = sourceRange->leftIdx - 1;
+    int64_t currTargetShape = targetShape[targetDim];
+    while (sourceDim < (sourceShape.size() - 1) &&
+           sourceShape[sourceDim] != ShapedType::kDynamic &&
+           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+      prodOfCollapsedDims *= sourceShape[sourceDim];
+      currIndices.push_back(sourceDim++);
     }
 
-    // Store the gathered information as required for the next iteration.
-    prevTargetSize = targetSize;
-    sourceDimIdx = sourceRange->rightIdx + 1;
-    reassocRanges.push_back(*sourceRange);
-  }
-  // Fail if the source shape wasn't a full match for the target shape. We only
-  // need to check the last recorded index - any other gaps should have been
-  // mended by the main loop.
-  if (reassocRanges.back().rightIdx < sourceShapeAsRange.rightIdx)
-    return failure();
-  return reassocRanges;
-}
-
-/// A variant of `findReassociationRangesForCollapse(...)` that can also scan
-/// the shapes right-to-left.
-static FailureOr<SmallVector<ReassociationIndexRange>>
-findReassociationRangesForCollapse(ArrayRef<int64_t> sourceShape,
-                                   ArrayRef<int64_t> targetShape,
-                                   bool iterateRightToLeft) {
-  if (!iterateRightToLeft)
-    return findReassociationRangesForCollapse(sourceShape, targetShape);
-  // NB: To iterate right-to-left, we currently reverse the shapes and then
-  // reverse the result back. The reversed shapes must not be temporary, as
-  // we're passing through an ArrayRef.
-  // FIXME: It would be preferable to avoid the expensive copies. At the moment,
-  // this approach is chosen for readability of the main implementation.
-  std::vector<int64_t> sourceToReverse = sourceShape.vec(),
-                       targetToReverse = targetShape.vec();
-  std::reverse(sourceToReverse.begin(), sourceToReverse.end());
-  std::reverse(targetToReverse.begin(), targetToReverse.end());
-  auto invertedRanges =
-      findReassociationRangesForCollapse(sourceToReverse, targetToReverse);
-  if (failed(invertedRanges))
-    return failure();
-  SmallVector<ReassociationIndexRange> &rangesToInvert = *invertedRanges;
-  unsigned numSourceDims = sourceShape.size();
-  // We have received the ranges for inverted shapes. Now we have to invert
-  // the ranges back to correspond with the original source shape.
-  for (auto &range : rangesToInvert) {
-    int64_t invLeftIdx = range.leftIdx, invRightIdx = range.rightIdx;
-    range.leftIdx = numSourceDims - 1 - invRightIdx;
-    range.rightIdx = numSourceDims - 1 - invLeftIdx;
+    // If the current expanded dimension is dynamic, then the collapsed
+    // dimensions should also be dynamic and product of all previous unprocessed
+    // dimensions of the expanded shape should be 1.
+    if (sourceShape[sourceDim] == ShapedType::kDynamic &&
+        (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
+      return std::nullopt;
+
+    // If the collapsed dim is dynamic, the current expanded dim should also
+    // be dynamic.
+    if (currTargetShape == ShapedType::kDynamic &&
+        sourceShape[sourceDim] != ShapedType::kDynamic)
+      return std::nullopt;
+
+    // For static shapes, if the product of dimensions of the expanded shape
+    // should match the collapsed dimension shape.
+    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
+      return std::nullopt;
+
+    currIndices.push_back(sourceDim++);
+    reassociationMap.emplace_back(ReassociationIndices{});
+    std::swap(reassociationMap.back(), currIndices);
+    prodOfCollapsedDims = 1;
   }
-  // Also invert the ordering of the ranges to correspond with the original
-  // target shape.
-  std::reverse(rangesToInvert.begin(), rangesToInvert.end());
-  return rangesToInvert;
-}
-
-std::optional<SmallVector<ReassociationIndices>>
-mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
-                                         ArrayRef<int64_t> targetShape) {
-  unsigned numSourceDims = sourceShape.size(),
-           numTargetDims = targetShape.size();
-  // We're supposed to search for a collapsing reassociation. If the sizes
-  // match, there's no actual collapsing taking place - it's either a no-op or a
-  // `tensor.reshape`-style reassociation (that would be beyond the scope of
-  // this utility).
-  if (numSourceDims <= numTargetDims)
+  // All the dimensions in the target must have been processed.
+  if (reassociationMap.size() != targetShape.size())
     return std::nullopt;
-  // Early handling for scalar target types.
-  if (numTargetDims == 0) {
-    ReassociationIndices allSourceIndices;
-    allSourceIndices.reserve(numSourceDims);
-    for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
-         ++sourceDimIdx) {
-      int64_t sourceSize = sourceShape[sourceDimIdx];
-      // All source dimensions must be unit or dynamic.
-      if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
-        return std::nullopt;
-      allSourceIndices.push_back(sourceDimIdx);
-    }
-    return SmallVector<ReassociationIndices>{allSourceIndices};
-  }
-
-  // Collect source ranges by iterating over the target shape left-to-right.
-  FailureOr<SmallVector<ReassociationIndexRange>> maybeForwardRanges =
-      findReassociationRangesForCollapse(sourceShape, targetShape);
-  if (failed(maybeForwardRanges))
-    return std::nullopt;
-  auto &ranges = *maybeForwardRanges;
-  // Now do the same in reverse. We need to get another valid reassociation
-  // through some other strategy, and then compare the results in order to
-  // disambiguate mixed subshapes, such as:
-  // ?x?x? into ?x?, ?x2x? into ?x?, ?x2x3x6x? into ?x6x?
-  // This leads us to lose some of the reassociation opportunities that can only
-  // be found by iterating in a certain direction, e.g. 2x2x? into 2x? - without
-  // backtracking, the algorithm will fail right-to-left. However, this is the
-  // best way to preserve correctness.
-  FailureOr<SmallVector<ReassociationIndexRange>> maybeReverseRanges =
-      findReassociationRangesForCollapse(sourceShape, targetShape,
-                                         /*iterateRightToLeft=*/true);
-  if (failed(maybeReverseRanges))
-    return std::nullopt;
-  auto &reverseRanges = *maybeReverseRanges;
-
-  if (ranges.size() != numTargetDims || reverseRanges.size() != numTargetDims)
-    return std::nullopt;
-  // Now we can check for ambiguity of each target dimension's reassociation. If
-  // successful, we put the full indices into our result map for the target
-  // shape.
-  SmallVector<ReassociationIndices> reassociationMap(numTargetDims);
-  for (unsigned targetDimIdx = 0; targetDimIdx < numTargetDims;
-       ++targetDimIdx) {
-    ReassociationIndexRange &range = ranges[targetDimIdx];
-    ReassociationIndexRange &reverseRange = reverseRanges[targetDimIdx];
-    // Get non-overlapping indices between the ranges
-    ReassociationIndices nonMatchingIndices =
-        range.getNonOverlappingIndicesWith(reverseRange);
-    // Unit dimensions can be collapsed wherever - this is the only ambiguity
-    // that we allow.
-    for (int64_t sourceDimIdx : nonMatchingIndices) {
-      if (sourceShape[sourceDimIdx] != 1)
-        return std::nullopt;
-    }
-    reassociationMap[targetDimIdx] = range.getFullIndices();
+  // Process any remaining entries in the source shape. They all need to be
+  // 1 or dynamic.
+  for (; sourceDim < sourceShape.size(); sourceDim++) {
+    if (sourceShape[sourceDim] != ShapedType::kDynamic &&
+        sourceShape[sourceDim] != 1)
+      return std::nullopt;
+    // The map is empty when the target type is a scalar.
+    if (!reassociationMap.empty())
+      reassociationMap.back().push_back(sourceDim);
   }
   return reassociationMap;
 }
@@ -581,11 +315,11 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
     // have proven that these are not sliced. In this case we just take
     // the full extent of each dimension in the reassociation list.
     if (linearizedDimensions[it.index()]) {
-      llvm::append_range(offsetsSizesAndStrides,
-                         llvm::map_range(it.value(), [&](int64_t idx) -> Range {
-                           return {zeroAttr, collapseShapeInputShape[idx],
-                                   oneAttr};
-                         }));
+      llvm::append_range(
+          offsetsSizesAndStrides,
+          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
+          }));
       continue;
     }
 

diff  --git a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
index 6979770154bab..51350e5bc8498 100644
--- a/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir
@@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
 // -----
 
 // CHECK-LABEL: func.func @unpack_dynamic
-// CHECK:     tensor.collapse
-// CHECK-NOT:         linalg.unpack
+// CHECK-NOT:     tensor.collapse
+// CHECK:         linalg.unpack
 func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
   %c32 = arith.constant 32 : index
   %c0 = arith.constant 0 : index

diff  --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 646b2197d9aa6..0abec7e01d184 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1117,7 +1117,7 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf3
 
 // -----
 
-func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
     -> tensor<?x4x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x4x?xf32> into tensor<?x?xf32>
@@ -1125,28 +1125,12 @@ func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %ar
       : tensor<?x?xf32> into tensor<?x4x?xf32>
   return %1 : tensor<?x4x?xf32>
 }
-// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape
+// CHECK-LABEL: @fold_expand_of_collapse_dynamic
 //   CHECK-NOT:   tensor.{{.*}}_shape
 
 // -----
 
-func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index)
-    -> tensor<?x4x?xf32> {
-  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
-      : tensor<?x4x?x2xf32> into tensor<?x?xf32>
-  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
-      : tensor<?x?xf32> into tensor<?x4x?xf32>
-  return %1 : tensor<?x4x?xf32>
-}
-// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape
-//   CHECK-NOT:   tensor.expand_shape
-//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]]
-//  CHECK-SAME:     : tensor<?x4x?x2xf32> into tensor<?x4x?xf32>
-//  CHECK-NEXT:   return %[[COLLAPSE]]
-
-// -----
-
-func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
     -> tensor<?x?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x?x?xf32> into tensor<?x?xf32>
@@ -1154,22 +1138,7 @@ func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %
       : tensor<?x?xf32> into tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
-// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic
-//       CHECK:   tensor.collapse_shape
-//       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape
-//       CHECK:   return %[[EXPAND]]
-
-// -----
-
-func.func @no_fold_expand_of_collapse_adjacent_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index)
-    -> tensor<?x?xf32> {
-  %0 = tensor.collapse_shape %arg0 [[0, 1, 2]]
-      : tensor<?x?x?xf32> into tensor<?xf32>
-  %1 = tensor.expand_shape %0 [[0, 1]] output_shape [%arg1, %arg2]
-      : tensor<?xf32> into tensor<?x?xf32>
-  return %1 : tensor<?x?xf32>
-}
-// CHECK-LABEL: @no_fold_expand_of_collapse_adjacent_dynamic
+// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
 //       CHECK:   tensor.collapse_shape
 //       CHECK:   %[[EXPAND:.+]] = tensor.expand_shape
 //       CHECK:   return %[[EXPAND]]

diff  --git a/mlir/unittests/Dialect/Utils/CMakeLists.txt b/mlir/unittests/Dialect/Utils/CMakeLists.txt
index e921c8bcfb4e5..61b9cdcb3b8f3 100644
--- a/mlir/unittests/Dialect/Utils/CMakeLists.txt
+++ b/mlir/unittests/Dialect/Utils/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_unittest(MLIRDialectUtilsTests
   StructuredOpsUtilsTest.cpp
-  ReshapeOpsUtilsTest.cpp
   IndexingUtilsTest.cpp
 )
 mlir_target_link_libraries(MLIRDialectUtilsTests

diff  --git a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp b/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
deleted file mode 100644
index db1a87a4de2d5..0000000000000
--- a/mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp
+++ /dev/null
@@ -1,203 +0,0 @@
-//===- ReshapeOpsUtilsTest.cpp - ReshapeOpsUtils unit tests ---------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
-#include "mlir/IR/BuiltinTypeInterfaces.h"
-#include "llvm/ADT/STLExtras.h"
-#include "gtest/gtest.h"
-#include <optional>
-
-using namespace mlir;
-
-/// Helper to make constructing
-/// `std::optional<SmallVector<ReassociationIndices>>` more readable.
-static std::optional<SmallVector<ReassociationIndices>>
-makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
-  return std::optional<SmallVector<ReassociationIndices>>(list);
-}
-
-TEST(ReassociationIndicesForCollapse, ScalarTest) {
-  EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}),
-            makeOptionalIndices({{0}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}),
-            makeOptionalIndices({{0, 1}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}),
-            makeOptionalIndices({{0}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic,
-                                                ShapedType::kDynamic, 1,
-                                                ShapedType::kDynamic},
-                                               {}),
-            makeOptionalIndices({{0, 1, 2, 3, 4}}));
-}
-
-TEST(ReassociationIndicesForCollapse, ScalarTestFailure) {
-  EXPECT_EQ(getReassociationIndicesForCollapse({}, {}), std::nullopt);
-  EXPECT_EQ(getReassociationIndicesForCollapse({}, {1}), std::nullopt);
-  EXPECT_EQ(getReassociationIndicesForCollapse({2}, {}), std::nullopt);
-  EXPECT_EQ(
-      getReassociationIndicesForCollapse({1, 2, ShapedType::kDynamic, 1}, {}),
-      std::nullopt);
-}
-
-TEST(ReassociationIndicesForCollapse, StaticTest) {
-  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {200}),
-            makeOptionalIndices({{0, 1}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {10, 600}),
-            makeOptionalIndices({{0}, {1, 2}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 30}),
-            makeOptionalIndices({{0, 1}, {2}}));
-}
-
-TEST(ReassociationIndicesForCollapse, StaticTestFailure) {
-  // No-op reassociation
-  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10, 20}),
-            std::nullopt);
-  // Invalid static reassociations
-  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {10}), std::nullopt);
-  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {200, 300}),
-            std::nullopt);
-  // Non-collapsing (expanding) reassociation
-  EXPECT_EQ(getReassociationIndicesForCollapse({10, 20, 30}, {1, 10, 20, 30}),
-            std::nullopt);
-}
-
-TEST(ReassociationIndicesForCollapse, StaticTestUnitDims) {
-  EXPECT_EQ(getReassociationIndicesForCollapse({10, 1}, {10}),
-            makeOptionalIndices({{0, 1}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse({1, 20, 30}, {600}),
-            makeOptionalIndices({{0, 1, 2}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1}, {1}),
-            makeOptionalIndices({{0, 1, 2}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse({1, 1, 1, 1}, {1, 1, 1}),
-            makeOptionalIndices({{0}, {1}, {2, 3}}));
-}
-
-TEST(ReassociationIndicesForCollapse, DynamicTest) {
-  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1},
-                                               {ShapedType::kDynamic}),
-            makeOptionalIndices({{0, 1}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 1, 1},
-                                               {ShapedType::kDynamic}),
-            makeOptionalIndices({{0, 1, 2}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {1, ShapedType::kDynamic, 1, ShapedType::kDynamic, 1},
-                {ShapedType::kDynamic, ShapedType::kDynamic}),
-            makeOptionalIndices({{0, 1}, {2, 3, 4}}));
-  EXPECT_EQ(
-      getReassociationIndicesForCollapse(
-          {ShapedType::kDynamic, ShapedType::kDynamic}, {ShapedType::kDynamic}),
-      makeOptionalIndices({{0, 1}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {1, ShapedType::kDynamic, ShapedType::kDynamic},
-                {1, ShapedType::kDynamic}),
-            makeOptionalIndices({{0}, {1, 2}}));
-
-  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10},
-                                               {ShapedType::kDynamic}),
-            makeOptionalIndices({{0, 1}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {1, ShapedType::kDynamic, ShapedType::kDynamic},
-                {ShapedType::kDynamic}),
-            makeOptionalIndices({{0, 1, 2}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic},
-                                               {ShapedType::kDynamic}),
-            makeOptionalIndices({{0, 1}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {ShapedType::kDynamic, 1, 2, ShapedType::kDynamic, 10},
-                {ShapedType::kDynamic, 10}),
-            makeOptionalIndices({{0, 1, 2, 3}, {4}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
-                                               {ShapedType::kDynamic, 20}),
-            makeOptionalIndices({{0, 1}, {2}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse({10, ShapedType::kDynamic, 20},
-                                               {ShapedType::kDynamic, 20}),
-            makeOptionalIndices({{0, 1}, {2}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {ShapedType::kDynamic, 3, 2, 5, 2}, {ShapedType::kDynamic, 20}),
-            makeOptionalIndices({{0, 1}, {2, 3, 4}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {10, ShapedType::kDynamic, 20, ShapedType::kDynamic, 1},
-                {ShapedType::kDynamic, 20, ShapedType::kDynamic}),
-            makeOptionalIndices({{0, 1}, {2}, {3, 4}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 1},
-                                               {ShapedType::kDynamic}),
-            makeOptionalIndices({{0, 1, 2}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {ShapedType::kDynamic, ShapedType::kDynamic, 1},
-                {ShapedType::kDynamic, ShapedType::kDynamic}),
-            makeOptionalIndices({{0}, {1, 2}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {1, ShapedType::kDynamic, ShapedType::kDynamic},
-                {ShapedType::kDynamic, ShapedType::kDynamic}),
-            makeOptionalIndices({{0, 1}, {2}}));
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {ShapedType::kDynamic, 1, ShapedType::kDynamic},
-                {ShapedType::kDynamic, ShapedType::kDynamic}),
-            makeOptionalIndices({{0}, {1, 2}}));
-}
-
-TEST(ReassociationIndicesForCollapse, DynamicTestFailure) {
-  EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic, 10, 20},
-                                               {ShapedType::kDynamic, 10}),
-            std::nullopt);
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {ShapedType::kDynamic, 10, ShapedType::kDynamic},
-                {ShapedType::kDynamic, ShapedType::kDynamic}),
-            std::nullopt);
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {20, ShapedType::kDynamic, 10, ShapedType::kDynamic},
-                {ShapedType::kDynamic, ShapedType::kDynamic}),
-            std::nullopt);
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {ShapedType::kDynamic, 5, 3, 2, 2}, {ShapedType::kDynamic, 20}),
-            std::nullopt);
-  EXPECT_EQ(
-      getReassociationIndicesForCollapse(
-          {ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic},
-          {ShapedType::kDynamic, ShapedType::kDynamic}),
-      std::nullopt);
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {ShapedType::kDynamic, ShapedType::kDynamic, 10, 1,
-                 ShapedType::kDynamic},
-                {ShapedType::kDynamic, ShapedType::kDynamic}),
-            std::nullopt);
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic},
-                {ShapedType::kDynamic, 10, ShapedType::kDynamic}),
-            std::nullopt);
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {ShapedType::kDynamic, 10, 10, 10, ShapedType::kDynamic},
-                {ShapedType::kDynamic, 2, 2, ShapedType::kDynamic}),
-            std::nullopt);
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {ShapedType::kDynamic, 3, 4, 3, ShapedType::kDynamic},
-                {ShapedType::kDynamic, 12, ShapedType::kDynamic}),
-            std::nullopt);
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {ShapedType::kDynamic, 8, 4, 2, 16, ShapedType::kDynamic},
-                {ShapedType::kDynamic, 32, ShapedType::kDynamic}),
-            std::nullopt);
-
-  //===----------------------------------------------------------------------===//
-  // TODO: Reassociation for the following examples can be computed, but isn't
-  // supported by `getReassociationIndicesForCollapse`.
-  //===----------------------------------------------------------------------===//
-
-  // TODO: Fails because there's no backtracking when some source dimensions
-  // remain unmatched at either edge.
-  EXPECT_EQ(getReassociationIndicesForCollapse(
-                {ShapedType::kDynamic, 10, ShapedType::kDynamic, 10},
-                {ShapedType::kDynamic, 10}),
-            std::nullopt);
-  EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic, 2, 2},
-                                               {1, ShapedType::kDynamic, 2}),
-            std::nullopt);
-  EXPECT_EQ(getReassociationIndicesForCollapse({2, 2, ShapedType::kDynamic, 1},
-                                               {2, ShapedType::kDynamic}),
-            std::nullopt);
-}


        
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to