Author: MaheshRavishankar Date: 2021-01-22T14:39:27-08:00 New Revision: 6e8ef3b76ab65960edd6ee99f387e75564d8d9db
URL: https://github.com/llvm/llvm-project/commit/6e8ef3b76ab65960edd6ee99f387e75564d8d9db DIFF: https://github.com/llvm/llvm-project/commit/6e8ef3b76ab65960edd6ee99f387e75564d8d9db.diff LOG: [mlir][Linalg] Make Fill operation work on tensors. Depends on D95109 Added: Modified: mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp mlir/test/Dialect/Linalg/invalid.mlir mlir/test/Dialect/Linalg/roundtrip.mlir mlir/test/Dialect/Linalg/tile-tensors.mlir Removed: ################################################################################ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index 26db4c2f6735..436dab1ade2b 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -148,8 +148,9 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> { } def FillOp : LinalgStructured_Op<"fill", []> { - let arguments = (ins AnyStridedMemRef:$output, + let arguments = (ins AnyShaped:$output, AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value); + let results = (outs Optional<AnyRankedTensor>:$result); let extraClassDeclaration = libraryCallName # [{ ValueRange inputs() { return {}; } ValueRange outputs() { return getOperands().take_front(); } @@ -174,6 +175,14 @@ def FillOp : LinalgStructured_Op<"fill", []> { } }]; + let assemblyFormat = [{ + `(` operands `)` attr-dict `:` type(operands) (`->` type($result)^)? + }]; + + let builders = [ + OpBuilderDAG<(ins "Value":$output, "Value":$value)> + ]; + let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index b500eefa9d0c..a6f3576c4240 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -220,6 +220,16 @@ static LogicalResult foldMemRefCast(Operation *op) { // LinalgOps.td), we define an overloaded `print` function and a // parse`className` function. +//===----------------------------------------------------------------------===// +// FillOp +//===----------------------------------------------------------------------===// + +void FillOp::build(OpBuilder &builder, OperationState &result, Value output, + Value value) { + build(builder, result, output.getType().dyn_cast<RankedTensorType>(), output, + value); +} + //===----------------------------------------------------------------------===// // GenericOps //===----------------------------------------------------------------------===// @@ -1726,6 +1736,10 @@ static LogicalResult verify(FillOp op) { auto fillType = op.value().getType(); if (viewType.getElementType() != fillType) return op.emitOpError("expects fill type to match view elemental type"); + if (!op.getNumResults() && !viewType.isa<MemRefType>()) { + return op.emitOpError( + "expected fill op with no result value to use memref type"); + } return success(); } diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index a3ef242c29f9..6579add14c50 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -659,3 +659,41 @@ func @pad_block_args(%arg0: tensor<?x4xi32>, %arg1: i32) -> tensor<?x9xi32> { } : tensor<?x4xi32> to tensor<?x9xi32> return %0 : tensor<?x9xi32> } + +// ----- + +func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f32) +{ + %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32> + // expected-error @+1 {{expected fill op with no result value to use memref type}} + linalg.fill(%0, %arg2) : tensor<?x?xf32>, f32 +} + +// ----- + +func @illegal_fill_memref_with_return(%arg0 : memref<?x?xf32>, %arg1 : f32) -> memref<?x?xf32> +{ + // expected-error @+1 {{unexpected #results > #outputs}} + %0 = linalg.fill(%arg0, %arg1) : memref<?x?xf32>, f32 -> memref<?x?xf32> + return %0 : memref<?x?xf32> +} + +// ----- + +func @illegal_fill_memref_with_tensor_return + (%arg0 : memref<?x?xf32>, %arg1 : f32) -> tensor<?x?xf32> +{ + // expected-error @+1 {{unexpected #results > #outputs}} + %0 = linalg.fill(%arg0, %arg1) : memref<?x?xf32>, f32 -> tensor<?x?xf32> + return %0 : tensor<?x?xf32> +} + +// ----- + +func @illegal_fill_tensor_with_memref_return + (%arg0 : tensor<?x?xf32>, %arg1 : f32) -> memref<?x?xf32> +{ + // expected-error @+1 {{expected type of operand #0 ('tensor<?x?xf32>') to match type of corresponding result ('memref<?x?xf32>')}} + %0 = linalg.fill(%arg0, %arg1) : tensor<?x?xf32>, f32 -> memref<?x?xf32> + return %0 : memref<?x?xf32> +} diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index c4a3247fdc88..44743eaedc8c 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -805,3 +805,12 @@ func @legal_collapsing_reshape_dynamic_memref // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d4)> // CHECK: func @legal_collapsing_reshape_dynamic_memref // CHECK: linalg.reshape %{{.+}} [#[[MAP0]], #[[MAP1]], #[[MAP2]]] + +// ----- + +func @fill_tensor(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32> { + %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32> + %1 = linalg.fill(%0, %arg2) : tensor<?x?xf32>, f32 -> tensor<?x?xf32> + return %1 : tensor<?x?xf32> +} +// CHECK: %{{.+}} = linalg.fill(%{{.+}}, %{{.+}}) : tensor<?x?xf32>, f32 -> tensor<?x?xf32> diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir index f52d7fefa8be..f8b996e1ae05 100644 --- a/mlir/test/Dialect/Linalg/tile-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir @@ -41,7 +41,7 @@ func @generic_op_tensors( %4 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2, d1)>, - affine_map<(d0, d1, d2) -> (d2, d1, d0)>], + affine_map<(d0, d1, d2) -> (d2, d1, d0)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%3 : tensor<?x?x?xf32>) { @@ -88,7 +88,7 @@ func @indexed_generic_op_tensors( %4 = linalg.indexed_generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2, d1)>, - affine_map<(d0, d1, d2) -> (d2, d1, d0)>], + affine_map<(d0, d1, d2) -> (d2, d1, d0)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>) outs(%3 : tensor<?x?x?xf32>) { @@ -120,3 +120,26 @@ func @indexed_generic_op_tensors( // CHECK: scf.yield %[[TD1]] // CHECK: } // CHECK: return %[[TD0]] + +// ----- + +func @fill_tensors(%arg0 : index, %arg1 : index, %arg2 : f32) -> tensor<?x?xf32> { + %0 = linalg.init_tensor [%arg0, %arg1] : tensor<?x?xf32> + %1 = linalg.fill(%0, %arg2) : tensor<?x?xf32>, f32 -> tensor<?x?xf32> + return %1 : tensor<?x?xf32> +} +// CHECK: func @fill_tensors +// CHECK: %[[INIT:.+]] = linalg.init_tensor +// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-z0-9_]+]] +// CHECK-SAME: iter_args(%[[ARG4:.+]] = %[[INIT]]) -> (tensor<?x?xf32>) { +// CHECK: %[[YIELD_1:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]] +// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor<?x?xf32>) { +// CHECK: %[[FILL_TILE:.+]] = subtensor %[[ARG6]][%[[IV0]], %[[IV1]]] +// CHECK: %[[RESULT_TILE:.+]] = linalg.fill(%[[FILL_TILE]], %{{.+}}) +// CHECK: %[[YIELD_2:.+]] = subtensor_insert %[[RESULT_TILE]] +// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]] +// CHECK: scf.yield %[[YIELD_2]] +// CHECK: } +// CHECK: scf.yield %[[YIELD_1]] +// CHECK: } +// CHECK: return %[[RESULT]] _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits