Author: MaheshRavishankar Date: 2020-12-17T14:45:51-08:00 New Revision: de031216bf1755e61418a1515f2b0db0a9cfeddc
URL: https://github.com/llvm/llvm-project/commit/de031216bf1755e61418a1515f2b0db0a9cfeddc DIFF: https://github.com/llvm/llvm-project/commit/de031216bf1755e61418a1515f2b0db0a9cfeddc.diff LOG: [mlir] Add canonicalization from `tensor_cast` to `dim` op. Fold a `tensor_cast` -> `dim` to take the `dim` of the original tensor. Differential Revision: https://reviews.llvm.org/D93492 Added: Modified: mlir/lib/Dialect/StandardOps/IR/Ops.cpp mlir/test/Dialect/Standard/canonicalize.mlir Removed: ################################################################################ diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index b19264ec4972..7ed9cffa8806 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1472,11 +1472,29 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { return success(); } }; + +/// Fold dim of a cast into the dim of the source of the cast, i.e. rewrite +/// `dim(cast(%src), %idx)` as `dim(%src, %idx)`. +template <typename CastOpTy> +struct DimOfCastOp : public OpRewritePattern<DimOp> { + using OpRewritePattern<DimOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(DimOp dimOp, + PatternRewriter &rewriter) const override { + auto castOp = dimOp.memrefOrTensor().getDefiningOp<CastOpTy>(); + if (!castOp) + return failure(); + Value newSource = castOp.getOperand(); + rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.index()); + return success(); + } +}; + } // end anonymous namespace. 
void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert<DimOfMemRefReshape>(context); + results.insert<DimOfMemRefReshape, DimOfCastOp<TensorCastOp>>(context); } // --------------------------------------------------------------------------- diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir index d3c781d1290f..af67453a1f3c 100644 --- a/mlir/test/Dialect/Standard/canonicalize.mlir +++ b/mlir/test/Dialect/Standard/canonicalize.mlir @@ -115,3 +115,19 @@ func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>) %1 = dim %0, %c3 : memref<*xf32> return %1 : index } + +// Test case: Folding dim(tensor_cast %0, %idx) -> dim %0, %idx; the static dim 0 of the source (size 4) is expected to become `constant 4`. +// CHECK-LABEL: func @fold_dim_of_tensor_cast +// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32> +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C4:.+]] = constant 4 : index +// CHECK: %[[T0:.+]] = dim %[[ARG0]], %[[C1]] +// CHECK-NEXT: return %[[C4]], %[[T0]] +func @fold_dim_of_tensor_cast(%arg0 : tensor<4x?xf32>) -> (index, index) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = tensor_cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32> + %1 = dim %0, %c0 : tensor<?x?xf32> + %2 = dim %0, %c1 : tensor<?x?xf32> + return %1, %2: index, index +} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits