Author: Aart Bik
Date: 2021-01-11T13:32:39-08:00
New Revision: 046612d29d7894783e8fcecbc62ebd6b4a78499f

URL: https://github.com/llvm/llvm-project/commit/046612d29d7894783e8fcecbc62ebd6b4a78499f
DIFF: https://github.com/llvm/llvm-project/commit/046612d29d7894783e8fcecbc62ebd6b4a78499f.diff

LOG: [mlir][vector] verify memref of vector memory ops

This ensures the memref base + indices expression is well-formed

Reviewed By: ThomasRaoux, ftynse

Differential Revision: https://reviews.llvm.org/D94441

Added:

Modified:
    mlir/lib/Dialect/Vector/VectorOps.cpp
    mlir/test/Dialect/Vector/invalid.mlir

Removed:

################################################################################
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 731ddae85ead..54e5e008e56f 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2365,10 +2365,12 @@ static LogicalResult verify(MaskedLoadOp op) {
   VectorType maskVType = op.getMaskVectorType();
   VectorType passVType = op.getPassThruVectorType();
   VectorType resVType = op.getResultVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (resVType.getElementType() != op.getMemRefType().getElementType())
+  if (resVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and result element type should match");
-
+  if (llvm::size(op.indices()) != memType.getRank())
+    return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
     return op.emitOpError("expected result dim to match mask dim");
   if (resVType != passVType)
@@ -2410,10 +2412,12 @@ void MaskedLoadOp::getCanonicalizationPatterns(
 static LogicalResult verify(MaskedStoreOp op) {
   VectorType maskVType = op.getMaskVectorType();
   VectorType valueVType = op.getValueVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (valueVType.getElementType() != op.getMemRefType().getElementType())
+  if (valueVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and value element type should match");
-
+  if (llvm::size(op.indices()) != memType.getRank())
+    return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
     return op.emitOpError("expected value dim to match mask dim");
   return success();
@@ -2454,10 +2458,10 @@ static LogicalResult verify(GatherOp op) {
   VectorType indicesVType = op.getIndicesVectorType();
   VectorType maskVType = op.getMaskVectorType();
   VectorType resVType = op.getResultVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (resVType.getElementType() != op.getMemRefType().getElementType())
+  if (resVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and result element type should match");
-
   if (resVType.getDimSize(0) != indicesVType.getDimSize(0))
     return op.emitOpError("expected result dim to match indices dim");
   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
@@ -2500,10 +2504,10 @@ static LogicalResult verify(ScatterOp op) {
   VectorType indicesVType = op.getIndicesVectorType();
   VectorType maskVType = op.getMaskVectorType();
   VectorType valueVType = op.getValueVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (valueVType.getElementType() != op.getMemRefType().getElementType())
+  if (valueVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and value element type should match");
-
   if (valueVType.getDimSize(0) != indicesVType.getDimSize(0))
     return op.emitOpError("expected value dim to match indices dim");
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
@@ -2544,10 +2548,12 @@ static LogicalResult verify(ExpandLoadOp op) {
   VectorType maskVType = op.getMaskVectorType();
   VectorType passVType = op.getPassThruVectorType();
   VectorType resVType = op.getResultVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (resVType.getElementType() != op.getMemRefType().getElementType())
+  if (resVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and result element type should match");
-
+  if (llvm::size(op.indices()) != memType.getRank())
+    return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (resVType.getDimSize(0) != maskVType.getDimSize(0))
     return op.emitOpError("expected result dim to match mask dim");
   if (resVType != passVType)
@@ -2589,10 +2595,12 @@ void ExpandLoadOp::getCanonicalizationPatterns(
 static LogicalResult verify(CompressStoreOp op) {
   VectorType maskVType = op.getMaskVectorType();
   VectorType valueVType = op.getValueVectorType();
+  MemRefType memType = op.getMemRefType();
 
-  if (valueVType.getElementType() != op.getMemRefType().getElementType())
+  if (valueVType.getElementType() != memType.getElementType())
     return op.emitOpError("base and value element type should match");
-
+  if (llvm::size(op.indices()) != memType.getRank())
+    return op.emitOpError("requires ") << memType.getRank() << " indices";
   if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
     return op.emitOpError("expected value dim to match mask dim");
   return success();

diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 11100c4e615e..099dad7eada4 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -1222,6 +1222,13 @@ func @maskedload_pass_thru_type_mask_mismatch(%base: memref<?xf32>, %mask: vecto
 
 // -----
 
+func @maskedload_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pass: vector<16xf32>) {
+  // expected-error@+1 {{'vector.maskedload' op requires 1 indices}}
+  %0 = vector.maskedload %base[], %mask, %pass : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
 func @maskedstore_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = constant 0 : index
   // expected-error@+1 {{'vector.maskedstore' op base and value element type should match}}
@@ -1238,6 +1245,14 @@ func @maskedstore_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<15xi1>,
 
 // -----
 
+func @maskedstore_memref_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
+  // expected-error@+1 {{'vector.maskedstore' op requires 1 indices}}
+  vector.maskedstore %base[%c0, %c0], %mask, %value : memref<?xf32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
 func @gather_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
                                 %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
   // expected-error@+1 {{'vector.gather' op base and result element type should match}}
@@ -1343,6 +1358,14 @@ func @expand_pass_thru_mismatch(%base: memref<?xf32>, %mask: vector<16xi1>, %pas
 
 // -----
 
+func @expand_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
+  %c0 = constant 0 : index
+  // expected-error@+1 {{'vector.expandload' op requires 2 indices}}
+  %0 = vector.expandload %base[%c0], %mask, %pass_thru : memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
+}
+
+// -----
+
 func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
   %c0 = constant 0 : index
   // expected-error@+1 {{'vector.compressstore' op base and value element type should match}}
@@ -1359,6 +1382,14 @@ func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %va
 
 // -----
 
+func @compress_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>, %value: vector<16xf32>) {
+  %c0 = constant 0 : index
+  // expected-error@+1 {{'vector.compressstore' op requires 2 indices}}
+  vector.compressstore %base[%c0, %c0, %c0], %mask, %value : memref<?x?xf32>, vector<16xi1>, vector<16xf32>
+}
+
+// -----
+
 func @extract_map_rank(%v: vector<32xf32>, %id : index) {
   // expected-error@+1 {{'vector.extract_map' op expected source and destination vectors of same rank}}
   %0 = vector.extract_map %v[%id] : vector<32xf32> to vector<2x1xf32>

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits