llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clangir

Author: Erich Keane (erichkeane)

<details>
<summary>Changes</summary>

As a follow-up to 3c4dff3ac6884b85fe93fe512c5bdaf014738c45, I audited all uses 
of 'process clause and use additive methods', and added explicit functions to 
the construct to make it easier for the next project to attempt to use this 
mechanism (vs. constructing all operands/etc. in advance, then adding them all at once).

I've only done the ones that I have attempted to use so far (as a catch-up, so no 
var-list clauses, no constructs that can't be used without a var-list, no loop, 
and no compound constructs).  I intend to do those "as I go" with the 
lowering of each of those things instead.

---

Patch is 33.12 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/137396.diff


3 Files Affected:

- (modified) clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp (+33-170) 
- (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+80) 
- (modified) mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp (+289) 


``````````diff
diff --git a/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp 
b/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp
index 6e65f94c78bed..6f86d2b681a1e 100644
--- a/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp
@@ -46,7 +46,17 @@ class OpenACCClauseCIREmitter final
   // diagnostics are gone.
   SourceLocation dirLoc;
 
-  const OpenACCDeviceTypeClause *lastDeviceTypeClause = nullptr;
+  llvm::SmallVector<mlir::acc::DeviceType> lastDeviceTypeValues;
+
+  void setLastDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
+    lastDeviceTypeValues.clear();
+
+    llvm::for_each(clause.getArchitectures(),
+                   [this](const DeviceTypeArgument &arg) {
+                     lastDeviceTypeValues.push_back(
+                         decodeDeviceType(arg.getIdentifierInfo()));
+                   });
+  }
 
   void clauseNotImplemented(const OpenACCClause &c) {
     cgf.cgm.errorNYI(c.getSourceRange(), "OpenACC Clause", c.getClauseKind());
@@ -95,114 +105,6 @@ class OpenACCClauseCIREmitter final
         .CaseLower("radeon", mlir::acc::DeviceType::Radeon);
   }
 
