Author: Aniket Singh
Date: 2026-01-15T09:49:53Z
New Revision: fe68b17f46d470c2aa5223bb3cc4fec0d14801f9
URL: https://github.com/llvm/llvm-project/commit/fe68b17f46d470c2aa5223bb3cc4fec0d14801f9 DIFF: https://github.com/llvm/llvm-project/commit/fe68b17f46d470c2aa5223bb3cc4fec0d14801f9.diff LOG: [MLIR][SCFToOpenMP] Fix crash when lowering vector reductions (#173978) This patch fixes a crash in the SCF to OpenMP conversion pass when encountering scf.parallel with vector reductions. - Extracts scalar element types for bitwidth calculations. - Uses DenseElementsAttr for vector splat initializers. - Bypasses llvm.atomicrmw for vector types (not supported in LLVM IR). Fixes #173860 --------- Co-authored-by: Aniket Singh <[email protected]> Added: mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir Modified: mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp Removed: ################################################################################ diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 6423d49859c97..5fcaea7f39c3c 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -150,32 +150,48 @@ static const llvm::fltSemantics &fltSemanticsForType(FloatType type) { llvm_unreachable("unknown float type"); } +/// Helper to create a splat attribute for vector types, or return the scalar +/// attribute for scalar types. +static Attribute getSplatOrScalarAttr(Type type, Attribute val) { + if (auto vecType = dyn_cast<VectorType>(type)) + return DenseElementsAttr::get(vecType, val); + return val; +} + /// Returns an attribute with the minimum (if `min` is set) or the maximum value /// (otherwise) for the given float type. 
static Attribute minMaxValueForFloat(Type type, bool min) { - auto fltType = cast<FloatType>(type); - return FloatAttr::get( - type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min)); + Type elType = getElementTypeOrSelf(type); + auto fltType = cast<FloatType>(elType); + auto val = llvm::APFloat::getLargest(fltSemanticsForType(fltType), min); + + return getSplatOrScalarAttr(type, FloatAttr::get(elType, val)); } /// Returns an attribute with the signed integer minimum (if `min` is set) or /// the maximum value (otherwise) for the given integer type, regardless of its /// signedness semantics (only the width is considered). static Attribute minMaxValueForSignedInt(Type type, bool min) { - auto intType = cast<IntegerType>(type); + Type elType = getElementTypeOrSelf(type); + auto intType = cast<IntegerType>(elType); unsigned bitwidth = intType.getWidth(); - return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth) - : llvm::APInt::getSignedMaxValue(bitwidth)); + auto val = min ? llvm::APInt::getSignedMinValue(bitwidth) + : llvm::APInt::getSignedMaxValue(bitwidth); + + return getSplatOrScalarAttr(type, IntegerAttr::get(elType, val)); } /// Returns an attribute with the unsigned integer minimum (if `min` is set) or /// the maximum value (otherwise) for the given integer type, regardless of its /// signedness semantics (only the width is considered). static Attribute minMaxValueForUnsignedInt(Type type, bool min) { - auto intType = cast<IntegerType>(type); + Type elType = getElementTypeOrSelf(type); + auto intType = cast<IntegerType>(elType); unsigned bitwidth = intType.getWidth(); - return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth) - : llvm::APInt::getAllOnes(bitwidth)); + auto val = + min ? 
llvm::APInt::getZero(bitwidth) : llvm::APInt::getAllOnes(bitwidth); + + return getSplatOrScalarAttr(type, IntegerAttr::get(elType, val)); } /// Creates an OpenMP reduction declaration and inserts it into the provided @@ -203,7 +219,7 @@ createDecl(PatternRewriter &builder, SymbolTable &symbolTable, Operation *terminator = &reduce.getReductions()[reductionIndex].front().back(); assert(isa<scf::ReduceReturnOp>(terminator) && - "expected reduce op to be terminated by redure return"); + "expected reduce op to be terminated by reduce return"); builder.setInsertionPoint(terminator); builder.replaceOpWithNewOp<omp::YieldOp>(terminator, terminator->getOperands()); @@ -237,6 +253,11 @@ static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, return decl; } +/// Returns true if the type is supported by llvm.atomicrmw. +/// LLVM IR currently does not support atomic operations on vector types. +/// See LLVM Language Reference Manual on 'atomicrmw'. +static bool supportsAtomic(Type type) { return !isa<VectorType>(type); } + /// Creates an OpenMP reduction declaration that corresponds to the given SCF /// reduction and returns it. Recognizes common reductions in order to identify /// the neutral value, necessary for the OpenMP declaration. If the reduction @@ -261,91 +282,119 @@ static omp::DeclareReductionOp declareReduction(PatternRewriter &builder, // Match simple binary reductions that can be expressed with atomicrmw. Type type = reduce.getOperands()[reductionIndex].getType(); Block &reduction = reduce.getReductions()[reductionIndex].front(); + + // Handle scalar element type extraction for vector bitwidth safety. 
+ Type elType = getElementTypeOrSelf(type); + + // Arithmetic Reductions if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) { - omp::DeclareReductionOp decl = - createDecl(builder, symbolTable, reduce, reductionIndex, - builder.getFloatAttr(type, 0.0)); - return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce, - reductionIndex); + omp::DeclareReductionOp decl = createDecl( + builder, symbolTable, reduce, reductionIndex, + getSplatOrScalarAttr(type, builder.getFloatAttr(elType, 0.0))); + return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, + decl, reduce, reductionIndex) + : decl; } if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) { - omp::DeclareReductionOp decl = - createDecl(builder, symbolTable, reduce, reductionIndex, - builder.getIntegerAttr(type, 0)); - return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce, - reductionIndex); + omp::DeclareReductionOp decl = createDecl( + builder, symbolTable, reduce, reductionIndex, + getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0))); + return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::add, + decl, reduce, reductionIndex) + : decl; } if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) { - omp::DeclareReductionOp decl = - createDecl(builder, symbolTable, reduce, reductionIndex, - builder.getIntegerAttr(type, 0)); - return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce, - reductionIndex); + omp::DeclareReductionOp decl = createDecl( + builder, symbolTable, reduce, reductionIndex, + getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0))); + return supportsAtomic(type) ? 
addAtomicRMW(builder, LLVM::AtomicBinOp::_or, + decl, reduce, reductionIndex) + : decl; } if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) { - omp::DeclareReductionOp decl = - createDecl(builder, symbolTable, reduce, reductionIndex, - builder.getIntegerAttr(type, 0)); - return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce, - reductionIndex); + omp::DeclareReductionOp decl = createDecl( + builder, symbolTable, reduce, reductionIndex, + getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 0))); + return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, + decl, reduce, reductionIndex) + : decl; } if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) { + APInt allOnes = llvm::APInt::getAllOnes(elType.getIntOrFloatBitWidth()); omp::DeclareReductionOp decl = createDecl( builder, symbolTable, reduce, reductionIndex, - builder.getIntegerAttr( - type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth()))); - return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce, - reductionIndex); + getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, allOnes))); + return supportsAtomic(type) ? addAtomicRMW(builder, LLVM::AtomicBinOp::_and, + decl, reduce, reductionIndex) + : decl; } // Match simple binary reductions that cannot be expressed with atomicrmw. // TODO: add atomic region using cmpxchg (which needs atomic load to be // available as an op). 
if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) { - return createDecl(builder, symbolTable, reduce, reductionIndex, - builder.getFloatAttr(type, 1.0)); + return createDecl( + builder, symbolTable, reduce, reductionIndex, + getSplatOrScalarAttr(type, builder.getFloatAttr(elType, 1.0))); } + if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) { - return createDecl(builder, symbolTable, reduce, reductionIndex, - builder.getIntegerAttr(type, 1)); + return createDecl( + builder, symbolTable, reduce, reductionIndex, + getSplatOrScalarAttr(type, builder.getIntegerAttr(elType, 1))); } // Match select-based min/max reductions. bool isMin; - if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>( + // Floating Point Min/Max + if (matchSelectReduction<arith::CmpFOp, arith::SelectOp, + arith::CmpFPredicate>( reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) || - matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>( - reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole}, - {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) { + matchSelectReduction<arith::CmpFOp, arith::SelectOp, + arith::CmpFPredicate>( + reduction, {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, + {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, isMin)) { return createDecl(builder, symbolTable, reduce, reductionIndex, minMaxValueForFloat(type, !isMin)); } - if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>( + + // Integer Min/Max + if (matchSelectReduction<arith::CmpIOp, arith::SelectOp, + arith::CmpIPredicate>( reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) || - matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>( - reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle}, - {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) { + 
matchSelectReduction<arith::CmpIOp, arith::SelectOp, + arith::CmpIPredicate>( + reduction, {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, + {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, isMin)) { omp::DeclareReductionOp decl = createDecl(builder, symbolTable, reduce, reductionIndex, minMaxValueForSignedInt(type, !isMin)); - return addAtomicRMW(builder, - isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max, - decl, reduce, reductionIndex); + return supportsAtomic(type) ? addAtomicRMW(builder, + isMin ? LLVM::AtomicBinOp::min + : LLVM::AtomicBinOp::max, + decl, reduce, reductionIndex) + : decl; } - if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>( + + // Unsigned Integer Min/Max + if (matchSelectReduction<arith::CmpIOp, arith::SelectOp, + arith::CmpIPredicate>( reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) || - matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>( - reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule}, - {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) { + matchSelectReduction<arith::CmpIOp, arith::SelectOp, + arith::CmpIPredicate>( + reduction, {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, + {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, isMin)) { omp::DeclareReductionOp decl = createDecl(builder, symbolTable, reduce, reductionIndex, minMaxValueForUnsignedInt(type, !isMin)); - return addAtomicRMW( - builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax, - decl, reduce, reductionIndex); + return supportsAtomic(type) ? addAtomicRMW(builder, + isMin ? 
LLVM::AtomicBinOp::umin + : LLVM::AtomicBinOp::umax, + decl, reduce, reductionIndex) + : decl; } return nullptr; diff --git a/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir new file mode 100644 index 0000000000000..018f8a03c8e34 --- /dev/null +++ b/mlir/test/Conversion/SCFToOpenMP/vector-reduction.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-opt %s --convert-scf-to-openmp | FileCheck %s + +// CHECK-LABEL: omp.declare_reduction @__scf_reduction : vector<2xi1> +// CHECK: init { +// CHECK: %[[INIT:.*]] = llvm.mlir.constant(dense<true> : vector<2xi1>) : vector<2xi1> +// CHECK: omp.yield(%[[INIT]] : vector<2xi1>) +// CHECK: } +// CHECK: combiner { +// CHECK: ^bb0(%[[ARG0:.*]]: vector<2xi1>, %[[ARG1:.*]]: vector<2xi1>): +// CHECK: %[[RES:.*]] = arith.andi %[[ARG0]], %[[ARG1]] : vector<2xi1> +// CHECK: omp.yield(%[[RES]] : vector<2xi1>) +// CHECK: } +// CHECK-NOT: atomic + +func.func @vector_and_reduction() { + %v_mask = vector.constant_mask [1] : vector<2xi1> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %result = scf.parallel (%i) = (%c0) to (%c2) step (%c1) init(%v_mask) -> vector<2xi1> { + %val = vector.constant_mask [1] : vector<2xi1> + scf.reduce (%val : vector<2xi1>) { + ^bb0(%lhs: vector<2xi1>, %rhs: vector<2xi1>): + %0 = arith.andi %lhs, %rhs : vector<2xi1> + scf.reduce.return %0 : vector<2xi1> + } + } + return +} \ No newline at end of file _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
