This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new efd7c45462 [TIR] [bfloat16] add bfloat16 promotion for CallNode (#12370)
efd7c45462 is described below

commit efd7c45462a4f666919fccb82189a0818e2d722f
Author: Youlei Yang <[email protected]>
AuthorDate: Fri Aug 19 01:44:07 2022 +0800

    [TIR] [bfloat16] add bfloat16 promotion for CallNode (#12370)
    
    * add bfloat16 promotion for CallNode
    
    * add softmax to bfloat16 build test
---
 src/tir/transforms/bf16_legalize.cc         | 21 +++++++++++++++++++++
 tests/python/relay/test_cpp_build_module.py |  4 +++-
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/src/tir/transforms/bf16_legalize.cc b/src/tir/transforms/bf16_legalize.cc
index 193584f84b..5dc08f31c2 100644
--- a/src/tir/transforms/bf16_legalize.cc
+++ b/src/tir/transforms/bf16_legalize.cc
@@ -55,6 +55,7 @@ class BF16PromoteRewriter : public StmtExprMutator {
   PrimExpr VisitExpr_(const LENode* op) final;
   PrimExpr VisitExpr_(const GTNode* op) final;
   PrimExpr VisitExpr_(const GENode* op) final;
+  PrimExpr VisitExpr_(const CallNode* op) final;
 };
 
 #define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC, NEEDCAST)            
    \
@@ -88,6 +89,26 @@ DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=, false)
 DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>, false)
 DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=, false)
 
+PrimExpr BF16PromoteRewriter::VisitExpr_(const CallNode* op) {  // Rebuilds a call so bf16 operands/results are computed in fp32.
+  Array<PrimExpr> args;  // rewritten argument list for the new Call
+  for (auto& arg : op->args) {
+    PrimExpr x = this->VisitExpr(arg);  // mutate the argument first, so nested expressions are rewritten too
+    if (x.dtype().is_bfloat16()) {
+      DataType fp32_dtype(kDLFloat, 32, x.dtype().lanes());  // 32-bit float with the same lane count as x
+      args.push_back(Cast(fp32_dtype, {x}, op->span));  // bf16 argument is cast up to fp32
+    } else {
+      args.push_back(x);  // non-bf16 arguments pass through as mutated
+    }
+  }
+  if (op->dtype.is_bfloat16()) {  // the call itself yields bf16
+    DataType fp32_dtype(kDLFloat, 32, op->dtype.lanes());  // fp32 result type, lanes preserved
+    PrimExpr result_fp32 = Call(fp32_dtype, op->op, args, op->span);  // same op, re-typed to return fp32
+    return Cast(op->dtype, {result_fp32}, op->span);  // cast back so the expression keeps its original bf16 dtype
+  } else {
+    return Call(op->dtype, op->op, args, op->span);  // result dtype unchanged; only the arguments may have been promoted
+  }
+}
+
 /*
  * Eliminate verbose casting between fp32 and bf16
  * Checks if the AST has the pattern:
diff --git a/tests/python/relay/test_cpp_build_module.py b/tests/python/relay/test_cpp_build_module.py
index ccf961fbe4..e8e6676863 100644
--- a/tests/python/relay/test_cpp_build_module.py
+++ b/tests/python/relay/test_cpp_build_module.py
@@ -127,7 +127,9 @@ def test_bf16_build():
     relu_bf16 = relay.nn.relu(bn_bf16[0])
     maxpool_bf16 = relay.nn.max_pool2d(relu_bf16, pool_size=(2, 2), 
strides=(2, 2))
     avgpool_bf16 = relay.nn.avg_pool2d(maxpool_bf16, pool_size=(2, 2), 
strides=(2, 2))
-    mod_bf16 = tvm.IRModule.from_expr(avgpool_bf16)
+    flattened_bf16 = relay.nn.batch_flatten(avgpool_bf16)
+    softmax_bf16 = relay.nn.softmax(flattened_bf16)
+    mod_bf16 = tvm.IRModule.from_expr(softmax_bf16)
     with tvm.transform.PassContext(opt_level=3):
         relay.build(mod_bf16, target="llvm", params=params)
 

Reply via email to