csigg updated this revision to Diff 432074.
csigg added a comment.

Rebase.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D126158/new/

https://reviews.llvm.org/D126158

Files:
  mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
  mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
  mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
  mlir/test/Dialect/LLVMIR/nvvm.mlir
  mlir/test/Target/LLVMIR/nvvmir.mlir

Index: mlir/test/Target/LLVMIR/nvvmir.mlir
===================================================================
--- mlir/test/Target/LLVMIR/nvvmir.mlir
+++ mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
 
+// CHECK-LABEL: @nvvm_special_regs
 llvm.func @nvvm_special_regs() -> i32 {
   // CHECK: %1 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
   %1 = nvvm.read.ptx.sreg.tid.x : i32
@@ -32,12 +33,21 @@
   llvm.return %1 : i32
 }
 
+// CHECK-LABEL: @nvvm_rcp
+llvm.func @nvvm_rcp(%0: f32) -> f32 {
+  // CHECK: call float @llvm.nvvm.rcp.approx.ftz.f
+  %1 = nvvm.rcp.approx.ftz.f %0 : f32
+  llvm.return %1 : f32
+}
+
+// CHECK-LABEL: @llvm_nvvm_barrier0
 llvm.func @llvm_nvvm_barrier0() {
   // CHECK: call void @llvm.nvvm.barrier0()
   nvvm.barrier0
   llvm.return
 }
 
+// CHECK-LABEL: @nvvm_shfl
 llvm.func @nvvm_shfl(
     %0 : i32, %1 : i32, %2 : i32,
     %3 : i32, %4 : f32) -> i32 {
@@ -60,6 +70,7 @@
   llvm.return %6 : i32
 }
 
+// CHECK-LABEL: @nvvm_shfl_pred
 llvm.func @nvvm_shfl_pred(
     %0 : i32, %1 : i32, %2 : i32,
     %3 : i32, %4 : f32) -> !llvm.struct<(i32, i1)> {
@@ -82,6 +93,7 @@
   llvm.return %6 : !llvm.struct<(i32, i1)>
 }
 
+// CHECK-LABEL: @nvvm_vote
 llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
   // CHECK: call i32 @llvm.nvvm.vote.ballot.sync(i32 %{{.*}}, i1 %{{.*}})
   %3 = nvvm.vote.ballot.sync %0, %1 : i32
@@ -99,6 +111,7 @@
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k16_f16_f16
 llvm.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                 %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                 %b0 : vector<2xf16>, %b1 : vector<2xf16>,
@@ -111,6 +124,7 @@
 }
 
 // f32 return type, f16 accumulate type
+// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16
 llvm.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                 %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                 %b0 : vector<2xf16>, %b1 : vector<2xf16>,
@@ -123,6 +137,7 @@
 }
 
 // f16 return type, f32 accumulate type
+// CHECK-LABEL: @nvvm_mma_m16n8k16_f16_f32
 llvm.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                 %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                 %b0 : vector<2xf16>, %b1 : vector<2xf16>,
@@ -135,6 +150,7 @@
 }
 
 // f32 return type, f32 accumulate type
+// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f32
 llvm.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                 %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                 %b0 : vector<2xf16>, %b1 : vector<2xf16>,
@@ -146,7 +162,8 @@
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 
-llvm.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32,                                
+// CHECK-LABEL: @nvvm_mma_m16n8k16_s8_s8
+llvm.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32,
                                 %b0 : i32, 
                                 %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.s8
@@ -158,7 +175,8 @@
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
-llvm.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,                                
+// CHECK-LABEL: @nvvm_mma_m16n8k16_s8_u8
+llvm.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
                                 %b0 : i32, 
                                 %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32, i32, i32, i32)> {
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.m16n8k16.row.col.satfinite.s8.u8
@@ -170,7 +188,8 @@
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
-llvm.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32, 
+// CHECK-LABEL: @nvvm_mma_m16n8k128_b1_b1
+llvm.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32,
                                     %b0 : i32,
                                     %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> {  
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.mma.xor.popc.m16n8k128.row.col.b1
@@ -181,6 +200,7 @@
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k32_s4_s4
 llvm.func @nvvm_mma_m16n8k32_s4_s4(%a0 : i32, %a1 : i32,
                                %b0 : i32,
                                %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) -> !llvm.struct<(i32,i32,i32,i32)> {  
@@ -193,6 +213,7 @@
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m8n8k4_f64_f64
 llvm.func @nvvm_mma_m8n8k4_f64_f64(%a0 : f64,
                                    %b0 : f64, 
                                    %c0 : f64, %c1 : f64) -> !llvm.struct<(f64, f64)> {
@@ -203,6 +224,7 @@
   llvm.return %0 : !llvm.struct<(f64, f64)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k4_tf32_f32
 llvm.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
                                      %b0 : i32,
                                      %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) -> !llvm.struct<(f32, f32, f32, f32)> {
@@ -228,6 +250,7 @@
 
 // The test below checks the correct mapping of the nvvm.wmma.*.store.* op to the correct intrinsic
 // in the LLVM NVPTX backend.
+// CHECK-LABEL: @gpu_wmma_store_op
 llvm.func @gpu_wmma_store_op(%arg0: !llvm.ptr<i32, 3>, %arg1: i32,
                             %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
                             %arg4: vector<2 xf16>, %arg5: vector<2 x f16>) {
@@ -240,6 +263,7 @@
 
 // The test below checks the correct mapping of the nvvm.wmma.*.mma.* op to the correct intrinsic
 // in the LLVM NVPTX backend.
+// CHECK-LABEL: @gpu_wmma_mma_op
 llvm.func @gpu_wmma_mma_op(%arg0: vector<2 x f16>, %arg1: vector<2 x f16>,
                         %arg2: vector<2 x f16>, %arg3: vector<2 x f16>,
                         %arg4: vector<2 x f16>, %arg5: vector<2 x f16>,
@@ -261,6 +285,7 @@
   llvm.return
 }
 
+// CHECK-LABEL: @nvvm_wmma_load_tf32
 llvm.func @nvvm_wmma_load_tf32(%arg0: !llvm.ptr<i32>, %arg1 : i32) {
   // CHECK: call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k8.load.a.row.stride.tf32.p0i32(i32* %{{.*}}, i32 %{{.*}})
   %0 = nvvm.wmma.load %arg0, %arg1
@@ -269,6 +294,7 @@
   llvm.return
 }
 
+// CHECK-LABEL: @nvvm_wmma_mma
 llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 : i32,
                     %6 : i32, %7 : i32, %8 : f32, %9 : f32, %10 : f32,
                     %11 : f32, %12 : f32, %13 : f32, %14 : f32, %15 : f32) {
@@ -280,6 +306,7 @@
   llvm.return
 }
 
+// CHECK-LABEL: @cp_async
 llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
 // CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
   nvvm.cp.async.shared.global %arg0, %arg1, 4
@@ -296,7 +323,7 @@
   llvm.return
 }
 
-// CHECK-LABEL: @ld_matrix(
+// CHECK-LABEL: @ld_matrix
 llvm.func @ld_matrix(%arg0: !llvm.ptr<i32, 3>) {
   // CHECK: call i32 @llvm.nvvm.ldmatrix.sync.aligned.m8n8.x1.b16.p3i32(i32 addrspace(3)* %{{.*}})
   %l1 = nvvm.ldmatrix %arg0 {num = 1 : i32, layout = #nvvm.mma_layout<row>} : (!llvm.ptr<i32, 3>) -> i32
Index: mlir/test/Dialect/LLVMIR/nvvm.mlir
===================================================================
--- mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
 
+// CHECK-LABEL: @nvvm_special_regs
 func.func @nvvm_special_regs() -> i32 {
   // CHECK: nvvm.read.ptx.sreg.tid.x : i32
   %0 = nvvm.read.ptx.sreg.tid.x : i32
@@ -28,12 +29,21 @@
   llvm.return %0 : i32
 }
 
-func.func @llvm.nvvm.barrier0() {
+// CHECK-LABEL: @nvvm_rcp
+func.func @nvvm_rcp(%arg0: f32) -> f32 {
+  // CHECK: nvvm.rcp.approx.ftz.f %arg0 : f32
+  %0 = nvvm.rcp.approx.ftz.f %arg0 : f32
+  llvm.return %0 : f32
+}
+
+// CHECK-LABEL: @llvm_nvvm_barrier0
+func.func @llvm_nvvm_barrier0() {
   // CHECK: nvvm.barrier0
   nvvm.barrier0
   llvm.return
 }
 
+// CHECK-LABEL: @nvvm_shfl
 func.func @nvvm_shfl(
     %arg0 : i32, %arg1 : i32, %arg2 : i32,
     %arg3 : i32, %arg4 : f32) -> i32 {
@@ -50,6 +60,7 @@
   llvm.return %0 : i32
 }
 
+// CHECK-LABEL: @nvvm_shfl_pred
 func.func @nvvm_shfl_pred(
     %arg0 : i32, %arg1 : i32, %arg2 : i32,
     %arg3 : i32, %arg4 : f32) -> !llvm.struct<(i32, i1)> {
@@ -60,6 +71,7 @@
   llvm.return %0 : !llvm.struct<(i32, i1)>
 }
 
+// CHECK-LABEL: @nvvm_vote
 func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
   // CHECK: nvvm.vote.ballot.sync %{{.*}}, %{{.*}} : i32
   %0 = nvvm.vote.ballot.sync %arg0, %arg1 : i32
@@ -77,6 +89,7 @@
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m8n8k4_f16_f16
 func.func @nvvm_mma_m8n8k4_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                               %b0 : vector<2xf16>, %b1 : vector<2xf16>,
                               %c0 : vector<2xf16>, %c1 : vector<2xf16>, %c2 : vector<2xf16>, %c3 : vector<2xf16>) {  
@@ -87,6 +100,7 @@
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m8n8k16_s8_s8
 func.func @nvvm_mma_m8n8k16_s8_s8(%a0 : i32, %b0 : i32,
                              %c0 : i32, %c1 : i32) {                             
   // CHECK: nvvm.mma.sync A[{{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32)> 
@@ -98,7 +112,8 @@
   llvm.return %0 : !llvm.struct<(i32, i32)>
 }
 
-func.func @nvvm_mma_m16n8k8_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,                                
+// CHECK-LABEL: @nvvm_mma_m16n8k8_f16_f16
+func.func @nvvm_mma_m16n8k8_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                %b0 : vector<2xf16>,
                                %c0 : vector<2xf16>, %c1 : vector<2xf16>) {
   // CHECK: nvvm.mma.sync A[%{{.*}}, %{{.*}}] B[%{{.*}}] C[%{{.*}}, %{{.*}}] {{{.*}}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
@@ -108,6 +123,7 @@
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k16_f16_f16
 func.func @nvvm_mma_m16n8k16_f16_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                 %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                 %b0 : vector<2xf16>, %b1 : vector<2xf16>,
@@ -119,6 +135,7 @@
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f16
 func.func @nvvm_mma_m16n8k16_f32_f16(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                 %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                 %b0 : vector<2xf16>, %b1 : vector<2xf16>,
@@ -130,6 +147,7 @@
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k16_f16_f32
 func.func @nvvm_mma_m16n8k16_f16_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                 %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                 %b0 : vector<2xf16>, %b1 : vector<2xf16>,
@@ -141,6 +159,7 @@
   llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k16_f32_f32
 func.func @nvvm_mma_m16n8k16_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                                 %a2 : vector<2xf16>, %a3 : vector<2xf16>,
                                 %b0 : vector<2xf16>, %b1 : vector<2xf16>,
@@ -152,7 +171,8 @@
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 
-func.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,                                
+// CHECK-LABEL: @nvvm_mma_m16n8k4_tf32_f32
+func.func @nvvm_mma_m16n8k4_tf32_f32(%a0 : i32, %a1 : i32,
                                      %b0 : i32,
                                      %c0 : f32, %c1 : f32, %c2 : f32, %c3 : f32) {
   // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<tf32>, multiplicandBPtxType = #nvvm.mma_type<tf32>, shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, f32) -> !llvm.struct<(f32, f32, f32, f32)>
@@ -163,7 +183,8 @@
   llvm.return %0 : !llvm.struct<(f32, f32, f32, f32)>
 }
 
-func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32, 
+// CHECK-LABEL: @nvvm_mma_m16n8k16_s8_s8
+func.func @nvvm_mma_m16n8k16_s8_s8(%a0 : i32, %a1 : i32, %b0 : i32,
                               %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {  
   // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<wrapped>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<s8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
   %0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
@@ -174,7 +195,8 @@
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
-func.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,                                
+// CHECK-LABEL: @nvvm_mma_m16n8k16_s8_u8
+func.func @nvvm_mma_m16n8k16_s8_u8(%a0 : i32, %a1 : i32,
                                 %b0 : i32, 
                                 %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {  
   // CHECK: nvvm.mma.sync A[{{.*}}, {{.*}}] B[{{.*}}] C[{{.*}}, {{.*}}, {{.*}}, {{.*}}] {intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>, layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>, multiplicandAPtxType = #nvvm.mma_type<s8>, multiplicandBPtxType = #nvvm.mma_type<u8>, shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32, i32, i32, i32)>
@@ -186,6 +208,7 @@
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k256_b1_b1
 func.func @nvvm_mma_m16n8k256_b1_b1(%a0 : i32, %a1 : i32, %a2 : i32, %a3 : i32,
                                %b0 : i32, %b1 : i32,
                                %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {  
@@ -197,6 +220,7 @@
   llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
 }
 
+// CHECK-LABEL: @nvvm_mma_m16n8k128_b1_b1
 func.func @nvvm_mma_m16n8k128_b1_b1(%a0 : i32, %a1 : i32,
                                %b0 : i32,
                                %c0 : i32, %c1 : i32, %c2 : i32, %c3 : i32) {  
@@ -243,6 +267,7 @@
   llvm.return %0 : !llvm.struct<(i32, i32, i32, i32)>
 }
 
+// CHECK-LABEL: @nvvm_wmma_mma
 func.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 : i32,
                     %6 : i32, %7 : i32, %8 : f32, %9 : f32, %10 : f32,
                     %11 : f32, %12 : f32, %13 : f32, %14 : f32, %15 : f32)
@@ -255,6 +280,7 @@
   llvm.return %r : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
 }
 
+// CHECK-LABEL: @cp_async
 llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
 // CHECK:  nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16
   nvvm.cp.async.shared.global %arg0, %arg1, 16
Index: mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
===================================================================
--- mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -488,3 +488,30 @@
   }
 }
 
+// -----
+
+gpu.module @test_module {
+  // CHECK-LABEL: func @gpu_divf_fp16
+  func.func @gpu_divf_fp16(%arg0 : f16, %arg1 : f16) -> f16 {
+    // CHECK: %[[lhs:.*]]     = llvm.fpext %arg0 : f16 to f32
+    // CHECK: %[[rhs:.*]]     = llvm.fpext %arg1 : f16 to f32
+    // CHECK: %[[rcp:.*]]     = nvvm.rcp.approx.ftz.f %[[rhs]] : f32
+    // CHECK: %[[approx:.*]]  = llvm.fmul %[[lhs]], %[[rcp]] : f32
+    // CHECK: %[[neg:.*]]     = llvm.fneg %[[rhs]] : f32
+    // CHECK: %[[err:.*]]     = "llvm.intr.fma"(%[[approx]], %[[neg]], %[[lhs]]) : (f32, f32, f32) -> f32
+    // CHECK: %[[refined:.*]] = "llvm.intr.fma"(%[[err]], %[[rcp]], %[[approx]]) : (f32, f32, f32) -> f32
+    // CHECK: %[[mask:.*]]    = llvm.mlir.constant(2139095040 : ui32) : i32
+    // CHECK: %[[cast:.*]]    = llvm.bitcast %[[approx]] : f32 to i32
+    // CHECK: %[[exp:.*]]     = llvm.and %[[cast]], %[[mask]] : i32
+    // CHECK: %[[c0:.*]]      = llvm.mlir.constant(0 : ui32) : i32
+    // CHECK: %[[is_zero:.*]] = llvm.icmp "eq" %[[exp]], %[[c0]] : i32
+    // CHECK: %[[is_mask:.*]] = llvm.icmp "eq" %[[exp]], %[[mask]] : i32
+    // CHECK: %[[pred:.*]]    = llvm.or %[[is_zero]], %[[is_mask]] : i1
+    // CHECK: %[[select:.*]]  = llvm.select %[[pred]], %[[approx]], %[[refined]] : i1, f32
+    // CHECK: %[[result:.*]]  = llvm.fptrunc %[[select]] : f32 to f16
+    %result = arith.divf %arg0, %arg1 : f16
+    // CHECK: llvm.return %[[result]] : f16
+    func.return %result : f16
+  }
+}
+
Index: mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
===================================================================
--- mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -148,6 +148,62 @@
   }
 };
 
+// Replaces fdiv on fp16 with an fp32 multiplication by the reciprocal plus
+// one (conditional) Newton iteration.
+//
+// This is as accurate as promoting the division to fp32 in the NVPTX backend,
+// but faster because it performs fewer Newton iterations, avoids the slow
+// path for e.g. denormals, and allows reuse of the reciprocal for multiple
+// divisions by the same divisor.
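+//
+// One Newton-Raphson step for f(x) = rhs * x - lhs refines the initial
+// approximation approx = lhs * rcp to
+//   refined = approx + (lhs - approx * rhs) * rcp,
+// which, given a good initial rcp, roughly doubles the number of correct bits.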
+struct ExpandDivF16 : public ConvertOpToLLVMPattern<LLVM::FDivOp> {
+  using ConvertOpToLLVMPattern<LLVM::FDivOp>::ConvertOpToLLVMPattern;
+
+private:
+  LogicalResult
+  matchAndRewrite(LLVM::FDivOp op, LLVM::FDivOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op.getType().isF16())
+      return rewriter.notifyMatchFailure(op, "not f16");
+    Location loc = op.getLoc();
+
+    Type f32Type = rewriter.getF32Type();
+    Type i32Type = rewriter.getI32Type();
+
+    // Extend lhs and rhs to fp32.
+    Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, adaptor.getLhs());
+    Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, adaptor.getRhs());
+
+    // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp.
+    Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs);
+    Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp);
+
+    // Refine the approximation with one Newton iteration:
+    // float refined = approx + (lhs - approx * rhs) * rcp;
+    Value err = rewriter.create<LLVM::FMAOp>(
+        loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs);
+    Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx);
+
+    // Use the refined value only if approx is normal, i.e. its exponent bits
+    // are neither all zeros (zero/denormal) nor all ones (inf/NaN).
+    Value mask = rewriter.create<LLVM::ConstantOp>(
+        loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000));
+    Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx);
+    Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask);
+    Value zero = rewriter.create<LLVM::ConstantOp>(
+        loc, i32Type, rewriter.getUI32IntegerAttr(0));
+    Value pred = rewriter.create<LLVM::OrOp>(
+        loc,
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero),
+        rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask));
+    Value result =
+        rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined);
+
+    // Replace the op with a truncation of the result back to fp16.
+    rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result);
+
+    return success();
+  }
+};
+
 /// Import the GPU Ops to NVVM Patterns.
 #include "GPUToNVVM.cpp.inc"
 
@@ -222,6 +278,10 @@
                       LLVM::FCeilOp, LLVM::FFloorOp, LLVM::LogOp, LLVM::Log10Op,
                       LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp, LLVM::SqrtOp>();
 
+  // Expand fdiv on fp16 to faster code than the NVPTX backend's fp32 promotion.
+  target.addDynamicallyLegalOp<LLVM::FDivOp>(
+      [&](LLVM::FDivOp op) { return !op.getType().isF16(); });
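+  // f16 divisions are thus illegal and will be rewritten by the ExpandDivF16
+  // pattern defined above.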
+
   // TODO: Remove once we support replacing non-root ops.
   target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp, gpu::ModuleEndOp>();
 }
@@ -241,6 +301,8 @@
            GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
           converter);
 
+  patterns.add<ExpandDivF16>(converter);
+
   // Explicitly drop memory space when lowering private memory
   // attributions since NVVM models it as `alloca`s in the default
   // memory space and does not support `alloca`s with addrspace(5).
Index: mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
===================================================================
--- mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -51,21 +51,21 @@
 // NVVM intrinsic operations
 //===----------------------------------------------------------------------===//
 
-class NVVM_IntrOp<string mnem, list<int> overloadedResults,
-                  list<int> overloadedOperands, list<Trait> traits,
+class NVVM_IntrOp<string mnem, list<Trait> traits,
                   int numResults>
   : LLVM_IntrOpBase<NVVM_Dialect, mnem, "nvvm_" # !subst(".", "_", mnem),
-                    overloadedResults, overloadedOperands, traits, numResults>;
+                    /*list<int> overloadedResults=*/[],
+                    /*list<int> overloadedOperands=*/[],
+                    traits, numResults>;
 
 
 //===----------------------------------------------------------------------===//
 // NVVM special register op definitions
 //===----------------------------------------------------------------------===//
 
-class NVVM_SpecialRegisterOp<string mnemonic,
-    list<Trait> traits = []> :
-  NVVM_IntrOp<mnemonic, [], [], !listconcat(traits, [NoSideEffect]), 1>,
-  Arguments<(ins)> {
+class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
+  NVVM_IntrOp<mnemonic, !listconcat(traits, [NoSideEffect]), 1> {
+  let arguments = (ins);
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
@@ -92,6 +92,16 @@
 def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
 def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
 
+//===----------------------------------------------------------------------===//
+// NVVM approximate op definitions
+//===----------------------------------------------------------------------===//
+
+def NVVM_RcpApproxFtzF32Op : NVVM_IntrOp<"rcp.approx.ftz.f", [NoSideEffect], 1> {
+  let arguments = (ins F32:$arg);
+  let results = (outs F32:$res);
+  let assemblyFormat = "$arg attr-dict `:` type($res)";
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM synchronization op definitions
 //===----------------------------------------------------------------------===//