llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clangir Author: Erich Keane (erichkeane) <details> <summary>Changes</summary> As a follow-up to 3c4dff3ac6884b85fe93fe512c5bdaf014738c45 I audited all uses of 'process clause and use additive methods', and added explicit functions to the construct to make it easier for the next project to attempt to use this mechanism (vs. constructing all operands/etc. in advance, then adding all at once). I've only done ones that I have attempted to use so far (as a catch-up, so no var-list clauses, and no constructs that can't be used without a var-list, and no loop, and no compound constructs). I intend to do those "as I go" with the lowering of each of those things instead. --- Patch is 33.12 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/137396.diff 3 Files Affected: - (modified) clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp (+33-170) - (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+80) - (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+289) ``````````diff diff --git a/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp b/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp index 6e65f94c78bed..6f86d2b681a1e 100644 --- a/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp @@ -46,7 +46,17 @@ class OpenACCClauseCIREmitter final // diagnostics are gone. 
SourceLocation dirLoc; - const OpenACCDeviceTypeClause *lastDeviceTypeClause = nullptr; + llvm::SmallVector<mlir::acc::DeviceType> lastDeviceTypeValues; + + void setLastDeviceTypeClause(const OpenACCDeviceTypeClause &clause) { + lastDeviceTypeValues.clear(); + + llvm::for_each(clause.getArchitectures(), + [this](const DeviceTypeArgument &arg) { + lastDeviceTypeValues.push_back( + decodeDeviceType(arg.getIdentifierInfo())); + }); + } void clauseNotImplemented(const OpenACCClause &c) { cgf.cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind()); @@ -95,114 +105,6 @@ class OpenACCClauseCIREmitter final .CaseLower("radeon", mlir::acc::DeviceType::Radeon); } - // Overload of this function that only returns the device-types list. - mlir::ArrayAttr - handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes) { - mlir::ValueRange argument; - mlir::MutableOperandRange range{operation}; - - return handleDeviceTypeAffectedClause(existingDeviceTypes, argument, range); - } - // Overload of this function for when 'segments' aren't necessary. - mlir::ArrayAttr - handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes, - mlir::ValueRange argument, - mlir::MutableOperandRange argCollection) { - llvm::SmallVector<int32_t> segments; - assert(argument.size() <= 1 && - "Overload only for cases where segments don't need to be added"); - return handleDeviceTypeAffectedClause(existingDeviceTypes, argument, - argCollection, segments); - } - - // Handle a clause affected by the 'device_type' to the point that they need - // to have attributes added in the correct/corresponding order, such as - // 'num_workers' or 'vector_length' on a compute construct. The 'argument' is - // a collection of operands that need to be appended to the `argCollection` as - // we're adding a 'device_type' entry. If there is more than 0 elements in - // the 'argument', the collection must be non-null, as it is needed to add to - // it. 
- // As some clauses, such as 'num_gangs' or 'wait' require a 'segments' list to - // be maintained, this takes a list of segments that will be updated with the - // proper counts as 'argument' elements are added. - // - // In MLIR, the 'operands' are stored as a large array, with a separate array - // of 'segments' that show which 'operand' applies to which 'operand-kind'. - // That is, a 'num_workers' operand-kind or 'num_vectors' operand-kind. - // - // So the operands array might have 4 elements, but the 'segments' array will - // be something like: - // - // {0, 0, 0, 2, 0, 1, 1, 0, 0...} - // - // Where each position belongs to a specific 'operand-kind'. So that - // specifies that whichever operand-kind corresponds with index '3' has 2 - // elements, and should take the 1st 2 operands off the list (since all - // preceding values are 0). operand-kinds corresponding to 5 and 6 each have - // 1 element. - // - // Fortunately, the `MutableOperandRange` append function actually takes care - // of that for us at the 'top level'. - // - // However, in cases like `num_gangs' or 'wait', where each individual - // 'element' might be itself array-like, there is a separate 'segments' array - // for them. So in the case of: - // - // device_type(nvidia, radeon) num_gangs(1, 2, 3) - // - // We have to emit that as TWO arrays into the IR (where the device_type is an - // attribute), so they look like: - // - // num_gangs({One : i32, Two : i32, Three : i32} [#acc.device_type<nvidia>],\ - // {One : i32, Two : i32, Three : i32} [#acc.device_type<radeon>]) - // - // When stored in the 'operands' list, the top-level 'segment' for - // 'num_gangs' just shows 6 elements. In order to get the array-like - // apperance, the 'numGangsSegments' list is kept as well. In the above case, - // we've inserted 6 operands, so the 'numGangsSegments' must contain 2 - // elements, 1 per array, and each will have a value of 3. 
The verifier will - // ensure that the collections counts are correct. - mlir::ArrayAttr - handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes, - mlir::ValueRange argument, - mlir::MutableOperandRange argCollection, - llvm::SmallVector<int32_t> &segments) { - llvm::SmallVector<mlir::Attribute> deviceTypes; - - // Collect the 'existing' device-type attributes so we can re-create them - // and insert them. - if (existingDeviceTypes) { - for (const mlir::Attribute &Attr : existingDeviceTypes) - deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get( - builder.getContext(), - cast<mlir::acc::DeviceTypeAttr>(Attr).getValue())); - } - - // Insert 1 version of the 'expr' to the NumWorkers list per-current - // device type. - if (lastDeviceTypeClause) { - for (const DeviceTypeArgument &arch : - lastDeviceTypeClause->getArchitectures()) { - deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get( - builder.getContext(), decodeDeviceType(arch.getIdentifierInfo()))); - if (!argument.empty()) { - argCollection.append(argument); - segments.push_back(argument.size()); - } - } - } else { - // Else, we just add a single for 'none'. 
- deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get( - builder.getContext(), mlir::acc::DeviceType::None)); - if (!argument.empty()) { - argCollection.append(argument); - segments.push_back(argument.size()); - } - } - - return mlir::ArrayAttr::get(builder.getContext(), deviceTypes); - } - public: OpenACCClauseCIREmitter(OpTy &operation, CIRGenFunction &cgf, CIRGenBuilderTy &builder, @@ -236,7 +138,8 @@ class OpenACCClauseCIREmitter final } void VisitDeviceTypeClause(const OpenACCDeviceTypeClause &clause) { - lastDeviceTypeClause = &clause; + setLastDeviceTypeClause(clause); + if constexpr (isOneOfTypes<OpTy, InitOp, ShutdownOp>) { llvm::for_each( clause.getArchitectures(), [this](const DeviceTypeArgument &arg) { @@ -253,8 +156,8 @@ class OpenACCClauseCIREmitter final } else if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, DataOp>) { // Nothing to do here, these constructs don't have any IR for these, as - // they just modify the other clauses IR. So setting of `lastDeviceType` - // (done above) is all we need. + // they just modify the other clauses IR. So setting of + // `lastDeviceTypeValues` (done above) is all we need. } else { // TODO: When we've implemented this for everything, switch this to an // unreachable. update, data, loop, routine, combined constructs remain. 
@@ -264,10 +167,9 @@ class OpenACCClauseCIREmitter final void VisitNumWorkersClause(const OpenACCNumWorkersClause &clause) { if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) { - mlir::MutableOperandRange range = operation.getNumWorkersMutable(); - operation.setNumWorkersDeviceTypeAttr(handleDeviceTypeAffectedClause( - operation.getNumWorkersDeviceTypeAttr(), - createIntExpr(clause.getIntExpr()), range)); + operation.addNumWorkersOperand(builder.getContext(), + createIntExpr(clause.getIntExpr()), + lastDeviceTypeValues); } else if constexpr (isOneOfTypes<OpTy, SerialOp>) { llvm_unreachable("num_workers not valid on serial"); } else { @@ -279,10 +181,9 @@ class OpenACCClauseCIREmitter final void VisitVectorLengthClause(const OpenACCVectorLengthClause &clause) { if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) { - mlir::MutableOperandRange range = operation.getVectorLengthMutable(); - operation.setVectorLengthDeviceTypeAttr(handleDeviceTypeAffectedClause( - operation.getVectorLengthDeviceTypeAttr(), - createIntExpr(clause.getIntExpr()), range)); + operation.addVectorLengthOperand(builder.getContext(), + createIntExpr(clause.getIntExpr()), + lastDeviceTypeValues); } else if constexpr (isOneOfTypes<OpTy, SerialOp>) { llvm_unreachable("vector_length not valid on serial"); } else { @@ -294,15 +195,12 @@ class OpenACCClauseCIREmitter final void VisitAsyncClause(const OpenACCAsyncClause &clause) { if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, DataOp>) { - if (!clause.hasIntExpr()) { - operation.setAsyncOnlyAttr( - handleDeviceTypeAffectedClause(operation.getAsyncOnlyAttr())); - } else { - mlir::MutableOperandRange range = operation.getAsyncOperandsMutable(); - operation.setAsyncOperandsDeviceTypeAttr(handleDeviceTypeAffectedClause( - operation.getAsyncOperandsDeviceTypeAttr(), - createIntExpr(clause.getIntExpr()), range)); - } + if (!clause.hasIntExpr()) + operation.addAsyncOnly(builder.getContext(), lastDeviceTypeValues); + else + 
operation.addAsyncOperand(builder.getContext(), + createIntExpr(clause.getIntExpr()), + lastDeviceTypeValues); } else if constexpr (isOneOfTypes<OpTy, WaitOp>) { // Wait doesn't have a device_type, so its handling here is slightly // different. @@ -366,19 +264,11 @@ class OpenACCClauseCIREmitter final void VisitNumGangsClause(const OpenACCNumGangsClause &clause) { if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) { llvm::SmallVector<mlir::Value> values; - for (const Expr *E : clause.getIntExprs()) values.push_back(createIntExpr(E)); - llvm::SmallVector<int32_t> segments; - if (operation.getNumGangsSegments()) - llvm::copy(*operation.getNumGangsSegments(), - std::back_inserter(segments)); - - mlir::MutableOperandRange range = operation.getNumGangsMutable(); - operation.setNumGangsDeviceTypeAttr(handleDeviceTypeAffectedClause( - operation.getNumGangsDeviceTypeAttr(), values, range, segments)); - operation.setNumGangsSegments(llvm::ArrayRef<int32_t>{segments}); + operation.addNumGangsOperands(builder.getContext(), values, + lastDeviceTypeValues); } else { // TODO: When we've implemented this for everything, switch this to an // unreachable. Combined constructs remain. 
@@ -389,42 +279,15 @@ class OpenACCClauseCIREmitter final void VisitWaitClause(const OpenACCWaitClause &clause) { if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, DataOp>) { if (!clause.hasExprs()) { - operation.setWaitOnlyAttr( - handleDeviceTypeAffectedClause(operation.getWaitOnlyAttr())); + operation.addWaitOnly(builder.getContext(), lastDeviceTypeValues); } else { llvm::SmallVector<mlir::Value> values; - if (clause.hasDevNumExpr()) values.push_back(createIntExpr(clause.getDevNumExpr())); for (const Expr *E : clause.getQueueIdExprs()) values.push_back(createIntExpr(E)); - - llvm::SmallVector<int32_t> segments; - if (operation.getWaitOperandsSegments()) - llvm::copy(*operation.getWaitOperandsSegments(), - std::back_inserter(segments)); - - unsigned beforeSegmentSize = segments.size(); - - mlir::MutableOperandRange range = operation.getWaitOperandsMutable(); - operation.setWaitOperandsDeviceTypeAttr(handleDeviceTypeAffectedClause( - operation.getWaitOperandsDeviceTypeAttr(), values, range, - segments)); - operation.setWaitOperandsSegments(segments); - - // In addition to having to set the 'segments', wait also has a list of - // bool attributes whether it is annotated with 'devnum'. We can use - // our knowledge of how much the 'segments' array grew to determine how - // many we need to add. 
- llvm::SmallVector<bool> hasDevNums; - if (operation.getHasWaitDevnumAttr()) - for (mlir::Attribute A : operation.getHasWaitDevnumAttr()) - hasDevNums.push_back(cast<mlir::BoolAttr>(A).getValue()); - - hasDevNums.insert(hasDevNums.end(), segments.size() - beforeSegmentSize, - clause.hasDevNumExpr()); - - operation.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasDevNums)); + operation.addWaitOperands(builder.getContext(), clause.hasDevNumExpr(), + values, lastDeviceTypeValues); } } else { // TODO: When we've implemented this for everything, switch this to an @@ -589,7 +452,7 @@ CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) { if (s.hasDevNumExpr()) waitOp.getWaitDevnumMutable().append(createIntExpr(s.getDevNumExpr())); - for (Expr *QueueExpr : s.getQueueIdExprs()) + for (Expr *QueueExpr : s.getQueueIdExprs()) waitOp.getWaitOperandsMutable().append(createIntExpr(QueueExpr)); } diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td index 2167129e9e1c7..cd2dc86f7bf51 100644 --- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td +++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td @@ -1408,6 +1408,31 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel", static mlir::acc::Construct getConstructId() { return mlir::acc::Construct::acc_construct_parallel; } + /// Add a value to 'num_workers' with the current list of device types. + void addNumWorkersOperand(MLIRContext *, mlir::Value, + llvm::ArrayRef<DeviceType>); + /// Add a value to 'vector_length' with the current list of device types. + void addVectorLengthOperand(MLIRContext *, mlir::Value, + llvm::ArrayRef<DeviceType>); + /// Add an entry to the 'async-only' attribute (clause spelled without + /// arguments)for each of the additional device types (or a none if it is + /// empty). + void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>); + /// Add a value to the 'async' with the current list of device types. 
+ void addAsyncOperand(MLIRContext *, mlir::Value, + llvm::ArrayRef<DeviceType>); + /// Add an array-like entry to the 'num_gangs' with the current list of + /// device types. + void addNumGangsOperands(MLIRContext *, mlir::ValueRange, + llvm::ArrayRef<DeviceType>); + /// Add an entry to the 'wait-only' attribute (clause spelled without + /// arguments)for each of the additional device types (or a none if it is + /// empty). + void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>); + /// Add an arrya-like entry to the 'wait' with the current list of device + /// types. + void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange, + llvm::ArrayRef<DeviceType>); }]; let assemblyFormat = [{ @@ -1535,6 +1560,21 @@ def OpenACC_SerialOp : OpenACC_Op<"serial", static mlir::acc::Construct getConstructId() { return mlir::acc::Construct::acc_construct_serial; } + /// Add an entry to the 'async-only' attribute (clause spelled without + /// arguments)for each of the additional device types (or a none if it is + /// empty). + void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>); + /// Add a value to the 'async' with the current list of device types. + void addAsyncOperand(MLIRContext *, mlir::Value, + llvm::ArrayRef<DeviceType>); + /// Add an entry to the 'wait-only' attribute (clause spelled without + /// arguments)for each of the additional device types (or a none if it is + /// empty). + void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>); + /// Add an arrya-like entry to the 'wait' with the current list of device + /// types. + void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange, + llvm::ArrayRef<DeviceType>); }]; let assemblyFormat = [{ @@ -1679,6 +1719,31 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels", static mlir::acc::Construct getConstructId() { return mlir::acc::Construct::acc_construct_kernels; } + /// Add a value to 'num_workers' with the current list of device types. 
+ void addNumWorkersOperand(MLIRContext *, mlir::Value, + llvm::ArrayRef<DeviceType>); + /// Add a value to 'vector_length' with the current list of device types. + void addVectorLengthOperand(MLIRContext *, mlir::Value, + llvm::ArrayRef<DeviceType>); + /// Add an entry to the 'async-only' attribute (clause spelled without + /// arguments)for each of the additional device types (or a none if it is + /// empty). + void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>); + /// Add a value to the 'async' with the current list of device types. + void addAsyncOperand(MLIRContext *, mlir::Value, + llvm::ArrayRef<DeviceType>); + /// Add an array-like entry to the 'num_gangs' with the current list of + /// device types. + void addNumGangsOperands(MLIRContext *, mlir::ValueRange, + llvm::ArrayRef<DeviceType>); + /// Add an entry to the 'wait-only' attribute (clause spelled without + /// arguments)for each of the additional device types (or a none if it is + /// empty). + void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>); + /// Add an arrya-like entry to the 'wait' with the current list of device + /// types. + void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange, + llvm::ArrayRef<DeviceType>); }]; let assemblyFormat = [{ @@ -1785,6 +1850,21 @@ def OpenACC_DataOp : OpenACC_Op<"data", /// Return the wait devnum value clause for the given device_type if /// present. mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType); + /// Add an entry to the 'async-only' attribute (clause spelled without + /// arguments)for each of the additional device types (or a none if it is + /// empty). + void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>); + /// Add a value to the 'async' with the current list of device types. 
+ void addAsyncOperand(MLIRContext *, mlir::Value, + llvm::ArrayRef<DeviceType>); + /// Add an entry to the 'wait-only' attribute (clause spelled without + /// arguments)for each of the additional device types (or a none if it is + /// empty). + void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>); + /// Add an arrya-like entry to the 'wait' with the current list of device + /// types. + void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange, + llvm::ArrayRef<DeviceType>); }]; let assemblyFormat = [{ diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 04cbe200eafe9..fa684942f60d0 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -76,6 +76,69 @@ struct LLVMPointerPointerLikeModel LLVM::LLVMPointerType> { Type getElementType(Type pointer) const { return Type(); } }; + +/// Helper function for any of the times we need to modify an ArrayAttr based on +/// a device type list. Returns a new ArrayAttr with all of the +/// existingDeviceTypes, plus the effective new ones(or an added none if hte new +/// lis... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/137396 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits