Author: Jeff Niu Date: 2025-09-12T13:53:32-07:00 New Revision: 86bcd1c2b256cd6aa5e65e1a54b63f929d616464
URL: https://github.com/llvm/llvm-project/commit/86bcd1c2b256cd6aa5e65e1a54b63f929d616464 DIFF: https://github.com/llvm/llvm-project/commit/86bcd1c2b256cd6aa5e65e1a54b63f929d616464.diff LOG: [mlir][Intrange] Fix materializing ShapedType constant values (#158359) When materializing integer ranges of splat tensors or vectors as constants, they should use DenseElementsAttr of the shaped type, not IntegerAttrs of the element types, since this can violate the invariants of tensor/vector ops. Co-authored-by: Jeff Niu <jeff...@openai.com> Added: Modified: mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp mlir/test/Dialect/Arith/int-range-opts.mlir Removed: ################################################################################ diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index e79f6a8aec1cf..70b56ca77b2da 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -26,6 +26,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Support/DebugStringHelper.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" @@ -76,9 +77,17 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { else dialect = value.getParentBlock()->getParentOp()->getDialect(); - Type type = getElementTypeOrSelf(value); - solver->propagateIfChanged( - cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect))); + Attribute cstAttr; + if (isa<IntegerType, IndexType>(value.getType())) { + cstAttr = IntegerAttr::get(value.getType(), *constant); + } else if (auto shapedTy = dyn_cast<ShapedType>(value.getType())) { + cstAttr = SplatElementsAttr::get(shapedTy, *constant); + } else { + llvm::report_fatal_error( 
Twine("FIXME: Don't know how to create a constant for this type: ") + + mlir::debugString(value.getType())); + } + solver->propagateIfChanged(cv, cv->join(ConstantValue(cstAttr, dialect))); } LogicalResult IntegerRangeAnalysis::visitOperation( diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 777ff0ecaa314..2017905587b26 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -8,6 +8,7 @@ #include <utility> +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Dialect/Arith/Transforms/Passes.h" @@ -485,6 +486,7 @@ struct IntRangeOptimizationsPass final MLIRContext *ctx = op->getContext(); DataFlowSolver solver; solver.load<DeadCodeAnalysis>(); + solver.load<SparseConstantPropagation>(); solver.load<IntegerRangeAnalysis>(); if (failed(solver.initializeAndRun(op))) return signalPassFailure(); diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir index ea5969a100258..e6e48d30cece5 100644 --- a/mlir/test/Dialect/Arith/int-range-opts.mlir +++ b/mlir/test/Dialect/Arith/int-range-opts.mlir @@ -132,3 +132,19 @@ func.func @wraps() -> i8 { %mod = arith.remsi %val, %c64 : i8 return %mod : i8 } + +// ----- + +// CHECK-LABEL: @analysis_crash +func.func @analysis_crash(%arg0: i32, %arg1: tensor<128xi1>) -> tensor<128xi64> { + %c0_i32 = arith.constant 0 : i32 + %cst = arith.constant dense<-1> : tensor<128xi32> + %splat = tensor.splat %arg0 : tensor<128xi32> + %0 = scf.for %arg2 = %c0_i32 to %arg0 step %arg0 iter_args(%arg3 = %splat) -> (tensor<128xi32>) : i32 { + scf.yield %arg3 : tensor<128xi32> + } + %1 = arith.select %arg1, %0#0, %cst : tensor<128xi1>, tensor<128xi32> + // Make sure the analysis doesn't crash when materializing the range as a tensor constant. 
+ %2 = arith.extsi %1 : tensor<128xi32> to tensor<128xi64> + return %2 : tensor<128xi64> +} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits