llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-memref Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> Implement runtime op verification for `memref.copy`. Only ranked memrefs are verified at the moment. --- Full diff: https://github.com/llvm/llvm-project/pull/130437.diff 2 Files Affected: - (modified) mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp (+48) - (added) mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir (+28) ``````````diff diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index ceea27a35a225..c604af249ba2e 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -182,6 +182,53 @@ struct CastOpInterface } }; +struct CopyOpInterface + : public RuntimeVerifiableOpInterface::ExternalModel<CopyOpInterface, + CopyOp> { + void generateRuntimeVerification(Operation *op, OpBuilder &builder, + Location loc) const { + auto copyOp = cast<CopyOp>(op); + BaseMemRefType sourceType = copyOp.getSource().getType(); + BaseMemRefType targetType = copyOp.getTarget().getType(); + auto rankedSourceType = dyn_cast<MemRefType>(sourceType); + auto rankedTargetType = dyn_cast<MemRefType>(targetType); + + // TODO: Verification for unranked memrefs is not supported yet. + if (!rankedSourceType || !rankedTargetType) + return; + + assert(sourceType.getRank() == targetType.getRank() && "rank mismatch"); + for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) { + // Fully static dimensions in both source and target operand are already + // verified by the op verifier. + if (!rankedSourceType.isDynamicDim(i) && + !rankedTargetType.isDynamicDim(i)) + continue; + Value sourceDim; + if (rankedSourceType.isDynamicDim(i)) { + sourceDim = builder.create<DimOp>(loc, copyOp.getSource(), i); + } else { + sourceDim = builder.create<arith::ConstantIndexOp>( + loc, rankedSourceType.getDimSize(i)); + } + Value targetDim; + if (rankedTargetType.isDynamicDim(i)) { + targetDim = builder.create<DimOp>(loc, copyOp.getTarget(), i); + } else { + targetDim = builder.create<arith::ConstantIndexOp>( + loc, rankedTargetType.getDimSize(i)); + } + Value sameDimSize = builder.create<arith::CmpIOp>( + loc, arith::CmpIPredicate::eq, sourceDim, targetDim); + builder.create<cf::AssertOp>( + loc, sameDimSize, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "size of " + std::to_string(i) + + "-th source/target dim does not match")); + } + } +}; + struct DimOpInterface : public RuntimeVerifiableOpInterface::ExternalModel<DimOpInterface, DimOp> { @@ -383,6 +430,7 @@ void mlir::memref::registerRuntimeVerifiableOpInterfaceExternalModels( AssumeAlignmentOp::attachInterface<AssumeAlignmentOpInterface>(*ctx); AtomicRMWOp::attachInterface<LoadStoreOpInterface<AtomicRMWOp>>(*ctx); CastOp::attachInterface<CastOpInterface>(*ctx); + CopyOp::attachInterface<CopyOpInterface>(*ctx); DimOp::attachInterface<DimOpInterface>(*ctx); ExpandShapeOp::attachInterface<ExpandShapeOpInterface>(*ctx); GenericAtomicRMWOp::attachInterface< diff --git a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir new file mode 100644 index 0000000000000..95b9db2832cee --- /dev/null +++ b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir @@ -0,0 +1,28 @@ +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + +// Put memref.copy in a function, otherwise the memref.cast may fold. +func.func @memcpy_helper(%src: memref<?xf32>, %dest: memref<?xf32>) { + memref.copy %src, %dest : memref<?xf32> to memref<?xf32> + return +} + +func.func @main() { + %alloca1 = memref.alloca() : memref<4xf32> + %alloca2 = memref.alloca() : memref<5xf32> + %cast1 = memref.cast %alloca1 : memref<4xf32> to memref<?xf32> + %cast2 = memref.cast %alloca2 : memref<5xf32> to memref<?xf32> + + // CHECK: ERROR: Runtime op verification failed + // CHECK-NEXT: "memref.copy"(%{{.*}}, %{{.*}}) : (memref<?xf32>, memref<?xf32>) -> () + // CHECK-NEXT: ^ size of 0-th source/target dim does not match + // CHECK-NEXT: Location: loc({{.*}}) + call @memcpy_helper(%cast1, %cast2) : (memref<?xf32>, memref<?xf32>) -> () + + return +} `````````` </details> https://github.com/llvm/llvm-project/pull/130437 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits