https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/86239
>From 3c4adb5458f054634d51e1502736bb3dbebad106 Mon Sep 17 00:00:00 2001 From: Matthias Springer <spring...@google.com> Date: Sat, 23 Mar 2024 06:02:28 +0000 Subject: [PATCH] [mlir][SCF][NFC] `ValueBoundsConstraintSet`: Simplify `scf.for` implementation This commit simplifies the implementation of the `ValueBoundsOpInterface` for `scf.for` based on the newly added `ValueBoundsConstraintSet::compare` API and adds additional documentation. Previously, the interface implementation created a new constraint set just to check if the yielded value and iter_arg are equal. This was inefficient because constraints were added multiple times (to two different constraint sets) for ops that are inside the loop. --- .../SCF/IR/ValueBoundsOpInterfaceImpl.cpp | 80 +++++++++---------- 1 file changed, 36 insertions(+), 44 deletions(-) diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp index 8e9d1021f93e4b..72c5aaa2306783 100644 --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -12,7 +12,6 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" using namespace mlir; -using presburger::BoundType; namespace mlir { namespace scf { @@ -21,7 +20,28 @@ namespace { struct ForOpInterface : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> { - /// Populate bounds of values/dimensions for iter_args/OpResults. + /// Populate bounds of values/dimensions for iter_args/OpResults. If the + /// value/dimension size does not change in an iteration, we can deduce that + /// it is the same as the initial value/dimension. + /// + /// Example 1: + /// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> { + /// ... + /// %1 = tensor.insert %f into %arg0[...] : tensor<?xf32> + /// scf.yield %1 : tensor<?xf32> + /// } + /// --> bound(%0)[0] == bound(%t)[0] + /// --> bound(%arg0)[0] == bound(%t)[0] + /// + /// Example 2: + /// %0 = scf.for ... 
iter_args(%arg0 = %t) -> tensor<?xf32> { + /// %sz = tensor.dim %arg0, %c0 : tensor<?xf32> + /// %incr = arith.addi %sz, %c1 : index + /// %1 = tensor.empty(%incr) : tensor<?xf32> + /// scf.yield %1 : tensor<?xf32> + /// } + /// --> The yielded tensor dimension size changes with each iteration. Such + /// loops are not supported and no constraints are added. static void populateIterArgBounds(scf::ForOp forOp, Value value, std::optional<int64_t> dim, ValueBoundsConstraintSet &cstr) { @@ -33,59 +53,31 @@ struct ForOpInterface iterArgIdx = llvm::cast<OpResult>(value).getResultNumber(); } - // An EQ constraint can be added if the yielded value (dimension size) - // equals the corresponding block argument (dimension size). Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator()) .getOperand(iterArgIdx); Value iterArg = forOp.getRegionIterArg(iterArgIdx); Value initArg = forOp.getInitArgs()[iterArgIdx]; - auto addEqBound = [&]() { + // Populate constraints for the yielded value. + cstr.populateConstraints(yieldedValue, dim); + // Populate constraints for the iter_arg. This is just to ensure that the + // iter_arg is mapped in the constraint set, which is a prerequisite for + // `compare`. It may lead to a recursive call to this function in case the + // iter_arg was not visited when the constraints for the yielded value were + // populated, but no additional work is done. + cstr.populateConstraints(iterArg, dim); + + // An EQ constraint can be added if the yielded value (dimension size) + // equals the corresponding block argument (dimension size). + if (cstr.compare(yieldedValue, dim, + ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg, + dim)) { if (dim.has_value()) { cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim); } else { cstr.bound(value) == initArg; } - }; - - if (yieldedValue == iterArg) { - addEqBound(); - return; - } - - // Compute EQ bound for yielded value. 
- AffineMap bound; - ValueDimList boundOperands; - LogicalResult status = ValueBoundsConstraintSet::computeBound( - bound, boundOperands, BoundType::EQ, yieldedValue, dim, - [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) { - // Stop when reaching a block argument of the loop body. - if (auto bbArg = llvm::dyn_cast<BlockArgument>(v)) - return bbArg.getOwner()->getParentOp() == forOp; - // Stop when reaching a value that is defined outside of the loop. It - // is impossible to reach an iter_arg from there. - Operation *op = v.getDefiningOp(); - return forOp.getRegion().findAncestorOpInRegion(*op) == nullptr; - }); - if (failed(status)) - return; - if (bound.getNumResults() != 1) - return; - - // Check if computed bound equals the corresponding iter_arg. - Value singleValue = nullptr; - std::optional<int64_t> singleDim; - if (auto dimExpr = dyn_cast<AffineDimExpr>(bound.getResult(0))) { - int64_t idx = dimExpr.getPosition(); - singleValue = boundOperands[idx].first; - singleDim = boundOperands[idx].second; - } else if (auto symExpr = dyn_cast<AffineSymbolExpr>(bound.getResult(0))) { - int64_t idx = symExpr.getPosition() + bound.getNumDims(); - singleValue = boundOperands[idx].first; - singleDim = boundOperands[idx].second; } - if (singleValue == iterArg && singleDim == dim) - addEqBound(); } void populateBoundsForIndexValue(Operation *op, Value value, _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits