https://github.com/luporl updated 
https://github.com/llvm/llvm-project/pull/183800

From 878f7ca04216944e97b602a689af911a715b3285 Mon Sep 17 00:00:00 2001
From: Leandro Lupori <[email protected]>
Date: Fri, 27 Feb 2026 15:32:58 -0300
Subject: [PATCH] [mlir][OpenMP] Fix update of linear iteration variables

The final value of a linear iteration variable must be the loop's
limit_value + step. Before this patch, it was limit_value.

This fixes the second issue reported in #170784.
---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 129 ++++++++++++------
 mlir/test/Target/LLVMIR/openmp-llvm.mlir      |  15 +-
 2 files changed, 103 insertions(+), 41 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index fb7a5116bad74..6510e126c710d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -150,6 +150,37 @@ class LinearClauseProcessor {
   llvm::BasicBlock *linearFinalizationBB;
   llvm::BasicBlock *linearExitBB;
   llvm::BasicBlock *linearLastIterExitBB;
+  Value linearLoopIV;
+  Value linearLoopIVStart;
+
+  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::Type *varType,
+                       llvm::Value *var, llvm::Value *varStart,
+                       llvm::Value *step, llvm::Value *iv) {
+    if (!iv->getType()->isIntegerTy())
+      llvm_unreachable("OpenMP loop induction variable must be an integer "
+                       "type");
+
+    if (varType->isIntegerTy()) {
+      // Integer path: normalize all arithmetic to varType
+      iv = builder.CreateSExtOrTrunc(iv, varType);
+      step = builder.CreateSExtOrTrunc(step, varType);
+
+      llvm::Value *mulInst = builder.CreateMul(iv, step);
+      llvm::Value *addInst = builder.CreateAdd(varStart, mulInst);
+      builder.CreateStore(addInst, var);
+    } else if (varType->isFloatingPointTy()) {
+      // Float path: perform multiply in integer, then convert to float
+      step = builder.CreateSExtOrTrunc(step, iv->getType());
+
+      llvm::Value *mulInst = builder.CreateMul(iv, step);
+      llvm::Value *mulFp = builder.CreateSIToFP(mulInst, varType);
+      llvm::Value *addInst = builder.CreateFAdd(varStart, mulFp);
+      builder.CreateStore(addInst, var);
+    } else {
+      llvm_unreachable(
+          "Linear variable must be of integer or floating-point type");
+    }
+  }
 
 public:
   // Register type for the linear variables
@@ -189,46 +220,63 @@ class LinearClauseProcessor {
     }
   }
 
+  // Find linear iteration variable and save it for later updates
+  void initLinearIV(omp::SimdOp simdOp) {
+    auto loopOp = cast<omp::LoopNestOp>(simdOp.getWrappedLoop());
+    // NOTE iteration variables can only be linear in non-nested loops.
+    if (loopOp.getIVs().size() != 1)
+      return;
+    // The linear IV is the loop IV's store address.
+    BlockArgument arg = loopOp.getIVs().front();
+    for (const Operation *user : arg.getUsers()) {
+      if (auto storeOp = dyn_cast<LLVM::StoreOp>(user)) {
+        for (Value linearVar : simdOp.getLinearVars()) {
+          if (linearVar == storeOp.getAddr()) {
+            linearLoopIV = linearVar;
+            linearLoopIVStart = loopOp.getLoopLowerBounds().front();
+            break;
+          }
+        }
+      }
+    }
+  }
+
   // Emit IR for updating Linear variables
-  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
-                       llvm::Value *loopInductionVar) {
+  void updateLinearVars(llvm::IRBuilderBase &builder,
+                        llvm::BasicBlock *loopBody,
+                        llvm::Value *loopInductionVar) {
     builder.SetInsertPoint(loopBody->getTerminator());
     for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
-      llvm::Type *linearVarType = linearVarTypes[index];
-      llvm::Value *iv = loopInductionVar;
-      llvm::Value *step = linearSteps[index];
-
-      if (!iv->getType()->isIntegerTy())
-        llvm_unreachable("OpenMP loop induction variable must be an integer "
-                         "type");
-
-      if (linearVarType->isIntegerTy()) {
-        // Integer path: normalize all arithmetic to linearVarType
-        iv = builder.CreateSExtOrTrunc(iv, linearVarType);
-        step = builder.CreateSExtOrTrunc(step, linearVarType);
-
-        llvm::LoadInst *linearVarStart =
-            builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
-        llvm::Value *mulInst = builder.CreateMul(iv, step);
-        llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
-        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
-      } else if (linearVarType->isFloatingPointTy()) {
-        // Float path: perform multiply in integer, then convert to float
-        step = builder.CreateSExtOrTrunc(step, iv->getType());
-        llvm::Value *mulInst = builder.CreateMul(iv, step);
-
-        llvm::LoadInst *linearVarStart =
-            builder.CreateLoad(linearVarType, linearPreconditionVars[index]);
-        llvm::Value *mulFp = builder.CreateSIToFP(mulInst, linearVarType);
-        llvm::Value *addInst = builder.CreateFAdd(linearVarStart, mulFp);
-        builder.CreateStore(addInst, linearLoopBodyTemps[index]);
-      } else {
-        llvm_unreachable(
-            "Linear variable must be of integer or floating-point type");
-      }
+      llvm::LoadInst *linearVarStart = builder.CreateLoad(
+          linearVarTypes[index], linearPreconditionVars[index]);
+      updateLinearVar(builder, linearVarTypes[index],
+                      linearLoopBodyTemps[index], linearVarStart,
+                      linearSteps[index], loopInductionVar);
     }
   }
 
+  // Emit IR for updating linear iteration variables on loop exit
+  void updateLinearIV(llvm::IRBuilderBase &builder,
+                      LLVM::ModuleTranslation &moduleTranslation,
+                      llvm::Value *loopIV) {
+    if (!linearLoopIV)
+      return;
+    llvm::Value *linearIV = moduleTranslation.lookupValue(linearLoopIV);
+    llvm::Value *linearIVStart =
+        moduleTranslation.lookupValue(linearLoopIVStart);
+
+    // Find linearIV's index
+    size_t index;
+    for (index = 0; index < linearOrigVal.size(); index++)
+      if (linearIV == linearOrigVal[index])
+        break;
+    if (index == linearOrigVal.size())
+      return;
+
+    updateLinearVar(builder, linearVarTypes[index], linearLoopBodyTemps[index],
+                    linearIVStart, linearSteps[index], loopIV);
+  }
+
   // Linear variable finalization is conditional on the last logical iteration.
   // Create BB splits to manage the same.
   void splitLinearFiniBB(llvm::IRBuilderBase &builder,
@@ -3838,8 +3886,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
     if (failed(handleError(afterBarrierIP, *loopOp)))
       return failure();
     builder.restoreIP(*afterBarrierIP);
-    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
-                                          loopInfo->getIndVar());
+    linearClauseProcessor.updateLinearVars(builder, loopInfo->getBody(),
+                                           loopInfo->getIndVar());
     linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
   }
 
@@ -4139,6 +4187,8 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
   // Initialize linear variables and linear step
   LinearClauseProcessor linearClauseProcessor;
 
+  linearClauseProcessor.initLinearIV(simdOp);
+
   if (!simdOp.getLinearVars().empty()) {
     auto linearVarTypes = simdOp.getLinearVarTypes().value();
     for (mlir::Attribute linearVarType : linearVarTypes)
@@ -4236,8 +4286,8 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
     linearClauseProcessor.initLinearVar(builder, moduleTranslation,
                                         loopInfo->getPreheader());
 
-    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
-                                          loopInfo->getIndVar());
+    linearClauseProcessor.updateLinearVars(builder, loopInfo->getBody(),
+                                           loopInfo->getIndVar());
   }
   builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
 
@@ -4247,6 +4297,9 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
                             : nullptr,
                         order, simdlen, safelen);
 
+  linearClauseProcessor.updateLinearIV(builder, moduleTranslation,
+                                       loopInfo->getIndVar());
+
   linearClauseProcessor.emitStoresForLinearVar(builder);
 
   // Check if this SIMD loop contains ordered regions
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index f002a64a593a0..5d1cb9407d5d2 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -743,6 +743,7 @@ llvm.func @simd_simple(%lb : i64, %ub : i64, %step : i64, %arg0: !llvm.ptr) {
 llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
 
 // CHECK-LABEL: @simd_linear
+// CHECK-SAME: (i32 %[[LB:.*]], i32 %{{.*}}, i32 %[[STEP:.*]], ptr %[[X:.*]])
 
 // CHECK: %[[LINEAR_VAR:.*]] = alloca i32, align 4
 // CHECK: %[[LINEAR_RESULT:.*]] = alloca i32, align 4
@@ -758,8 +759,16 @@ llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
 // CHECK: %[[MUL:.*]] = mul i32 %omp_loop.iv, {{.*}}
 // CHECK: %[[ADD:.*]] = add i32 %[[LOAD]], %[[MUL]]
 // CHECK: store i32 %[[ADD]], ptr %[[LINEAR_RESULT]], align 4, !llvm.access.group !1
+
+// CHECK: omp.region.cont:
+// CHECK: %[[MUL:.*]] = mul i32 %omp_loop.iv, %[[STEP]]
+// CHECK: %[[ADD:.*]] = add i32 %[[LB]], %[[MUL]]
+// CHECK: store i32 %[[ADD]], ptr %[[LINEAR_RESULT]], align 4
+// CHECK: %[[LOAD:.*]] = load i32, ptr %[[LINEAR_RESULT]], align 4
+// CHECK: store i32 %[[LOAD]], ptr %[[X]], align 4
   omp.simd linear(%x : !llvm.ptr = %step : i32) {
     omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+      llvm.store %iv, %x : i32, !llvm.ptr
       omp.yield
     }
   } {linear_var_types = [i32]}
@@ -784,8 +793,8 @@ llvm.func @simd_linear_i64_var_i32_step(%lb : i32, %ub : i32, %x : !llvm.ptr) {
 
 // CHECK: omp_loop.body:
 // Verify type conversions: iv (i32) is extended to i64 before multiplication
-// CHECK: %[[IV_I64:.*]] = sext i32 %omp_loop.iv to i64
 // CHECK: %[[LOAD:.*]] = load i64, ptr %[[LINEAR_VAR]], {{.*}}!llvm.access.group
+// CHECK: %[[IV_I64:.*]] = sext i32 %omp_loop.iv to i64
 // Verify multiplication and addition use consistent i64 types
 // CHECK: %[[MUL:.*]] = mul i64 %[[IV_I64]], {{.*}}
 // CHECK: %[[ADD:.*]] = add i64 %[[LOAD]], %[[MUL]]
@@ -817,8 +826,8 @@ llvm.func @simd_linear_f64_var_i32_step(%lb : i32, %ub : i32, %x : !llvm.ptr) {
 // CHECK: omp_loop.body:
 // Verify integer multiplication, load, and conversion to float
 // CHECK: mul i32 %omp_loop.iv
-// CHECK: %[[MUL_INT:.*]] = mul i32 %omp_loop.iv, {{.*}}
-// CHECK-NEXT: %[[LOAD:.*]] = load double, ptr %[[LINEAR_VAR]], {{.*}}!llvm.access.group
+// CHECK: %[[LOAD:.*]] = load double, ptr %[[LINEAR_VAR]], {{.*}}!llvm.access.group
+// CHECK-NEXT: %[[MUL_INT:.*]] = mul i32 %omp_loop.iv, {{.*}}
 // CHECK-NEXT: %[[MUL_FP:.*]] = sitofp i32 %[[MUL_INT]] to double
 // CHECK-NEXT: %[[ADD:.*]] = fadd double %[[LOAD]], %[[MUL_FP]]
 // CHECK-NEXT: store double %[[ADD]], ptr %[[LINEAR_RESULT]], {{.*}}!llvm.access.group

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to