================ @@ -289,6 +938,123 @@ FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
   return transformedOutput.getDefiningOp();
 }
 
+/// Decompose linalg.winograd_filter_transform by building the explicit
+/// computation returned by `filterTransform`.
+///
+/// Dimensions 1 and 2 of the filter shape are treated as its height and
+/// width; a side whose extent is 1 skips its transform (F(m x 1, r x 1) or
+/// F(1 x m, 1 x r)). On success `op` is replaced by the transformed value and
+/// the operation defining that value is returned; if `filterTransform`
+/// produces no value, failure is returned and the IR is left untouched.
+FailureOr<Operation *>
+decomposeWinogradFilterTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradFilterTransformOp op) {
+  Location loc = op.getLoc();
+  Value filter = op.getFilter();
+  auto filterType = cast<ShapedType>(filter.getType());
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = filterH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = filterW != 1;
+  Value transformedFilter =
+      filterTransform(rewriter, loc, filter, op.getOutput(), op.getM(),
+                      op.getR(), leftTransform, rightTransform);
+  if (!transformedFilter)
+    return failure();
+
+  rewriter.replaceOp(op, transformedFilter);
+
+  return transformedFilter.getDefiningOp();
+}
+
+/// Decompose linalg.winograd_input_transform by building the explicit
+/// computation returned by `inputTransform`.
+///
+/// Dimensions 1 and 2 of the input shape are treated as its height and
+/// width; a side whose extent is 1 skips its transform. On success `op` is
+/// replaced and the operation defining the replacement value is returned;
+/// if `inputTransform` produces no value, failure is returned.
+FailureOr<Operation *>
+decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
+                                      linalg::WinogradInputTransformOp op) {
+  Location loc = op.getLoc();
+  Value input = op.getInput();
+  auto inputType = cast<ShapedType>(input.getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = inputH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = inputW != 1;
+  // Pass the already-bound `input`, matching the filter/output helpers.
+  Value transformedInput =
+      inputTransform(rewriter, loc, input, op.getOutput(), op.getM(),
+                     op.getR(), leftTransform, rightTransform);
+  if (!transformedInput)
+    return failure();
+
+  rewriter.replaceOp(op, transformedInput);
+
+  return transformedInput.getDefiningOp();
+}
+
+/// Decompose linalg.winograd_output_transform by building the explicit
+/// computation returned by `outputTransform`.
+///
+/// Dimensions 2 and 3 of the value operand's shape are the extents checked
+/// here (named valueH/valueW); a side whose extent is 1 skips its transform.
+/// On success `op` is replaced and the operation defining the replacement
+/// value is returned; if `outputTransform` produces no value, failure is
+/// returned.
+FailureOr<Operation *>
+decomposeWinogradOutputTransformHelper(RewriterBase &rewriter,
+                                       linalg::WinogradOutputTransformOp op) {
+  Location loc = op.getLoc();
+  Value value = op.getValue();
+  auto valueType = cast<ShapedType>(value.getType());
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  int64_t valueH = valueShape[2];
+  int64_t valueW = valueShape[3];
+
+  // For F(m x 1, r x 1), we only need to do left side transform.
+  bool leftTransform = valueH != 1;
+  // For F(1 x m, 1 x r), we only need to do right side transform.
+  bool rightTransform = valueW != 1;
+  Value transformedOutput =
+      outputTransform(rewriter, loc, value, op.getOutput(), op.getM(),
+                      op.getR(), leftTransform, rightTransform);
+  if (!transformedOutput)
+    return failure();
+
+  rewriter.replaceOp(op, transformedOutput);
+
+  return transformedOutput.getDefiningOp();
+}
+
+/// Rewrite pattern that decomposes linalg.winograd_filter_transform using
+/// decomposeWinogradFilterTransformHelper.
+class DecomposeWinogradFilterTransform final
+    : public OpRewritePattern<linalg::WinogradFilterTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradFilterTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(decomposeWinogradFilterTransformHelper(rewriter, op)))
+      return failure();
+
+    return success();
+  }
+};
+
+class DecomposeWinogradInputTransform final
+    : public OpRewritePattern<linalg::WinogradInputTransformOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::WinogradInputTransformOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(decomposeWinogradInputTransformHelper(rewriter, op)))
+      return failure();
+
+    return success(); ---------------- Hsiangkai wrote:
Done. https://github.com/llvm/llvm-project/pull/96183 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits