================
@@ -583,3 +602,480 @@ Region *mlir::getEnclosingRepetitiveRegion(Value value) {
   LDBG() << "No enclosing repetitive region found for value";
   return nullptr;
 }
+
+/// Return "true" if `a` may be used wherever `b` is used. `b` must be a
+/// region successor input of `regionBranchOp` and `a` must be one of its
+/// "possible values", i.e., a successor operand value that is (possibly
+/// transitively) forwarded to `b`.
+static bool isDefinedBefore(Operation *regionBranchOp, Value a, Value b) {
+  assert((b.getDefiningOp() == regionBranchOp ||
+          b.getParentRegion()->getParentOp() == regionBranchOp) &&
+         "b must be a region successor input");
+
+  // `a` is defined outside of the region branch op. It is safe to assume
+  // that `a` was defined before `b` then: otherwise, it could not be among
+  // the possible values for a region successor input.
+  // Example:
+  // {   <- %a1 parent region begins here.
+  // ^bb0(%a1: ...):
+  //   %a2 = ...
+  //   %b1 = region_op({
+  //   ^bb1(%b2: ...):
+  //     ...
+  //   })
+  // }
+  if (a.getParentRegion()->getParentOp() != regionBranchOp)
+    return true;
+
+  // From here on, `a` is defined inside of the region branch op. It must be
+  // directly nested in the op; a more deeply nested value could not have been
+  // among the possible values for a region successor input.
+
+  // `b` is a result of the region branch op: nothing defined inside of the
+  // op is in scope for `b`.
+  // Example:
+  // %b = region_op({
+  // ^bb0(%a1: ...):
+  //   %a2 = ...
+  // })
+  if (isa<OpResult>(b))
+    return false;
+
+  // `b` is an entry block argument of a region: `a` is in scope for `b` only
+  // if `a` is a block argument that is owned by the very same block.
+  // Example:
+  // region_op({
+  // ^bb0(%b: ..., %a: ...):
+  //   ...
+  // })
+  assert(isa<BlockArgument>(b) && "b must be a block argument");
+  auto aArg = dyn_cast<BlockArgument>(a);
+  return aArg && aArg.getOwner() == cast<BlockArgument>(b).getOwner();
+}
+
+/// Compute all non-successor-input values that a successor input could have
+/// based on the given successor input to successor operand mapping.
+///
+/// `value` must be a key of `inputToOperands` (i.e., a successor input). The
+/// returned set never contains a successor input; in particular, it never
+/// contains `value` itself.
+///
+/// Example 1:
+/// %r = scf.if ... {
+///   scf.yield %a : ...
+/// } else {
+///   scf.yield %b : ...
+/// }
+/// possibleValues(%r) = {%a, %b}
+///
+/// Example 2:
+/// %r = scf.for ... iter_args(%arg0 = %0) -> ... {
+///   scf.yield %arg0 : ...
+/// }
+/// possibleValues(%arg0) = {%0}
+/// possibleValues(%r) = {%0}
+///
+/// Example 3:
+/// %r = scf.for ... iter_args(%arg0 = %0) -> ... {
+///   ...
+///   scf.yield %1 : ...
+/// }
+/// possibleValues(%arg0) = {%0, %1}
+/// possibleValues(%r) = {%0, %1}
+static llvm::SmallDenseSet<Value> computePossibleValuesOfSuccessorInput(
+    Value value, const RegionBranchInverseSuccessorMapping &inputToOperands) {
+  assert(inputToOperands.contains(value) && "value must be a successor input");
+  // Starting with the given value, trace back all predecessor values (i.e.,
+  // preceding successor operands) and add them to the set of possible values.
+  // If the successor operand is again a successor input, do not add it to
+  // result set, but instead continue the traversal.
+  llvm::SmallDenseSet<Value> possibleValues;
+  llvm::SmallDenseSet<Value> visited;
+  // Mark the root as visited up front. With cyclic forwarding (e.g., a loop
+  // that yields its own iter_arg), the root would otherwise be enqueued and
+  // expanded a second time without contributing anything new.
+  visited.insert(value);
+  SmallVector<Value> worklist;
+  worklist.push_back(value);
+  while (!worklist.empty()) {
+    Value next = worklist.pop_back_val();
+    auto it = inputToOperands.find(next);
+    if (it == inputToOperands.end()) {
+      // Not a successor input: this is a terminal "possible value".
+      possibleValues.insert(next);
+      continue;
+    }
+    // `next` is a successor input: keep tracing through the operands that
+    // are forwarded to it, enqueuing each distinct value only once.
+    for (OpOperand *operand : it->second) {
+      Value predecessor = operand->get();
+      if (visited.insert(predecessor).second)
+        worklist.push_back(predecessor);
+    }
+  }
+  // Note: The result does not contain any successor inputs. (Therefore,
+  // `value` is also guaranteed to be excluded.)
+  return possibleValues;
+}
+
+namespace {
+/// Rewrite pattern that tries to make successor inputs dead: whenever a
+/// successor input can only ever take a single value that is not itself a
+/// successor input, all uses of the successor input are replaced with that
+/// value. This enables additional canonicalization opportunities for
+/// RemoveDeadRegionBranchOpSuccessorInputs.
+///
+/// Example:
+///
+/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
+///   scf.yield %arg1, %arg1 : ...
+/// }
+/// use(%r0, %r1)
+///
+/// possibleValues(%r0) = {%0, %1}
+/// possibleValues(%r1) = {%1} ==> replace uses of %r1 with %1.
+/// possibleValues(%arg0) = {%0, %1}
+/// possibleValues(%arg1) = {%1} ==> replace uses of %arg1 with %1.
+///
+/// IR after pattern application:
+///
+/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
+///   scf.yield %1, %1 : ...
+/// }
+/// use(%r0, %1)
+///
+/// Note that %r1 and %arg1 are dead now. The IR can now be further
+/// canonicalized by RemoveDeadRegionBranchOpSuccessorInputs.
+struct MakeRegionBranchOpSuccessorInputsDead : public RewritePattern {
+  MakeRegionBranchOpSuccessorInputsDead(MLIRContext *context, StringRef name,
+                                        PatternBenefit benefit = 1)
+      : RewritePattern(name, benefit, context) {}
+
+  /// Returns success if the uses of at least one successor input of `op`
+  /// were replaced, failure otherwise.
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+           "isolated-from-above ops are not supported");
+
+    // Build the successor input -> successor operands mapping.
+    auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+    RegionBranchInverseSuccessorMapping inputToOperands;
+    regionBranchOp.getSuccessorInputOperandMapping(inputToOperands);
+
+    // Visit the successor inputs one at a time and try to replace their uses.
+    bool changed = false;
+    for (Value input : inputToOperands.keys()) {
+      // Already dead: there is nothing to replace.
+      if (input.use_empty())
+        continue;
+
+      // Only successor inputs with exactly one possible value can be
+      // replaced.
+      llvm::SmallDenseSet<Value> candidates =
+          computePossibleValuesOfSuccessorInput(input, inputToOperands);
+      if (candidates.size() != 1)
+        continue;
+      Value replacement = *candidates.begin();
+      assert(replacement != input &&
+             "successor inputs are supposed to be excluded");
+
+      // Skip the replacement if it would violate dominance. Example:
+      // %r = scf.execute_region ... {
+      //   %a = ...
+      //   scf.yield %a : ...
+      // }
+      // use(%r)
+      // Here, possibleValues(%r) = {%a}, but %a cannot be used as a
+      // replacement for %r due to dominance / scope.
+      if (!isDefinedBefore(regionBranchOp, replacement, input))
+        continue;
+
+      rewriter.replaceAllUsesWith(input, replacement);
+      changed = true;
+    }
+    return success(changed);
+  }
+};
+
+/// Return a reference to the bit vector that `mapping` (a DenseMap) stores
+/// for `key`. If there is no entry for `key` yet, one is created first: a bit
+/// vector with `size` bits, all of them initialized to false.
+template <typename MappingTy, typename KeyTy>
+static BitVector &lookupOrCreateBitVector(MappingTy &mapping, KeyTy key,
+                                          unsigned size) {
+  // try_emplace constructs the bit vector only if `key` is not present yet.
+  auto insertion = mapping.try_emplace(key, size, false);
+  return insertion.first->second;
+}
+
+/// Group successor inputs into sets of tied values. Tied successor inputs
+/// must be erased together: erasing only some values of a set would make the
+/// op structurally invalid. Every successor input belongs to exactly one set.
+///
+/// Example:
+/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
+///   ...
+/// }
+/// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}.
+static llvm::EquivalenceClasses<Value> computeTiedSuccessorInputs(
+    const RegionBranchSuccessorMapping &operandToInputs) {
+  llvm::EquivalenceClasses<Value> tiedSets;
+  for (const auto &entry : operandToInputs) {
+    // All successor inputs that the same successor operand is forwarded to
+    // are tied to each other.
+    const auto &inputs = entry.second;
+    assert(!inputs.empty() && "expected non-empty inputs");
+    Value leader = inputs.front();
+    tiedSets.insert(leader);
+    // Merge the remaining inputs into the leader's set. Sets that were
+    // created while visiting earlier successor operands may get merged here.
+    for (size_t i = 1, e = inputs.size(); i < e; ++i)
+      tiedSets.unionSets(leader, inputs[i]);
+  }
+  return tiedSets;
+}
+
+/// Remove dead successor inputs from region branch ops. A successor input is
+/// dead if it has no uses. Successor inputs come in sets of tied values: if
+/// you remove one value from a set, you must remove all values from the set.
+/// Furthermore, successor operands must also be removed. (Op operands are not
+/// part of the set, but the set is built based on the successor operand to
+/// successor input mapping.)
+///
+/// Example 1:
+/// %r0, %r1 = scf.for ... iter_args(%arg0 = %0, %arg1 = %1) -> ... {
+///   scf.yield %0, %arg1 : ...
+/// }
+/// use(%0, %1)
+///
+/// There are two sets: {{%r0, %arg0}, {%r1, %arg1}}. All values in the first
+/// set are dead, so %arg0 and %r0 can be removed, but not %r1 and %arg1. The
+/// resulting IR is as follows:
+///
+/// %r1 = scf.for ... iter_args(%arg1 = %1) -> ... {
+///   scf.yield %arg1 : ...
+/// }
+/// use(%0, %1)
+///
+/// Example 2:
+/// %r0, %r1 = scf.while (%arg0 = %0) {
+///   scf.condition(...) %arg0, %arg0 : ...
+/// } do {
+/// ^bb0(%arg1: ..., %arg2: ...):
+///   scf.yield %arg1 : ...
+/// }
+/// There are three sets: {{%r0, %arg1}, {%r1, %arg2}, {%arg0}}.
+///
+/// Example 3:
+/// %r1, %r2 = scf.if ... {
+///   scf.yield %0, %1 : ...
+/// } else {
+///   scf.yield %2, %3 : ...
+/// }
+/// There are two sets: {{%r1}, {%r2}}. Each set has one value, so each
+/// value can be removed independently of the other values.
+struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern {
+  RemoveDeadRegionBranchOpSuccessorInputs(MLIRContext *context, StringRef name,
+                                          PatternBenefit benefit = 1)
+      : RewritePattern(name, benefit, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    assert(!op->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
+           "isolated-from-above ops are not supported");
+
+    // Compute tied values: values that must come as a set. If you remove one,
+    // you must remove all. If a successor op operand is forwarded to two
+    // successor inputs %a and %b, both %a and %b are in the same set.
+    auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+    RegionBranchSuccessorMapping operandToInputs;
+    regionBranchOp.getSuccessorOperandInputMapping(operandToInputs);
+    llvm::EquivalenceClasses<Value> tiedSuccessorInputs =
+        computeTiedSuccessorInputs(operandToInputs);
+
+    // Determine which values to remove and group them by block and operation.
+    SmallVector<Value> valuesToRemove;
+    DenseMap<Block *, BitVector> blockArgsToRemove;
+    DenseMap<Operation *, BitVector> resultsToRemove;
----------------
linuxlonelyeagle wrote:

"The results are the successor input, and they should only be related to the 
regionOp via the results.

https://github.com/llvm/llvm-project/pull/174094
_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to