https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/199535
>From c8b3672ab238f0bd1ee20d2772ad28c0b5500770 Mon Sep 17 00:00:00 2001
From: Matthias Springer <[email protected]>
Date: Mon, 25 May 2026 14:40:25 +0000
Subject: [PATCH] [mlir][SCF] Add `scf.loop` op and terminators

---
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 150 +++++++++++++++++++
 mlir/lib/Dialect/SCF/IR/SCF.cpp            | 162 +++++++++++++++++++++
 mlir/test/Dialect/SCF/invalid.mlir         | 101 +++++++++++++
 mlir/test/Dialect/SCF/ops.mlir             |  73 ++++++++++
 4 files changed, 486 insertions(+)

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 0b33ecb48b7f2..57c07fa0a50fc 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -147,6 +147,156 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+def LoopOp : SCF_Op<"loop", [
+    AutomaticAllocationScope,
+    RecursiveMemoryEffects,
+    SingleBlock
+  ]> {
+  let summary = "Loop until a break operation";
+  let description = [{
+    The `scf.loop` operation represents an infinite loop that executes until an
+    `scf.break` is reached.
+
+    The loop consists of (1) a set of loop-carried values which are initialized
+    by `initValues` and updated by each iteration of the loop, and (2) a region
+    which represents the loop body.
+
+    The loop body must end with an explicit terminator, which must be one of:
+
+    - `scf.continue`: re-enters the loop, supplying the next iteration's value
+      for each loop-carried variable. Terminator operand types and loop operand
+      types must match. If the loop has op results, their values are undefined.
+    - `scf.break`: terminates the loop, supplying the final values for the
+      `scf.loop` results. Terminator operand types and loop op result types
+      must match.
+
+    Note: This operation will be extended in the future to support breaking and
+    continuing from nested regions. For now, `scf.break` and `scf.continue`
+    must be terminators of the loop body. In practice this means that an
+    `scf.loop` either runs forever (terminator is `scf.continue`) or executes
+    exactly one iteration (terminator is `scf.break`).
+
+    Examples:
+
+    ```mlir
+    // Loop with iteration-carried values updated by `scf.continue`.
+    scf.loop iter_args(%i = %init) : i32 {
+      %v = "some.compute"(%i) : (i32) -> (i32)
+      scf.continue %v : i32
+    }
+    ```
+
+    ```mlir
+    // Loop with both an iteration-carried value and a result. The iter_arg
+    // and result types may differ.
+    %r = scf.loop iter_args(%i = %init) : i32 -> i64 {
+      %v = "some.compute"(%i) : (i32) -> (i64)
+      scf.break %v : i64
+    }
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$initValues);
+  let results = (outs Variadic<AnyType>:$resultValues);
+  let regions = (region SizedRegion<1>:$region);
+
+  let builders = [
+    OpBuilder<(ins
+      CArg<"::mlir::TypeRange", "{}">:$resultTypes,
+      CArg<"::mlir::ValueRange", "{}">:$initValues,
+      CArg<"::llvm::function_ref<void(::mlir::OpBuilder &, ::mlir::Location, "
+           "::mlir::ValueRange)>", "nullptr">:$bodyBuilder)>
+  ];
+
+  let extraClassDeclaration = [{
+    /// Return the iteration values of the loop region.
+    Block::BlockArgListType getRegionIterValues() {
+      return getRegion().getArguments();
+    }
+
+    /// Return the `index`-th region iteration value.
+    BlockArgument getRegionIterValue(unsigned index) {
+      return getRegionIterValues()[index];
+    }
+
+    /// Returns the number of region arguments for loop-carried values.
+    unsigned getNumRegionIterValues() { return getRegion().getNumArguments(); }
+
+    /// Returns the loop body block.
+    Block *getBody() { return &getRegion().front(); }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasRegionVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// BreakOp
+//===----------------------------------------------------------------------===//
+
+def BreakOp : SCF_Op<"break", [
+    Pure, ReturnLike, Terminator, ParentOneOf<["LoopOp"]>
+  ]> {
+  let summary = "Break from an `scf.loop`";
+  let description = [{
+    The `scf.break` operation terminates the immediately enclosing `scf.loop`.
+    Its operands become the loop's result values; their types must match the
+    result types of the enclosing `scf.loop` (verified by the loop).
+
+    Example:
+
+    ```mlir
+    %r = scf.loop -> i32 {
+      ...
+      scf.break %v : i32
+    }
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
+
+  let assemblyFormat = [{
+    attr-dict ($operands^ `:` type($operands))?
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// ContinueOp
+//===----------------------------------------------------------------------===//
+
+def ContinueOp : SCF_Op<"continue", [
+    Pure, Terminator, ParentOneOf<["LoopOp"]>
+  ]> {
+  let summary = "Continue to the next iteration of an `scf.loop`";
+  let description = [{
+    The `scf.continue` operation re-enters the immediately enclosing `scf.loop`
+    for its next iteration. Its operands become the loop-carried values
+    (`iter_args`) for the next iteration; their types must match the loop's
+    iter_arg types (verified by the loop).
+
+    Example:
+
+    ```mlir
+    scf.loop iter_args(%i = %init) : i32 {
+      %next = arith.addi %i, %one : i32
+      scf.continue %next : i32
+    }
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$operands);
+  let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
+
+  let assemblyFormat = [{
+    attr-dict ($operands^ `:` type($operands))?
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ForOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9f4f4dc9f58e6..60e5975f4ec48 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -282,6 +282,168 @@ ValueRange ExecuteRegionOp::getSuccessorInputs(RegionSuccessor successor) {
                               : ValueRange();
 }
 
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+void LoopOp::build(
+    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
+    ValueRange initValues,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
+  // `createBlock` moves the insertion point into the new block. Restore it on
+  // exit so the loop op is inserted at the caller's position, not its body.
+  OpBuilder::InsertionGuard guard(builder);
+  result.addOperands(initValues);
+  result.addTypes(resultTypes);
+
+  // One entry block with an argument per init value. `bodyBuilder` (if any)
+  // must terminate it with either `scf.continue` or `scf.break`.
+  Region *bodyRegion = result.addRegion();
+  Block *bodyBlock = builder.createBlock(bodyRegion);
+  SmallVector<Type> argTypes(initValues.getTypes());
+  SmallVector<Location> argLocs(initValues.size(), result.location);
+  bodyBlock->addArguments(argTypes, argLocs);
+  if (bodyBuilder) {
+    builder.setInsertionPointToStart(bodyBlock);
+    bodyBuilder(builder, result.location, bodyBlock->getArguments());
+  }
+}
+
+LogicalResult LoopOp::verifyRegions() {
+  if (getRegion().empty())
+    return emitOpError("region cannot be empty");
+  Block &body = getRegion().front();
+  if (body.getNumArguments() != getNumOperands())
+    return emitOpError(
+        "mismatch in number of loop-carried values and defined values");
+  for (auto [index, regionArg, initOperand] :
+       llvm::enumerate(body.getArguments(), getOperands())) {
+    if (regionArg.getType() != initOperand.getType())
+      return emitOpError() << "type mismatch between " << index
+                           << "th iter operand (" << initOperand.getType()
+                           << ") and region argument (" << regionArg.getType()
+                           << ")";
+  }
+
+  // The loop body must end with an explicit `scf.break` or `scf.continue`.
+  Operation *terminator = body.getTerminator();
+  if (auto breakOp = dyn_cast<BreakOp>(terminator)) {
+    if (breakOp.getNumOperands() != getNumResults())
+      return breakOp.emitOpError()
+             << "has " << breakOp.getNumOperands()
+             << " operands, but enclosing scf.loop returns " << getNumResults()
+             << " result(s)";
+    for (auto [index, operandType, resultType] :
+         llvm::enumerate(breakOp.getOperandTypes(), getResultTypes())) {
+      if (operandType != resultType)
+        return breakOp.emitOpError()
+               << "type mismatch between " << index << "th operand ("
+               << operandType << ") and " << index
+               << "th result of enclosing scf.loop (" << resultType << ")";
+    }
+  } else if (auto continueOp = dyn_cast<ContinueOp>(terminator)) {
+    if (continueOp.getNumOperands() != getNumRegionIterValues())
+      return continueOp.emitOpError()
+             << "has " << continueOp.getNumOperands()
+             << " operands, but enclosing scf.loop has "
+             << getNumRegionIterValues() << " iter_args";
+    for (auto [index, operandType, iterArgType] : llvm::enumerate(
+             continueOp.getOperandTypes(), body.getArgumentTypes())) {
+      if (operandType != iterArgType)
+        return continueOp.emitOpError()
+               << "type mismatch between " << index << "th operand ("
+               << operandType << ") and " << index
+               << "th iter_arg of enclosing scf.loop (" << iterArgType << ")";
+    }
+  } else {
+    return emitOpError("body must be terminated by 'scf.break' or "
+                       "'scf.continue', got '")
+           << terminator->getName() << "'";
+  }
+  return success();
+}
+
+/// Print a type list in functional-return-type style: a single bare type or
+static void printFunctionalTypeList(OpAsmPrinter &p, TypeRange types) { + if (types.size() == 1) { + p << types.front(); + return; + } + p << "("; + llvm::interleaveComma(types, p); + p << ")"; +} + +void LoopOp::print(OpAsmPrinter &p) { + p << " "; + if (!getInitValues().empty()) { + p << "iter_args("; + llvm::interleaveComma( + llvm::zip(getRegionIterValues(), getInitValues()), p, [&](auto it) { + p.printRegionArgument(std::get<0>(it), /*argAttrs=*/{}, + /*omitType=*/true); + p << " = " << std::get<1>(it); + }); + p << ") : "; + printFunctionalTypeList(p, getInitValues().getTypes()); + p << " "; + } + if (!getResultTypes().empty()) { + p << "-> "; + printFunctionalTypeList(p, getResultTypes()); + p << " "; + } + + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + p.printOptionalAttrDict((*this)->getAttrs()); +} + +ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) { + SmallVector<OpAsmParser::Argument, 4> regionArgs; + SmallVector<OpAsmParser::UnresolvedOperand, 4> iterOperands; + SmallVector<Type, 4> iterTypes; + + if (failed(parser.parseOptionalKeyword("iter_args"))) { + // No iter_args, but may still have a result type list. + if (parser.parseOptionalArrowTypeList(result.types)) + return failure(); + } else { + if (parser.parseAssignmentList(regionArgs, iterOperands) || + parser.parseColon()) + return failure(); + if (parser.parseOptionalLParen()) { + // Single iter_arg type, no parens. 
+ Type type; + if (parser.parseType(type)) + return failure(); + iterTypes.push_back(type); + } else { + if (parser.parseTypeList(iterTypes) || parser.parseRParen()) + return failure(); + } + if (regionArgs.size() != iterTypes.size()) + return parser.emitError(parser.getCurrentLocation(), + "found different number of iter_args and types"); + if (parser.parseOptionalArrowTypeList(result.types)) + return failure(); + for (auto [regionArg, type] : llvm::zip_equal(regionArgs, iterTypes)) + regionArg.type = type; + } + + Region *body = result.addRegion(); + if (parser.parseRegion(*body, regionArgs)) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + if (parser.resolveOperands(iterOperands, iterTypes, parser.getNameLoc(), + result.operands)) + return failure(); + return success(); +} + //===----------------------------------------------------------------------===// // ConditionOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir index 33a8921eeb993..099c02631804f 100644 --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -852,3 +852,104 @@ func.func @for_missing_induction_var(%arg0: index, %arg1: index) { }) : (index, index, index) -> () return } + +// ----- + +func.func @break_outside_loop(%v: i32) { + // expected-error@+1 {{'scf.break' op expects parent op 'scf.loop'}} + scf.break %v : i32 +} + +// ----- + +func.func @continue_outside_loop() { + // expected-error@+1 {{'scf.continue' op expects parent op 'scf.loop'}} + scf.continue +} + +// ----- + +func.func @loop_bad_terminator() { + // expected-error@+1 {{'scf.loop' op body must be terminated by 'scf.break' or 'scf.continue'}} + "scf.loop"() ({ + ^bb0: + "test.foo"() : () -> () + "test.terminator"() : () -> () + }) : () -> () + return +} + +// ----- + +func.func @loop_init_arg_count_mismatch(%init: i32) { + // 
expected-error@+1 {{'scf.loop' op mismatch in number of loop-carried values and defined values}} + "scf.loop"(%init) ({ + ^bb0: + scf.continue + }) : (i32) -> () + return +} + +// ----- + +func.func @loop_init_arg_type_mismatch(%init: i32) { + // expected-error@+1 {{'scf.loop' op type mismatch between 0th iter operand ('i32') and region argument ('i64')}} + "scf.loop"(%init) ({ + ^bb0(%i: i64): + scf.continue %i : i64 + }) : (i32) -> () + return +} + +// ----- + +func.func @loop_break_count_mismatch(%v: i32) -> (i32, i32) { + // expected-error@+2 {{'scf.break' op has 1 operands, but enclosing scf.loop returns 2 result(s)}} + %r:2 = scf.loop -> (i32, i32) { + scf.break %v : i32 + } + return %r#0, %r#1 : i32, i32 +} + +// ----- + +func.func @loop_break_type_mismatch(%v: i32) -> i64 { + // expected-error@+2 {{'scf.break' op type mismatch between 0th operand ('i32') and 0th result of enclosing scf.loop ('i64')}} + %r = scf.loop -> i64 { + scf.break %v : i32 + } + return %r : i64 +} + +// ----- + +func.func @loop_continue_count_mismatch(%init: i32) { + // expected-error@+2 {{'scf.continue' op has 0 operands, but enclosing scf.loop has 1 iter_args}} + scf.loop iter_args(%i = %init) : i32 { + scf.continue + } + return +} + +// ----- + +func.func @loop_continue_type_mismatch(%init: i32, %v: i64) { + // expected-error@+2 {{'scf.continue' op type mismatch between 0th operand ('i64') and 0th iter_arg of enclosing scf.loop ('i32')}} + scf.loop iter_args(%i = %init) : i32 { + scf.continue %v : i64 + } + return +} + +// ----- + +func.func @loop_more_than_one_block(%v: i32) -> i32 { + // expected-error@+1 {{'scf.loop' op expects region #0 to have 0 or 1 blocks}} + %r = "scf.loop"() ({ + ^bb0: + "test.unreachable"() [^bb1] : () -> () + ^bb1: + scf.break %v : i32 + }) : () -> i32 + return %r : i32 +} diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir index 5930a1df04266..e8f5294b40a4d 100644 --- a/mlir/test/Dialect/SCF/ops.mlir +++ 
b/mlir/test/Dialect/SCF/ops.mlir @@ -441,3 +441,76 @@ func.func @switch(%arg0: index) -> i32 { return %0 : i32 } + +// CHECK-LABEL: @loop_infinite +func.func @loop_infinite() { + // CHECK: scf.loop { + scf.loop { + // CHECK-NEXT: "test.foo" + "test.foo"() : () -> () + // CHECK-NEXT: scf.continue + scf.continue + } + return +} + +// CHECK-LABEL: @loop_break_no_operands +func.func @loop_break_no_operands() { + // CHECK: scf.loop { + scf.loop { + // CHECK-NEXT: scf.break + scf.break + } + return +} + +// CHECK-LABEL: @loop_break_single +func.func @loop_break_single(%v: i32) -> i32 { + // CHECK: %{{.*}} = scf.loop -> i32 { + %r = scf.loop -> i32 { + // CHECK-NEXT: scf.break %{{.*}} : i32 + scf.break %v : i32 + } + return %r : i32 +} + +// CHECK-LABEL: @loop_break_multi +func.func @loop_break_multi(%v: i32, %w: i64) -> (i32, i64) { + // CHECK: %{{.*}}:2 = scf.loop -> (i32, i64) { + %r:2 = scf.loop -> (i32, i64) { + // CHECK-NEXT: scf.break %{{.*}}, %{{.*}} : i32, i64 + scf.break %v, %w : i32, i64 + } + return %r#0, %r#1 : i32, i64 +} + +// CHECK-LABEL: @loop_iter_single +func.func @loop_iter_single(%init: i32) { + // CHECK: scf.loop iter_args(%{{.*}} = %{{.*}}) : i32 { + scf.loop iter_args(%i = %init) : i32 { + // CHECK: scf.continue %{{.*}} : i32 + scf.continue %i : i32 + } + return +} + +// CHECK-LABEL: @loop_iter_multi +func.func @loop_iter_multi(%init0: i32, %init1: i64) { + // CHECK: scf.loop iter_args(%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) : (i32, i64) { + scf.loop iter_args(%i = %init0, %j = %init1) : (i32, i64) { + // CHECK: scf.continue %{{.*}}, %{{.*}} : i32, i64 + scf.continue %i, %j : i32, i64 + } + return +} + +// Loop with iter_args of one type and a single result of another type. 
+// CHECK-LABEL: @loop_iter_and_result +func.func @loop_iter_and_result(%init: i32, %v: i64) -> i64 { + // CHECK: %{{.*}} = scf.loop iter_args(%{{.*}} = %{{.*}}) : i32 -> i64 { + %r = scf.loop iter_args(%i = %init) : i32 -> i64 { + // CHECK: scf.break %{{.*}} : i64 + scf.break %v : i64 + } + return %r : i64 +} _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
