This revision was automatically updated to reflect the committed changes.
Closed by commit rGbef3eb83442a: [Clang][NVPTX]Add NVPTX intrinsics and 
builtins for CUDA PTX cvt sm80… (authored by JackAKirk, committed by tra).

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D116673/new/

https://reviews.llvm.org/D116673

Files:
  clang/include/clang/Basic/BuiltinsNVPTX.def
  clang/test/CodeGen/builtins-nvptx.c
  llvm/include/llvm/IR/IntrinsicsNVVM.td
  llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
  llvm/lib/Target/NVPTX/NVPTX.h
  llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
  llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
  llvm/test/CodeGen/NVPTX/convert-sm80.ll

Index: llvm/test/CodeGen/NVPTX/convert-sm80.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/NVPTX/convert-sm80.ll
@@ -0,0 +1,136 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s
+
+
+; CHECK-LABEL: cvt_rn_bf16x2_f32
+define i32 @cvt_rn_bf16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rn.bf16x2.f32
+  %val = call i32 @llvm.nvvm.ff2bf16x2.rn(float %f1, float %f2);
+
+ret i32 %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_bf16x2_f32
+define i32 @cvt_rn_relu_bf16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rn.relu.bf16x2.f32
+%val = call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float %f1, float %f2);
+
+ret i32 %val
+}
+
+; CHECK-LABEL: cvt_rz_bf16x2_f32
+define i32 @cvt_rz_bf16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rz.bf16x2.f32
+  %val = call i32 @llvm.nvvm.ff2bf16x2.rz(float %f1, float %f2);
+
+ret i32 %val
+}
+
+; CHECK-LABEL: cvt_rz_relu_bf16x2_f32
+define i32 @cvt_rz_relu_bf16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rz.relu.bf16x2.f32
+%val = call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float %f1, float %f2);
+
+ret i32 %val
+}
+
+declare i32 @llvm.nvvm.ff2bf16x2.rn(float, float)
+declare i32 @llvm.nvvm.ff2bf16x2.rn.relu(float, float)
+declare i32 @llvm.nvvm.ff2bf16x2.rz(float, float)
+declare i32 @llvm.nvvm.ff2bf16x2.rz.relu(float, float)
+
+; CHECK-LABEL: cvt_rn_f16x2_f32
+define <2 x half> @cvt_rn_f16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rn.f16x2.f32
+  %val = call <2 x half> @llvm.nvvm.ff2f16x2.rn(float %f1, float %f2);
+
+ret <2 x half> %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_f16x2_f32
+define <2 x half> @cvt_rn_relu_f16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rn.relu.f16x2.f32
+%val = call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float %f1, float %f2);
+
+ret <2 x half> %val
+}
+
+; CHECK-LABEL: cvt_rz_f16x2_f32
+define <2 x half> @cvt_rz_f16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rz.f16x2.f32
+  %val = call <2 x half> @llvm.nvvm.ff2f16x2.rz(float %f1, float %f2);
+
+ret <2 x half> %val
+}
+
+; CHECK-LABEL: cvt_rz_relu_f16x2_f32
+define <2 x half> @cvt_rz_relu_f16x2_f32(float %f1, float %f2) {
+
+; CHECK: cvt.rz.relu.f16x2.f32
+%val = call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float %f1, float %f2);
+
+ret <2 x half> %val
+}
+
+declare <2 x half> @llvm.nvvm.ff2f16x2.rn(float, float)
+declare <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float, float)
+declare <2 x half> @llvm.nvvm.ff2f16x2.rz(float, float)
+declare <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float, float)
+
+; CHECK-LABEL: cvt_rn_bf16_f32
+define i16 @cvt_rn_bf16_f32(float %f1) {
+
+; CHECK: cvt.rn.bf16.f32
+  %val = call i16 @llvm.nvvm.f2bf16.rn(float %f1);
+
+ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rn_relu_bf16_f32
+define i16 @cvt_rn_relu_bf16_f32(float %f1) {
+
+; CHECK: cvt.rn.relu.bf16.f32
+%val = call i16 @llvm.nvvm.f2bf16.rn.relu(float %f1);
+
+ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rz_bf16_f32
+define i16 @cvt_rz_bf16_f32(float %f1) {
+
+; CHECK: cvt.rz.bf16.f32
+  %val = call i16 @llvm.nvvm.f2bf16.rz(float %f1);
+
+ret i16 %val
+}
+
+; CHECK-LABEL: cvt_rz_relu_bf16_f32
+define i16 @cvt_rz_relu_bf16_f32(float %f1) {
+
+; CHECK: cvt.rz.relu.bf16.f32
+%val = call i16 @llvm.nvvm.f2bf16.rz.relu(float %f1);
+
+ret i16 %val
+}
+
+declare i16 @llvm.nvvm.f2bf16.rn(float)
+declare i16 @llvm.nvvm.f2bf16.rn.relu(float)
+declare i16 @llvm.nvvm.f2bf16.rz(float)
+declare i16 @llvm.nvvm.f2bf16.rz.relu(float)
+
+; CHECK-LABEL: cvt_rna_tf32_f32
+define i32 @cvt_rna_tf32_f32(float %f1) {
+
+; CHECK: cvt.rna.tf32.f32
+  %val = call i32 @llvm.nvvm.f2tf32.rna(float %f1);
+
+ret i32 %val
+}
+
+declare i32 @llvm.nvvm.f2tf32.rna(float)
Index: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1046,6 +1046,38 @@
 def : Pat<(int_nvvm_ui2f_rp Int32Regs:$a),
           (CVT_f32_u32 Int32Regs:$a, CvtRP)>;
 
+def : Pat<(int_nvvm_ff2bf16x2_rn Float32Regs:$a, Float32Regs:$b),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
+def : Pat<(int_nvvm_ff2bf16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
+def : Pat<(int_nvvm_ff2bf16x2_rz Float32Regs:$a, Float32Regs:$b),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
+def : Pat<(int_nvvm_ff2bf16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
+
+def : Pat<(int_nvvm_ff2f16x2_rn Float32Regs:$a, Float32Regs:$b),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
+def : Pat<(int_nvvm_ff2f16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
+def : Pat<(int_nvvm_ff2f16x2_rz Float32Regs:$a, Float32Regs:$b),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
+def : Pat<(int_nvvm_ff2f16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
+
+def : Pat<(int_nvvm_f2bf16_rn Float32Regs:$a),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
+def : Pat<(int_nvvm_f2bf16_rn_relu Float32Regs:$a),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRN_RELU)>;
+def : Pat<(int_nvvm_f2bf16_rz Float32Regs:$a),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRZ)>;
+def : Pat<(int_nvvm_f2bf16_rz_relu Float32Regs:$a),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRZ_RELU)>;
+
+def CVT_tf32_f32 : NVPTXInst<(outs Int32Regs:$dest), (ins Float32Regs:$a),
+    "cvt.rna.tf32.f32 \t$dest, $a;",  // .rna is fixed in the asm string; no CvtMode operand
+    [(set Int32Regs:$dest, (int_nvvm_f2tf32_rna Float32Regs:$a))]>,
+    Requires<[hasPTX70, hasSM80]>;  // same gate as the sm_80 CVT_* multiclasses and the clang builtin
+
 def INT_NVVM_LOHI_I2D : F_MATH_2<"mov.b64 \t$dst, {{$src0, $src1}};",
   Float64Regs, Int32Regs, Int32Regs, int_nvvm_lohi_i2d>;
 
Index: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -48,6 +48,7 @@
 def CvtRZ   : PatLeaf<(i32 0x6)>;
 def CvtRM   : PatLeaf<(i32 0x7)>;
 def CvtRP   : PatLeaf<(i32 0x8)>;
+def CvtRNA   : PatLeaf<(i32 0x9)>; // next value after CvtRP (0x8); presumably mirrors PTXCvtMode::RNA, printed as ".rna". Not referenced elsewhere in this patch.
 
 def CvtNONE_FTZ : PatLeaf<(i32 0x10)>;
 def CvtRNI_FTZ  : PatLeaf<(i32 0x11)>;
@@ -62,6 +63,10 @@
 def CvtSAT      : PatLeaf<(i32 0x20)>;
 def CvtSAT_FTZ  : PatLeaf<(i32 0x30)>;
 
+def CvtNONE_RELU   : PatLeaf<(i32 0x40)>; // RELU flag bit (0x40) alone, base-mode bits zero
+def CvtRN_RELU   : PatLeaf<(i32 0x45)>; // 0x40 | 0x5 -- 0x5 is presumably CvtRN (defined outside this hunk)
+def CvtRZ_RELU   : PatLeaf<(i32 0x46)>; // 0x40 | 0x6 (CvtRZ)
+
 def CvtMode : Operand<i32> {
   let PrintMethod = "printCvtMode";
 }
@@ -526,6 +531,29 @@
                                     "cvt.s64.s16 \t$dst, $src;", []>;
   def CVT_INREG_s64_s32 : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src),
                                     "cvt.s64.s32 \t$dst, $src;", []>;
+
+multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> { // cvt of one f32 source into an RC register
+    def _f32 : // defm prefix + "_f32" gives the record name, e.g. CVT_bf16_f32
+      NVPTXInst<(outs RC:$dst),
+                (ins Float32Regs:$src, CvtMode:$mode),
+                !strconcat("cvt${mode:base}${mode:relu}.", // $mode is printed twice: "base" modifier, then "relu" modifier
+                FromName, ".f32 \t$dst, $src;"), []>, // FromName fills the type slot before ".f32"; [] = no isel pattern on the instruction itself
+                Requires<[hasPTX70, hasSM80]>;
+  }
+
+  defm CVT_bf16 : CVT_FROM_FLOAT_SM80<"bf16", Int16Regs>;
+
+    multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> { // cvt of two f32 sources into one RC register
+    def _f32 : // defm prefix + "_f32" gives the record name, e.g. CVT_f16x2_f32
+      NVPTXInst<(outs RC:$dst),
+                (ins Float32Regs:$src1, Float32Regs:$src2,  CvtMode:$mode),
+                !strconcat("cvt${mode:base}${mode:relu}.", // $mode is printed twice: "base" modifier, then "relu" modifier
+                FromName, ".f32 \t$dst, $src1, $src2;"), []>, // FromName fills the type slot before ".f32"; [] = no isel pattern on the instruction itself
+    Requires<[hasPTX70, hasSM80]>;
+  }
+
+  defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", Float16x2Regs>;
+  defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", Int32Regs>;
 }
 
 //-----------------------------------
Index: llvm/lib/Target/NVPTX/NVPTX.h
===================================================================
--- llvm/lib/Target/NVPTX/NVPTX.h
+++ llvm/lib/Target/NVPTX/NVPTX.h
@@ -137,10 +137,12 @@
   RZ,
   RM,
   RP,
+  RNA,
 
   BASE_MASK = 0x0F,
   FTZ_FLAG = 0x10,
-  SAT_FLAG = 0x20
+  SAT_FLAG = 0x20,
+  RELU_FLAG = 0x40
 };
 }
 
Index: llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
===================================================================
--- llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -108,6 +108,10 @@
     // SAT flag
     if (Imm & NVPTX::PTXCvtMode::SAT_FLAG)
       O << ".sat";
+  } else if (strcmp(Modifier, "relu") == 0) {
+    // RELU flag
+    if (Imm & NVPTX::PTXCvtMode::RELU_FLAG)
+      O << ".relu";
   } else if (strcmp(Modifier, "base") == 0) {
     // Default operand
     switch (Imm & NVPTX::PTXCvtMode::BASE_MASK) {
@@ -139,6 +143,9 @@
     case NVPTX::PTXCvtMode::RP:
       O << ".rp";
       break;
+    case NVPTX::PTXCvtMode::RNA:
+      O << ".rna";
+      break;
     }
   } else {
     llvm_unreachable("Invalid conversion modifier");
Index: llvm/include/llvm/IR/IntrinsicsNVVM.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1185,6 +1185,36 @@
   def int_nvvm_f2h_rn : GCCBuiltin<"__nvvm_f2h_rn">,
       DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
 
+  def int_nvvm_ff2bf16x2_rn : GCCBuiltin<"__nvvm_ff2bf16x2_rn">,
+       Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2bf16x2_rn_relu : GCCBuiltin<"__nvvm_ff2bf16x2_rn_relu">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2bf16x2_rz : GCCBuiltin<"__nvvm_ff2bf16x2_rz">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2bf16x2_rz_relu : GCCBuiltin<"__nvvm_ff2bf16x2_rz_relu">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+
+  def int_nvvm_ff2f16x2_rn : GCCBuiltin<"__nvvm_ff2f16x2_rn">,
+      Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2f16x2_rn_relu : GCCBuiltin<"__nvvm_ff2f16x2_rn_relu">,
+      Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2f16x2_rz : GCCBuiltin<"__nvvm_ff2f16x2_rz">,
+      Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_ff2f16x2_rz_relu : GCCBuiltin<"__nvvm_ff2f16x2_rz_relu">,
+      Intrinsic<[llvm_v2f16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem]>;
+
+  def int_nvvm_f2bf16_rn : GCCBuiltin<"__nvvm_f2bf16_rn">,
+      Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_f2bf16_rn_relu : GCCBuiltin<"__nvvm_f2bf16_rn_relu">,
+      Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_f2bf16_rz : GCCBuiltin<"__nvvm_f2bf16_rz">,
+      Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem]>;
+  def int_nvvm_f2bf16_rz_relu : GCCBuiltin<"__nvvm_f2bf16_rz_relu">,
+       Intrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem]>;
+
+  def int_nvvm_f2tf32_rna : GCCBuiltin<"__nvvm_f2tf32_rna">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem]>;
+
 //
 // Bitcast
 //
Index: clang/test/CodeGen/builtins-nvptx.c
===================================================================
--- clang/test/CodeGen/builtins-nvptx.c
+++ clang/test/CodeGen/builtins-nvptx.c
@@ -754,4 +754,40 @@
   __nvvm_cp_async_wait_all();
   #endif
   // CHECK: ret void
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: nvvm_cvt_sm80
+__device__ void nvvm_cvt_sm80() {
+#if __CUDA_ARCH__ >= 800
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2bf16x2_rn(1, 1);
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2bf16x2_rn_relu(1, 1);
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2bf16x2_rz(1, 1);
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2bf16x2_rz_relu(1, 1);
+
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2f16x2_rn(1, 1);
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2f16x2_rn_relu(1, 1);
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2f16x2_rz(1, 1);
+  // CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
+  __nvvm_ff2f16x2_rz_relu(1, 1);
+
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
+  __nvvm_f2bf16_rn(1);
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00)
+  __nvvm_f2bf16_rn_relu(1);
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz(float 1.000000e+00)
+  __nvvm_f2bf16_rz(1);
+  // CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
+  __nvvm_f2bf16_rz_relu(1);
+
+  // CHECK_PTX70_SM80: call i32 @llvm.nvvm.f2tf32.rna(float 1.000000e+00)
+  __nvvm_f2tf32_rna(1);
+#endif
+  // CHECK: ret void
+}
Index: clang/include/clang/Basic/BuiltinsNVPTX.def
===================================================================
--- clang/include/clang/Basic/BuiltinsNVPTX.def
+++ clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -402,6 +402,23 @@
 BUILTIN(__nvvm_f2h_rn_ftz, "Usf", "")
 BUILTIN(__nvvm_f2h_rn, "Usf", "")
 
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "ZUiff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "ZUiff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "ZUiff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "ZUiff", "", AND(SM_80,PTX70))
+
+TARGET_BUILTIN(__nvvm_ff2f16x2_rn, "V2hff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2f16x2_rn_relu, "V2hff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2f16x2_rz, "V2hff", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_ff2f16x2_rz_relu, "V2hff", "", AND(SM_80,PTX70))
+
+TARGET_BUILTIN(__nvvm_f2bf16_rn, "ZUsf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "ZUsf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2bf16_rz, "ZUsf", "", AND(SM_80,PTX70))
+TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "ZUsf", "", AND(SM_80,PTX70))
+
+TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70))
+
 // Bitcast
 
 BUILTIN(__nvvm_bitcast_f2i, "if", "")
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to