Author: Uday Bondhugula Date: 2021-09-23T05:56:13+05:30 New Revision: eff9b542c7e3cbc44faac86046461a982ac9dcdc
URL: https://github.com/llvm/llvm-project/commit/eff9b542c7e3cbc44faac86046461a982ac9dcdc DIFF: https://github.com/llvm/llvm-project/commit/eff9b542c7e3cbc44faac86046461a982ac9dcdc.diff LOG: [MLIR] Introduce IfOp in the xla_lhlo dialect Introduce LHLO IfOp to model conditionals on the memref form. This is the lowered form of HLO IfOp (the latter operates on tensors). Its design is similar to that of the LHLO WhileOp, taking in tuples of memrefs or elemental types. The true and the false bodies return a tuple of the same type. Added: Modified: mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td mlir/test/Dialect/LHLO/lhlo_ops.mlir Removed: ################################################################################ diff --git a/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td index 752992761485..711549940570 100644 --- a/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td +++ b/mlir/include/mlir/Dialect/LHLO/IR/LHLOOps.td @@ -467,6 +467,53 @@ def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp { ); } +def LHLO_IfOp : LHLO_Op<"if", [AffineScope, RecursiveSideEffects]> { + string summary = "If operator"; + + string description = [{ + Returns the result of executing either the true or the false branch depending on + the value of the predicate operand `pred`. In contrast to the HLO version, the + operands passed to the true and false branches are each a tuple of memrefs or int/fp + values. Both the true and false branches also return such a tuple: they + both return the same type, and this type matches the result type of the op.
+ + Example: + + ```mlir + func @lhlo_if(%arg0: memref<1x1x10xf32>, %arg1: memref<1x1x10xf32>, %arg2: memref<i1>) { + %0 = "xla_lhlo.tuple"(%arg0, %arg1) : (memref<1x1x10xf32>, memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>, memref<1x1x10xf32>> + %1 = "xla_lhlo.if"(%arg2, %0, %0) ( { + ^bb0(%arg3: tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>): + %2 = "xla_lhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32> + %3 = "xla_lhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>> + "xla_lhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> () + }, { + ^bb0(%arg3: tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>): // no predecessors + %2 = "xla_lhlo.get_tuple_element"(%arg3) {index = 1 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32> + %3 = "xla_lhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>> + "xla_lhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> () + }) : (memref<i1>, tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>, tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> tuple<memref<1x1x10xf32>> + "xla_lhlo.terminator"() : () -> () + } + ``` + + See https://www.tensorflow.org/xla/operation_semantics#conditional. 
+ }]; + + let arguments = (ins Arg<LHLO_TupleOfBufferOrIntOrFP> : $init); + + let arguments = (ins + LHLO_PredBufferOrI1:$pred, + LHLO_TupleOfBufferOrIntOrFP:$true_arg, + LHLO_TupleOfBufferOrIntOrFP:$false_arg + ); + + let regions = (region AnyRegion:$true_branch, + AnyRegion:$false_branch); + + let results = (outs Arg<LHLO_TupleOfBufferOrIntOrFP>); +} + def LHLO_MapOp : LHLO_Op<"map", [RecursiveSideEffects, SameOperandsShape]>, BASE_HLO_MapOp { let description = [{ diff --git a/mlir/test/Dialect/LHLO/lhlo_ops.mlir b/mlir/test/Dialect/LHLO/lhlo_ops.mlir index 30a84c5b4bcb..59ccfc8113be 100644 --- a/mlir/test/Dialect/LHLO/lhlo_ops.mlir +++ b/mlir/test/Dialect/LHLO/lhlo_ops.mlir @@ -235,3 +235,40 @@ func @while_op(%arg0: memref<4x?x16xf32>, %arg1: memref<4x?x16xf32>) { }) : (tuple<i32, memref<4xi32>>) -> tuple<i32, memref<4xi32>> "xla_lhlo.terminator"() : () -> () } + +// ----- + +func @lhlo_if(%arg0: memref<1x1x10xf32>, %arg1: memref<1x1x10xf32>, %arg2: memref<i1>) { + %0 = "xla_lhlo.tuple"(%arg0, %arg1) : (memref<1x1x10xf32>, memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>, memref<1x1x10xf32>> + // CHECK: xla_lhlo.if + %1 = "xla_lhlo.if"(%arg2, %0, %0) ( { + ^bb0(%arg3: tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>): + %2 = "xla_lhlo.get_tuple_element"(%arg3) {index = 0 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32> + %3 = "xla_lhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>> + "xla_lhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> () + }, { + ^bb0(%arg3: tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>): // no predecessors + %2 = "xla_lhlo.get_tuple_element"(%arg3) {index = 1 : i32} : (tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> memref<1x1x10xf32> + %3 = "xla_lhlo.tuple"(%2) : (memref<1x1x10xf32>) -> tuple<memref<1x1x10xf32>> + "xla_lhlo.yield"(%3) : (tuple<memref<1x1x10xf32>>) -> () + }) : (memref<i1>, tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>, tuple<memref<1x1x10xf32>, memref<1x1x10xf32>>) -> 
tuple<memref<1x1x10xf32>> + "xla_lhlo.terminator"() : () -> () +} + +// CHECK-LABEL: func @lhlo_if_empty_arg +func @lhlo_if_empty_arg(%arg0: memref<i1>) { + %cst = constant 1.000000e+00 : f32 + %cst_0 = constant 0.000000e+00 : f32 + %0 = "xla_lhlo.tuple"() : () -> tuple<> + // CHECK: xla_lhlo.if + %1 = "xla_lhlo.if"(%arg0, %0, %0) ( { + ^bb0(%arg1: tuple<>): + %2 = "xla_lhlo.tuple"(%cst, %cst_0) : (f32, f32) -> tuple<f32, f32> + "xla_lhlo.yield"(%2) : (tuple<f32, f32>) -> () + }, { + ^bb0(%arg1: tuple<>): + %2 = "xla_lhlo.tuple"(%cst_0, %cst) : (f32, f32) -> tuple<f32, f32> + "xla_lhlo.yield"(%2) : (tuple<f32, f32>) -> () + }) : (memref<i1>, tuple<>, tuple<>) -> tuple<f32, f32> + "xla_lhlo.terminator"() : () -> () +} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits