Author: Max191 Date: 2023-11-27T16:28:22-08:00 New Revision: b823f8469b5364411cde31a215c9bcbe0d3c08f7
URL: https://github.com/llvm/llvm-project/commit/b823f8469b5364411cde31a215c9bcbe0d3c08f7 DIFF: https://github.com/llvm/llvm-project/commit/b823f8469b5364411cde31a215c9bcbe0d3c08f7.diff LOG: [mlir] Add support for `memref.alloca` sub-byte emulation (#73138) Adds a similar case to `memref.alloc` for `memref.alloca` in EmulateNarrowTypes. Fixes https://github.com/openxla/iree/issues/15515 Added: Modified: mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp mlir/test/Dialect/MemRef/emulate-narrow-type.mlir Removed: ################################################################################ diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index dec5936fa7e83ce..e5801c3733ed5a8 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -112,18 +112,22 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx, namespace { //===----------------------------------------------------------------------===// -// ConvertMemRefAlloc +// ConvertMemRefAllocation //===----------------------------------------------------------------------===// -struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> { - using OpConversionPattern::OpConversionPattern; +template <typename OpTy> +struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> { + using OpConversionPattern<OpTy>::OpConversionPattern; LogicalResult - matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto currentType = op.getMemref().getType().cast<MemRefType>(); - auto newResultType = - getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>(); + static_assert(std::is_same<OpTy, memref::AllocOp>() || + std::is_same<OpTy, memref::AllocaOp>(), + "expected only memref::AllocOp or memref::AllocaOp"); + auto currentType = cast<MemRefType>(op.getMemref().getType()); + auto newResultType = dyn_cast<MemRefType>( + this->getTypeConverter()->convertType(op.getType())); if (!newResultType) { return rewriter.notifyMatchFailure( op->getLoc(), @@ -132,9 +136,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> { // Special case zero-rank memrefs. if (currentType.getRank() == 0) { - rewriter.replaceOpWithNewOp<memref::AllocOp>( - op, newResultType, ValueRange{}, adaptor.getSymbolOperands(), - adaptor.getAlignmentAttr()); + rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{}, + adaptor.getSymbolOperands(), + adaptor.getAlignmentAttr()); return success(); } @@ -156,9 +160,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> { rewriter, loc, linearizedMemRefInfo.linearizedSize)); } - rewriter.replaceOpWithNewOp<memref::AllocOp>( - op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(), - adaptor.getAlignmentAttr()); + rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize, + adaptor.getSymbolOperands(), + adaptor.getAlignmentAttr()); return success(); } }; @@ -344,10 +348,11 @@ void memref::populateMemRefNarrowTypeEmulationPatterns( RewritePatternSet &patterns) { // Populate `memref.*` conversion patterns. - patterns - .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment, - ConvertMemRefSubview, ConvertMemRefReinterpretCast>( - typeConverter, patterns.getContext()); + patterns.add<ConvertMemRefAllocation<memref::AllocOp>, + ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad, + ConvertMemRefAssumeAlignment, ConvertMemRefSubview, + ConvertMemRefReinterpretCast>(typeConverter, + patterns.getContext()); memref::populateResolveExtractStridedMetadataPatterns(patterns); } diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir index 2c411defb47e3ba..dc32a59a1a14931 100644 --- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir +++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir @@ -232,3 +232,36 @@ func.func @reinterpret_cast_memref_load_1D(%arg0: index) -> i4 { // CHECK32: %[[SHR:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] : i32 // CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i32 to i4 // CHECK32: return %[[TRUNC]] + +// ----- + +func.func @memref_alloca_load_i4(%arg0: index) -> i4 { + %0 = memref.alloca() : memref<5xi4> + %1 = memref.load %0[%arg0] : memref<5xi4> + return %1 : i4 +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8) +// CHECK: func @memref_alloca_load_i4( +// CHECK-SAME: %[[ARG0:.+]]: index +// CHECK: %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8> +// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]] +// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]] +// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]] +// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8 +// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]] +// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4 +// CHECK: return %[[TRUNC]] + +// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)> +// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32) +// CHECK32: func @memref_alloca_load_i4( +// CHECK32-SAME: %[[ARG0:.+]]: index +// CHECK32: %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32> +// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]] +// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]] +// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]] +// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32 +// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]] +// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4 +// CHECK32: return %[[TRUNC]] _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits