This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new efd7c45462 [TIR] [bfloat16] add bfloat16 promotion for CallNode (#12370)
efd7c45462 is described below
commit efd7c45462a4f666919fccb82189a0818e2d722f
Author: Youlei Yang <[email protected]>
AuthorDate: Fri Aug 19 01:44:07 2022 +0800
[TIR] [bfloat16] add bfloat16 promotion for CallNode (#12370)
* add bfloat16 promotion for CallNode
* add softmax to bfloat16 build test
---
src/tir/transforms/bf16_legalize.cc | 21 +++++++++++++++++++++
tests/python/relay/test_cpp_build_module.py | 4 +++-
2 files changed, 24 insertions(+), 1 deletion(-)
diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc
index 193584f84b..5dc08f31c2 100644
--- a/src/tir/transforms/bf16_legalize.cc
+++ b/src/tir/transforms/bf16_legalize.cc
@@ -55,6 +55,7 @@ class BF16PromoteRewriter : public StmtExprMutator {
PrimExpr VisitExpr_(const LENode* op) final;
PrimExpr VisitExpr_(const GTNode* op) final;
PrimExpr VisitExpr_(const GENode* op) final;
+ PrimExpr VisitExpr_(const CallNode* op) final;
};
#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC, NEEDCAST) \
@@ -88,6 +89,26 @@ DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=, false)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>, false)
DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=, false)
+PrimExpr BF16PromoteRewriter::VisitExpr_(const CallNode* op) {  // Rebuilds a call so that bfloat16 arguments and a bfloat16 result are handled as float32.
+ Array<PrimExpr> args;  // rewritten argument list for the new Call
+ for (auto& arg : op->args) {
+ PrimExpr x = this->VisitExpr(arg);  // mutate the argument first, so its dtype is checked after rewriting
+ if (x.dtype().is_bfloat16()) {
+ DataType fp32_dtype(kDLFloat, 32, x.dtype().lanes());  // 32-bit float with the argument's lane count
+ args.push_back(Cast(fp32_dtype, {x}, op->span));  // the cast reuses the call's span, not the argument's
+ } else {
+ args.push_back(x);  // non-bfloat16 arguments are passed through as visited
+ }
+ }
+ if (op->dtype.is_bfloat16()) {
+ DataType fp32_dtype(kDLFloat, 32, op->dtype.lanes());  // 32-bit float with the call's lane count
+ PrimExpr result_fp32 = Call(fp32_dtype, op->op, args, op->span);  // same op, evaluated with a float32 result type
+ return Cast(op->dtype, {result_fp32}, op->span);  // cast back so the expression keeps its original bfloat16 dtype
+ } else {
+ return Call(op->dtype, op->op, args, op->span);  // result dtype unchanged; a new Call is built even if no argument changed
+ }
+}
+
/*
* Eliminate verbose casting between fp32 and bf16
* Checks if the AST has the pattern:
diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py
index ccf961fbe4..e8e6676863 100644
--- a/tests/python/relay/test_cpp_build_module.py
+++ b/tests/python/relay/test_cpp_build_module.py
@@ -127,7 +127,9 @@ def test_bf16_build():
relu_bf16 = relay.nn.relu(bn_bf16[0])
maxpool_bf16 = relay.nn.max_pool2d(relu_bf16, pool_size=(2, 2),
strides=(2, 2))
avgpool_bf16 = relay.nn.avg_pool2d(maxpool_bf16, pool_size=(2, 2),
strides=(2, 2))
- mod_bf16 = tvm.IRModule.from_expr(avgpool_bf16)
+ flattened_bf16 = relay.nn.batch_flatten(avgpool_bf16)
+ softmax_bf16 = relay.nn.softmax(flattened_bf16)
+ mod_bf16 = tvm.IRModule.from_expr(softmax_bf16)
with tvm.transform.PassContext(opt_level=3):
relay.build(mod_bf16, target="llvm", params=params)