kushanam updated this revision to Diff 522290.
kushanam added a comment.

Addressing review comments and removing bf16 registers
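
For reference, a minimal sketch of the scalar case this revision targets, mirroring the new bf16-instructions.ll test (value names are illustrative; the expected PTX is noted in the comment):

  define bfloat @fadd_bf16(bfloat %a, bfloat %b) {
    %r = fadd bfloat %a, %b   ; expected to select "add.rn.bf16" on sm_80 / ptx70
    ret bfloat %r
  }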


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D149976/new/

https://reviews.llvm.org/D149976

Files:
  clang/include/clang/Basic/BuiltinsNVPTX.def
  llvm/include/llvm/IR/IntrinsicsNVVM.td
  llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
  llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
  llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
  llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
  llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
  llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
  llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
  llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp
  llvm/lib/Target/NVPTX/NVPTXMCExpr.h
  llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
  llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
  llvm/lib/Target/NVPTX/NVPTXSubtarget.h
  llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
  llvm/test/CodeGen/NVPTX/bf16-instructions.ll

Index: llvm/test/CodeGen/NVPTX/bf16-instructions.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -0,0 +1,88 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s
+; RUN: %if ptxas-11.0 %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | %ptxas-verify -arch=sm_80 %}
+
+
+; CHECK-LABEL: test_fadd(
+; CHECK-DAG:  ld.param.b16    [[A:%h[0-9]+]], [test_fadd_param_0];
+; CHECK-DAG:  ld.param.b16    [[B:%h[0-9]+]], [test_fadd_param_1];
+; CHECK-NEXT: add.rn.bf16     [[R:%h[0-9]+]], [[A]], [[B]];
+; CHECK-NEXT: st.param.b16    [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define bfloat @test_fadd(bfloat %0, bfloat %1) {
+  %3 = fadd bfloat %0, %1
+  ret bfloat %3
+}
+
+; CHECK-LABEL: test_fsub(
+; CHECK-DAG:  ld.param.b16    [[A:%h[0-9]+]], [test_fsub_param_0];
+; CHECK-DAG:  ld.param.b16    [[B:%h[0-9]+]], [test_fsub_param_1];
+; CHECK-NEXT: sub.rn.bf16     [[R:%h[0-9]+]], [[A]], [[B]];
+; CHECK-NEXT: st.param.b16    [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define bfloat @test_fsub(bfloat %0, bfloat %1) {
+  %3 = fsub bfloat %0, %1
+  ret bfloat %3
+}
+
+; CHECK-LABEL: test_faddx2(
+; CHECK-DAG:  ld.param.b32    [[A:%hh[0-9]+]], [test_faddx2_param_0];
+; CHECK-DAG:  ld.param.b32    [[B:%hh[0-9]+]], [test_faddx2_param_1];
+; CHECK-NEXT: add.rn.bf16x2   [[R:%hh[0-9]+]], [[A]], [[B]];
+
+; CHECK:      st.param.b32    [func_retval0+0], [[R]];
+; CHECK:      ret;
+
+define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+  %r = fadd <2 x bfloat> %a, %b
+  ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fsubx2(
+; CHECK-DAG:  ld.param.b32    [[A:%hh[0-9]+]], [test_fsubx2_param_0];
+; CHECK-DAG:  ld.param.b32    [[B:%hh[0-9]+]], [test_fsubx2_param_1];
+; CHECK-NEXT: sub.rn.bf16x2   [[R:%hh[0-9]+]], [[A]], [[B]];
+
+; CHECK:      st.param.b32    [func_retval0+0], [[R]];
+; CHECK:      ret;
+
+define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+  %r = fsub <2 x bfloat> %a, %b
+  ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fmulx2(
+; CHECK-DAG:  ld.param.b32    [[A:%hh[0-9]+]], [test_fmulx2_param_0];
+; CHECK-DAG:  ld.param.b32    [[B:%hh[0-9]+]], [test_fmulx2_param_1];
+; CHECK-NEXT: mul.rn.bf16x2   [[R:%hh[0-9]+]], [[A]], [[B]];
+
+; CHECK:      st.param.b32    [func_retval0+0], [[R]];
+; CHECK:      ret;
+
+define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+  %r = fmul <2 x bfloat> %a, %b
+  ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fdiv(
+; CHECK-DAG:  ld.param.b32    [[A:%hh[0-9]+]], [test_fdiv_param_0];
+; CHECK-DAG:  ld.param.b32    [[B:%hh[0-9]+]], [test_fdiv_param_1];
+; CHECK-DAG:  mov.b32         {[[A0:%h[0-9]+]], [[A1:%h[0-9]+]]}, [[A]]
+; CHECK-DAG:  mov.b32         {[[B0:%h[0-9]+]], [[B1:%h[0-9]+]]}, [[B]]
+; CHECK-DAG:  cvt.f32.bf16     [[FA0:%f[0-9]+]], [[A0]];
+; CHECK-DAG:  cvt.f32.bf16     [[FA1:%f[0-9]+]], [[A1]];
+; CHECK-DAG:  cvt.f32.bf16     [[FB0:%f[0-9]+]], [[B0]];
+; CHECK-DAG:  cvt.f32.bf16     [[FB1:%f[0-9]+]], [[B1]];
+; CHECK-DAG:  div.rn.f32      [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
+; CHECK-DAG:  div.rn.f32      [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
+; CHECK-DAG:  cvt.rn.bf16.f32  [[R0:%h[0-9]+]], [[FR0]];
+; CHECK-DAG:  cvt.rn.bf16.f32  [[R1:%h[0-9]+]], [[FR1]];
+; CHECK-NEXT: mov.b32         [[R:%hh[0-9]+]], {[[R0]], [[R1]]}
+; CHECK-NEXT: st.param.b32    [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+  %r = fdiv <2 x bfloat> %a, %b
+  ret <2 x bfloat> %r
+}
Index: llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -204,6 +204,14 @@
       return {Intrinsic::fma, FTZ_MustBeOff, true};
     case Intrinsic::nvvm_fma_rn_ftz_f16x2:
       return {Intrinsic::fma, FTZ_MustBeOn, true};
+    case Intrinsic::nvvm_fma_rn_bf16:
+      return {Intrinsic::fma, FTZ_MustBeOff, true};
+    case Intrinsic::nvvm_fma_rn_ftz_bf16:
+      return {Intrinsic::fma, FTZ_MustBeOn, true};
+    case Intrinsic::nvvm_fma_rn_bf16x2:
+      return {Intrinsic::fma, FTZ_MustBeOff, true};
+    case Intrinsic::nvvm_fma_rn_ftz_bf16x2:
+      return {Intrinsic::fma, FTZ_MustBeOn, true};
     case Intrinsic::nvvm_fmax_d:
       return {Intrinsic::maxnum, FTZ_Any};
     case Intrinsic::nvvm_fmax_f:
Index: llvm/lib/Target/NVPTX/NVPTXSubtarget.h
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -76,7 +76,9 @@
   inline bool hasHWROT32() const { return SmVersion >= 32; }
   bool hasImageHandles() const;
   bool hasFP16Math() const { return SmVersion >= 53; }
+  bool hasBF16Math() const { return SmVersion >= 80; }
   bool allowFP16Math() const;
+  bool allowBF16Math() const;
   bool hasMaskOperator() const { return PTXVersion >= 71; }
   bool hasNoReturn() const { return SmVersion >= 30 && PTXVersion >= 64; }
   unsigned int getSmVersion() const { return SmVersion; }
Index: llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
+++ llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
@@ -26,7 +26,10 @@
     NoF16Math("nvptx-no-f16-math", cl::Hidden,
               cl::desc("NVPTX Specific: Disable generation of f16 math ops."),
               cl::init(false));
-
+static cl::opt<bool>
+    NoBF16Math("nvptx-no-bf16-math", cl::Hidden,
+               cl::desc("NVPTX Specific: Disable generation of bf16 math ops."),
+               cl::init(false));
 // Pin the vtable to this file.
 void NVPTXSubtarget::anchor() {}
 
@@ -65,3 +68,7 @@
 bool NVPTXSubtarget::allowFP16Math() const {
   return hasFP16Math() && NoF16Math == false;
 }
+
+bool NVPTXSubtarget::allowBF16Math() const {
+  return hasBF16Math() && NoBF16Math == false;
+}
Index: llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
+++ llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
@@ -60,8 +60,10 @@
 def Int16Regs : NVPTXRegClass<[i16], 16, (add (sequence "RS%u", 0, 4))>;
 def Int32Regs : NVPTXRegClass<[i32], 32, (add (sequence "R%u", 0, 4), VRFrame32, VRFrameLocal32)>;
 def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
-def Float16Regs : NVPTXRegClass<[f16,bf16], 16, (add (sequence "H%u", 0, 4))>;
-def Float16x2Regs : NVPTXRegClass<[v2f16,v2bf16], 32, (add (sequence "HH%u", 0, 4))>;
+def Float16Regs : NVPTXRegClass<[f16], 16, (add (sequence "H%u", 0, 4))>;
+def Float16x2Regs : NVPTXRegClass<[v2f16], 32, (add (sequence "HH%u", 0, 4))>;
+def BFloat16Regs : NVPTXRegClass<[bf16], 16, (add (sequence "H%u", 0, 4))>;
+def BFloat16x2Regs : NVPTXRegClass<[v2bf16], 32, (add (sequence "HH%u", 0, 4))>;
 def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;
 def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 4))>;
 def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 4))>;
Index: llvm/lib/Target/NVPTX/NVPTXMCExpr.h
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXMCExpr.h
+++ llvm/lib/Target/NVPTX/NVPTXMCExpr.h
@@ -21,6 +21,7 @@
 public:
   enum VariantKind {
     VK_NVPTX_None,
+    VK_NVPTX_BFLOAT_PREC_FLOAT, // FP constant in bfloat-precision
     VK_NVPTX_HALF_PREC_FLOAT,   // FP constant in half-precision
     VK_NVPTX_SINGLE_PREC_FLOAT, // FP constant in single-precision
     VK_NVPTX_DOUBLE_PREC_FLOAT  // FP constant in double-precision
@@ -40,6 +41,11 @@
   static const NVPTXFloatMCExpr *create(VariantKind Kind, const APFloat &Flt,
                                         MCContext &Ctx);
 
+  static const NVPTXFloatMCExpr *createConstantBFPHalf(const APFloat &Flt,
+                                                       MCContext &Ctx) {
+    return create(VK_NVPTX_BFLOAT_PREC_FLOAT, Flt, Ctx);
+  }
+
   static const NVPTXFloatMCExpr *createConstantFPHalf(const APFloat &Flt,
                                                         MCContext &Ctx) {
     return create(VK_NVPTX_HALF_PREC_FLOAT, Flt, Ctx);
Index: llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp
+++ llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp
@@ -34,6 +34,11 @@
     NumHex = 4;
     APF.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &Ignored);
     break;
+  case VK_NVPTX_BFLOAT_PREC_FLOAT:
+    OS << "0x";
+    NumHex = 4;
+    APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &Ignored);
+    break;
   case VK_NVPTX_SINGLE_PREC_FLOAT:
     OS << "0f";
     NumHex = 8;
Index: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -973,6 +973,18 @@
     FMA_TUPLE<"_rn_ftz_relu_f16", int_nvvm_fma_rn_ftz_relu_f16, Float16Regs,
       [hasPTX70, hasSM80]>,
 
+    FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, BFloat16Regs, [hasPTX70, hasSM80]>,
+    FMA_TUPLE<"_rn_ftz_bf16", int_nvvm_fma_rn_ftz_bf16, BFloat16Regs,
+      [hasPTX70, hasSM80]>,
+    FMA_TUPLE<"_rn_sat_bf16", int_nvvm_fma_rn_sat_bf16, BFloat16Regs,
+      [hasPTX70, hasSM80]>,
+    FMA_TUPLE<"_rn_ftz_sat_bf16", int_nvvm_fma_rn_ftz_sat_bf16, BFloat16Regs,
+      [hasPTX70, hasSM80]>,
+    FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, BFloat16Regs,
+      [hasPTX70, hasSM80]>,
+    FMA_TUPLE<"_rn_ftz_relu_bf16", int_nvvm_fma_rn_ftz_relu_bf16, BFloat16Regs,
+      [hasPTX70, hasSM80]>,
+
     FMA_TUPLE<"_rn_f16x2", int_nvvm_fma_rn_f16x2, Float16x2Regs,
       [hasPTX42, hasSM53]>,
     FMA_TUPLE<"_rn_ftz_f16x2", int_nvvm_fma_rn_ftz_f16x2, Float16x2Regs,
@@ -986,13 +998,9 @@
     FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2,
       Float16x2Regs, [hasPTX70, hasSM80]>,
 
-    FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, Int16Regs, [hasPTX70, hasSM80]>,
-    FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, Int16Regs,
-      [hasPTX70, hasSM80]>,
-
-    FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, Int32Regs,
+    FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, BFloat16x2Regs,
       [hasPTX70, hasSM80]>,
-    FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, Int32Regs,
+    FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, BFloat16x2Regs,
       [hasPTX70, hasSM80]>
   ] in {
     def P.Variant :
@@ -1243,24 +1251,6 @@
 def : Pat<(int_nvvm_ff2bf16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
           (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
 
-def : Pat<(int_nvvm_ff2f16x2_rn Float32Regs:$a, Float32Regs:$b),
-          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
-def : Pat<(int_nvvm_ff2f16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
-          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
-def : Pat<(int_nvvm_ff2f16x2_rz Float32Regs:$a, Float32Regs:$b),
-          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
-def : Pat<(int_nvvm_ff2f16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
-          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
-
-def : Pat<(int_nvvm_f2bf16_rn Float32Regs:$a),
-          (CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
-def : Pat<(int_nvvm_f2bf16_rn_relu Float32Regs:$a),
-          (CVT_bf16_f32 Float32Regs:$a, CvtRN_RELU)>;
-def : Pat<(int_nvvm_f2bf16_rz Float32Regs:$a),
-          (CVT_bf16_f32 Float32Regs:$a, CvtRZ)>;
-def : Pat<(int_nvvm_f2bf16_rz_relu Float32Regs:$a),
-          (CVT_bf16_f32 Float32Regs:$a, CvtRZ_RELU)>;
-
 def CVT_tf32_f32 :
    NVPTXInst<(outs Int32Regs:$dest), (ins Float32Regs:$a),
                    "cvt.rna.tf32.f32 \t$dest, $a;",
@@ -2136,6 +2126,8 @@
 defm INT_PTX_LDU_GLOBAL_i64 : LDU_G<"u64 \t$result, [$src];", Int64Regs>;
 defm INT_PTX_LDU_GLOBAL_f16 : LDU_G<"b16 \t$result, [$src];", Float16Regs>;
 defm INT_PTX_LDU_GLOBAL_f16x2 : LDU_G<"b32 \t$result, [$src];", Float16x2Regs>;
+defm INT_PTX_LDU_GLOBAL_bf16 : LDU_G<"b16 \t$result, [$src];", BFloat16Regs>;
+defm INT_PTX_LDU_GLOBAL_bf16x2 : LDU_G<"b32 \t$result, [$src];", BFloat16x2Regs>;
 defm INT_PTX_LDU_GLOBAL_f32 : LDU_G<"f32 \t$result, [$src];", Float32Regs>;
 defm INT_PTX_LDU_GLOBAL_f64 : LDU_G<"f64 \t$result, [$src];", Float64Regs>;
 defm INT_PTX_LDU_GLOBAL_p32 : LDU_G<"u32 \t$result, [$src];", Int32Regs>;
@@ -2190,6 +2182,10 @@
   : VLDU_G_ELE_V2<"v2.b16 \t{{$dst1, $dst2}}, [$src];", Float16Regs>;
 defm INT_PTX_LDU_G_v2f16x2_ELE
   : VLDU_G_ELE_V2<"v2.b32 \t{{$dst1, $dst2}}, [$src];", Float16x2Regs>;
+defm INT_PTX_LDU_G_v2bf16_ELE
+  : VLDU_G_ELE_V2<"v2.b16 \t{{$dst1, $dst2}}, [$src];", BFloat16Regs>;
+defm INT_PTX_LDU_G_v2bf16x2_ELE
+  : VLDU_G_ELE_V2<"v2.b32 \t{{$dst1, $dst2}}, [$src];", BFloat16x2Regs>;
 defm INT_PTX_LDU_G_v2f32_ELE
   : VLDU_G_ELE_V2<"v2.f32 \t{{$dst1, $dst2}}, [$src];", Float32Regs>;
 defm INT_PTX_LDU_G_v2i64_ELE
@@ -2253,6 +2249,10 @@
   : LDG_G<"b16 \t$result, [$src];", Float16Regs>;
 defm INT_PTX_LDG_GLOBAL_f16x2
   : LDG_G<"b32 \t$result, [$src];", Float16x2Regs>;
+defm INT_PTX_LDG_GLOBAL_bf16
+  : LDG_G<"b16 \t$result, [$src];", BFloat16Regs>;
+defm INT_PTX_LDG_GLOBAL_bf16x2
+  : LDG_G<"b32 \t$result, [$src];", BFloat16x2Regs>;
 defm INT_PTX_LDG_GLOBAL_f32
   : LDG_G<"f32 \t$result, [$src];", Float32Regs>;
 defm INT_PTX_LDG_GLOBAL_f64
Index: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -19,6 +19,8 @@
 
 let OperandType = "OPERAND_IMMEDIATE" in {
   def f16imm : Operand<f16>;
+  def bf16imm : Operand<bf16>;
+
 }
 
 // List of vector specific properties
@@ -172,6 +174,7 @@
 
 def useShortPtr : Predicate<"useShortPointers()">;
 def useFP16Math: Predicate<"Subtarget->allowFP16Math()">;
+def useBFP16Math: Predicate<"Subtarget->allowBF16Math()">;
 
 // Helper class to aid conversion between ValueType and a matching RegisterClass.
 
@@ -184,8 +187,8 @@
      !eq(name, "i64"): Int64Regs,
      !eq(name, "f16"): Float16Regs,
      !eq(name, "v2f16"): Float16x2Regs,
-     !eq(name, "bf16"): Float16Regs,
-     !eq(name, "v2bf16"): Float16x2Regs,
+     !eq(name, "bf16"): BFloat16Regs,
+     !eq(name, "v2bf16"): BFloat16x2Regs,
      !eq(name, "f32"): Float32Regs,
      !eq(name, "f64"): Float64Regs,
      !eq(name, "ai32"): Int32ArgRegs,
@@ -322,6 +325,31 @@
                !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
                [(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>,
                Requires<[useFP16Math]>;
+   def bf16rr_ftz :
+     NVPTXInst<(outs BFloat16Regs:$dst),
+               (ins BFloat16Regs:$a, BFloat16Regs:$b),
+               !strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"),
+               [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>,
+               Requires<[useBFP16Math, doF32FTZ]>;
+   def bf16rr :
+     NVPTXInst<(outs BFloat16Regs:$dst),
+               (ins BFloat16Regs:$a, BFloat16Regs:$b),
+               !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
+               [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>,
+               Requires<[useBFP16Math]>;
+
+   def bf16x2rr_ftz :
+     NVPTXInst<(outs BFloat16x2Regs:$dst),
+               (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b),
+               !strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"),
+               [(set BFloat16x2Regs:$dst, (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>,
+               Requires<[useBFP16Math, doF32FTZ]>;
+   def bf16x2rr :
+     NVPTXInst<(outs BFloat16x2Regs:$dst),
+               (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b),
+               !strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"),
+               [(set BFloat16x2Regs:$dst, (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>,
+               Requires<[useBFP16Math]>;
 }
 
 // Template for instructions which take three FP args.  The
@@ -396,7 +424,31 @@
                !strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
                [(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>,
                Requires<[useFP16Math, allowFMA]>;
-
+   def bf16rr_ftz :
+     NVPTXInst<(outs BFloat16Regs:$dst),
+               (ins BFloat16Regs:$a, BFloat16Regs:$b),
+               !strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"),
+               [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>,
+               Requires<[useBFP16Math, allowFMA, doF32FTZ]>;
+   def bf16rr :
+     NVPTXInst<(outs BFloat16Regs:$dst),
+               (ins BFloat16Regs:$a, BFloat16Regs:$b),
+               !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
+               [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>,
+               Requires<[useBFP16Math, allowFMA]>;
+
+   def bf16x2rr_ftz :
+     NVPTXInst<(outs BFloat16x2Regs:$dst),
+               (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b),
+               !strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"),
+               [(set (v2bf16 BFloat16x2Regs:$dst), (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>,
+               Requires<[useBFP16Math, allowFMA, doF32FTZ]>;
+   def bf16x2rr :
+     NVPTXInst<(outs BFloat16x2Regs:$dst),
+               (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b),
+               !strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"),
+               [(set BFloat16x2Regs:$dst, (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>,
+               Requires<[useBFP16Math, allowFMA]>;
    // These have strange names so we don't perturb existing mir tests.
    def _rnf64rr :
      NVPTXInst<(outs Float64Regs:$dst),
@@ -458,6 +510,30 @@
                !strconcat(OpcStr, ".rn.f16x2 \t$dst, $a, $b;"),
                [(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>,
                Requires<[useFP16Math, noFMA]>;
+  def _rnbf16rr_ftz :
+     NVPTXInst<(outs BFloat16Regs:$dst),
+               (ins BFloat16Regs:$a, BFloat16Regs:$b),
+               !strconcat(OpcStr, ".rn.ftz.bf16 \t$dst, $a, $b;"),
+               [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>,
+               Requires<[useBFP16Math, noFMA, doF32FTZ]>;
+   def _rnbf16rr :
+     NVPTXInst<(outs BFloat16Regs:$dst),
+               (ins BFloat16Regs:$a, BFloat16Regs:$b),
+               !strconcat(OpcStr, ".rn.bf16 \t$dst, $a, $b;"),
+               [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>,
+               Requires<[useBFP16Math, noFMA]>;
+   def _rnbf16x2rr_ftz :
+     NVPTXInst<(outs BFloat16x2Regs:$dst),
+               (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b),
+               !strconcat(OpcStr, ".rn.ftz.bf16x2 \t$dst, $a, $b;"),
+               [(set BFloat16x2Regs:$dst, (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>,
+               Requires<[useBFP16Math, noFMA, doF32FTZ]>;
+   def _rnbf16x2rr :
+     NVPTXInst<(outs BFloat16x2Regs:$dst),
+               (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b),
+               !strconcat(OpcStr, ".rn.bf16x2 \t$dst, $a, $b;"),
+               [(set BFloat16x2Regs:$dst, (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>,
+               Requires<[useBFP16Math, noFMA]>;
 }
 
 // Template for operations which take two f32 or f64 operands.  Provides three
@@ -534,6 +610,11 @@
                 (ins Float16Regs:$src, CvtMode:$mode),
                 !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
                 FromName, ".f16 \t$dst, $src;"), []>;
+    def _bf16 :
+      NVPTXInst<(outs RC:$dst),
+                (ins BFloat16Regs:$src, CvtMode:$mode),
+                !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
+                FromName, ".bf16 \t$dst, $src;"), []>;
     def _f32 :
       NVPTXInst<(outs RC:$dst),
                 (ins Float32Regs:$src, CvtMode:$mode),
@@ -556,6 +637,7 @@
   defm CVT_s64 : CVT_FROM_ALL<"s64", Int64Regs>;
   defm CVT_u64 : CVT_FROM_ALL<"u64", Int64Regs>;
   defm CVT_f16 : CVT_FROM_ALL<"f16", Float16Regs>;
+  defm CVT_bf16 : CVT_FROM_ALL<"bf16", BFloat16Regs>;
   defm CVT_f32 : CVT_FROM_ALL<"f32", Float32Regs>;
   defm CVT_f64 : CVT_FROM_ALL<"f64", Float64Regs>;
 
@@ -574,18 +656,7 @@
   def CVT_INREG_s64_s32 : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src),
                                     "cvt.s64.s32 \t$dst, $src;", []>;
 
-multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> {
-    def _f32 :
-      NVPTXInst<(outs RC:$dst),
-                (ins Float32Regs:$src, CvtMode:$mode),
-                !strconcat("cvt${mode:base}${mode:relu}.",
-                FromName, ".f32 \t$dst, $src;"), []>,
-                Requires<[hasPTX70, hasSM80]>;
-  }
-
-  defm CVT_bf16 : CVT_FROM_FLOAT_SM80<"bf16", Int16Regs>;
-
-    multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> {
+  multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> {
     def _f32 :
       NVPTXInst<(outs RC:$dst),
                 (ins Float32Regs:$src1, Float32Regs:$src2,  CvtMode:$mode),
@@ -594,7 +665,7 @@
     Requires<[hasPTX70, hasSM80]>;
   }
 
-  defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", Float16x2Regs>;
+  defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", BFloat16x2Regs>;
   defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", Int32Regs>;
 }
 
@@ -659,7 +730,7 @@
 defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>;
 defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>;
 defm SELP_f16 : SELP_PATTERN<"b16", f16, Float16Regs, f16imm, fpimm>;
-
+defm SELP_bf16 : SELP_PATTERN<"b16", bf16, BFloat16Regs, bf16imm, fpimm>;
 defm SELP_f32 : SELP_PATTERN<"f32", f32, Float32Regs, f32imm, fpimm>;
 defm SELP_f64 : SELP_PATTERN<"f64", f64, Float64Regs, f64imm, fpimm>;
 
@@ -1023,7 +1094,9 @@
 def LOAD_CONST_F16 :
   NVPTXInst<(outs Float16Regs:$dst), (ins f16imm:$a),
             "mov.b16 \t$dst, $a;", []>;
-
+def LOAD_CONST_BF16 :
+  NVPTXInst<(outs BFloat16Regs:$dst), (ins bf16imm:$a),
+            "mov.b16 \t$dst, $a;", []>;
 defm FADD : F3_fma_component<"add", fadd>;
 defm FSUB : F3_fma_component<"sub", fsub>;
 defm FMUL : F3_fma_component<"mul", fmul>;
@@ -1051,6 +1124,20 @@
 def FNEG16x2_ftz : FNEG_F16_F16X2<"neg.ftz.f16x2", v2f16, Float16x2Regs, doF32FTZ>;
 def FNEG16x2     : FNEG_F16_F16X2<"neg.f16x2", v2f16, Float16x2Regs, True>;
 
+//
+// BF16 NEG
+//
+
+class FNEG_BF16_F16X2<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> :
+      NVPTXInst<(outs RC:$dst), (ins RC:$src),
+                !strconcat(OpcStr, " \t$dst, $src;"),
+                [(set RC:$dst, (fneg (T RC:$src)))]>,
+                Requires<[useBFP16Math, hasPTX70, hasSM80, Pred]>;
+def BFNEG16_ftz   : FNEG_BF16_F16X2<"neg.ftz.bf16", bf16, BFloat16Regs, doF32FTZ>;
+def BFNEG16       : FNEG_BF16_F16X2<"neg.bf16", bf16, BFloat16Regs, True>;
+def BFNEG16x2_ftz : FNEG_BF16_F16X2<"neg.ftz.bf16x2", v2bf16, BFloat16x2Regs, doF32FTZ>;
+def BFNEG16x2     : FNEG_BF16_F16X2<"neg.bf16x2", v2bf16, BFloat16x2Regs, True>;
+
 //
 // F64 division
 //
@@ -1229,10 +1316,21 @@
                        Requires<[useFP16Math, Pred]>;
 }
 
+multiclass FMA_BF16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> {
+   def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+                       !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
+                       [(set RC:$dst, (fma (T RC:$a), (T RC:$b), (T RC:$c)))]>,
+                       Requires<[useBFP16Math, Pred]>;
+}
+
 defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Float16Regs, doF32FTZ>;
 defm FMA16     : FMA_F16<"fma.rn.f16", f16, Float16Regs, True>;
 defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Float16x2Regs, doF32FTZ>;
 defm FMA16x2     : FMA_F16<"fma.rn.f16x2", v2f16, Float16x2Regs, True>;
+defm BFMA16_ftz : FMA_BF16<"fma.rn.ftz.bf16", bf16, BFloat16Regs, doF32FTZ>;
+defm BFMA16     : FMA_BF16<"fma.rn.bf16", bf16, BFloat16Regs, True>;
+defm BFMA16x2_ftz : FMA_BF16<"fma.rn.ftz.bf16x2", v2bf16, BFloat16x2Regs, doF32FTZ>;
+defm BFMA16x2     : FMA_BF16<"fma.rn.bf16x2", v2bf16, BFloat16x2Regs, True>;
 defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
 defm FMA32     : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
 defm FMA64     : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
@@ -1679,6 +1777,18 @@
                 "setp${cmp:base}${cmp:ftz}.f16x2 \t$p|$q, $a, $b;",
                 []>,
                 Requires<[useFP16Math]>;
+def SETP_bf16rr :
+      NVPTXInst<(outs Int1Regs:$dst),
+                (ins BFloat16Regs:$a, BFloat16Regs:$b, CmpMode:$cmp),
+                "setp${cmp:base}${cmp:ftz}.bf16 \t$dst, $a, $b;",
+                []>, Requires<[useBFP16Math]>;
+
+def SETP_bf16x2rr :
+      NVPTXInst<(outs Int1Regs:$p, Int1Regs:$q),
+                (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b, CmpMode:$cmp),
+                "setp${cmp:base}${cmp:ftz}.bf16x2 \t$p|$q, $a, $b;",
+                []>,
+                Requires<[useBFP16Math]>;
 
 
 // FIXME: This doesn't appear to be correct.  The "set" mnemonic has the form
@@ -1709,6 +1819,7 @@
 defm SET_s64 : SET<"s64", Int64Regs, i64imm>;
 defm SET_u64 : SET<"u64", Int64Regs, i64imm>;
 defm SET_f16 : SET<"f16", Float16Regs, f16imm>;
+defm SET_bf16 : SET<"bf16", BFloat16Regs, bf16imm>;
 defm SET_f32 : SET<"f32", Float32Regs, f32imm>;
 defm SET_f64 : SET<"f64", Float64Regs, f64imm>;
 
@@ -1781,6 +1892,8 @@
   def FMOV16rr : NVPTXInst<(outs Float16Regs:$dst), (ins Float16Regs:$src),
                            // We have to use .b16 here as there's no mov.f16.
                            "mov.b16 \t$dst, $src;", []>;
+  def BFMOV16rr : NVPTXInst<(outs BFloat16Regs:$dst), (ins BFloat16Regs:$src),
+                           "mov.b16 \t$dst, $src;", []>;
   def FMOV32rr : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
                            "mov.f32 \t$dst, $src;", []>;
   def FMOV64rr : NVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$src),
@@ -1963,7 +2076,27 @@
             (SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>,
         Requires<[useFP16Math]>;
 
-  // f32 -> pred
+  // bf16 -> pred
+  def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))),
+            (SETP_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, ModeFTZ)>,
+        Requires<[useBFP16Math,doF32FTZ]>;
+  def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))),
+            (SETP_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, Mode)>,
+        Requires<[useBFP16Math]>;
+  def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), fpimm:$b)),
+            (SETP_bf16rr BFloat16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), ModeFTZ)>,
+        Requires<[useBFP16Math,doF32FTZ]>;
+  def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), fpimm:$b)),
+            (SETP_bf16rr BFloat16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), Mode)>,
+        Requires<[useBFP16Math]>;
+  def : Pat<(i1 (OpNode fpimm:$a, (bf16 BFloat16Regs:$b))),
+            (SETP_bf16rr (LOAD_CONST_BF16 fpimm:$a), BFloat16Regs:$b, ModeFTZ)>,
+        Requires<[useBFP16Math,doF32FTZ]>;
+  def : Pat<(i1 (OpNode fpimm:$a, (bf16 BFloat16Regs:$b))),
+            (SETP_bf16rr (LOAD_CONST_BF16 fpimm:$a), BFloat16Regs:$b, Mode)>,
+        Requires<[useBFP16Math]>;
+
+  // f32 -> pred
   def : Pat<(i1 (OpNode Float32Regs:$a, Float32Regs:$b)),
             (SETP_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>,
         Requires<[doF32FTZ]>;
@@ -2007,6 +2140,26 @@
   def : Pat<(i32 (OpNode fpimm:$a, (f16 Float16Regs:$b))),
             (SET_f16ir (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>,
         Requires<[useFP16Math]>;
+
+  // bf16 -> i32
+  def : Pat<(i32 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))),
+            (SET_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, ModeFTZ)>,
+        Requires<[useBFP16Math, doF32FTZ]>;
+  def : Pat<(i32 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))),
+            (SET_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, Mode)>,
+        Requires<[useBFP16Math]>;
+  def : Pat<(i32 (OpNode (bf16 BFloat16Regs:$a), fpimm:$b)),
+            (SET_bf16rr BFloat16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), ModeFTZ)>,
+        Requires<[useBFP16Math, doF32FTZ]>;
+  def : Pat<(i32 (OpNode (bf16 BFloat16Regs:$a), fpimm:$b)),
+            (SET_bf16rr BFloat16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), Mode)>,
+        Requires<[useBFP16Math]>;
+  def : Pat<(i32 (OpNode fpimm:$a, (bf16 BFloat16Regs:$b))),
+            (SET_bf16ir (LOAD_CONST_BF16 fpimm:$a), BFloat16Regs:$b, ModeFTZ)>,
+        Requires<[useBFP16Math, doF32FTZ]>;
+  def : Pat<(i32 (OpNode fpimm:$a, (bf16 BFloat16Regs:$b))),
+            (SET_bf16ir (LOAD_CONST_BF16 fpimm:$a), BFloat16Regs:$b, Mode)>,
+        Requires<[useBFP16Math]>;
 
   // f32 -> i32
   def : Pat<(i32 (OpNode Float32Regs:$a, Float32Regs:$b)),
@@ -2296,10 +2449,14 @@
 def LoadParamMemV4I8   : LoadParamV4MemInst<Int16Regs, ".b8">;
 def LoadParamMemF16    : LoadParamMemInst<Float16Regs, ".b16">;
 def LoadParamMemF16x2  : LoadParamMemInst<Float16x2Regs, ".b32">;
+def LoadParamMemBF16    : LoadParamMemInst<BFloat16Regs, ".b16">;
+def LoadParamMemBF16x2  : LoadParamMemInst<BFloat16x2Regs, ".b32">;
 def LoadParamMemF32    : LoadParamMemInst<Float32Regs, ".f32">;
 def LoadParamMemF64    : LoadParamMemInst<Float64Regs, ".f64">;
 def LoadParamMemV2F16  : LoadParamV2MemInst<Float16Regs, ".b16">;
 def LoadParamMemV2F16x2: LoadParamV2MemInst<Float16x2Regs, ".b32">;
+def LoadParamMemV2BF16  : LoadParamV2MemInst<BFloat16Regs, ".b16">;
+def LoadParamMemV2BF16x2: LoadParamV2MemInst<BFloat16x2Regs, ".b32">;
 def LoadParamMemV2F32  : LoadParamV2MemInst<Float32Regs, ".f32">;
 def LoadParamMemV2F64  : LoadParamV2MemInst<Float64Regs, ".f64">;
 def LoadParamMemV4F16  : LoadParamV4MemInst<Float16Regs, ".b16">;
@@ -2322,6 +2479,10 @@
 
 def StoreParamF16      : StoreParamInst<Float16Regs, ".b16">;
 def StoreParamF16x2    : StoreParamInst<Float16x2Regs, ".b32">;
+
+def StoreParamBF16      : StoreParamInst<BFloat16Regs, ".b16">;
+def StoreParamBF16x2    : StoreParamInst<BFloat16x2Regs, ".b32">;
+
 def StoreParamF32      : StoreParamInst<Float32Regs, ".f32">;
 def StoreParamF64      : StoreParamInst<Float64Regs, ".f64">;
 def StoreParamV2F16    : StoreParamV2Inst<Float16Regs, ".b16">;
@@ -2348,6 +2509,8 @@
 def StoreRetvalF32    : StoreRetvalInst<Float32Regs, ".f32">;
 def StoreRetvalF16    : StoreRetvalInst<Float16Regs, ".b16">;
 def StoreRetvalF16x2  : StoreRetvalInst<Float16x2Regs, ".b32">;
+def StoreRetvalBF16    : StoreRetvalInst<BFloat16Regs, ".b16">;
+def StoreRetvalBF16x2  : StoreRetvalInst<BFloat16x2Regs, ".b32">;
 def StoreRetvalV2F64  : StoreRetvalV2Inst<Float64Regs, ".f64">;
 def StoreRetvalV2F32  : StoreRetvalV2Inst<Float32Regs, ".f32">;
 def StoreRetvalV2F16  : StoreRetvalV2Inst<Float16Regs, ".b16">;
@@ -2450,6 +2613,7 @@
 def MoveParamF64 : MoveParamInst<f64, Float64Regs, ".f64">;
 def MoveParamF32 : MoveParamInst<f32, Float32Regs, ".f32">;
 def MoveParamF16 : MoveParamInst<f16, Float16Regs, ".f16">;
+def MoveParamBF16 : MoveParamInst<bf16, BFloat16Regs, ".bf16">;
 
 class PseudoUseParamInst<NVPTXRegClass regclass> :
   NVPTXInst<(outs), (ins regclass:$src),
@@ -2473,11 +2637,11 @@
   def ProxyRegI32   : ProxyRegInst<"b32",  i32, Int32Regs>;
   def ProxyRegI64   : ProxyRegInst<"b64",  i64, Int64Regs>;
   def ProxyRegF16   : ProxyRegInst<"b16",  f16, Float16Regs>;
-  def ProxyRegBF16  : ProxyRegInst<"b16",  bf16, Float16Regs>;
+  def ProxyRegBF16  : ProxyRegInst<"b16",  bf16, BFloat16Regs>;
   def ProxyRegF32   : ProxyRegInst<"f32",  f32, Float32Regs>;
   def ProxyRegF64   : ProxyRegInst<"f64",  f64, Float64Regs>;
   def ProxyRegF16x2 : ProxyRegInst<"b32",  v2f16, Float16x2Regs>;
-  def ProxyRegBF16x2 : ProxyRegInst<"b32",  v2bf16, Float16x2Regs>;
+  def ProxyRegBF16x2 : ProxyRegInst<"b32",  v2bf16, BFloat16x2Regs>;
 }
 
 //
@@ -2578,7 +2742,9 @@
   defm ST_i32 : ST<Int32Regs>;
   defm ST_i64 : ST<Int64Regs>;
   defm ST_f16 : ST<Float16Regs>;
+  defm ST_bf16 : ST<BFloat16Regs>;
   defm ST_f16x2 : ST<Float16x2Regs>;
+  defm ST_bf16x2 : ST<BFloat16x2Regs>;
   defm ST_f32 : ST<Float32Regs>;
   defm ST_f64 : ST<Float64Regs>;
 }
@@ -2667,6 +2833,8 @@
   defm LDV_i64 : LD_VEC<Int64Regs>;
   defm LDV_f16 : LD_VEC<Float16Regs>;
   defm LDV_f16x2 : LD_VEC<Float16x2Regs>;
+  defm LDV_bf16 : LD_VEC<BFloat16Regs>;
+  defm LDV_bf16x2 : LD_VEC<BFloat16x2Regs>;
   defm LDV_f32 : LD_VEC<Float32Regs>;
   defm LDV_f64 : LD_VEC<Float64Regs>;
 }
@@ -2762,6 +2930,8 @@
   defm STV_i64 : ST_VEC<Int64Regs>;
   defm STV_f16 : ST_VEC<Float16Regs>;
   defm STV_f16x2 : ST_VEC<Float16x2Regs>;
+  defm STV_bf16 : ST_VEC<BFloat16Regs>;
+  defm STV_bf16x2 : ST_VEC<BFloat16x2Regs>;
   defm STV_f32 : ST_VEC<Float32Regs>;
   defm STV_f64 : ST_VEC<Float64Regs>;
 }
@@ -2816,6 +2986,26 @@
 def : Pat<(f16 (uint_to_fp Int64Regs:$a)),
           (CVT_f16_u64 Int64Regs:$a, CvtRN)>;
 
+// sint -> bf16
+def : Pat<(bf16 (sint_to_fp Int1Regs:$a)),
+          (CVT_bf16_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
+def : Pat<(bf16 (sint_to_fp Int16Regs:$a)),
+          (CVT_bf16_s16 Int16Regs:$a, CvtRN)>;
+def : Pat<(bf16 (sint_to_fp Int32Regs:$a)),
+          (CVT_bf16_s32 Int32Regs:$a, CvtRN)>;
+def : Pat<(bf16 (sint_to_fp Int64Regs:$a)),
+          (CVT_bf16_s64 Int64Regs:$a, CvtRN)>;
+
+// uint -> bf16
+def : Pat<(bf16 (uint_to_fp Int1Regs:$a)),
+          (CVT_bf16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
+def : Pat<(bf16 (uint_to_fp Int16Regs:$a)),
+          (CVT_bf16_u16 Int16Regs:$a, CvtRN)>;
+def : Pat<(bf16 (uint_to_fp Int32Regs:$a)),
+          (CVT_bf16_u32 Int32Regs:$a, CvtRN)>;
+def : Pat<(bf16 (uint_to_fp Int64Regs:$a)),
+          (CVT_bf16_u64 Int64Regs:$a, CvtRN)>;
+
 // sint -> f32
 def : Pat<(f32 (sint_to_fp Int1Regs:$a)),
           (CVT_f32_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
@@ -2877,6 +3067,25 @@
 def : Pat<(i64 (fp_to_uint (f16 Float16Regs:$a))),
           (CVT_u64_f16 Float16Regs:$a, CvtRZI)>;
 
+// bf16 -> sint
+def : Pat<(i1 (fp_to_sint (bf16 BFloat16Regs:$a))),
+          (SETP_b16ri (BITCONVERT_16_BF2I BFloat16Regs:$a), 0, CmpEQ)>;
+def : Pat<(i16 (fp_to_sint (bf16 BFloat16Regs:$a))),
+          (CVT_s16_bf16 (bf16 BFloat16Regs:$a), CvtRZI)>;
+def : Pat<(i32 (fp_to_sint (bf16 BFloat16Regs:$a))),
+          (CVT_s32_bf16 (bf16 BFloat16Regs:$a), CvtRZI)>;
+def : Pat<(i64 (fp_to_sint (bf16 BFloat16Regs:$a))),
+          (CVT_s64_bf16 BFloat16Regs:$a, CvtRZI)>;
+
+// bf16 -> uint
+def : Pat<(i1 (fp_to_uint (bf16 BFloat16Regs:$a))),
+          (SETP_b16ri (BITCONVERT_16_BF2I BFloat16Regs:$a), 0, CmpEQ)>;
+def : Pat<(i16 (fp_to_uint (bf16 BFloat16Regs:$a))),
+          (CVT_u16_bf16 BFloat16Regs:$a, CvtRZI)>;
+def : Pat<(i32 (fp_to_uint (bf16 BFloat16Regs:$a))),
+          (CVT_u32_bf16 BFloat16Regs:$a, CvtRZI)>;
+def : Pat<(i64 (fp_to_uint (bf16 BFloat16Regs:$a))),
+          (CVT_u64_bf16 BFloat16Regs:$a, CvtRZI)>;
 // f32 -> sint
 def : Pat<(i1 (fp_to_sint Float32Regs:$a)),
           (SETP_b32ri (BITCONVERT_32_F2I Float32Regs:$a), 0, CmpEQ)>;
@@ -3024,6 +3233,9 @@
 def : Pat<(select Int32Regs:$pred, (f16 Float16Regs:$a), (f16 Float16Regs:$b)),
           (SELP_f16rr Float16Regs:$a, Float16Regs:$b,
           (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>;
+def : Pat<(select Int32Regs:$pred, (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)),
+          (SELP_bf16rr BFloat16Regs:$a, BFloat16Regs:$b,
+          (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>;
 def : Pat<(select Int32Regs:$pred, Float32Regs:$a, Float32Regs:$b),
           (SELP_f32rr Float32Regs:$a, Float32Regs:$b,
           (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>;
@@ -3124,6 +3336,42 @@
                                    (ins Int32Regs:$src),
                                    "mov.b32 \t{{$lo, $hi}}, $src;",
                                    []>;
+  def BF16x2toBF16_0 : NVPTXInst<(outs BFloat16Regs:$dst),
+                               (ins BFloat16x2Regs:$src),
+                               "{{ .reg .b16 \t%tmp_hi;\n\t"
+                               "  mov.b32 \t{$dst, %tmp_hi}, $src; }}",
+                               [(set BFloat16Regs:$dst,
+                                 (extractelt (v2bf16 BFloat16x2Regs:$src), 0))]>;
+  def BF16x2toBF16_1 : NVPTXInst<(outs BFloat16Regs:$dst),
+                               (ins BFloat16x2Regs:$src),
+                               "{{ .reg .b16 \t%tmp_lo;\n\t"
+                               "  mov.b32 \t{%tmp_lo, $dst}, $src; }}",
+                               [(set BFloat16Regs:$dst,
+                                 (extractelt (v2bf16 BFloat16x2Regs:$src), 1))]>;
+
+  // // Coalesce two bf16 registers into bf16x2
+  // def BuildBF16x2 : NVPTXInst<(outs BFloat16x2Regs:$dst),
+  //                            (ins BFloat16Regs:$a, BFloat16Regs:$b),
+  //                            "mov.b32 \t$dst, {{$a, $b}};",
+  //                            [(set (v2bf16 BFloat16x2Regs:$dst),
+  //                              (build_vector (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>;
+
+  // // Directly initializing the underlying b32 register is one less SASS
+  // // instruction than a vector-packing move.
+  // def BuildBF16x2i : NVPTXInst<(outs BFloat16x2Regs:$dst), (ins i32imm:$src),
+  //                             "mov.b32 \t$dst, $src;",
+  //                             []>;
+
+  // // Split f16x2 into two f16 registers.
+  // def SplitBF16x2  : NVPTXInst<(outs BFloat16Regs:$lo, BFloat16Regs:$hi),
+  //                             (ins BFloat16x2Regs:$src),
+  //                             "mov.b32 \t{{$lo, $hi}}, $src;",
+  //                             []>;
+  // // Split an i32 into two f16
+  // def SplitI32toBF16x2  : NVPTXInst<(outs BFloat16Regs:$lo, BFloat16Regs:$hi),
+  //                                  (ins Int32Regs:$src),
+  //                                  "mov.b32 \t{{$lo, $hi}}, $src;",
+  //                                  []>;
 }
 
 // Count leading zeros
@@ -3193,10 +3441,17 @@
 def : Pat<(f16 (fpround Float32Regs:$a)),
           (CVT_f16_f32 Float32Regs:$a, CvtRN)>;
 
+// fpround f32 -> bf16
+def : Pat<(bf16 (fpround Float32Regs:$a)),
+          (CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
+
 // fpround f64 -> f16
 def : Pat<(f16 (fpround Float64Regs:$a)),
           (CVT_f16_f64 Float64Regs:$a, CvtRN)>;
 
+// fpround f64 -> bf16
+def : Pat<(bf16 (fpround Float64Regs:$a)),
+          (CVT_bf16_f64 Float64Regs:$a, CvtRN)>;
 // fpround f64 -> f32
 def : Pat<(f32 (fpround Float64Regs:$a)),
           (CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>;
@@ -3208,11 +3463,20 @@
           (CVT_f32_f16 Float16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
 def : Pat<(f32 (fpextend (f16 Float16Regs:$a))),
           (CVT_f32_f16 Float16Regs:$a, CvtNONE)>;
+// fpextend bf16 -> f32
+def : Pat<(f32 (fpextend (bf16 BFloat16Regs:$a))),
+          (CVT_f32_bf16 BFloat16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
+def : Pat<(f32 (fpextend (bf16 BFloat16Regs:$a))),
+          (CVT_f32_bf16 BFloat16Regs:$a, CvtNONE)>;
 
 // fpextend f16 -> f64
 def : Pat<(f64 (fpextend (f16 Float16Regs:$a))),
           (CVT_f64_f16 Float16Regs:$a, CvtNONE)>;
 
+// fpextend bf16 -> f64
+def : Pat<(f64 (fpextend (bf16 BFloat16Regs:$a))),
+          (CVT_f64_bf16 BFloat16Regs:$a, CvtNONE)>;
+
 // fpextend f32 -> f64
 def : Pat<(f64 (fpextend Float32Regs:$a)),
           (CVT_f64_f32 Float32Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
@@ -3227,6 +3491,8 @@
 multiclass CVT_ROUND<SDNode OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
   def : Pat<(OpNode (f16 Float16Regs:$a)),
             (CVT_f16_f16 Float16Regs:$a, Mode)>;
+  def : Pat<(OpNode (bf16 BFloat16Regs:$a)),
+            (CVT_bf16_bf16 BFloat16Regs:$a, Mode)>;
   def : Pat<(OpNode Float32Regs:$a),
             (CVT_f32_f32 Float32Regs:$a, ModeFTZ)>, Requires<[doF32FTZ]>;
   def : Pat<(OpNode Float32Regs:$a),
Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -143,6 +143,26 @@
   }
 }
 
+static bool Isv2f16Orv2bf16Type(MVT VT) {
+  switch (VT.SimpleTy) {
+  default:
+    return false;
+  case MVT::v2f16:
+  case MVT::v2bf16:
+    return true;
+  }
+}
+
+static bool Isf16Orbf16Type(MVT VT) {
+  switch (VT.SimpleTy) {
+  default:
+    return false;
+  case MVT::f16:
+  case MVT::bf16:
+    return true;
+  }
+}
+
 /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
 /// EVTs that compose it.  Unlike ComputeValueVTs, this will break apart vectors
 /// into their primitive components.
@@ -193,7 +213,7 @@
       // Vectors with an even number of f16 elements will be passed to
       // us as an array of v2f16/v2bf16 elements. We must match this so we
       // stay in sync with Ins/Outs.
-      if ((EltVT == MVT::f16 || EltVT == MVT::bf16) && NumElts % 2 == 0) {
+      if ((Isf16Orbf16Type(EltVT.getSimpleVT())) && NumElts % 2 == 0) {
         EltVT = EltVT == MVT::f16 ? MVT::v2f16 : MVT::v2bf16;
         NumElts /= 2;
       }
@@ -398,6 +418,11 @@
     setOperationAction(Op, VT, STI.allowFP16Math() ? Action : NoF16Action);
   };
 
+  auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
+                                    LegalizeAction NoBF16Action) {
+    setOperationAction(Op, VT, STI.allowBF16Math() ? Action : NoBF16Action);
+  };
+
   addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
   addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
   addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
@@ -406,8 +431,6 @@
   addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
   addRegisterClass(MVT::f16, &NVPTX::Float16RegsRegClass);
   addRegisterClass(MVT::v2f16, &NVPTX::Float16x2RegsRegClass);
-  addRegisterClass(MVT::bf16, &NVPTX::Float16RegsRegClass);
-  addRegisterClass(MVT::v2bf16, &NVPTX::Float16x2RegsRegClass);
 
   // Conversion to/from FP16/FP16x2 is always legal.
   setOperationAction(ISD::SINT_TO_FP, MVT::f16, Legal);
@@ -420,6 +443,16 @@
   setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
   setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
 
+  // Conversion to/from bf16/bf16x2 is always legal.
+  setOperationAction(ISD::SINT_TO_FP, MVT::bf16, Legal);
+  setOperationAction(ISD::FP_TO_SINT, MVT::bf16, Legal);
+  setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Custom);
+  setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2bf16, Custom);
+  setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2bf16, Expand);
+  setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2bf16, Expand);
+
+  setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
+  setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
   // Operations not directly supported by NVPTX.
   for (MVT VT : {MVT::f16, MVT::v2f16, MVT::f32, MVT::f64, MVT::i1, MVT::i8,
                  MVT::i16, MVT::i32, MVT::i64}) {
@@ -476,17 +509,25 @@
   // Turn FP extload into load/fpextend
   setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
+  setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
   setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
   // Turn FP truncstore into trunc + store.
   // FIXME: vector types should also be expanded
   setTruncStoreAction(MVT::f32, MVT::f16, Expand);
   setTruncStoreAction(MVT::f64, MVT::f16, Expand);
+  setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
+  setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
   setTruncStoreAction(MVT::f64, MVT::f32, Expand);
 
   // PTX does not support load / store predicate registers
@@ -563,9 +604,9 @@
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
                        ISD::SREM, ISD::UREM});
 
-  // setcc for f16x2 needs special handling to prevent legalizer's
-  // attempt to scalarize it due to v2i1 not being legal.
-  if (STI.allowFP16Math())
+  // setcc for f16x2 and bf16x2 needs special handling to prevent
+  // legalizer's attempt to scalarize it due to v2i1 not being legal.
+  if (STI.allowFP16Math() || STI.allowBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -579,6 +620,11 @@
     setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
   }
 
+  for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
+    setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
+    setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
+  }
+
   // f16/f16x2 neg was introduced in PTX 60, SM_53.
   const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
                                         STI.getPTXVersion() >= 60 &&
@@ -587,19 +633,29 @@
     setOperationAction(ISD::FNEG, VT,
                        IsFP16FP16x2NegAvailable ? Legal : Expand);
 
+  const bool IsBFP16FP16x2NegAvailable = STI.getSmVersion() >= 80 &&
+                                         STI.getPTXVersion() >= 70 &&
+                                         STI.allowBF16Math();
+  for (const auto &VT : {MVT::bf16, MVT::v2bf16})
+    setOperationAction(ISD::FNEG, VT,
+                       IsBFP16FP16x2NegAvailable ? Legal : Expand);
   // (would be) Library functions.
 
   // These map to conversion instructions for scalar FP types.
   for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
                          ISD::FROUNDEVEN, ISD::FTRUNC}) {
+    setOperationAction(Op, MVT::bf16, Legal);
     setOperationAction(Op, MVT::f16, Legal);
     setOperationAction(Op, MVT::f32, Legal);
     setOperationAction(Op, MVT::f64, Legal);
     setOperationAction(Op, MVT::v2f16, Expand);
+    setOperationAction(Op, MVT::v2bf16, Expand);
   }
 
   setOperationAction(ISD::FROUND, MVT::f16, Promote);
   setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
+  setOperationAction(ISD::FROUND, MVT::bf16, Promote);
+  setOperationAction(ISD::FROUND, MVT::v2bf16, Expand);
   setOperationAction(ISD::FROUND, MVT::f32, Custom);
   setOperationAction(ISD::FROUND, MVT::f64, Custom);
 
@@ -607,6 +663,8 @@
   // 'Expand' implements FCOPYSIGN without calling an external library.
   setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
   setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);
+  setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
+  setOperationAction(ISD::FCOPYSIGN, MVT::v2bf16, Expand);
   setOperationAction(ISD::FCOPYSIGN, MVT::f32, Expand);
   setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand);
 
@@ -616,9 +674,11 @@
   for (const auto &Op :
        {ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FABS}) {
     setOperationAction(Op, MVT::f16, Promote);
+    setOperationAction(Op, MVT::bf16, Promote);
     setOperationAction(Op, MVT::f32, Legal);
     setOperationAction(Op, MVT::f64, Legal);
     setOperationAction(Op, MVT::v2f16, Expand);
+    setOperationAction(Op, MVT::v2bf16, Expand);
   }
   // max.f16, max.f16x2 and max.NaN are supported on sm_80+.
   auto GetMinMaxAction = [&](LegalizeAction NotSm80Action) {
@@ -636,6 +696,12 @@
     setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
     setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
   }
+  for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
+    setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote);
+    setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
+    setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Expand), Expand);
+    setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
+  }
 
   // No FEXP2, FLOG2.  The PTX ex2 and log2 functions are always approximate.
   // No FPOW or FREM in PTX.
@@ -1252,7 +1318,7 @@
   if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
       VT.getScalarType() == MVT::i1)
     return TypeSplitVector;
-  if (VT == MVT::v2f16)
+  if (Isv2f16Orv2bf16Type(VT))
     return TypeLegal;
   return TargetLoweringBase::getPreferredVectorAction(VT);
 }
@@ -1402,7 +1468,7 @@
         sz = promoteScalarArgumentSize(sz);
       } else if (isa<PointerType>(Ty)) {
         sz = PtrVT.getSizeInBits();
-      } else if (Ty->isHalfTy())
+      } else if (Ty->isHalfTy() || Ty->isBFloatTy())
         // PTX ABI requires all scalar parameters to be at least 32
         // bits in size.  fp16 normally uses .b16 as its storage type
         // in PTX, so its size must be adjusted here, too.
@@ -2037,7 +2103,7 @@
 // generates good SASS in both cases.
 SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
                                                SelectionDAG &DAG) const {
-  if (!(Op->getValueType(0) == MVT::v2f16 &&
+  if (!(Isv2f16Orv2bf16Type(Op->getValueType(0).getSimpleVT()) &&
         isa<ConstantFPSDNode>(Op->getOperand(0)) &&
         isa<ConstantFPSDNode>(Op->getOperand(1))))
     return Op;
@@ -2048,7 +2114,7 @@
       cast<ConstantFPSDNode>(Op->getOperand(1))->getValueAPF().bitcastToAPInt();
   SDValue Const =
       DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
-  return DAG.getNode(ISD::BITCAST, SDLoc(Op), MVT::v2f16, Const);
+  return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
 }
 
 SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
@@ -2409,7 +2475,7 @@
 
   // v2f16 is legal, so we can't rely on legalizer to handle unaligned
   // loads and have to handle it here.
-  if (Op.getValueType() == MVT::v2f16) {
+  if (Isv2f16Orv2bf16Type(Op.getValueType().getSimpleVT())) {
     LoadSDNode *Load = cast<LoadSDNode>(Op);
     EVT MemVT = Load->getMemoryVT();
     if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
@@ -2454,7 +2520,7 @@
 
   // v2f16 is legal, so we can't rely on legalizer to handle unaligned
   // stores and have to handle it here.
-  if (VT == MVT::v2f16 &&
+  if ((Isv2f16Orv2bf16Type(VT.getSimpleVT())) &&
       !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
                                       VT, *Store->getMemOperand()))
     return expandUnalignedStore(Store, DAG);
@@ -2541,7 +2607,7 @@
       // v8f16 is a special case. PTX doesn't have st.v8.f16
       // instruction. Instead, we split the vector into v2f16 chunks and
       // store them with st.v4.b32.
-      assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+      assert((Isf16Orbf16Type(EltVT.getSimpleVT())) &&
              "Wrong type for the vector.");
       Opcode = NVPTXISD::StoreV4;
       StoreF16x2 = true;
@@ -2557,11 +2623,12 @@
       // Combine f16,f16 -> v2f16
       NumElts /= 2;
       for (unsigned i = 0; i < NumElts; ++i) {
-        SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
+        SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
                                  DAG.getIntPtrConstant(i * 2, DL));
-        SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
+        SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
                                  DAG.getIntPtrConstant(i * 2 + 1, DL));
-        SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2f16, E0, E1);
+        EVT VecVT = EVT::getVectorVT(*DAG.getContext(), EltVT, 2);
+        SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, E0, E1);
         Ops.push_back(V2);
       }
     } else {
@@ -2733,9 +2800,9 @@
           EVT LoadVT = EltVT;
           if (EltVT == MVT::i1)
             LoadVT = MVT::i8;
-          else if (EltVT == MVT::v2f16)
+          else if (Isv2f16Orv2bf16Type(EltVT.getSimpleVT()))
             // getLoad needs a vector type, but it can't handle
-            // vectors which contain v2f16 elements. So we must load
+            // vectors which contain v2f16 or v2bf16 elements. So we must load
             // using i32 here and then bitcast back.
             LoadVT = MVT::i32;
 
@@ -5171,7 +5238,7 @@
     // v8f16 is a special case. PTX doesn't have ld.v8.f16
     // instruction. Instead, we split the vector into v2f16 chunks and
     // load them with ld.v4.b32.
-    assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+    assert(Isf16Orbf16Type(EltVT.getSimpleVT()) &&
            "Unsupported v8 vector type.");
     LoadF16x2 = true;
     Opcode = NVPTXISD::LoadV4;
Index: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -72,6 +72,7 @@
   bool trySurfaceIntrinsic(SDNode *N);
   bool tryBFE(SDNode *N);
   bool tryConstantFP16(SDNode *N);
+  bool tryConstantBF16(SDNode *N);
   bool SelectSETP_F16X2(SDNode *N);
   bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
 
Index: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -537,6 +537,16 @@
   return true;
 }
 
+bool NVPTXDAGToDAGISel::tryConstantBF16(SDNode *N) {
+  if (N->getValueType(0) != MVT::bf16)
+    return false;
+  SDValue Val = CurDAG->getTargetConstantFP(
+      cast<ConstantFPSDNode>(N)->getValueAPF(), SDLoc(N), MVT::bf16);
+  SDNode *LoadConstBF16 =
+      CurDAG->getMachineNode(NVPTX::LOAD_CONST_BF16, SDLoc(N), MVT::bf16, Val);
+  ReplaceNode(N, LoadConstBF16);
+  return true;
+}
 // Map ISD:CONDCODE value to appropriate CmpMode expected by
 // NVPTXInstPrinter::printCmpMode()
 static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
@@ -1288,6 +1298,10 @@
       assert(NumElts % 2 == 0 && "Vector must have even number of elements");
       EltVT = MVT::v2f16;
       NumElts /= 2;
+    } else if (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) {
+      assert(NumElts % 2 == 0 && "Vector must have even number of elements");
+      EltVT = MVT::v2bf16;
+      NumElts /= 2;
     }
   }
 
Index: llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -267,6 +267,10 @@
       MCOp = MCOperand::createExpr(
         NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
       break;
+    case Type::BFloatTyID:
+      MCOp = MCOperand::createExpr(
+          NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
+      break;
     case Type::FloatTyID:
       MCOp = MCOperand::createExpr(
         NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
@@ -1353,8 +1357,10 @@
     }
     break;
   }
+  case Type::BFloatTyID:
   case Type::HalfTyID:
-    // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
+    // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
+    // PTX assembly.
     return "b16";
   case Type::FloatTyID:
     return "f32";
@@ -1588,7 +1594,7 @@
       } else if (PTy) {
         assert(PTySizeInBits && "Invalid pointer size");
         sz = PTySizeInBits;
-      } else if (Ty->isHalfTy())
+      } else if (Ty->isHalfTy() || Ty->isBFloatTy())
         // PTX ABI requires all scalar parameters to be at least 32
         // bits in size.  fp16 normally uses .b16 as its storage type
         // in PTX, so its size must be adjusted here, too.
Index: llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
===================================================================
--- llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -61,9 +61,11 @@
     OS << "%fd";
     break;
   case 7:
+  case 9:
     OS << "%h";
     break;
   case 8:
+  case 10:
     OS << "%hh";
     break;
   }
Index: llvm/include/llvm/IR/IntrinsicsNVVM.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -597,16 +597,18 @@
           [IntrNoMem, IntrSpeculatable, Commutative]>;
     }
 
-    foreach variant = ["_bf16", "_nan_bf16", "_xorsign_abs_bf16",
-      "_nan_xorsign_abs_bf16"] in {
+    foreach variant = ["_bf16", "_ftz_bf16", "_nan_bf16", "_ftz_nan_bf16",
+      "_xorsign_abs_bf16", "_ftz_xorsign_abs_bf16", "_nan_xorsign_abs_bf16",
+      "_ftz_nan_xorsign_abs_bf16"] in {
       def int_nvvm_f # operation # variant :
         ClangBuiltin<!strconcat("__nvvm_f", operation, variant)>,
         DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_i16_ty, llvm_i16_ty],
           [IntrNoMem, IntrSpeculatable, Commutative]>;
     }
 
-    foreach variant = ["_bf16x2", "_nan_bf16x2", "_xorsign_abs_bf16x2",
-      "_nan_xorsign_abs_bf16x2"] in {
+    foreach variant = ["_bf16x2", "_ftz_bf16x2", "_nan_bf16x2",
+      "_ftz_nan_bf16x2", "_xorsign_abs_bf16x2", "_ftz_xorsign_abs_bf16x2",
+      "_nan_xorsign_abs_bf16x2", "_ftz_nan_xorsign_abs_bf16x2"]  in {
       def int_nvvm_f # operation # variant :
         ClangBuiltin<!strconcat("__nvvm_f", operation, variant)>,
         DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty],
@@ -874,17 +876,19 @@
         [IntrNoMem, IntrSpeculatable]>;
   }
 
-  foreach variant = ["_rn_bf16", "_rn_relu_bf16"] in {
+  foreach variant = ["_rn_bf16", "_rn_ftz_bf16", "_rn_sat_bf16",
+    "_rn_ftz_sat_bf16", "_rn_relu_bf16", "_rn_ftz_relu_bf16"] in {
     def int_nvvm_fma # variant : ClangBuiltin<!strconcat("__nvvm_fma", variant)>,
-      DefaultAttrsIntrinsic<[llvm_i16_ty],
-        [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty],
+      DefaultAttrsIntrinsic<[llvm_bfloat_ty],
+        [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty],
         [IntrNoMem, IntrSpeculatable]>;
   }
 
-  foreach variant = ["_rn_bf16x2", "_rn_relu_bf16x2"] in {
+  foreach variant = ["_rn_bf16x2", "_rn_ftz_bf16x2", "_rn_sat_bf16x2",
+    "_rn_ftz_sat_bf16x2", "_rn_relu_bf16x2", "_rn_ftz_relu_bf16x2"] in {
     def int_nvvm_fma # variant : ClangBuiltin<!strconcat("__nvvm_fma", variant)>,
-      DefaultAttrsIntrinsic<[llvm_i32_ty],
-        [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
+      DefaultAttrsIntrinsic<[llvm_v2bf16_ty],
+        [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty],
         [IntrNoMem, IntrSpeculatable]>;
   }
 
@@ -1236,6 +1240,11 @@
   def int_nvvm_f2h_rn : ClangBuiltin<"__nvvm_f2h_rn">,
       DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
 
+  def int_nvvm_bf2h_rn_ftz : ClangBuiltin<"__nvvm_bf2h_rn_ftz">,
+      DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_bfloat_ty], [IntrNoMem, IntrSpeculatable]>;
+  def int_nvvm_bf2h_rn : ClangBuiltin<"__nvvm_bf2h_rn">,
+      DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_bfloat_ty], [IntrNoMem, IntrSpeculatable]>;
+
   def int_nvvm_ff2bf16x2_rn : ClangBuiltin<"__nvvm_ff2bf16x2_rn">,
        Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
   def int_nvvm_ff2bf16x2_rn_relu : ClangBuiltin<"__nvvm_ff2bf16x2_rn_relu">,
Index: clang/include/clang/Basic/BuiltinsNVPTX.def
===================================================================
--- clang/include/clang/Basic/BuiltinsNVPTX.def
+++ clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -145,12 +145,16 @@
 TARGET_BUILTIN(__nvvm_fmin_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmin_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_bf16, "UsUsUs", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "UsUsUs", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmin_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
@@ -187,12 +191,16 @@
 TARGET_BUILTIN(__nvvm_fmax_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmax_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_bf16, "UsUsUs", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "UsUsUs", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmax_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
 TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
                AND(SM_86, PTX72))
 TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",