This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e973770932 [DataType] BF16 Support (#17670)
e973770932 is described below

commit e973770932bb41ad7136a0834c8cf3c4aedbae29
Author: Joshua Hong <[email protected]>
AuthorDate: Sun Mar 9 21:06:26 2025 -0700

    [DataType] BF16 Support (#17670)
    
    Adds general BF16 support to TVM
    
    - Addresses missing legalization in `comm_reducer`
    - Skips the BF16 compute and storage legalization passes when the target supports BF16
    - Adds unit tests for the `comm_reducer` changes as well as for legalization skipping
    - Modifies TVM datatypes to allow `T.bfloat16` to be used in the test file
    - Fixes BFloat16-related CUDA codegen
    
    Related PR in MLC-LLM adds BF16 support with quantization
    https://github.com/mlc-ai/mlc-llm/pull/3158
    
    Tested with the originally problematic model, Gemma 2 27B, using
    both newly added quantization configurations, `q4bf16_0` and `q4bf16_1`.
    While compilation succeeds and the first few rounds of prompting
    perform as expected, we observe that generation quality degrades
    for long contexts. The same behavior is not observed on
    Gemma 2 9B, quantized or unquantized.
    
    
    
    ---------
    
    Co-authored-by: Joshua Hong <[email protected]>
---
 include/tvm/runtime/data_type.h                    |   2 +
 include/tvm/script/ir_builder/tir/ir.h             |   2 +
 python/tvm/_ffi/runtime_ctypes.py                  |   1 +
 python/tvm/script/ir_builder/tir/ir.py             |   4 +-
 .../postproc/disallow_async_strided_mem_copy.cc    |  10 +-
 src/relax/op/nn/nn.cc                              |   3 +-
 src/relax/op/op_common.h                           |   3 +-
 src/script/ir_builder/tir/ir.cc                    |   2 +
 src/target/source/codegen_cuda.cc                  |   2 +-
 src/target/source/intrin_rule_cuda.cc              |   8 +-
 src/tir/op/op.cc                                   |   2 +-
 src/tir/transforms/unsupported_dtype_legalize.cc   |  48 ++-
 .../test_tir_transform_bf16_legalize.py            | 364 ++++++++++++++++-----
 13 files changed, 361 insertions(+), 90 deletions(-)

diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index 65fd0c98fd..40664f0c40 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -124,6 +124,8 @@ class DataType {
   bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
   /*! \return whether type is a float type. */
   bool is_float() const { return code() == DataType::kFloat; }
+  /*! \return whether type is a bfloat type. */
+  bool is_bfloat() const { return code() == DataType::kBFloat; }
   /*! \return whether type is a float8 type. */
   bool is_float8() const {
     return (code() == DataType::kFloat || code() == DataType::kFloat8_e4m3fn ||
diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index e60a3859ac..b01cb84222 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -473,6 +473,7 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(),
   TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32));    \
   TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64));
 
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(BFloat, DataType::BFloat);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Float, DataType::Float);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, DataType::UInt);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int);
@@ -490,6 +491,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int);
   TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32);    \
   TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64);
 
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(BFloat, DataType::BFloat);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt);
 TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 317bd6bead..8a9c231617 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -118,6 +118,7 @@ class DataType(ctypes.Structure):
         "float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
         "float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
         "float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
+        "bfloat16": {"type_code": DataTypeCode.BFLOAT, "bits": 16, "lanes": 1},
     }
 
     def __init__(self, type_str):
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 2fce022da3..3e835e8d9d 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1464,7 +1464,7 @@ float4_e2m1fnx16 = func_gen(("Float4E2M1FNx16"))
 float4_e2m1fnx32 = func_gen(("Float4E2M1FNx32"))
 float4_e2m1fnx64 = func_gen(("Float4E2M1FNx64"))
 
-
+bfloat16 = func_gen(("BFloat16"))
 # pylint: enable=invalid-name
 
 
@@ -1961,7 +1961,6 @@ tvm_call_cpacked_lowered = call_cpacked_lowered
 
 # pylint: enable=invalid-name
 
-
 __all__ = [
     "int8",
     "int16",
@@ -2048,6 +2047,7 @@ __all__ = [
     "float16x64",
     "float32x64",
     "float64x64",
+    "bfloat16",
     "buffer",
     "buffer_decl",
     "prim_func",
diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
index d654e467f1..a6a71202ae 100644
--- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
+++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
@@ -119,7 +119,11 @@ namespace meta_schedule {
 class DisallowAsyncStridedMemCopyNode : public PostprocNode {
  public:
   // Inherited from PostprocNode
-  void InitializeWithTuneContext(const TuneContext& context) final {}
+  void InitializeWithTuneContext(const TuneContext& context) final {
+    /* Null check */
+    ICHECK(context->target) << "Context must contain a target";
+    this->target = context->target.value();
+  }
   // Inherited from PostprocNode
   bool Apply(const tir::Schedule& sch) final {
     IRModule mod = sch->mod();
@@ -130,6 +134,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
         IRModule lowered{nullptr};
         try {
           auto pass_list = Array<tvm::transform::Pass>();
+          pass_list.push_back(tir::transform::BindTarget(this->target));
           pass_list.push_back(tir::transform::LowerInitBlock());
           
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
           pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
@@ -168,6 +173,9 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
 
   static constexpr const char* _type_key = 
"meta_schedule.DisallowAsyncStridedMemCopy";
   TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode);
+
+ private:
+  tvm::Target target;
 };
 
 Postproc Postproc::DisallowAsyncStridedMemCopy() {
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 526b816d09..b4668d65d3 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -191,7 +191,8 @@ bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx,
     axes_non_neg = NormalizeAxes(call, ctx, data_sinfo->ndim, axes);
   }
   int n_axis = axes.size();
-  if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
+  if (!data_sinfo->IsUnknownDtype() &&
+      (!data_sinfo->dtype.is_float() && !data_sinfo->dtype.is_bfloat())) {
     ctx->ReportFatal(
         Diagnostic::Error(call)
         << op << " requires the input data to have float dtype. However, the 
given data dtype is "
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 6e2ef6bd2b..eea6db22fd 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -199,7 +199,8 @@ template <bool require_float_dtype, typename FType>
 inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& 
ctx,
                                        FType f_compute_out_dtype) {
   TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
-  if (require_float_dtype && !input_sinfo->IsUnknownDtype() && 
!input_sinfo->dtype.is_float()) {
+  if (require_float_dtype && !input_sinfo->IsUnknownDtype() &&
+      (!input_sinfo->dtype.is_float() && !input_sinfo->dtype.is_bfloat())) {
     ctx->ReportFatal(
         Diagnostic::Error(call)
         << call->op
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index a75a357810..83e32f5af8 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -752,8 +752,10 @@ TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float);
 TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt);
 TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int);
 
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.BFloat16").set_body_typed(BFloat16);
 
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3FN").set_body_typed(Float8E4M3FN);
 
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E5M2").set_body_typed(Float8E5M2);
+TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16);
 TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN);
 TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2);
 
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 35973776c8..34023e0bb7 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -1661,7 +1661,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& val
       os << '(';
     }
     if (i % 2 == 0) {
-      os << "__pack_bfloat162(" << value;
+      os << "__pack_nv_bfloat162(" << value;
     } else {
       os << "," << value << ")";
       if (i != t.lanes() - 1) {
diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc
index 79ea7a458f..e762bde69f 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -53,7 +53,13 @@ struct CUDAMath {
           return "";
       }
     } else if (t.is_bfloat16()) {
-      return 'h' + name;
+      if (name == "fabs") {
+        return "__habs";
+      } else if (name == "round") {
+        return "hrint";
+      } else {
+        return "h" + name;
+      }
     } else if (t.is_int() || t.is_uint()) {
       switch (t.bits()) {
         case 32:
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 63c82d1d6c..46c15cb3df 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -801,7 +801,7 @@ PrimExpr abs(PrimExpr x, Span span) {
       return IntImm(x.dtype(), std::abs(px->value), px->span);
     }
     return tir::Select(x >= make_zero(x.dtype()), x, -x, span);
-  } else if (x.dtype().is_float()) {
+  } else if (x.dtype().is_float() || x.dtype().is_bfloat()) {
     using tir::FloatImmNode;
     const FloatImmNode* fx = x.as<FloatImmNode>();
     if (fx) {
diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc
index c75ecf77e7..e20ffcff0b 100644
--- a/src/tir/transforms/unsupported_dtype_legalize.cc
+++ b/src/tir/transforms/unsupported_dtype_legalize.cc
@@ -339,7 +339,6 @@ class ComputeLegalizer : public StmtExprMutator {
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     Stmt ret = StmtExprMutator::VisitStmt_(op);
     op = ret.as<AttrStmtNode>();
-
     if (auto buffer = op->node.as<Buffer>()) {
       auto it = buffer_remap_.find(buffer.value());
       if (it != buffer_remap_.end()) {
@@ -350,6 +349,43 @@ class ComputeLegalizer : public StmtExprMutator {
       if (it != var_remap_.end()) {
         return AttrStmt(it->second, op->attr_key, op->value, op->body);
       }
+    } else if (auto reducer = op->node.as<CommReducerNode>()) {
+      auto legalized_identity_elements =
+          reducer->identity_element.Map([this](PrimExpr expr) { return 
this->VisitExpr(expr); });
+
+      // Remap input variables
+      for (size_t i = 0; i < legalized_identity_elements.size(); i++) {
+        Var lhs_var = reducer->lhs[i];
+        if (lhs_var.dtype() != legalized_identity_elements[i].dtype()) {
+          var_remap_[lhs_var] = 
lhs_var.copy_with_dtype(legalized_identity_elements[i].dtype());
+        }
+        Var rhs_var = reducer->rhs[i];
+        if (rhs_var.dtype() != legalized_identity_elements[i].dtype()) {
+          var_remap_[rhs_var] = 
rhs_var.copy_with_dtype(legalized_identity_elements[i].dtype());
+        }
+      }
+
+      auto legalized_results =
+          reducer->result.Map([this](PrimExpr expr) { return 
this->VisitExpr(expr); });
+
+      auto legalized_lhs = reducer->lhs.Map([this](Var var) {
+        auto it = var_remap_.find(var);
+        if (it != var_remap_.end()) {
+          return it->second;
+        }
+        return var;
+      });
+
+      auto legalized_rhs = reducer->rhs.Map([this](Var var) {
+        auto it = var_remap_.find(var);
+        if (it != var_remap_.end()) {
+          return it->second;
+        }
+        return var;
+      });
+      return AttrStmt(CommReducer(legalized_lhs, legalized_rhs, 
legalized_results,
+                                  legalized_identity_elements, reducer->span),
+                      op->attr_key, op->value, op->body);
     }
     return ret;
   }
@@ -714,7 +750,10 @@ bool CheckDataTypeSupport(const Target& target, const std::string& support_func_
 
 Pass BF16ComputeLegalize() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
-    // TODO(tvm-team): skip if the target supports bf16
+    auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
+    if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) {
+      return f;
+    }
     return BF16ComputeLegalizer().Legalize(f);
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {});
@@ -724,7 +763,10 @@ TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16Comp
 
 Pass BF16StorageLegalize() {
   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
-    // TODO(tvm-team): skip if the target supports bf16
+    auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
+    if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) {
+      return f;
+    }
     return BF16StorageLegalizer().Legalize(f);
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {});
diff --git a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py
index e2752e8bbb..fa1aa558b6 100644
--- a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py
+++ b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py
@@ -16,26 +16,10 @@
 # under the License.
 import tvm
 import tvm.script
+from tvm.target import Target
 from tvm.script import tir as T
-
-
-def get_before():
-    @tvm.script.ir_module
-    class Before:
-        @T.prim_func
-        def main(
-            Aptr: T.handle("bfloat16"), Bptr: T.handle("bfloat16"), Dptr: 
T.handle("bfloat16")
-        ):
-            T.func_attr({"global_symbol": "main"})
-            A = T.decl_buffer((100,), "bfloat16", data=Aptr)
-            B = T.decl_buffer((100,), "bfloat16", data=Bptr)
-            D = T.decl_buffer((100,), "bfloat16", data=Dptr)
-            C = T.decl_buffer((100,), "bfloat16")
-            for i in T.grid(100):
-                C[i] = A[i] + B[i]
-                D[i] = T.exp(C[i])
-
-    return Before
+from tvm.target import Target
+from tvm.tir.transform.transform import BindTarget
 
 
 def u16tof32(v):
@@ -60,61 +44,7 @@ def f32tobf16(v):
     return T.reinterpret("bfloat16", f32tou16(v))
 
 
-def get_after_compute_legalize():
-    @tvm.script.ir_module
-    class After:
-        @T.prim_func
-        def main(
-            Aptr: T.handle("bfloat16"), Bptr: T.handle("bfloat16"), Dptr: 
T.handle("bfloat16")
-        ):
-            T.func_attr({"global_symbol": "main"})
-            A = T.decl_buffer((100,), "bfloat16", data=Aptr)
-            B = T.decl_buffer((100,), "bfloat16", data=Bptr)
-            D = T.decl_buffer((100,), "bfloat16", data=Dptr)
-            C = T.decl_buffer((100,), "float32")
-            for i in T.grid(100):
-                C[i] = bf16tof32(A[i]) + bf16tof32(B[i])
-                D[i] = f32tobf16(T.exp(C[i]))
-
-    return After
-
-
-def get_after_storage_legalize():
-    @tvm.script.ir_module
-    class After:
-        @T.prim_func
-        def main(Aptr: T.handle("uint16"), Bptr: T.handle("uint16"), Dptr: 
T.handle("uint16")):
-            T.func_attr({"global_symbol": "main"})
-            A = T.decl_buffer((100,), "uint16", data=Aptr)
-            B = T.decl_buffer((100,), "uint16", data=Bptr)
-            D = T.decl_buffer((100,), "uint16", data=Dptr)
-            C = T.decl_buffer((100,), "float32")
-            for i in T.grid(100):
-                C[i] = u16tof32(A[i]) + u16tof32(B[i])
-                D[i] = f32tou16(T.exp(C[i]))
-
-    return After
-
-
-def test_bf16_compute_legalize():
-    before = get_before()
-    expected = get_after_compute_legalize()
-    # run the transform twice to ensure we can afford to deal
-    # with this repeative optimizations
-    after = tvm.tir.transform.BF16ComputeLegalize()(before)
-    after = tvm.tir.transform.BF16ComputeLegalize()(after)
-
-    tvm.ir.assert_structural_equal(after, expected)
-
-
-def test_bf16_storage_legalize():
-    before = get_after_compute_legalize()
-    after = tvm.tir.transform.BF16StorageLegalize()(before)
-    expected = get_after_storage_legalize()
-    tvm.ir.assert_structural_equal(after, expected)
-
-
-def test_bf16_storage_scope():
+def test_bf16_storage_compute_scope_will_legalize():
     def get_before():
         @tvm.script.ir_module
         class Before:
@@ -175,13 +105,289 @@ def test_bf16_storage_scope():
 
         return After
 
-    before = get_before()
+    target = Target("nvidia/geforce-rtx-2080-ti")
+    before = BindTarget(target)(get_before())
+    after_compute = tvm.tir.transform.BF16ComputeLegalize()(before)
+    after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute)
+    tvm.ir.assert_structural_equal(after_compute, 
BindTarget(target)(after_compute_legalize()))
+    tvm.ir.assert_structural_equal(after_storage, 
BindTarget(target)(after_storage_legalize()))
+
+
+def test_bf16_storage_compute_scope_wont_legalize():
+    def get_before():
+        @tvm.script.ir_module
+        class Before:
+            @T.prim_func
+            def main(
+                Aptr: T.handle("bfloat16", storage_scope="shared"),
+                Bptr: T.handle("bfloat16", storage_scope="local"),
+                Dptr: T.handle("bfloat16"),
+            ):
+                T.func_attr({"global_symbol": "main"})
+                A = T.decl_buffer((100,), "bfloat16", data=Aptr)
+                B = T.decl_buffer((100,), "bfloat16", data=Bptr)
+                D = T.decl_buffer((100,), "bfloat16", data=Dptr)
+                C = T.decl_buffer((100,), "bfloat16")
+                for i in T.grid(100):
+                    C[i] = A[i] + B[i]
+                    D[i] = T.exp(C[i])
+
+        return Before
+
+    def after_compute_legalize():
+        @tvm.script.ir_module
+        class After:
+            @T.prim_func
+            def main(
+                Aptr: T.handle("bfloat16", storage_scope="shared"),
+                Bptr: T.handle("bfloat16", storage_scope="local"),
+                Dptr: T.handle("bfloat16"),
+            ):
+                T.func_attr({"global_symbol": "main"})
+                A = T.decl_buffer((100,), "bfloat16", data=Aptr)
+                B = T.decl_buffer((100,), "bfloat16", data=Bptr)
+                D = T.decl_buffer((100,), "bfloat16", data=Dptr)
+                C = T.decl_buffer((100,), "bfloat16")
+                for i in T.grid(100):
+                    C[i] = A[i] + B[i]
+                    D[i] = T.exp(C[i])
+
+        return After
+
+    def after_storage_legalize():
+        @tvm.script.ir_module
+        class After:
+            @T.prim_func
+            def main(
+                Aptr: T.handle("bfloat16", storage_scope="shared"),
+                Bptr: T.handle("bfloat16", storage_scope="local"),
+                Dptr: T.handle("bfloat16"),
+            ):
+                T.func_attr({"global_symbol": "main"})
+                A = T.decl_buffer((100,), "bfloat16", data=Aptr)
+                B = T.decl_buffer((100,), "bfloat16", data=Bptr)
+                D = T.decl_buffer((100,), "bfloat16", data=Dptr)
+                C = T.decl_buffer((100,), "bfloat16")
+                for i in T.grid(100):
+                    C[i] = A[i] + B[i]
+                    D[i] = T.exp(C[i])
+
+        return After
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    before = BindTarget(target)(get_before())
+    after_compute = tvm.tir.transform.BF16ComputeLegalize()(before)
+    after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute)
+    tvm.ir.assert_structural_equal(after_compute, 
BindTarget(target)(after_compute_legalize()))
+    tvm.ir.assert_structural_equal(after_storage, 
BindTarget(target)(after_storage_legalize()))
+
+
+def test_bf16_reduce_will_legalize():
+    def get_before():
+        @tvm.script.ir_module
+        class Before:
+            @T.prim_func(private=True)
+            def main(
+                Aptr: T.handle("bfloat16", storage_scope="shared"),
+            ):
+                A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr)
+
+                for i in range(128):
+                    threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+                    reduce = T.decl_buffer(1, dtype="bfloat16", scope="local")
+
+                    with T.attr(
+                        T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]),
+                        "reduce_scope",
+                        T.reinterpret("handle", T.uint64(0)),
+                    ):
+                        T.tvm_thread_allreduce(
+                            T.uint32(1),
+                            A_flat[0],
+                            T.bool(True),
+                            reduce[0],
+                            threadIdx_x,
+                        )
+
+        return Before
+
+    def after_compute_legalize():
+        @tvm.script.ir_module
+        class After:
+            @T.prim_func(private=True)
+            def main(
+                Aptr: T.handle("bfloat16", storage_scope="shared"),
+            ):
+                A_flat_1 = T.decl_buffer(4096, "bfloat16", data=Aptr)
+
+                for i in range(128):
+                    threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+                    reduce = T.decl_buffer(1, dtype="float32", scope="local")
+
+                    with T.attr(
+                        T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                        "reduce_scope",
+                        T.reinterpret("handle", T.uint64(0)),
+                    ):
+                        T.tvm_thread_allreduce(
+                            T.uint32(1),
+                            T.reinterpret(
+                                "float32",
+                                T.shift_left(
+                                    T.Cast("uint32", T.reinterpret("uint16", 
A_flat_1[0])),
+                                    T.uint32(16),
+                                ),
+                            ),
+                            T.bool(True),
+                            reduce[0],
+                            threadIdx_x,
+                        )
+
+        return After
+
+    def after_storage_legalize():
+        @tvm.script.ir_module
+        class After:
+            @T.prim_func(private=True)
+            def main(
+                Aptr: T.handle("uint16", storage_scope="shared"),
+            ):
+                A_flat_1 = T.decl_buffer(4096, "uint16", data=Aptr)
+
+                for i in range(128):
+                    threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+                    reduce = T.decl_buffer(1, dtype="float32", scope="local")
+
+                    with T.attr(
+                        T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                        "reduce_scope",
+                        T.reinterpret("handle", T.uint64(0)),
+                    ):
+                        T.tvm_thread_allreduce(
+                            T.uint32(1),
+                            T.reinterpret(
+                                "float32",
+                                T.shift_left(
+                                    T.Cast("uint32", T.reinterpret("uint16", 
A_flat_1[0])),
+                                    T.uint32(16),
+                                ),
+                            ),
+                            T.bool(True),
+                            reduce[0],
+                            threadIdx_x,
+                        )
+
+        return After
+
+    target = Target("nvidia/geforce-rtx-2080-ti")
+    before = BindTarget(target)(get_before())
+    after_compute = tvm.tir.transform.BF16ComputeLegalize()(before)
+    after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute)
+    tvm.ir.assert_structural_equal(after_compute, 
BindTarget(target)(after_compute_legalize()))
+    tvm.ir.assert_structural_equal(after_storage, 
BindTarget(target)(after_storage_legalize()))
+
+
+def test_bf16_reduce_wont_legalize():
+    def get_before():
+        @tvm.script.ir_module
+        class Before:
+            @T.prim_func(private=True)
+            def main(
+                Aptr: T.handle("bfloat16", storage_scope="shared"),
+            ):
+                A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr)
+
+                for i in range(128):
+                    threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+                    reduce = T.decl_buffer(1, dtype="bfloat16", scope="local")
+
+                    with T.attr(
+                        T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]),
+                        "reduce_scope",
+                        T.reinterpret("handle", T.uint64(0)),
+                    ):
+                        T.tvm_thread_allreduce(
+                            T.uint32(1),
+                            A_flat[0],
+                            T.bool(True),
+                            reduce[0],
+                            threadIdx_x,
+                        )
+
+        return Before
+
+    def after_compute_legalize():
+        @tvm.script.ir_module
+        class After:
+            @T.prim_func(private=True)
+            def main(
+                Aptr: T.handle("bfloat16", storage_scope="shared"),
+            ):
+                A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr)
+
+                for i in range(128):
+                    threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+                    reduce = T.decl_buffer(1, dtype="bfloat16", scope="local")
+
+                    with T.attr(
+                        T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]),
+                        "reduce_scope",
+                        T.reinterpret("handle", T.uint64(0)),
+                    ):
+                        T.tvm_thread_allreduce(
+                            T.uint32(1),
+                            A_flat[0],
+                            T.bool(True),
+                            reduce[0],
+                            threadIdx_x,
+                        )
+
+        return After
+
+    def after_storage_legalize():
+        @tvm.script.ir_module
+        class After:
+            @T.prim_func(private=True)
+            def main(
+                Aptr: T.handle("bfloat16", storage_scope="shared"),
+            ):
+                A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr)
+
+                for i in range(128):
+                    threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+                    reduce = T.decl_buffer(1, dtype="bfloat16", scope="local")
+
+                    with T.attr(
+                        T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]),
+                        "reduce_scope",
+                        T.reinterpret("handle", T.uint64(0)),
+                    ):
+                        T.tvm_thread_allreduce(
+                            T.uint32(1),
+                            A_flat[0],
+                            T.bool(True),
+                            reduce[0],
+                            threadIdx_x,
+                        )
+
+        return After
+
+    target = Target("nvidia/geforce-rtx-3090-ti")
+    before = BindTarget(target)(get_before())
     after_compute = tvm.tir.transform.BF16ComputeLegalize()(before)
     after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute)
-    tvm.ir.assert_structural_equal(after_compute, after_compute_legalize())
-    tvm.ir.assert_structural_equal(after_storage, after_storage_legalize())
+    tvm.ir.assert_structural_equal(after_compute, 
BindTarget(target)(after_compute_legalize()))
+    tvm.ir.assert_structural_equal(after_storage, 
BindTarget(target)(after_storage_legalize()))
 
 
 if __name__ == "__main__":
-    test_bf16_storage_legalize()
-    test_bf16_storage_scope()
+    test_bf16_storage_compute_scope_will_legalize()
+    test_bf16_storage_compute_scope_wont_legalize()
+    test_bf16_reduce_will_legalize()
+    test_bf16_reduce_wont_legalize()

Reply via email to