llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-scf Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> This commit adds a new public API to `ValueBoundsOpInterface` to compare values/dims. Supported comparison operators are: LT, LE, EQ, GE, GT. The new `ValueBoundsConstraintSet::compare` API replaces and generalizes `ValueBoundsConstraintSet::areEqual`. Not only does it provide additional comparison operators, it also works in cases where the difference between the two values/dims is non-constant. The previous implementation of `areEqual` computed a constant bound of `val1 - val2`, so it could only succeed when that difference was a constant. Note: This commit refactors, generalizes and adds a public API for value/dim comparison. The comparison functionality itself was introduced in #85895 and is already in use for analyzing `scf.if`. In the long term, this improvement will allow for a more powerful analysis of subset ops. A future commit will update `areOverlappingSlices` to use the new comparison API. (`areEquivalentSlices` is already using the new API.) This will improve subset equivalence/disjointness checks with non-constant offsets/sizes/strides. 
--- Patch is 31.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/86915.diff 7 Files Affected: - (modified) mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h (+51-10) - (modified) mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp (+9-22) - (modified) mlir/lib/Interfaces/ValueBoundsOpInterface.cpp (+181-58) - (modified) mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir (+42-1) - (modified) mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir (+12) - (modified) mlir/test/Dialect/Tensor/value-bounds-op-interface-impl.mlir (+8-8) - (modified) mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp (+66-13) ``````````diff diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index f35432ca0136f3..d27081fad8c6c0 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -211,7 +211,8 @@ class ValueBoundsConstraintSet /// Comparison operator for `ValueBoundsConstraintSet::compare`. enum ComparisonOperator { LT, LE, EQ, GT, GE }; - /// Try to prove that, based on the current state of this constraint set + /// Populate constraints for lhs/rhs (until the stop condition is met). Then, + /// try to prove that, based on the current state of this constraint set /// (i.e., without analyzing additional IR or adding new constraints), the /// "lhs" value/dim is LE/LT/EQ/GT/GE than the "rhs" value/dim. /// @@ -220,24 +221,37 @@ class ValueBoundsConstraintSet /// proven. This could be because the specified relation does in fact not hold /// or because there is not enough information in the constraint set. In other /// words, if we do not know for sure, this function returns "false". 
- bool compare(Value lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp, - Value rhs, std::optional<int64_t> rhsDim); + bool populateAndCompare(OpFoldResult lhs, std::optional<int64_t> lhsDim, + ComparisonOperator cmp, OpFoldResult rhs, + std::optional<int64_t> rhsDim); + + /// Return "true" if "lhs cmp rhs" was proven to hold. Return "false" if the + /// specified relation could not be proven. This could be because the + /// specified relation does in fact not hold or because there is not enough + /// information in the constraint set. In other words, if we do not know for + /// sure, this function returns "false". + /// + /// This function keeps traversing the backward slice of lhs/rhs until could + /// prove the relation or until it ran out of IR. + static bool compare(OpFoldResult lhs, std::optional<int64_t> lhsDim, + ComparisonOperator cmp, OpFoldResult rhs, + std::optional<int64_t> rhsDim); + static bool compare(AffineMap lhs, ValueDimList lhsOperands, + ComparisonOperator cmp, AffineMap rhs, + ValueDimList rhsOperands); + static bool compare(AffineMap lhs, ArrayRef<Value> lhsOperands, + ComparisonOperator cmp, AffineMap rhs, + ArrayRef<Value> rhsOperands); /// Compute whether the given values/dimensions are equal. Return "failure" if /// equality could not be determined. /// /// `dim1`/`dim2` must be `nullopt` if and only if `value1`/`value2` are /// index-typed. - static FailureOr<bool> areEqual(Value value1, Value value2, + static FailureOr<bool> areEqual(OpFoldResult value1, OpFoldResult value2, std::optional<int64_t> dim1 = std::nullopt, std::optional<int64_t> dim2 = std::nullopt); - /// Compute whether the given values/attributes are equal. Return "failure" if - /// equality could not be determined. - /// - /// `ofr1`/`ofr2` must be of index type. - static FailureOr<bool> areEqual(OpFoldResult ofr1, OpFoldResult ofr2); - /// Return "true" if the given slices are guaranteed to be overlapping. 
/// Return "false" if the given slices are guaranteed to be non-overlapping. /// Return "failure" if unknown. @@ -290,6 +304,20 @@ class ValueBoundsConstraintSet ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition); + /// Return "true" if, based on the current state of the constraint system, + /// "lhs cmp rhs" was proven to hold. Return "false" if the specified relation + /// could not be proven. This could be because the specified relation does in + /// fact not hold or because there is not enough information in the constraint + /// set. In other words, if we do not know for sure, this function returns + /// "false". + /// + /// This function does not analyze any IR and does not populate any additional + /// constraints. + bool compareValueDims(OpFoldResult lhs, std::optional<int64_t> lhsDim, + ComparisonOperator cmp, OpFoldResult rhs, + std::optional<int64_t> rhsDim); + bool comparePos(int64_t lhsPos, ComparisonOperator cmp, int64_t rhsPos); + /// Given an affine map with a single result (and map operands), add a new /// column to the constraint set that represents the result of the map. /// Traverse additional IR starting from the map operands as needed (as long @@ -311,6 +339,14 @@ class ValueBoundsConstraintSet /// value/dimension exists in the constraint set. int64_t getPos(Value value, std::optional<int64_t> dim = std::nullopt) const; + /// Return an affine expression that represents column `pos` in the constraint + /// set. + AffineExpr getPosExpr(int64_t pos); + + /// Return "true" if the given value/dim is mapped (i.e., has a corresponding + /// column in the constraint system). + bool isMapped(Value value, std::optional<int64_t> dim = std::nullopt) const; + /// Insert a value/dimension into the constraint set. If `isSymbol` is set to /// "false", a dimension is added. The value/dimension is added to the /// worklist if `addToWorklist` is set. @@ -330,6 +366,11 @@ class ValueBoundsConstraintSet /// dimensions but not for symbols. 
int64_t insert(bool isSymbol = true); + /// Insert the given affine map and its bound operands as a new column in the + /// constraint system. Return the position of the new column. Any operands + /// that were not analyzed yet are put on the worklist. + int64_t insert(AffineMap map, ValueDimList operands, bool isSymbol = true); + /// Project out the given column in the constraint set. void projectOut(int64_t pos); diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp index 72c5aaa2306783..087ffc438a830a 100644 --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -58,20 +58,11 @@ struct ForOpInterface Value iterArg = forOp.getRegionIterArg(iterArgIdx); Value initArg = forOp.getInitArgs()[iterArgIdx]; - // Populate constraints for the yielded value. - cstr.populateConstraints(yieldedValue, dim); - // Populate constraints for the iter_arg. This is just to ensure that the - // iter_arg is mapped in the constraint set, which is a prerequisite for - // `compare`. It may lead to a recursive call to this function in case the - // iter_arg was not visited when the constraints for the yielded value were - // populated, but no additional work is done. - cstr.populateConstraints(iterArg, dim); - // An EQ constraint can be added if the yielded value (dimension size) // equals the corresponding block argument (dimension size). 
- if (cstr.compare(yieldedValue, dim, - ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg, - dim)) { + if (cstr.populateAndCompare( + yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ, + iterArg, dim)) { if (dim.has_value()) { cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim); } else { @@ -113,10 +104,6 @@ struct IfOpInterface Value thenValue = ifOp.thenYield().getResults()[resultNum]; Value elseValue = ifOp.elseYield().getResults()[resultNum]; - // Populate constraints for the yielded value (and all values on the - // backward slice, as long as the current stop condition is not satisfied). - cstr.populateConstraints(thenValue, dim); - cstr.populateConstraints(elseValue, dim); auto boundsBuilder = cstr.bound(value); if (dim) boundsBuilder[*dim]; @@ -125,9 +112,9 @@ struct IfOpInterface // If thenValue <= elseValue: // * result <= elseValue // * result >= thenValue - if (cstr.compare(thenValue, dim, - ValueBoundsConstraintSet::ComparisonOperator::LE, - elseValue, dim)) { + if (cstr.populateAndCompare( + thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE, + elseValue, dim)) { if (dim) { cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim); cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim); @@ -139,9 +126,9 @@ struct IfOpInterface // If elseValue <= thenValue: // * result <= thenValue // * result >= elseValue - if (cstr.compare(elseValue, dim, - ValueBoundsConstraintSet::ComparisonOperator::LE, - thenValue, dim)) { + if (cstr.populateAndCompare( + elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE, + thenValue, dim)) { if (dim) { cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim); cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim); diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index dd98da9adc7d96..d7ffed14daccdd 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ 
-212,6 +212,28 @@ int64_t ValueBoundsConstraintSet::insert(bool isSymbol) { return pos; } +int64_t ValueBoundsConstraintSet::insert(AffineMap map, ValueDimList operands, + bool isSymbol) { + assert(map.getNumResults() == 1 && "expected affine map with one result"); + int64_t pos = insert(/*isSymbol=*/false); + + // Add map and operands to the constraint set. Dimensions are converted to + // symbols. All operands are added to the worklist (unless they were already + // processed). + auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) { + return getExpr(v.first, v.second); + }; + SmallVector<AffineExpr> dimReplacements = llvm::to_vector( + llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper)); + SmallVector<AffineExpr> symReplacements = llvm::to_vector( + llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper)); + addBound( + presburger::BoundType::EQ, pos, + map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements)); + + return pos; +} + int64_t ValueBoundsConstraintSet::getPos(Value value, std::optional<int64_t> dim) const { #ifndef NDEBUG @@ -227,6 +249,20 @@ int64_t ValueBoundsConstraintSet::getPos(Value value, return it->second; } +AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) { + assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position"); + return pos < cstr.getNumDimVars() + ? 
builder.getAffineDimExpr(pos) + : builder.getAffineSymbolExpr(pos - cstr.getNumDimVars()); +} + +bool ValueBoundsConstraintSet::isMapped(Value value, + std::optional<int64_t> dim) const { + auto it = + valueDimToPosition.find(std::make_pair(value, dim.value_or(kIndexValue))); + return it != valueDimToPosition.end(); +} + static Operation *getOwnerOfValue(Value value) { if (auto bbArg = dyn_cast<BlockArgument>(value)) return bbArg.getOwner()->getParentOp(); @@ -563,27 +599,10 @@ void ValueBoundsConstraintSet::populateConstraints(Value value, int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map, ValueDimList operands) { - assert(map.getNumResults() == 1 && "expected affine map with one result"); - int64_t pos = insert(/*isSymbol=*/false); - - // Add map and operands to the constraint set. Dimensions are converted to - // symbols. All operands are added to the worklist (unless they were already - // processed). - auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) { - return getExpr(v.first, v.second); - }; - SmallVector<AffineExpr> dimReplacements = llvm::to_vector( - llvm::map_range(ArrayRef(operands).take_front(map.getNumDims()), mapper)); - SmallVector<AffineExpr> symReplacements = llvm::to_vector( - llvm::map_range(ArrayRef(operands).drop_front(map.getNumDims()), mapper)); - addBound( - presburger::BoundType::EQ, pos, - map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements)); - + int64_t pos = insert(map, operands, /*isSymbol=*/false); // Process the backward slice of `operands` (i.e., reverse use-def chain) // until `stopCondition` is met. 
processWorklist(); - return pos; } @@ -603,9 +622,18 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2, {{value1, dim1}, {value2, dim2}}); } -bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim, - ComparisonOperator cmp, Value rhs, - std::optional<int64_t> rhsDim) { +bool ValueBoundsConstraintSet::compareValueDims(OpFoldResult lhs, + std::optional<int64_t> lhsDim, + ComparisonOperator cmp, + OpFoldResult rhs, + std::optional<int64_t> rhsDim) { +#ifndef NDEBUG + if (auto lhsVal = dyn_cast<Value>(lhs)) + assertValidValueDim(lhsVal, lhsDim); + if (auto rhsVal = dyn_cast<Value>(rhs)) + assertValidValueDim(rhsVal, rhsDim); +#endif // NDEBUG + // This function returns "true" if "lhs CMP rhs" is proven to hold. // // Example for ComparisonOperator::LE and index-typed values: We would like to @@ -624,19 +652,61 @@ bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim, // EQ can be expressed as LE and GE. if (cmp == EQ) - return compare(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) && - compare(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim); + return compareValueDims(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) && + compareValueDims(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim); // Construct inequality. For the above example: lhs > rhs. // `IntegerRelation` inequalities are expressed in the "flattened" form and // with ">= 0". I.e., lhs - rhs - 1 >= 0. 
+ SmallVector<int64_t> eq(cstr.getNumCols(), 0); + auto addToEq = [&](OpFoldResult ofr, std::optional<int64_t> dim, + int64_t factor) { + if (auto constVal = ::getConstantIntValue(ofr)) { + eq[cstr.getNumCols() - 1] += *constVal * factor; + } else { + eq[getPos(cast<Value>(ofr), dim)] += factor; + } + }; + if (cmp == LT || cmp == LE) { + addToEq(lhs, lhsDim, 1); + addToEq(rhs, rhsDim, -1); + } else if (cmp == GT || cmp == GE) { + addToEq(lhs, lhsDim, -1); + addToEq(rhs, rhsDim, 1); + } else { + llvm_unreachable("unsupported comparison operator"); + } + if (cmp == LE || cmp == GE) + eq[cstr.getNumCols() - 1] -= 1; + + // Add inequality to the constraint set and check if it made the constraint + // set empty. + int64_t ineqPos = cstr.getNumInequalities(); + cstr.addInequality(eq); + bool isEmpty = cstr.isEmpty(); + cstr.removeInequality(ineqPos); + return isEmpty; +} + +bool ValueBoundsConstraintSet::comparePos(int64_t lhsPos, + ComparisonOperator cmp, + int64_t rhsPos) { + // This function returns "true" if "lhs CMP rhs" is proven to hold. For + // detailed documentation, see `compareValueDims`. + + // EQ can be expressed as LE and GE. + if (cmp == EQ) + return comparePos(lhsPos, ComparisonOperator::LE, rhsPos) && + comparePos(lhsPos, ComparisonOperator::GE, rhsPos); + + // Construct inequality. 
SmallVector<int64_t> eq(cstr.getNumDimAndSymbolVars() + 1, 0); if (cmp == LT || cmp == LE) { - ++eq[getPos(lhs, lhsDim)]; - --eq[getPos(rhs, rhsDim)]; + ++eq[lhsPos]; + --eq[rhsPos]; } else if (cmp == GT || cmp == GE) { - --eq[getPos(lhs, lhsDim)]; - ++eq[getPos(rhs, rhsDim)]; + --eq[lhsPos]; + ++eq[rhsPos]; } else { llvm_unreachable("unsupported comparison operator"); } @@ -652,40 +722,93 @@ bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim, return isEmpty; } +bool ValueBoundsConstraintSet::populateAndCompare( + OpFoldResult lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp, + OpFoldResult rhs, std::optional<int64_t> rhsDim) { +#ifndef NDEBUG + if (auto lhsVal = dyn_cast<Value>(lhs)) + assertValidValueDim(lhsVal, lhsDim); + if (auto rhsVal = dyn_cast<Value>(rhs)) + assertValidValueDim(rhsVal, rhsDim); +#endif // NDEBUG + + if (auto lhsVal = dyn_cast<Value>(lhs)) + populateConstraints(lhsVal, lhsDim); + if (auto rhsVal = dyn_cast<Value>(rhs)) + populateConstraints(rhsVal, rhsDim); + + return compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim); +} + +bool ValueBoundsConstraintSet::compare(OpFoldResult lhs, + std::optional<int64_t> lhsDim, + ComparisonOperator cmp, OpFoldResult rhs, + std::optional<int64_t> rhsDim) { + auto stopCondition = [&](Value v, std::optional<int64_t> dim, + ValueBoundsConstraintSet &cstr) { + // Keep processing as long as lhs/rhs are not mapped. + if (auto lhsVal = dyn_cast<Value>(lhs)) + if (!cstr.isMapped(lhsVal, dim)) + return false; + if (auto rhsVal = dyn_cast<Value>(rhs)) + if (!cstr.isMapped(rhsVal, dim)) + return false; + // Keep processing as long as the relation cannot be proven. 
+ return cstr.compareValueDims(lhs, lhsDim, cmp, rhs, rhsDim); + }; + + ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition); + return cstr.populateAndCompare(lhs, lhsDim, cmp, rhs, rhsDim); +} + +bool ValueBoundsConstraintSet::compare(AffineMap lhs, ValueDimList lhsOperands, + ComparisonOperator cmp, AffineMap rhs, + ValueDimList rhsOperands) { + int64_t lhsPos = -1, rhsPos = -1; + auto stopCondition = [&](Value v, std::optional<int64_t> dim, + ValueBoundsConstraintSet &cstr) { + // Keep processing as long as lhs/rhs were not processed. + if (lhsPos >= cstr.positionToValueDim.size() || + rhsPos >= cstr.positionToValueDim.size()) + return false; + // Keep processing as long as the relation cannot be proven. + return cstr.comparePos(lhsPos, cmp, rhsPos); + }; + ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition); + lhsPos = cstr.insert(lhs, lhsOperands); + rhsPos = cstr.insert(rhs, rhsOperands); + return cstr.comparePos(lhsPos, cmp, rhsPos); +} + +bool ValueBoundsConstraintSet::compare(AffineMap lhs, + ArrayRef<Value> lhsOperands, + ComparisonOperator cmp, AffineMap rhs, + ArrayRef<Value> rhsOperands) { + ValueDimList lhsValueDimOperands = + llvm::map_to_vector(lhsOperands, [](Value v) { + return std::make_pair(v, std::optional<int64_t>()); + }); + ValueDimList rhsValueDimOperands = + llvm::map_to_vector(rhsOperands, [](Value v) { + return std::make_pair(v, std::optional<int64_t>()); + }); + return ValueBoundsConstraintSet::compare(lhs, lhsValueDimOperands, cmp, rhs, + rhsValueDimOperands); +} + FailureOr<bool> -ValueBoundsConstraintSet::areEqual(Value value1, Value value2, +ValueBoundsConstraintSet::areEqual(OpFoldResult value1, OpFoldResult value2, std::optional<int64_t> dim1, std::optional<int64_t> dim2) { - // Subtract the two values/dimensions from each other. If the result is 0, - // both are equal. 
- FailureOr<int64_t> delta = computeConstantDelta(value1, value2, dim1, dim2); - if (failed(delta)) - return failure(); - return *delta == 0; -} - -FailureOr<bool> ValueBoundsConstraintSet::areEqual(OpFoldResult ofr1, - OpFoldR... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/86915 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits