llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> Implement runtime verification for `memref.atomic_rmw` and `memref.generic_atomic_rmw`. --- Full diff: https://github.com/llvm/llvm-project/pull/130414.diff 3 Files Affected: - (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+26-19) - (added) mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir (+45) - (added) mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir (+45) ``````````diff diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index c8e7325d7ac89..ceea27a35a225 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -35,6 +35,26 @@ Value generateInBoundsCheck(OpBuilder &builder, Location loc, Value value, return inBounds; } +/// Generate a runtime check to see if the given indices are in-bounds with +/// respect to the given ranked memref. 
+Value generateIndicesInBoundsCheck(OpBuilder &builder, Location loc, + Value memref, ValueRange indices) { + auto memrefType = cast<MemRefType>(memref.getType()); + assert(memrefType.getRank() == static_cast<int64_t>(indices.size()) && + "rank mismatch"); + Value cond = builder.create<arith::ConstantOp>( + loc, builder.getIntegerAttr(builder.getI1Type(), 1)); + + auto zero = builder.create<arith::ConstantIndexOp>(loc, 0); + for (auto [dim, idx] : llvm::enumerate(indices)) { + Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, dim); + Value inBounds = generateInBoundsCheck(builder, loc, idx, zero, dimOp); + cond = builder.createOrFold<arith::AndIOp>(loc, cond, inBounds); + } + + return cond; +} + struct AssumeAlignmentOpInterface : public RuntimeVerifiableOpInterface::ExternalModel< AssumeAlignmentOpInterface, AssumeAlignmentOp> { @@ -186,26 +206,10 @@ struct LoadStoreOpInterface void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto loadStoreOp = cast<LoadStoreOp>(op); - - auto memref = loadStoreOp.getMemref(); - auto rank = memref.getType().getRank(); - if (rank == 0) { - return; - } - auto indices = loadStoreOp.getIndices(); - - auto zero = builder.create<arith::ConstantIndexOp>(loc, 0); - Value assertCond; - for (auto i : llvm::seq<int64_t>(0, rank)) { - Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i); - Value inBounds = - generateInBoundsCheck(builder, loc, indices[i], zero, dimOp); - assertCond = - i > 0 ? 
builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds) + : inBounds; - } + Value cond = generateIndicesInBoundsCheck( + builder, loc, loadStoreOp.getMemref(), loadStoreOp.getIndices()); builder.create<cf::AssertOp>( - loc, assertCond, + loc, cond, RuntimeVerifiableOpInterface::generateErrorMessage( op, "out-of-bounds access")); } @@ -377,9 +381,12 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels( DialectRegistry &registry) { registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) { AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx); + AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx); CastOp::attachInterface<CastOpInterface>(*ctx); DimOp::attachInterface<DimOpInterface>(*ctx); ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx); + GenericAtomicRMWOp::attachInterface< + LoadStoreOpInterface<GenericAtomicRMWOp>>(*ctx); LoadOp::attachInterface<LoadStoreOpInterface<LoadOp>>(*ctx); ReinterpretCastOp::attachInterface<ReinterpretCastOpInterface>(*ctx); StoreOp::attachInterface<LoadStoreOpInterface<StoreOp>>(*ctx); diff --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir new file mode 100644 index 0000000000000..9f70c5ca66f65 --- /dev/null +++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir @@ -0,0 +1,45 @@ +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -test-cf-assert \ +// RUN: -expand-strided-metadata \ +// RUN: -lower-affine \ +// RUN: -convert-to-llvm | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + +func.func @store_dynamic(%memref: memref<?xf32>, %index: index) { + %cst = arith.constant 1.0 : f32 + memref.atomic_rmw addf %cst, %memref[%index] : (f32, memref<?xf32>) -> f32 + return +} + +func.func @main() { + // Allocate a memref<10xf32>, but 
disguise it as a memref<5xf32>. This is + // necessary because "-test-cf-assert" does not abort the program and we do + // not want to segfault when running the test case. + %alloc = memref.alloca() : memref<10xf32> + %ptr = memref.extract_aligned_pointer_as_index %alloc : memref<10xf32> -> index + %ptr_i64 = arith.index_cast %ptr : index to i64 + %ptr_llvm = llvm.inttoptr %ptr_i64 : i64 to !llvm.ptr + %c0 = llvm.mlir.constant(0 : index) : i64 + %c1 = llvm.mlir.constant(1 : index) : i64 + %c5 = llvm.mlir.constant(5 : index) : i64 + %4 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %5 = llvm.insertvalue %ptr_llvm, %4[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %6 = llvm.insertvalue %ptr_llvm, %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %8 = llvm.insertvalue %c0, %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %9 = llvm.insertvalue %c5, %8[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %10 = llvm.insertvalue %c1, %9[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %buffer = builtin.unrealized_conversion_cast %10 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<5xf32> + %cast = memref.cast %buffer : memref<5xf32> to memref<?xf32> + + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.atomic_rmw"(%{{.*}}, %{{.*}}, %{{.*}}) <{kind = 0 : i64}> : (f32, memref<?xf32>, index) -> f32 + // CHECK-NEXT: ^ out-of-bounds access + // CHECK-NEXT: Location: loc({{.*}}) + %c9 = arith.constant 9 : index + func.call @store_dynamic(%cast, %c9) : (memref<?xf32>, index) -> () + + return +} + diff --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir new file mode 100644 index 0000000000000..58961ba31d93a --- /dev/null +++ 
b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir @@ -0,0 +1,45 @@ +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -test-cf-assert \ +// RUN: -expand-strided-metadata \ +// RUN: -lower-affine \ +// RUN: -convert-to-llvm | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + +func.func @store_dynamic(%memref: memref<?xf32>, %index: index) { + %cst = arith.constant 1.0 : f32 + memref.store %cst, %memref[%index] : memref<?xf32> + return +} + +func.func @main() { + // Allocate a memref<10xf32>, but disguise it as a memref<5xf32>. This is + // necessary because "-test-cf-assert" does not abort the program and we do + // not want to segfault when running the test case. + %alloc = memref.alloca() : memref<10xf32> + %ptr = memref.extract_aligned_pointer_as_index %alloc : memref<10xf32> -> index + %ptr_i64 = arith.index_cast %ptr : index to i64 + %ptr_llvm = llvm.inttoptr %ptr_i64 : i64 to !llvm.ptr + %c0 = llvm.mlir.constant(0 : index) : i64 + %c1 = llvm.mlir.constant(1 : index) : i64 + %c5 = llvm.mlir.constant(5 : index) : i64 + %4 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %5 = llvm.insertvalue %ptr_llvm, %4[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %6 = llvm.insertvalue %ptr_llvm, %5[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %8 = llvm.insertvalue %c0, %6[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %9 = llvm.insertvalue %c5, %8[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %10 = llvm.insertvalue %c1, %9[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> + %buffer = builtin.unrealized_conversion_cast %10 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<5xf32> + %cast = memref.cast %buffer : memref<5xf32> to memref<?xf32> + + // CHECK: ERROR: Runtime op verification failed 
+ // CHECK-NEXT: "memref.store"(%{{.*}}, %{{.*}}, %{{.*}}) : (f32, memref<?xf32>, index) -> () + // CHECK-NEXT: ^ out-of-bounds access + // CHECK-NEXT: Location: loc({{.*}}) + %c9 = arith.constant 9 : index + func.call @store_dynamic(%cast, %c9) : (memref<?xf32>, index) -> () + + return +} + `````````` </details> https://github.com/llvm/llvm-project/pull/130414 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits