================ @@ -48,6 +287,261 @@ Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) { reassociation); } +// This function transforms the filter. The data layout of the filter is FHWC. +// The transformation matrix is 2-dimension. We need to extract H x W from +// FHWC first. We need to generate 2 levels of loops to iterate on F and C. +// After the transformation, we get +// +// scf.for %f = lo_f to hi_f step 1 +// scf.for %c = lo_c to hi_c step 1 +// %extracted = extract filter<h x w> from filter<f x h x w x c> +// %ret = linalg.matmul G, %extracted +// %ret = linalg.matmul %ret, GT +// %inserted = insert %ret into filter<tile_h x tile_w x h x w x c x f> +// +Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, + Value retValue, int64_t m, int64_t r, + bool leftTransform = true, bool rightTransform = true) { + // Map from (m, r) to G transform matrix. + static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> + GMatrices = { + {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)}, + {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)}, + {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)}, + }; + + // Map from (m, r) to GT transform matrix. 
+ static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix> + GTMatrices = { + {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)}, + {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)}, + {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)}, + }; + + auto filterType = cast<ShapedType>(filter.getType()); + Type elementType = filterType.getElementType(); + auto filterShape = filterType.getShape(); // F, H, W, C + int64_t filterF = filterShape[0]; + int64_t filterH = filterShape[1]; + int64_t filterW = filterShape[2]; + int64_t filterC = filterShape[3]; + + if (filterH != r && filterH != 1) + return Value(); + if (filterW != r && filterW != 1) + return Value(); + + // Return shape is <H x W x C x F> + auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0); + auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF); + auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC); + auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1); + auto outerForOp = + rewriter.create<scf::ForOp>(loc, zeroIdx, fUpperBound, oneStep, retValue); + Block *outerForBody = outerForOp.getBody(); + rewriter.setInsertionPointToStart(outerForBody); + Value FIter = outerForBody->getArgument(0); + + auto innerForOp = rewriter.create<scf::ForOp>( + loc, zeroIdx, cUpperBound, oneStep, outerForOp.getRegionIterArgs()[0]); + Block *innerForBody = innerForOp.getBody(); + rewriter.setInsertionPointToStart(innerForBody); + Value CIter = innerForBody->getArgument(0); + + // Extract (H, W) from (F, H, W, C) + auto extractFilter = extract2DData( + rewriter, loc, filter, FIter, CIter, /*outLoopIdx=*/0, + /*inLoopIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2, /*srcSize=*/4); + + TransformMapKeyTy key = {m, r}; + int64_t retRows = 1; + Value matmulRetValue = extractFilter; + if (leftTransform) { + // Get constant transform matrix G + auto it = GMatrices.find(key); + if (it == GMatrices.end()) + return Value(); + const TransformMatrix &GMatrix = it->second; + + retRows = GMatrix.rows; 
+ auto matmulType = RankedTensorType::get({retRows, filterW}, elementType); + auto init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(), + elementType); + + Value G = create2DTransformMatrix(rewriter, loc, GMatrix, elementType); ---------------- Hsiangkai wrote:
There is a `ConstantOpInterface` bufferization model that converts `arith.constant` ops into `memref.get_global` ops during bufferization. https://github.com/llvm/llvm-project/pull/96183 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits