This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e973770932 [DataType] BF16 Support (#17670)
e973770932 is described below
commit e973770932bb41ad7136a0834c8cf3c4aedbae29
Author: Joshua Hong <[email protected]>
AuthorDate: Sun Mar 9 21:06:26 2025 -0700
[DataType] BF16 Support (#17670)
Adds general BF16 support to TVM
- Addresses missing legalization in `comm_reducer`
- Adds BF16 legalization pass skipping if the target supports BF16
- Unit Tests for `comm_reducer` changes as well as legalization skipping
- Modifications to TVM datatypes to allow for `T.bfloat16` in the test file
- Fixes for BFloat related cuda codegen
Related PR in MLC-LLM adds BF16 support with quantization
https://github.com/mlc-ai/mlc-llm/pull/3158
Tested with the originally problematic model, Gemma 2 27B, with
both added quantization configurations `q4bf16_0` and `q4bf16_1`.
While compilation is successful and the first few rounds of prompting
show the expected performance, we observe that for long contexts the
generation quality degrades. The same behavior is not observed on
Gemma 2 9B, quantized or unquantized.
---------
Co-authored-by: Joshua Hong <[email protected]>
---
include/tvm/runtime/data_type.h | 2 +
include/tvm/script/ir_builder/tir/ir.h | 2 +
python/tvm/_ffi/runtime_ctypes.py | 1 +
python/tvm/script/ir_builder/tir/ir.py | 4 +-
.../postproc/disallow_async_strided_mem_copy.cc | 10 +-
src/relax/op/nn/nn.cc | 3 +-
src/relax/op/op_common.h | 3 +-
src/script/ir_builder/tir/ir.cc | 2 +
src/target/source/codegen_cuda.cc | 2 +-
src/target/source/intrin_rule_cuda.cc | 8 +-
src/tir/op/op.cc | 2 +-
src/tir/transforms/unsupported_dtype_legalize.cc | 48 ++-
.../test_tir_transform_bf16_legalize.py | 364 ++++++++++++++++-----
13 files changed, 361 insertions(+), 90 deletions(-)
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index 65fd0c98fd..40664f0c40 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -124,6 +124,8 @@ class DataType {
bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
/*! \return whether type is a float type. */
bool is_float() const { return code() == DataType::kFloat; }
+ /*! \return whether type is a bfloat type. */
+ bool is_bfloat() const { return code() == DataType::kBFloat; }
/*! \return whether type is a float8 type. */
bool is_float8() const {
return (code() == DataType::kFloat || code() == DataType::kFloat8_e4m3fn ||
diff --git a/include/tvm/script/ir_builder/tir/ir.h
b/include/tvm/script/ir_builder/tir/ir.h
index e60a3859ac..b01cb84222 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -473,6 +473,7 @@ inline Var Handle(runtime::DataType dtype =
runtime::DataType::Void(),
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64));
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(BFloat, DataType::BFloat);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Float, DataType::Float);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, DataType::UInt);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int);
@@ -490,6 +491,7 @@ TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64);
+TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(BFloat, DataType::BFloat);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
diff --git a/python/tvm/_ffi/runtime_ctypes.py
b/python/tvm/_ffi/runtime_ctypes.py
index 317bd6bead..8a9c231617 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -118,6 +118,7 @@ class DataType(ctypes.Structure):
"float16": {"type_code": DataTypeCode.FLOAT, "bits": 16, "lanes": 1},
"float32": {"type_code": DataTypeCode.FLOAT, "bits": 32, "lanes": 1},
"float64": {"type_code": DataTypeCode.FLOAT, "bits": 64, "lanes": 1},
+ "bfloat16": {"type_code": DataTypeCode.BFLOAT, "bits": 16, "lanes": 1},
}
def __init__(self, type_str):
diff --git a/python/tvm/script/ir_builder/tir/ir.py
b/python/tvm/script/ir_builder/tir/ir.py
index 2fce022da3..3e835e8d9d 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1464,7 +1464,7 @@ float4_e2m1fnx16 = func_gen(("Float4E2M1FNx16"))
float4_e2m1fnx32 = func_gen(("Float4E2M1FNx32"))
float4_e2m1fnx64 = func_gen(("Float4E2M1FNx64"))
-
+bfloat16 = func_gen(("BFloat16"))
# pylint: enable=invalid-name
@@ -1961,7 +1961,6 @@ tvm_call_cpacked_lowered = call_cpacked_lowered
# pylint: enable=invalid-name
-
__all__ = [
"int8",
"int16",
@@ -2048,6 +2047,7 @@ __all__ = [
"float16x64",
"float32x64",
"float64x64",
+ "bfloat16",
"buffer",
"buffer_decl",
"prim_func",
diff --git a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
index d654e467f1..a6a71202ae 100644
--- a/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
+++ b/src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc
@@ -119,7 +119,11 @@ namespace meta_schedule {
class DisallowAsyncStridedMemCopyNode : public PostprocNode {
public:
// Inherited from PostprocNode
- void InitializeWithTuneContext(const TuneContext& context) final {}
+ void InitializeWithTuneContext(const TuneContext& context) final {
+ /* Null check */
+ ICHECK(context->target) << "Context must contain a target";
+ this->target = context->target.value();
+ }
// Inherited from PostprocNode
bool Apply(const tir::Schedule& sch) final {
IRModule mod = sch->mod();
@@ -130,6 +134,7 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode
{
IRModule lowered{nullptr};
try {
auto pass_list = Array<tvm::transform::Pass>();
+ pass_list.push_back(tir::transform::BindTarget(this->target));
pass_list.push_back(tir::transform::LowerInitBlock());
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
@@ -168,6 +173,9 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode
{
static constexpr const char* _type_key =
"meta_schedule.DisallowAsyncStridedMemCopy";
TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode);
+
+ private:
+ tvm::Target target;
};
Postproc Postproc::DisallowAsyncStridedMemCopy() {
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 526b816d09..b4668d65d3 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -191,7 +191,8 @@ bool NormCheckDtypeAndShape(const Call& call, const
BlockBuilder& ctx,
axes_non_neg = NormalizeAxes(call, ctx, data_sinfo->ndim, axes);
}
int n_axis = axes.size();
- if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
+ if (!data_sinfo->IsUnknownDtype() &&
+ (!data_sinfo->dtype.is_float() && !data_sinfo->dtype.is_bfloat())) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< op << " requires the input data to have float dtype. However, the
given data dtype is "
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 6e2ef6bd2b..eea6db22fd 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -199,7 +199,8 @@ template <bool require_float_dtype, typename FType>
inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder&
ctx,
FType f_compute_out_dtype) {
TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
- if (require_float_dtype && !input_sinfo->IsUnknownDtype() &&
!input_sinfo->dtype.is_float()) {
+ if (require_float_dtype && !input_sinfo->IsUnknownDtype() &&
+ (!input_sinfo->dtype.is_float() && !input_sinfo->dtype.is_bfloat())) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< call->op
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index a75a357810..83e32f5af8 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -752,8 +752,10 @@
TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Float", Float);
TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.UInt", UInt);
TVM_REGISTER_GLOBAL_SIZES_LANES("script.ir_builder.tir.Int", Int);
+TVM_REGISTER_GLOBAL("script.ir_builder.tir.BFloat16").set_body_typed(BFloat16);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E4M3FN").set_body_typed(Float8E4M3FN);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Float8E5M2").set_body_typed(Float8E5M2);
+TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.BFloat16", BFloat16);
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E4M3FN", Float8E4M3FN);
TVM_REGISTER_GLOBAL_LANES("script.ir_builder.tir.Float8E5M2", Float8E5M2);
diff --git a/src/target/source/codegen_cuda.cc
b/src/target/source/codegen_cuda.cc
index 35973776c8..34023e0bb7 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -1661,7 +1661,7 @@ void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i,
const std::string& val
os << '(';
}
if (i % 2 == 0) {
- os << "__pack_bfloat162(" << value;
+ os << "__pack_nv_bfloat162(" << value;
} else {
os << "," << value << ")";
if (i != t.lanes() - 1) {
diff --git a/src/target/source/intrin_rule_cuda.cc
b/src/target/source/intrin_rule_cuda.cc
index 79ea7a458f..e762bde69f 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -53,7 +53,13 @@ struct CUDAMath {
return "";
}
} else if (t.is_bfloat16()) {
- return 'h' + name;
+ if (name == "fabs") {
+ return "__habs";
+ } else if (name == "round") {
+ return "hrint";
+ } else {
+ return "h" + name;
+ }
} else if (t.is_int() || t.is_uint()) {
switch (t.bits()) {
case 32:
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 63c82d1d6c..46c15cb3df 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -801,7 +801,7 @@ PrimExpr abs(PrimExpr x, Span span) {
return IntImm(x.dtype(), std::abs(px->value), px->span);
}
return tir::Select(x >= make_zero(x.dtype()), x, -x, span);
- } else if (x.dtype().is_float()) {
+ } else if (x.dtype().is_float() || x.dtype().is_bfloat()) {
using tir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) {
diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc
b/src/tir/transforms/unsupported_dtype_legalize.cc
index c75ecf77e7..e20ffcff0b 100644
--- a/src/tir/transforms/unsupported_dtype_legalize.cc
+++ b/src/tir/transforms/unsupported_dtype_legalize.cc
@@ -339,7 +339,6 @@ class ComputeLegalizer : public StmtExprMutator {
Stmt VisitStmt_(const AttrStmtNode* op) final {
Stmt ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<AttrStmtNode>();
-
if (auto buffer = op->node.as<Buffer>()) {
auto it = buffer_remap_.find(buffer.value());
if (it != buffer_remap_.end()) {
@@ -350,6 +349,43 @@ class ComputeLegalizer : public StmtExprMutator {
if (it != var_remap_.end()) {
return AttrStmt(it->second, op->attr_key, op->value, op->body);
}
+ } else if (auto reducer = op->node.as<CommReducerNode>()) {
+ auto legalized_identity_elements =
+ reducer->identity_element.Map([this](PrimExpr expr) { return
this->VisitExpr(expr); });
+
+ // Remap input variables
+ for (size_t i = 0; i < legalized_identity_elements.size(); i++) {
+ Var lhs_var = reducer->lhs[i];
+ if (lhs_var.dtype() != legalized_identity_elements[i].dtype()) {
+ var_remap_[lhs_var] =
lhs_var.copy_with_dtype(legalized_identity_elements[i].dtype());
+ }
+ Var rhs_var = reducer->rhs[i];
+ if (rhs_var.dtype() != legalized_identity_elements[i].dtype()) {
+ var_remap_[rhs_var] =
rhs_var.copy_with_dtype(legalized_identity_elements[i].dtype());
+ }
+ }
+
+ auto legalized_results =
+ reducer->result.Map([this](PrimExpr expr) { return
this->VisitExpr(expr); });
+
+ auto legalized_lhs = reducer->lhs.Map([this](Var var) {
+ auto it = var_remap_.find(var);
+ if (it != var_remap_.end()) {
+ return it->second;
+ }
+ return var;
+ });
+
+ auto legalized_rhs = reducer->rhs.Map([this](Var var) {
+ auto it = var_remap_.find(var);
+ if (it != var_remap_.end()) {
+ return it->second;
+ }
+ return var;
+ });
+ return AttrStmt(CommReducer(legalized_lhs, legalized_rhs,
legalized_results,
+ legalized_identity_elements, reducer->span),
+ op->attr_key, op->value, op->body);
}
return ret;
}
@@ -714,7 +750,10 @@ bool CheckDataTypeSupport(const Target& target, const
std::string& support_func_
Pass BF16ComputeLegalize() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
- // TODO(tvm-team): skip if the target supports bf16
+ auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
+ if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) {
+ return f;
+ }
return BF16ComputeLegalizer().Legalize(f);
};
return CreatePrimFuncPass(pass_func, 0, "tir.BF16ComputeLegalize", {});
@@ -724,7 +763,10 @@
TVM_REGISTER_GLOBAL("tir.transform.BF16ComputeLegalize").set_body_typed(BF16Comp
Pass BF16StorageLegalize() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
- // TODO(tvm-team): skip if the target supports bf16
+ auto target = f->GetAttr<Target>(tvm::attr::kTarget).value();
+ if (CheckDataTypeSupport(target, "tvm.contrib.nvcc.supports_bf16")) {
+ return f;
+ }
return BF16StorageLegalizer().Legalize(f);
};
return CreatePrimFuncPass(pass_func, 0, "tir.BF16StorageLegalize", {});
diff --git a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py
b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py
index e2752e8bbb..fa1aa558b6 100644
--- a/tests/python/tir-transform/test_tir_transform_bf16_legalize.py
+++ b/tests/python/tir-transform/test_tir_transform_bf16_legalize.py
@@ -16,26 +16,10 @@
# under the License.
import tvm
import tvm.script
+from tvm.target import Target
from tvm.script import tir as T
-
-
-def get_before():
- @tvm.script.ir_module
- class Before:
- @T.prim_func
- def main(
- Aptr: T.handle("bfloat16"), Bptr: T.handle("bfloat16"), Dptr:
T.handle("bfloat16")
- ):
- T.func_attr({"global_symbol": "main"})
- A = T.decl_buffer((100,), "bfloat16", data=Aptr)
- B = T.decl_buffer((100,), "bfloat16", data=Bptr)
- D = T.decl_buffer((100,), "bfloat16", data=Dptr)
- C = T.decl_buffer((100,), "bfloat16")
- for i in T.grid(100):
- C[i] = A[i] + B[i]
- D[i] = T.exp(C[i])
-
- return Before
+from tvm.target import Target
+from tvm.tir.transform.transform import BindTarget
def u16tof32(v):
@@ -60,61 +44,7 @@ def f32tobf16(v):
return T.reinterpret("bfloat16", f32tou16(v))
-def get_after_compute_legalize():
- @tvm.script.ir_module
- class After:
- @T.prim_func
- def main(
- Aptr: T.handle("bfloat16"), Bptr: T.handle("bfloat16"), Dptr:
T.handle("bfloat16")
- ):
- T.func_attr({"global_symbol": "main"})
- A = T.decl_buffer((100,), "bfloat16", data=Aptr)
- B = T.decl_buffer((100,), "bfloat16", data=Bptr)
- D = T.decl_buffer((100,), "bfloat16", data=Dptr)
- C = T.decl_buffer((100,), "float32")
- for i in T.grid(100):
- C[i] = bf16tof32(A[i]) + bf16tof32(B[i])
- D[i] = f32tobf16(T.exp(C[i]))
-
- return After
-
-
-def get_after_storage_legalize():
- @tvm.script.ir_module
- class After:
- @T.prim_func
- def main(Aptr: T.handle("uint16"), Bptr: T.handle("uint16"), Dptr:
T.handle("uint16")):
- T.func_attr({"global_symbol": "main"})
- A = T.decl_buffer((100,), "uint16", data=Aptr)
- B = T.decl_buffer((100,), "uint16", data=Bptr)
- D = T.decl_buffer((100,), "uint16", data=Dptr)
- C = T.decl_buffer((100,), "float32")
- for i in T.grid(100):
- C[i] = u16tof32(A[i]) + u16tof32(B[i])
- D[i] = f32tou16(T.exp(C[i]))
-
- return After
-
-
-def test_bf16_compute_legalize():
- before = get_before()
- expected = get_after_compute_legalize()
- # run the transform twice to ensure we can afford to deal
- # with this repeative optimizations
- after = tvm.tir.transform.BF16ComputeLegalize()(before)
- after = tvm.tir.transform.BF16ComputeLegalize()(after)
-
- tvm.ir.assert_structural_equal(after, expected)
-
-
-def test_bf16_storage_legalize():
- before = get_after_compute_legalize()
- after = tvm.tir.transform.BF16StorageLegalize()(before)
- expected = get_after_storage_legalize()
- tvm.ir.assert_structural_equal(after, expected)
-
-
-def test_bf16_storage_scope():
+def test_bf16_storage_compute_scope_will_legalize():
def get_before():
@tvm.script.ir_module
class Before:
@@ -175,13 +105,289 @@ def test_bf16_storage_scope():
return After
- before = get_before()
+ target = Target("nvidia/geforce-rtx-2080-ti")
+ before = BindTarget(target)(get_before())
+ after_compute = tvm.tir.transform.BF16ComputeLegalize()(before)
+ after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute)
+ tvm.ir.assert_structural_equal(after_compute,
BindTarget(target)(after_compute_legalize()))
+ tvm.ir.assert_structural_equal(after_storage,
BindTarget(target)(after_storage_legalize()))
+
+
+def test_bf16_storage_compute_scope_wont_legalize():
+ def get_before():
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def main(
+ Aptr: T.handle("bfloat16", storage_scope="shared"),
+ Bptr: T.handle("bfloat16", storage_scope="local"),
+ Dptr: T.handle("bfloat16"),
+ ):
+ T.func_attr({"global_symbol": "main"})
+ A = T.decl_buffer((100,), "bfloat16", data=Aptr)
+ B = T.decl_buffer((100,), "bfloat16", data=Bptr)
+ D = T.decl_buffer((100,), "bfloat16", data=Dptr)
+ C = T.decl_buffer((100,), "bfloat16")
+ for i in T.grid(100):
+ C[i] = A[i] + B[i]
+ D[i] = T.exp(C[i])
+
+ return Before
+
+ def after_compute_legalize():
+ @tvm.script.ir_module
+ class After:
+ @T.prim_func
+ def main(
+ Aptr: T.handle("bfloat16", storage_scope="shared"),
+ Bptr: T.handle("bfloat16", storage_scope="local"),
+ Dptr: T.handle("bfloat16"),
+ ):
+ T.func_attr({"global_symbol": "main"})
+ A = T.decl_buffer((100,), "bfloat16", data=Aptr)
+ B = T.decl_buffer((100,), "bfloat16", data=Bptr)
+ D = T.decl_buffer((100,), "bfloat16", data=Dptr)
+ C = T.decl_buffer((100,), "bfloat16")
+ for i in T.grid(100):
+ C[i] = A[i] + B[i]
+ D[i] = T.exp(C[i])
+
+ return After
+
+ def after_storage_legalize():
+ @tvm.script.ir_module
+ class After:
+ @T.prim_func
+ def main(
+ Aptr: T.handle("bfloat16", storage_scope="shared"),
+ Bptr: T.handle("bfloat16", storage_scope="local"),
+ Dptr: T.handle("bfloat16"),
+ ):
+ T.func_attr({"global_symbol": "main"})
+ A = T.decl_buffer((100,), "bfloat16", data=Aptr)
+ B = T.decl_buffer((100,), "bfloat16", data=Bptr)
+ D = T.decl_buffer((100,), "bfloat16", data=Dptr)
+ C = T.decl_buffer((100,), "bfloat16")
+ for i in T.grid(100):
+ C[i] = A[i] + B[i]
+ D[i] = T.exp(C[i])
+
+ return After
+
+ target = Target("nvidia/geforce-rtx-3090-ti")
+ before = BindTarget(target)(get_before())
+ after_compute = tvm.tir.transform.BF16ComputeLegalize()(before)
+ after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute)
+ tvm.ir.assert_structural_equal(after_compute,
BindTarget(target)(after_compute_legalize()))
+ tvm.ir.assert_structural_equal(after_storage,
BindTarget(target)(after_storage_legalize()))
+
+
+def test_bf16_reduce_will_legalize():
+ def get_before():
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func(private=True)
+ def main(
+ Aptr: T.handle("bfloat16", storage_scope="shared"),
+ ):
+ A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr)
+
+ for i in range(128):
+ threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+ reduce = T.decl_buffer(1, dtype="bfloat16", scope="local")
+
+ with T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ ):
+ T.tvm_thread_allreduce(
+ T.uint32(1),
+ A_flat[0],
+ T.bool(True),
+ reduce[0],
+ threadIdx_x,
+ )
+
+ return Before
+
+ def after_compute_legalize():
+ @tvm.script.ir_module
+ class After:
+ @T.prim_func(private=True)
+ def main(
+ Aptr: T.handle("bfloat16", storage_scope="shared"),
+ ):
+ A_flat_1 = T.decl_buffer(4096, "bfloat16", data=Aptr)
+
+ for i in range(128):
+ threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+ reduce = T.decl_buffer(1, dtype="float32", scope="local")
+
+ with T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ ):
+ T.tvm_thread_allreduce(
+ T.uint32(1),
+ T.reinterpret(
+ "float32",
+ T.shift_left(
+ T.Cast("uint32", T.reinterpret("uint16",
A_flat_1[0])),
+ T.uint32(16),
+ ),
+ ),
+ T.bool(True),
+ reduce[0],
+ threadIdx_x,
+ )
+
+ return After
+
+ def after_storage_legalize():
+ @tvm.script.ir_module
+ class After:
+ @T.prim_func(private=True)
+ def main(
+ Aptr: T.handle("uint16", storage_scope="shared"),
+ ):
+ A_flat_1 = T.decl_buffer(4096, "uint16", data=Aptr)
+
+ for i in range(128):
+ threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+ reduce = T.decl_buffer(1, dtype="float32", scope="local")
+
+ with T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ ):
+ T.tvm_thread_allreduce(
+ T.uint32(1),
+ T.reinterpret(
+ "float32",
+ T.shift_left(
+ T.Cast("uint32", T.reinterpret("uint16",
A_flat_1[0])),
+ T.uint32(16),
+ ),
+ ),
+ T.bool(True),
+ reduce[0],
+ threadIdx_x,
+ )
+
+ return After
+
+ target = Target("nvidia/geforce-rtx-2080-ti")
+ before = BindTarget(target)(get_before())
+ after_compute = tvm.tir.transform.BF16ComputeLegalize()(before)
+ after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute)
+ tvm.ir.assert_structural_equal(after_compute,
BindTarget(target)(after_compute_legalize()))
+ tvm.ir.assert_structural_equal(after_storage,
BindTarget(target)(after_storage_legalize()))
+
+
+def test_bf16_reduce_wont_legalize():
+ def get_before():
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func(private=True)
+ def main(
+ Aptr: T.handle("bfloat16", storage_scope="shared"),
+ ):
+ A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr)
+
+ for i in range(128):
+ threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+ reduce = T.decl_buffer(1, dtype="bfloat16", scope="local")
+
+ with T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ ):
+ T.tvm_thread_allreduce(
+ T.uint32(1),
+ A_flat[0],
+ T.bool(True),
+ reduce[0],
+ threadIdx_x,
+ )
+
+ return Before
+
+ def after_compute_legalize():
+ @tvm.script.ir_module
+ class After:
+ @T.prim_func(private=True)
+ def main(
+ Aptr: T.handle("bfloat16", storage_scope="shared"),
+ ):
+ A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr)
+
+ for i in range(128):
+ threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+ reduce = T.decl_buffer(1, dtype="bfloat16", scope="local")
+
+ with T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ ):
+ T.tvm_thread_allreduce(
+ T.uint32(1),
+ A_flat[0],
+ T.bool(True),
+ reduce[0],
+ threadIdx_x,
+ )
+
+ return After
+
+ def after_storage_legalize():
+ @tvm.script.ir_module
+ class After:
+ @T.prim_func(private=True)
+ def main(
+ Aptr: T.handle("bfloat16", storage_scope="shared"),
+ ):
+ A_flat = T.decl_buffer(4096, "bfloat16", data=Aptr)
+
+ for i in range(128):
+ threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+ reduce = T.decl_buffer(1, dtype="bfloat16", scope="local")
+
+ with T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.bfloat16(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ ):
+ T.tvm_thread_allreduce(
+ T.uint32(1),
+ A_flat[0],
+ T.bool(True),
+ reduce[0],
+ threadIdx_x,
+ )
+
+ return After
+
+ target = Target("nvidia/geforce-rtx-3090-ti")
+ before = BindTarget(target)(get_before())
after_compute = tvm.tir.transform.BF16ComputeLegalize()(before)
after_storage = tvm.tir.transform.BF16StorageLegalize()(after_compute)
- tvm.ir.assert_structural_equal(after_compute, after_compute_legalize())
- tvm.ir.assert_structural_equal(after_storage, after_storage_legalize())
+ tvm.ir.assert_structural_equal(after_compute,
BindTarget(target)(after_compute_legalize()))
+ tvm.ir.assert_structural_equal(after_storage,
BindTarget(target)(after_storage_legalize()))
if __name__ == "__main__":
- test_bf16_storage_legalize()
- test_bf16_storage_scope()
+ test_bf16_storage_compute_scope_will_legalize()
+ test_bf16_storage_compute_scope_wont_legalize()
+ test_bf16_reduce_will_legalize()
+ test_bf16_reduce_wont_legalize()