-  // Overload of this function that only returns the device-types list.
-  mlir::ArrayAttr
-  handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes) {
-    mlir::ValueRange argument;
-    mlir::MutableOperandRange range{operation};
-
-    return handleDeviceTypeAffectedClause(existingDeviceTypes, argument, 
range);
-  }
-  // Overload of this function for when 'segments' aren't necessary.
-  mlir::ArrayAttr
-  handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes,
-                                 mlir::ValueRange argument,
-                                 mlir::MutableOperandRange argCollection) {
-    llvm::SmallVector<int32_t> segments;
-    assert(argument.size() <= 1 &&
-           "Overload only for cases where segments don't need to be added");
-    return handleDeviceTypeAffectedClause(existingDeviceTypes, argument,
-                                          argCollection, segments);
-  }
-
-  // Handle a clause affected by the 'device_type' to the point that they need
-  // to have attributes added in the correct/corresponding order, such as
-  // 'num_workers' or 'vector_length' on a compute construct. The 'argument' is
-  // a collection of operands that need to be appended to the `argCollection` 
as
-  // we're adding a 'device_type' entry.  If there is more than 0 elements in
-  // the 'argument', the collection must be non-null, as it is needed to add to
-  // it.
-  // As some clauses, such as 'num_gangs' or 'wait' require a 'segments' list 
to
-  // be maintained, this takes a list of segments that will be updated with the
-  // proper counts as 'argument' elements are added.
-  //
-  // In MLIR, the 'operands' are stored as a large array, with a separate array
-  // of 'segments' that show which 'operand' applies to which 'operand-kind'.
-  // That is, a 'num_workers' operand-kind or 'num_vectors' operand-kind.
-  //
-  // So the operands array might have 4 elements, but the 'segments' array will
-  // be something like:
-  //
-  // {0, 0, 0, 2, 0, 1, 1, 0, 0...}
-  //
-  // Where each position belongs to a specific 'operand-kind'.  So that
-  // specifies that whichever operand-kind corresponds with index '3' has 2
-  // elements, and should take the 1st 2 operands off the list (since all
-  // preceding values are 0). operand-kinds corresponding to 5 and 6 each have
-  // 1 element.
-  //
-  // Fortunately, the `MutableOperandRange` append function actually takes care
-  // of that for us at the 'top level'.
-  //
-  // However, in cases like `num_gangs' or 'wait', where each individual
-  // 'element' might be itself array-like, there is a separate 'segments' array
-  // for them. So in the case of:
-  //
-  // device_type(nvidia, radeon) num_gangs(1, 2, 3)
-  //
-  // We have to emit that as TWO arrays into the IR (where the device_type is 
an
-  // attribute), so they look like:
-  //
-  // num_gangs({One : i32, Two : i32, Three : i32} [#acc.device_type<nvidia>],\
-  //           {One : i32, Two : i32, Three : i32} [#acc.device_type<radeon>])
-  //
-  // When stored in the 'operands' list, the top-level 'segment' for
-  // 'num_gangs' just shows 6 elements. In order to get the array-like
-  // apperance, the 'numGangsSegments' list is kept as well. In the above case,
-  // we've inserted 6 operands, so the 'numGangsSegments' must contain 2
-  // elements, 1 per array, and each will have a value of 3.  The verifier will
-  // ensure that the collections counts are correct.
-  mlir::ArrayAttr
-  handleDeviceTypeAffectedClause(mlir::ArrayAttr existingDeviceTypes,
-                                 mlir::ValueRange argument,
-                                 mlir::MutableOperandRange argCollection,
-                                 llvm::SmallVector<int32_t> &segments) {
-    llvm::SmallVector<mlir::Attribute> deviceTypes;
-
-    // Collect the 'existing' device-type attributes so we can re-create them
-    // and insert them.
-    if (existingDeviceTypes) {
-      for (const mlir::Attribute &Attr : existingDeviceTypes)
-        deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
-            builder.getContext(),
-            cast<mlir::acc::DeviceTypeAttr>(Attr).getValue()));
-    }
-
-    // Insert 1 version of the 'expr' to the NumWorkers list per-current
-    // device type.
-    if (lastDeviceTypeClause) {
-      for (const DeviceTypeArgument &arch :
-           lastDeviceTypeClause->getArchitectures()) {
-        deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
-            builder.getContext(), decodeDeviceType(arch.getIdentifierInfo())));
-        if (!argument.empty()) {
-          argCollection.append(argument);
-          segments.push_back(argument.size());
-        }
-      }
-    } else {
-      // Else, we just add a single for 'none'.
-      deviceTypes.push_back(mlir::acc::DeviceTypeAttr::get(
-          builder.getContext(), mlir::acc::DeviceType::None));
-      if (!argument.empty()) {
-        argCollection.append(argument);
-        segments.push_back(argument.size());
-      }
-    }
-
-    return mlir::ArrayAttr::get(builder.getContext(), deviceTypes);
-  }
-
 public:
   OpenACCClauseCIREmitter(OpTy &operation, CIRGenFunction &cgf,
                           CIRGenBuilderTy &builder,
@@ -236,7 +138,8 @@ class OpenACCClauseCIREmitter final
   }
 
   void VisitDeviceTypeClause(const OpenACCDeviceTypeClause &clause) {
-    lastDeviceTypeClause = &clause;
+    setLastDeviceTypeClause(clause);
+
     if constexpr (isOneOfTypes<OpTy, InitOp, ShutdownOp>) {
       llvm::for_each(
           clause.getArchitectures(), [this](const DeviceTypeArgument &arg) {
@@ -253,8 +156,8 @@ class OpenACCClauseCIREmitter final
     } else if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp,
                                       DataOp>) {
       // Nothing to do here, these constructs don't have any IR for these, as
-      // they just modify the other clauses IR.  So setting of `lastDeviceType`
-      // (done above) is all we need.
+      // they just modify the other clauses IR.  So setting of
+      // `lastDeviceTypeValues` (done above) is all we need.
     } else {
       // TODO: When we've implemented this for everything, switch this to an
       // unreachable. update, data, loop, routine, combined constructs remain.
@@ -264,10 +167,9 @@ class OpenACCClauseCIREmitter final
 
   void VisitNumWorkersClause(const OpenACCNumWorkersClause &clause) {
     if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
-      mlir::MutableOperandRange range = operation.getNumWorkersMutable();
-      operation.setNumWorkersDeviceTypeAttr(handleDeviceTypeAffectedClause(
-          operation.getNumWorkersDeviceTypeAttr(),
-          createIntExpr(clause.getIntExpr()), range));
+      operation.addNumWorkersOperand(builder.getContext(),
+                                     createIntExpr(clause.getIntExpr()),
+                                     lastDeviceTypeValues);
     } else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
       llvm_unreachable("num_workers not valid on serial");
     } else {
@@ -279,10 +181,9 @@ class OpenACCClauseCIREmitter final
 
   void VisitVectorLengthClause(const OpenACCVectorLengthClause &clause) {
     if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
-      mlir::MutableOperandRange range = operation.getVectorLengthMutable();
-      operation.setVectorLengthDeviceTypeAttr(handleDeviceTypeAffectedClause(
-          operation.getVectorLengthDeviceTypeAttr(),
-          createIntExpr(clause.getIntExpr()), range));
+      operation.addVectorLengthOperand(builder.getContext(),
+                                       createIntExpr(clause.getIntExpr()),
+                                       lastDeviceTypeValues);
     } else if constexpr (isOneOfTypes<OpTy, SerialOp>) {
       llvm_unreachable("vector_length not valid on serial");
     } else {
@@ -294,15 +195,12 @@ class OpenACCClauseCIREmitter final
 
   void VisitAsyncClause(const OpenACCAsyncClause &clause) {
     if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, DataOp>) 
{
-      if (!clause.hasIntExpr()) {
-        operation.setAsyncOnlyAttr(
-            handleDeviceTypeAffectedClause(operation.getAsyncOnlyAttr()));
-      } else {
-        mlir::MutableOperandRange range = operation.getAsyncOperandsMutable();
-        
operation.setAsyncOperandsDeviceTypeAttr(handleDeviceTypeAffectedClause(
-            operation.getAsyncOperandsDeviceTypeAttr(),
-            createIntExpr(clause.getIntExpr()), range));
-      }
+      if (!clause.hasIntExpr())
+        operation.addAsyncOnly(builder.getContext(), lastDeviceTypeValues);
+      else
+        operation.addAsyncOperand(builder.getContext(),
+                                  createIntExpr(clause.getIntExpr()),
+                                  lastDeviceTypeValues);
     } else if constexpr (isOneOfTypes<OpTy, WaitOp>) {
       // Wait doesn't have a device_type, so its handling here is slightly
       // different.
@@ -366,19 +264,11 @@ class OpenACCClauseCIREmitter final
   void VisitNumGangsClause(const OpenACCNumGangsClause &clause) {
     if constexpr (isOneOfTypes<OpTy, ParallelOp, KernelsOp>) {
       llvm::SmallVector<mlir::Value> values;
-
       for (const Expr *E : clause.getIntExprs())
         values.push_back(createIntExpr(E));
 
-      llvm::SmallVector<int32_t> segments;
-      if (operation.getNumGangsSegments())
-        llvm::copy(*operation.getNumGangsSegments(),
-                   std::back_inserter(segments));
-
-      mlir::MutableOperandRange range = operation.getNumGangsMutable();
-      operation.setNumGangsDeviceTypeAttr(handleDeviceTypeAffectedClause(
-          operation.getNumGangsDeviceTypeAttr(), values, range, segments));
-      operation.setNumGangsSegments(llvm::ArrayRef<int32_t>{segments});
+      operation.addNumGangsOperands(builder.getContext(), values,
+                                    lastDeviceTypeValues);
     } else {
       // TODO: When we've implemented this for everything, switch this to an
       // unreachable. Combined constructs remain.
@@ -389,42 +279,15 @@ class OpenACCClauseCIREmitter final
   void VisitWaitClause(const OpenACCWaitClause &clause) {
     if constexpr (isOneOfTypes<OpTy, ParallelOp, SerialOp, KernelsOp, DataOp>) 
{
       if (!clause.hasExprs()) {
-        operation.setWaitOnlyAttr(
-            handleDeviceTypeAffectedClause(operation.getWaitOnlyAttr()));
+        operation.addWaitOnly(builder.getContext(), lastDeviceTypeValues);
       } else {
         llvm::SmallVector<mlir::Value> values;
-
         if (clause.hasDevNumExpr())
           values.push_back(createIntExpr(clause.getDevNumExpr()));
         for (const Expr *E : clause.getQueueIdExprs())
           values.push_back(createIntExpr(E));
-
-        llvm::SmallVector<int32_t> segments;
-        if (operation.getWaitOperandsSegments())
-          llvm::copy(*operation.getWaitOperandsSegments(),
-                     std::back_inserter(segments));
-
-        unsigned beforeSegmentSize = segments.size();
-
-        mlir::MutableOperandRange range = operation.getWaitOperandsMutable();
-        operation.setWaitOperandsDeviceTypeAttr(handleDeviceTypeAffectedClause(
-            operation.getWaitOperandsDeviceTypeAttr(), values, range,
-            segments));
-        operation.setWaitOperandsSegments(segments);
-
-        // In addition to having to set the 'segments', wait also has a list of
-        // bool attributes whether it is annotated with 'devnum'.  We can use
-        // our knowledge of how much the 'segments' array grew to determine how
-        // many we need to add.
-        llvm::SmallVector<bool> hasDevNums;
-        if (operation.getHasWaitDevnumAttr())
-          for (mlir::Attribute A : operation.getHasWaitDevnumAttr())
-            hasDevNums.push_back(cast<mlir::BoolAttr>(A).getValue());
-
-        hasDevNums.insert(hasDevNums.end(), segments.size() - 
beforeSegmentSize,
-                          clause.hasDevNumExpr());
-
-        operation.setHasWaitDevnumAttr(builder.getBoolArrayAttr(hasDevNums));
+        operation.addWaitOperands(builder.getContext(), clause.hasDevNumExpr(),
+                                  values, lastDeviceTypeValues);
       }
     } else {
       // TODO: When we've implemented this for everything, switch this to an
@@ -589,7 +452,7 @@ CIRGenFunction::emitOpenACCWaitConstruct(const 
OpenACCWaitConstruct &s) {
     if (s.hasDevNumExpr())
       waitOp.getWaitDevnumMutable().append(createIntExpr(s.getDevNumExpr()));
 
-    for (Expr *QueueExpr  : s.getQueueIdExprs())
+    for (Expr *QueueExpr : s.getQueueIdExprs())
       waitOp.getWaitOperandsMutable().append(createIntExpr(QueueExpr));
   }
 
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td 
b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index 2167129e9e1c7..cd2dc86f7bf51 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -1408,6 +1408,31 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
     static mlir::acc::Construct getConstructId() {
       return mlir::acc::Construct::acc_construct_parallel;
     }
+    /// Add a value to 'num_workers' with the current list of device types.
+    void addNumWorkersOperand(MLIRContext *, mlir::Value,
+                              llvm::ArrayRef<DeviceType>);
+    /// Add a value to 'vector_length' with the current list of device types.
+    void addVectorLengthOperand(MLIRContext *, mlir::Value,
+                                llvm::ArrayRef<DeviceType>);
+    /// Add an entry to the 'async-only' attribute (clause spelled without
+    /// arguments)for each of the additional device types (or a none if it is
+    /// empty).
+    void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
+    /// Add a value to the 'async' with the current list of device types.
+    void addAsyncOperand(MLIRContext *, mlir::Value,
+                         llvm::ArrayRef<DeviceType>);
+    /// Add an array-like entry to the 'num_gangs' with the current list of
+    /// device types.
+    void addNumGangsOperands(MLIRContext *, mlir::ValueRange,
+                             llvm::ArrayRef<DeviceType>);
+    /// Add an entry to the 'wait-only' attribute (clause spelled without
+    /// arguments)for each of the additional device types (or a none if it is
+    /// empty).
+    void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
+    /// Add an arrya-like entry  to the 'wait' with the current list of device
+    /// types.
+    void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
+                         llvm::ArrayRef<DeviceType>);
   }];
 
   let assemblyFormat = [{
@@ -1535,6 +1560,21 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
     static mlir::acc::Construct getConstructId() {
       return mlir::acc::Construct::acc_construct_serial;
     }
+    /// Add an entry to the 'async-only' attribute (clause spelled without
+    /// arguments)for each of the additional device types (or a none if it is
+    /// empty).
+    void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
+    /// Add a value to the 'async' with the current list of device types.
+    void addAsyncOperand(MLIRContext *, mlir::Value,
+                         llvm::ArrayRef<DeviceType>);
+    /// Add an entry to the 'wait-only' attribute (clause spelled without
+    /// arguments)for each of the additional device types (or a none if it is
+    /// empty).
+    void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
+    /// Add an arrya-like entry  to the 'wait' with the current list of device
+    /// types.
+    void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
+                         llvm::ArrayRef<DeviceType>);
   }];
 
   let assemblyFormat = [{
@@ -1679,6 +1719,31 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
     static mlir::acc::Construct getConstructId() {
       return mlir::acc::Construct::acc_construct_kernels;
     }
+    /// Add a value to 'num_workers' with the current list of device types.
+    void addNumWorkersOperand(MLIRContext *, mlir::Value,
+                              llvm::ArrayRef<DeviceType>);
+    /// Add a value to 'vector_length' with the current list of device types.
+    void addVectorLengthOperand(MLIRContext *, mlir::Value,
+                                llvm::ArrayRef<DeviceType>);
+    /// Add an entry to the 'async-only' attribute (clause spelled without
+    /// arguments)for each of the additional device types (or a none if it is
+    /// empty).
+    void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
+    /// Add a value to the 'async' with the current list of device types.
+    void addAsyncOperand(MLIRContext *, mlir::Value,
+                         llvm::ArrayRef<DeviceType>);
+    /// Add an array-like entry to the 'num_gangs' with the current list of
+    /// device types.
+    void addNumGangsOperands(MLIRContext *, mlir::ValueRange,
+                             llvm::ArrayRef<DeviceType>);
+    /// Add an entry to the 'wait-only' attribute (clause spelled without
+    /// arguments)for each of the additional device types (or a none if it is
+    /// empty).
+    void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
+    /// Add an arrya-like entry  to the 'wait' with the current list of device
+    /// types.
+    void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
+                         llvm::ArrayRef<DeviceType>);
   }];
 
   let assemblyFormat = [{
@@ -1785,6 +1850,21 @@ def OpenACC_DataOp : OpenACC_Op<"data",
     /// Return the wait devnum value clause for the given device_type if
     /// present.
     mlir::Value getWaitDevnum(mlir::acc::DeviceType deviceType);
+    /// Add an entry to the 'async-only' attribute (clause spelled without
+    /// arguments)for each of the additional device types (or a none if it is
+    /// empty).
+    void addAsyncOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
+    /// Add a value to the 'async' with the current list of device types.
+    void addAsyncOperand(MLIRContext *, mlir::Value,
+                         llvm::ArrayRef<DeviceType>);
+    /// Add an entry to the 'wait-only' attribute (clause spelled without
+    /// arguments)for each of the additional device types (or a none if it is
+    /// empty).
+    void addWaitOnly(MLIRContext *, llvm::ArrayRef<DeviceType>);
+    /// Add an arrya-like entry  to the 'wait' with the current list of device
+    /// types.
+    void addWaitOperands(MLIRContext *, bool hasDevnum, mlir::ValueRange,
+                         llvm::ArrayRef<DeviceType>);
   }];
 
   let assemblyFormat = [{
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp 
b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 04cbe200eafe9..fa684942f60d0 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -76,6 +76,69 @@ struct LLVMPointerPointerLikeModel
                                             LLVM::LLVMPointerType> {
   Type getElementType(Type pointer) const { return Type(); }
 };
+
+/// Helper function for any of the times we need to modify an ArrayAttr based 
on
+/// a device type list.  Returns a new ArrayAttr with all of the
+/// existingDeviceTypes, plus the effective new ones(or an added none if hte 
new
+/// lis...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/137396
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to