https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/138506
>From 33312c3ed87c5d299673be5e4831ea1583d536c6 Mon Sep 17 00:00:00 2001 From: ergawy <kareem.erg...@amd.com> Date: Mon, 5 May 2025 03:25:19 -0500 Subject: [PATCH] [flang][fir] Add locality specifiers modeling to `fir.do_concurrent.loop` Extends `fir.do_concurrent.loop` ops to model locality specifiers. This follows the same pattern used in OpenMP where an op of type `fir.local` (in OpenMP it is `omp.private`) is referenced from the `do concurrent` locality specifier. This PR adds the MLIR op changes as well as printing and parsing logic. --- .../include/flang/Optimizer/Dialect/FIROps.td | 33 +++++- flang/lib/Lower/Bridge.cpp | 2 +- flang/lib/Optimizer/Dialect/FIROps.cpp | 112 +++++++++++++++--- flang/test/Fir/do_concurrent.fir | 64 +++++++++- flang/test/Fir/invalid.fir | 10 +- 5 files changed, 195 insertions(+), 26 deletions(-) diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index f87ffce72192d..01248aa0095ec 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -3647,6 +3647,13 @@ def fir_DoConcurrentOp : fir_Op<"do_concurrent", let hasVerifier = 1; } +def fir_LocalSpecifier { + dag arguments = (ins + Variadic<AnyType>:$local_vars, + OptionalAttr<SymbolRefArrayAttr>:$local_syms + ); +} + def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop", [AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopLikeOpInterface, ["getLoopInductionVars"]>, @@ -3700,7 +3707,7 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop", LLVM. }]; - let arguments = (ins + defvar opArgs = (ins Variadic<Index>:$lowerBound, Variadic<Index>:$upperBound, Variadic<Index>:$step, @@ -3709,16 +3716,40 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop", OptionalAttr<LoopAnnotationAttr>:$loopAnnotation ); + let arguments = !con(opArgs, fir_LocalSpecifier.arguments); + let regions = (region SizedRegion<1>:$region); let hasCustomAssemblyFormat = 1; let hasVerifier = 1; let extraClassDeclaration = [{ + unsigned getNumInductionVars() { return getLowerBound().size(); } + + unsigned getNumLocalOperands() { return getLocalVars().size(); } + + mlir::Block::BlockArgListType getInductionVars() { + return getBody()->getArguments().slice(0, getNumInductionVars()); + } + + mlir::Block::BlockArgListType getRegionLocalArgs() { + return getBody()->getArguments().slice(getNumInductionVars(), + getNumLocalOperands()); + } + + /// Number of operands controlling the loop + unsigned getNumControlOperands() { return getLowerBound().size() * 3; } + // Get Number of reduction operands unsigned getNumReduceOperands() { return getReduceOperands().size(); } + + mlir::Operation::operand_range getLocalOperands() { + return getOperands() + .slice(getNumControlOperands() + getNumReduceOperands(), + getNumLocalOperands()); + } }]; } diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp index 8da05255d5f41..0a61f61ab8f75 100644 --- a/flang/lib/Lower/Bridge.cpp +++ b/flang/lib/Lower/Bridge.cpp @@ -2460,7 +2460,7 @@ class FirConverter : public Fortran::lower::AbstractConverter { nestReduceAttrs.empty() ? nullptr : mlir::ArrayAttr::get(builder->getContext(), nestReduceAttrs), - nullptr); + nullptr, /*local_vars=*/std::nullopt, /*local_syms=*/nullptr); llvm::SmallVector<mlir::Type> loopBlockArgTypes( incrementLoopNestInfo.size(), builder->getIndexType()); diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp index 65ec730e134c2..c95655d7dcef6 100644 --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -5033,21 +5033,25 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser, mlir::OperationState &result) { auto &builder = parser.getBuilder(); // Parse an opening `(` followed by induction variables followed by `)` - llvm::SmallVector<mlir::OpAsmParser::Argument, 4> ivs; - if (parser.parseArgumentList(ivs, mlir::OpAsmParser::Delimiter::Paren)) + llvm::SmallVector<mlir::OpAsmParser::Argument, 4> regionArgs; + + if (parser.parseArgumentList(regionArgs, mlir::OpAsmParser::Delimiter::Paren)) return mlir::failure(); + llvm::SmallVector<mlir::Type> argTypes(regionArgs.size(), + builder.getIndexType()); + // Parse loop bounds. llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> lower; if (parser.parseEqual() || - parser.parseOperandList(lower, ivs.size(), + parser.parseOperandList(lower, regionArgs.size(), mlir::OpAsmParser::Delimiter::Paren) || parser.resolveOperands(lower, builder.getIndexType(), result.operands)) return mlir::failure(); llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> upper; if (parser.parseKeyword("to") || - parser.parseOperandList(upper, ivs.size(), + parser.parseOperandList(upper, regionArgs.size(), mlir::OpAsmParser::Delimiter::Paren) || parser.resolveOperands(upper, builder.getIndexType(), result.operands)) return mlir::failure(); @@ -5055,7 +5059,7 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser, // Parse step values. llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> steps; if (parser.parseKeyword("step") || - parser.parseOperandList(steps, ivs.size(), + parser.parseOperandList(steps, regionArgs.size(), mlir::OpAsmParser::Delimiter::Paren) || parser.resolveOperands(steps, builder.getIndexType(), result.operands)) return mlir::failure(); @@ -5086,12 +5090,55 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser, builder.getArrayAttr(arrayAttr)); } - // Now parse the body. - mlir::Region *body = result.addRegion(); - for (auto &iv : ivs) - iv.type = builder.getIndexType(); - if (parser.parseRegion(*body, ivs)) - return mlir::failure(); + llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> localOperands; + if (succeeded(parser.parseOptionalKeyword("local"))) { + std::size_t oldArgTypesSize = argTypes.size(); + if (failed(parser.parseLParen())) + return mlir::failure(); + + llvm::SmallVector<mlir::SymbolRefAttr> localSymbolVec; + if (failed(parser.parseCommaSeparatedList([&]() { + if (failed(parser.parseAttribute(localSymbolVec.emplace_back()))) + return mlir::failure(); + + if (parser.parseOperand(localOperands.emplace_back()) || + parser.parseArrow() || + parser.parseArgument(regionArgs.emplace_back())) + return mlir::failure(); + + return mlir::success(); + }))) + return mlir::failure(); + + if (failed(parser.parseColon())) + return mlir::failure(); + + if (failed(parser.parseCommaSeparatedList([&]() { + if (failed(parser.parseType(argTypes.emplace_back()))) + return mlir::failure(); + + return mlir::success(); + }))) + return mlir::failure(); + + if (regionArgs.size() != argTypes.size()) + return parser.emitError(parser.getNameLoc(), + "mismatch in number of local arg and types"); + + if (failed(parser.parseRParen())) + return mlir::failure(); + + for (auto operandType : llvm::zip_equal( + localOperands, llvm::drop_begin(argTypes, oldArgTypesSize))) + if (parser.resolveOperand(std::get<0>(operandType), + std::get<1>(operandType), result.operands)) + return mlir::failure(); + + llvm::SmallVector<mlir::Attribute> symbolAttrs(localSymbolVec.begin(), + localSymbolVec.end()); + result.addAttribute(getLocalSymsAttrName(result.name), + builder.getArrayAttr(symbolAttrs)); + } // Set `operandSegmentSizes` attribute. result.addAttribute(DoConcurrentLoopOp::getOperandSegmentSizeAttr(), @@ -5099,7 +5146,16 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser, {static_cast<int32_t>(lower.size()), static_cast<int32_t>(upper.size()), static_cast<int32_t>(steps.size()), - static_cast<int32_t>(reduceOperands.size())})); + static_cast<int32_t>(reduceOperands.size()), + static_cast<int32_t>(localOperands.size())})); + + // Now parse the body. + for (auto [arg, type] : llvm::zip_equal(regionArgs, argTypes)) + arg.type = type; + + mlir::Region *body = result.addRegion(); + if (parser.parseRegion(*body, regionArgs)) + return mlir::failure(); // Parse attributes. if (parser.parseOptionalAttrDict(result.attributes)) @@ -5109,8 +5165,9 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser, } void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) { - p << " (" << getBody()->getArguments() << ") = (" << getLowerBound() - << ") to (" << getUpperBound() << ") step (" << getStep() << ")"; + p << " (" << getBody()->getArguments().slice(0, getNumInductionVars()) + << ") = (" << getLowerBound() << ") to (" << getUpperBound() << ") step (" + << getStep() << ")"; if (!getReduceOperands().empty()) { p << " reduce("; @@ -5123,12 +5180,27 @@ void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) { p << ')'; } + if (!getLocalVars().empty()) { + p << " local("; + llvm::interleaveComma(llvm::zip_equal(getLocalSymsAttr(), getLocalVars(), + getRegionLocalArgs()), + p, [&](auto it) { + p << std::get<0>(it) << " " << std::get<1>(it) + << " -> " << std::get<2>(it); + }); + p << " : "; + llvm::interleaveComma(getLocalVars(), p, + [&](auto it) { p << it.getType(); }); + p << ")"; + } + p << ' '; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); p.printOptionalAttrDict( (*this)->getAttrs(), /*elidedAttrs=*/{DoConcurrentLoopOp::getOperandSegmentSizeAttr(), - DoConcurrentLoopOp::getReduceAttrsAttrName()}); + DoConcurrentLoopOp::getReduceAttrsAttrName(), + DoConcurrentLoopOp::getLocalSymsAttrName()}); } llvm::SmallVector<mlir::Region *> fir::DoConcurrentLoopOp::getLoopRegions() { @@ -5139,6 +5211,7 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() { mlir::Operation::operand_range lbValues = getLowerBound(); mlir::Operation::operand_range ubValues = getUpperBound(); mlir::Operation::operand_range stepValues = getStep(); + mlir::Operation::operand_range localVars = getLocalVars(); if (lbValues.empty()) return emitOpError( @@ -5152,11 +5225,13 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() { // Check that the body defines the same number of block arguments as the // number of tuple elements in step. mlir::Block *body = getBody(); - if (body->getNumArguments() != stepValues.size()) + unsigned numIndVarArgs = body->getNumArguments() - localVars.size(); + + if (numIndVarArgs != stepValues.size()) return emitOpError() << "expects the same number of induction variables: " << body->getNumArguments() << " as bound and step values: " << stepValues.size(); - for (auto arg : body->getArguments()) + for (auto arg : body->getArguments().slice(0, numIndVarArgs)) if (!arg.getType().isIndex()) return emitOpError( "expects arguments for the induction variable to be of index type"); @@ -5171,7 +5246,8 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() { std::optional<llvm::SmallVector<mlir::Value>> fir::DoConcurrentLoopOp::getLoopInductionVars() { - return llvm::SmallVector<mlir::Value>{getBody()->getArguments()}; + return llvm::SmallVector<mlir::Value>{ + getBody()->getArguments().slice(0, getLowerBound().size())}; } //===----------------------------------------------------------------------===// diff --git a/flang/test/Fir/do_concurrent.fir b/flang/test/Fir/do_concurrent.fir index 4e55777402428..cfb9a7abac15b 100644 --- a/flang/test/Fir/do_concurrent.fir +++ b/flang/test/Fir/do_concurrent.fir @@ -91,7 +91,6 @@ func.func @dc_2d_reduction(%i_lb: index, %i_ub: index, %i_st: index, // CHECK: } // CHECK: } - fir.local {type = local} @local_privatizer : i32 // CHECK: fir.local {type = local} @[[LOCAL_PRIV_SYM:local_privatizer]] : i32 @@ -109,3 +108,66 @@ fir.local {type = local_init} @local_init_privatizer : i32 copy { // CHECK: fir.store %[[ORIG_VAL_LD]] to %[[LOCAL_VAL]] : !fir.ref<i32> // CHECK: fir.yield(%[[LOCAL_VAL]] : !fir.ref<i32>) // CHECK: } + +func.func @_QPdo_concurrent() { + %3 = fir.alloca i32 {bindc_name = "local_init_var", uniq_name = "_QFdo_concurrentElocal_init_var"} + %4:2 = hlfir.declare %3 {uniq_name = "_QFdo_concurrentElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>) + %5 = fir.alloca i32 {bindc_name = "local_var", uniq_name = "_QFdo_concurrentElocal_var"} + %6:2 = hlfir.declare %5 {uniq_name = "_QFdo_concurrentElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>) + %c1 = arith.constant 1 : index + %c10 = arith.constant 1 : index + fir.do_concurrent { + %9 = fir.alloca i32 {bindc_name = "i"} + %10:2 = hlfir.declare %9 {uniq_name = "_QFdo_concurrentEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>) + fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) local(@local_privatizer %6#0 -> %arg1, @local_init_privatizer %4#0 -> %arg2 : !fir.ref<i32>, !fir.ref<i32>) { + %11 = fir.convert %arg0 : (index) -> i32 + fir.store %11 to %10#0 : !fir.ref<i32> + %13:2 = hlfir.declare %arg1 {uniq_name = "_QFdo_concurrentElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>) + %15:2 = hlfir.declare %arg2 {uniq_name = "_QFdo_concurrentElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>) + %17 = fir.load %10#0 : !fir.ref<i32> + %c5_i32 = arith.constant 5 : i32 + %18 = arith.cmpi slt, %17, %c5_i32 : i32 + fir.if %18 { + %c42_i32 = arith.constant 42 : i32 + hlfir.assign %c42_i32 to %13#0 : i32, !fir.ref<i32> + } else { + %c84_i32 = arith.constant 84 : i32 + hlfir.assign %c84_i32 to %15#0 : i32, !fir.ref<i32> + } + } + } + return +} + +// CHECK-LABEL: func.func @_QPdo_concurrent() { +// CHECK: %[[LOC_INIT_ALLOC:.*]] = fir.alloca i32 {bindc_name = "local_init_var", {{.*}}} +// CHECK: %[[LOC_INIT_DECL:.*]]:2 = hlfir.declare %[[LOC_INIT_ALLOC]] + +// CHECK: %[[LOC_ALLOC:.*]] = fir.alloca i32 {bindc_name = "local_var", {{.*}}} +// CHECK: %[[LOC_DECL:.*]]:2 = hlfir.declare %[[LOC_ALLOC]] + +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C10:.*]] = arith.constant 1 : index + +// CHECK: fir.do_concurrent { +// CHECK: %[[DC_I_ALLOC:.*]] = fir.alloca i32 {bindc_name = "i"} +// CHECK: %[[DC_I_DECL:.*]]:2 = hlfir.declare %[[DC_I_ALLOC]] + +// CHECK: fir.do_concurrent.loop (%[[IV:.*]]) = (%[[C1]]) to (%[[C10]]) step (%[[C1]]) local(@[[LOCAL_PRIV_SYM]] %[[LOC_DECL]]#0 -> %[[LOC_ARG:.*]], @[[LOCAL_INIT_PRIV_SYM]] %[[LOC_INIT_DECL]]#0 -> %[[LOC_INIT_ARG:.*]] : !fir.ref<i32>, !fir.ref<i32>) { +// CHECK: %[[IV_CVT:.*]] = fir.convert %[[IV]] : (index) -> i32 +// CHECK: fir.store %[[IV_CVT]] to %[[DC_I_DECL]]#0 : !fir.ref<i32> + +// CHECK: %[[LOC_PRIV_DECL:.*]]:2 = hlfir.declare %[[LOC_ARG]] +// CHECK: %[[LOC_INIT_PRIV_DECL:.*]]:2 = hlfir.declare %[[LOC_INIT_ARG]] + +// CHECK: fir.if %{{.*}} { +// CHECK: %[[C42:.*]] = arith.constant 42 : i32 +// CHECK: hlfir.assign %[[C42]] to %[[LOC_PRIV_DECL]]#0 : i32, !fir.ref<i32> +// CHECK: } else { +// CHECK: %[[C84:.*]] = arith.constant 84 : i32 +// CHECK: hlfir.assign %[[C84]] to %[[LOC_INIT_PRIV_DECL]]#0 : i32, !fir.ref<i32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: return +// CHECK: } diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir index f9f5e267dd9bc..3cd3ab439b0e9 100644 --- a/flang/test/Fir/invalid.fir +++ b/flang/test/Fir/invalid.fir @@ -1198,7 +1198,7 @@ func.func @dc_0d() { func.func @dc_invalid_parent(%arg0: index, %arg1: index) { // expected-error@+1 {{'fir.do_concurrent.loop' op expects parent op 'fir.do_concurrent'}} - "fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>}> ({ + "fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0, 0>}> ({ ^bb0(%arg2: index): %tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32> }) : (index, index) -> () @@ -1210,7 +1210,7 @@ func.func @dc_invalid_parent(%arg0: index, %arg1: index) { func.func @dc_invalid_control(%arg0: index, %arg1: index) { // expected-error@+2 {{'fir.do_concurrent.loop' op different number of tuple elements for lowerBound, upperBound or step}} fir.do_concurrent { - "fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>}> ({ + "fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0, 0>}> ({ ^bb0(%arg2: index): %tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32> }) : (index, index) -> () @@ -1223,7 +1223,7 @@ func.func @dc_invalid_control(%arg0: index, %arg1: index) { func.func @dc_invalid_ind_var(%arg0: index, %arg1: index) { // expected-error@+2 {{'fir.do_concurrent.loop' op expects the same number of induction variables: 2 as bound and step values: 1}} fir.do_concurrent { - "fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>}> ({ + "fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 1, 0, 0>}> ({ ^bb0(%arg3: index, %arg4: index): %tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32> }) : (index, index, index) -> () @@ -1236,7 +1236,7 @@ func.func @dc_invalid_ind_var(%arg0: index, %arg1: index) { func.func @dc_invalid_ind_var_type(%arg0: index, %arg1: index) { // expected-error@+2 {{'fir.do_concurrent.loop' op expects arguments for the induction variable to be of index type}} fir.do_concurrent { - "fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>}> ({ + "fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 1, 0, 0>}> ({ ^bb0(%arg3: i32): %tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32> }) : (index, index, index) -> () @@ -1250,7 +1250,7 @@ func.func @dc_invalid_reduction(%arg0: index, %arg1: index) { %sum = fir.alloca i32 // expected-error@+2 {{'fir.do_concurrent.loop' op mismatch in number of reduction variables and reduction attributes}} fir.do_concurrent { - "fir.do_concurrent.loop"(%arg0, %arg1, %arg0, %sum) <{operandSegmentSizes = array<i32: 1, 1, 1, 1>}> ({ + "fir.do_concurrent.loop"(%arg0, %arg1, %arg0, %sum) <{operandSegmentSizes = array<i32: 1, 1, 1, 1, 0>}> ({ ^bb0(%arg3: index): %tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32> }) : (index, index, index, !fir.ref<i32>) -> () _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits