[GitHub] [incubator-tvm] Menooker commented on a change in pull request #5601: [DataType] Add bfloat16

2020-06-16 Thread GitBox


Menooker commented on a change in pull request #5601:
URL: https://github.com/apache/incubator-tvm/pull/5601#discussion_r441250362



##
File path: src/tir/transforms/bf16_legalize.cc
##
@@ -0,0 +1,387 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file bf16_legalize.cc
+ * \brief legalize bf16 type by adding cast_to_fp32
+ */
+
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../../arith/ir_visitor_with_analyzer.h"
+
+namespace tvm {
+namespace tir {
+
+using arith::Analyzer;
+using arith::IRMutatorWithAnalyzer;
+
+class BF16PromoteRewriter : public StmtExprMutator {
+ public:
+  BF16PromoteRewriter() {}
+
+  Stmt operator()(Stmt s) { return VisitStmt(s); }
+
+  std::tuple<PrimExpr, PrimExpr> DoCast(PrimExpr orig_a, PrimExpr orig_b, bool* is_bfloat16) {
+    auto a = this->VisitExpr(orig_a);
+    auto b = this->VisitExpr(orig_b);
+    *is_bfloat16 = false;
+    if (a->dtype.is_bfloat16()) {
+      CHECK(b->dtype.is_bfloat16());
+      *is_bfloat16 = true;
+    } else if (b->dtype.is_bfloat16()) {
+      CHECK(a->dtype.is_bfloat16());
+      *is_bfloat16 = true;
+    }
+
+    if (*is_bfloat16) {
+      DataType fp32ty(kDLFloat, 32, 1);
+      a = CastNode::make(fp32ty, a);
+      b = CastNode::make(fp32ty, b);
+    }
+    return std::make_tuple(a, b);
+  }
+
+  PrimExpr VisitExpr_(const AddNode* op) final;
+  PrimExpr VisitExpr_(const SubNode* op) final;
+  PrimExpr VisitExpr_(const MulNode* op) final;
+  PrimExpr VisitExpr_(const DivNode* op) final;
+  PrimExpr VisitExpr_(const MinNode* op) final;
+  PrimExpr VisitExpr_(const MaxNode* op) final;
+  PrimExpr VisitExpr_(const LTNode* op) final;
+  PrimExpr VisitExpr_(const LENode* op) final;
+  PrimExpr VisitExpr_(const GTNode* op) final;
+  PrimExpr VisitExpr_(const GENode* op) final;
+};
+
+#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC)           \
+  PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) {          \
+    PrimExpr a, b;                                                  \
+    bool is_bfloat16;                                               \
+    std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16);            \
+    if (a.same_as(op->a) && b.same_as(op->b)) {                     \
+      return GetRef<PrimExpr>(op);                                  \
+    } else {                                                        \
+      auto ret = FUNC(a, b);                                        \
+      if (!is_bfloat16)                                             \
+        return ret;                                                 \
+      else                                                          \
+        return CastNode::make(DataType(kDLBfloat, 16, 1), ret);     \
+    }                                                               \
+  }
+
+#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(OP, FUNC)   \
+  PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) {          \
+    PrimExpr a, b;                                                  \
+    bool is_bfloat16;                                               \
+    std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16);            \
+    if (a.same_as(op->a) && b.same_as(op->b)) {                     \
+      return GetRef<PrimExpr>(op);                                  \
+    } else {                                                        \
+      auto ret = FUNC(a, b);                                        \
+      return ret;                                                   \
+    }                                                               \
+  }
+
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LTNode, operator<)   // NOLINT(*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LENode, operator<=)  // NOLINT(*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GTNode, operator>)   // NOLINT(*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GENode, operator>=)  // NOLINT(*)
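
For orientation, here is a minimal sketch of what the promote rewrite above amounts to at the TE level. It is modeled on the unit test quoted later in this thread (the tvm/te/topi names come from that test, not from this hunk):

```python
import tvm
import topi
from tvm import te

a = te.placeholder((100,), dtype='bfloat16')
b = te.placeholder((100,), dtype='bfloat16')

# Before BF16Promote: a plain bfloat16 add.
before = te.compute((100,), lambda i: topi.add(a[i], b[i]))

# After BF16Promote: both operands are cast to fp32, the add is computed in
# fp32, and the result is cast back to bfloat16. Comparison ops (LT/LE/GT/GE)
# use the NO_CAST macro variant above and skip the final cast back.
after = te.compute((100,), lambda i: topi.cast(
    topi.add(topi.cast(a[i], 'float'), topi.cast(b[i], 'float')),
    'bfloat16'))
```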

[GitHub] [incubator-tvm] Menooker commented on a change in pull request #5601: [DataType] Add bfloat16

2020-06-16 Thread GitBox


Menooker commented on a change in pull request #5601:
URL: https://github.com/apache/incubator-tvm/pull/5601#discussion_r441250264



##
File path: python/tvm/_ffi/_cython/base.pxi
##
@@ -27,7 +27,7 @@ cdef enum TVMTypeCode:
 kUInt = 1
 kFloat = 2
 kTVMOpaqueHandle = 3
-kTVMNullptr = 4
+kBFloat = 4

Review comment:
   changed





This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [incubator-tvm] Menooker commented on a change in pull request #5601: [DataType] Add bfloat16

2020-06-16 Thread GitBox


Menooker commented on a change in pull request #5601:
URL: https://github.com/apache/incubator-tvm/pull/5601#discussion_r441250295



##
File path: include/tvm/runtime/data_type.h
##
@@ -372,7 +372,7 @@ inline DLDataType String2DLDataType(std::string s) {
     t.lanes = 1;
     return t;
   } else if (s.substr(0, 6) == "bfloat") {
-    t.code = kTVMBFloat;
+    t.code = kDLBfloat;

Review comment:
   changed





This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [incubator-tvm] Menooker commented on a change in pull request #5601: [DataType] Add bfloat16

2020-06-16 Thread GitBox


Menooker commented on a change in pull request #5601:
URL: https://github.com/apache/incubator-tvm/pull/5601#discussion_r441249227



##
File path: python/tvm/_ffi/runtime_ctypes.py
##
@@ -96,6 +98,9 @@ def __init__(self, type_str):
             self.type_code = DataTypeCode.HANDLE
             bits = 64
             head = ""
+        elif head.startswith("bfloat"):
+            self.type_code = 4

Review comment:
   > not sure if it is good to hard code here
   
   Changed to DataTypeCode. TVM refactors a lot (which is good); when this PR was raised, all the type codes here were hard-coded.
   
   The other two issues you raised have also been changed as requested.
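
A hedged sketch of the change described above: replace the literal 4 with the named constant. The exact DataTypeCode class lives in python/tvm/_ffi/runtime_ctypes.py; the values below mirror DLPack and are listed here only for illustration:

```python
class DataTypeCode(object):
    INT = 0
    UINT = 1
    FLOAT = 2
    HANDLE = 3
    BFLOAT = 4  # DLPack's kDLBfloat

# elif head.startswith("bfloat"):
#     self.type_code = DataTypeCode.BFLOAT   # instead of the hard-coded 4
```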





This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [incubator-tvm] Menooker commented on a change in pull request #5601: [DataType] Add bfloat16

2020-06-16 Thread GitBox


Menooker commented on a change in pull request #5601:
URL: https://github.com/apache/incubator-tvm/pull/5601#discussion_r441248116



##
File path: include/tvm/runtime/data_type.h
##
@@ -72,6 +73,9 @@ class DataType {
     data_.code = static_cast<uint8_t>(code);
     data_.bits = static_cast<uint8_t>(bits);
     data_.lanes = static_cast<uint16_t>(lanes);
+    if (code == kBFloat) {
+      CHECK_EQ(bits, 16);

Review comment:
   I understand your concern. Any suggestions on where this check should go? Thanks.
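
For reference, the constraint being placed (wherever the check ends up living), as a tiny Python sketch; the names here are illustrative, not TVM API:

```python
KDL_BFLOAT = 4  # DLPack type code for bfloat

def validate_dtype(code, bits, lanes):
    # bfloat is currently only supported at a width of 16 bits.
    if code == KDL_BFLOAT and bits != 16:
        raise ValueError("bfloat only supports 16 bits for now")
```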





This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [incubator-tvm] Menooker commented on a change in pull request #5601: [DataType] Add bfloat16

2020-06-12 Thread GitBox


Menooker commented on a change in pull request #5601:
URL: https://github.com/apache/incubator-tvm/pull/5601#discussion_r439704765



##
File path: tests/python/unittest/test_tir_transform_bf16_legalize.py
##
@@ -0,0 +1,152 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import topi
+from tvm import te
+from tvm.tir import const
+
+
+def lower_stmt(sche, params, passfunc):
+    func = tvm.driver.build_module.form_irmodule(sche, params, "main", None)["main"]
+    func = passfunc()(tvm.IRModule.from_expr(func))["main"]
+    stmt = func.body
+    return stmt
+
+
+def test_promote():
+    def runpass(op, passfunc):
+        a = te.placeholder((100,), dtype='bfloat16')
+        b = te.placeholder((100,), dtype='bfloat16')
+        c = te.compute((100,), lambda i: op(a[i], b[i]))
+        s = te.create_schedule(c.op)
+        return lower_stmt(s, [a, b, c], passfunc)
+
+    def get_promoted(op):
+        a = te.placeholder((100,), dtype='bfloat16')
+        b = te.placeholder((100,), dtype='bfloat16')
+        c = te.compute((100,), lambda i:
+                       topi.cast(op(topi.cast(a[i], 'float'),
+                                    topi.cast(b[i], 'float')), 'bfloat16')
+                       )
+        s = te.create_schedule(c.op)
+        func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"]
+        return func.body
+
+    def test_promoted(op):
+        stmt = runpass(op, tvm.tir.transform.BF16Promote)
+        tvm.ir.assert_structural_equal(stmt, get_promoted(op))
+    test_promoted(topi.add)
+    test_promoted(topi.subtract)
+    test_promoted(topi.multiply)
+    test_promoted(topi.divide)
+
+def test_eliminate():
+    def to32(v):
+        return topi.cast(v, 'float')
+    def to16(v):
+        return topi.cast(v, 'bfloat16')
+    def get_eliminated():
+        a = te.placeholder((100,), dtype='bfloat16')
+        b = te.placeholder((100,), dtype='bfloat16')
+        c = te.compute((100,), lambda i: to16(
+            topi.add(
+                to32(
+                    to16(
+                        topi.add(
+                            to32(a[i]),
+                            to32(b[i]),
+                        )
+                    )
+                ),
+                to32(
+                    to16(
+                        topi.add(
+                            to32(a[i]),
+                            to32(b[i]),
+                        )
+                    )
+                )
+            )
+        ))
+        s = te.create_schedule(c.op)
+        stmt = lower_stmt(s, [a, b, c], tvm.tir.transform.BF16CastElimination)
+        return stmt
+
+    def get_target():
+        a = te.placeholder((100,), dtype='bfloat16')
+        b = te.placeholder((100,), dtype='bfloat16')
+        c = te.compute((100,), lambda i: to16(
+            topi.add(topi.add(
+                to32(a[i]),
+                to32(b[i]),
+            ),
+                topi.add(
+                    to32(a[i]),
+                    to32(b[i]),
+                )
+            )
+        ))
+        s = te.create_schedule(c.op)
+        func = tvm.driver.build_module.form_irmodule(s, [a, b, c], "main", None)["main"]
+        return func.body
+    tvm.ir.assert_structural_equal(get_eliminated(), get_target())
+
+def test_legalize():
+    def to32(v):
+        uint32_v = topi.cast(v, "uint32")
+        uint32_v = tvm.tir.call_pure_intrin("uint32", "shift_left", uint32_v, tvm.tir.const(16, "uint32"))
+        return tvm.tir.call_pure_intrin("float32", "reinterpret", uint32_v)
+    def to16(v):
+        uint32_v = tvm.tir.call_pure_intrin("uint32", "reinterpret", v)
+        rounding_bias = tvm.tir.call_pure_intrin("uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32"))
+        rounding_bias = tvm.tir.call_pure_intrin("uint32", "bitwise_and", rounding_bias, tvm.tir.const(1, "uint32"))
+        rounding_bias = rounding_bias + tvm.tir.const(0x7FFF, "uint16")
+        uint32_v = uint32_v + rounding_bias
+        uint32_v = tvm.tir.call_pure_intrin("uint32", "shift_right", uint32_v, tvm.tir.const(16, "uint32"))
+        return topi.cast(uint32_v, 'uint16')
+
+    def 
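
The to32/to16 helpers above spell out, in TIR intrinsics, the standard bf16 bit tricks: widen by shifting the 16-bit pattern into the high half of a uint32, and narrow with a round-to-nearest-even bias. A standalone numpy sketch of the same arithmetic (illustrative only, not code from the PR):

```python
import numpy as np

def fp32_to_bf16(x):
    # Reinterpret the float32 bits as uint32, add the round-to-nearest-even
    # bias, then keep the upper 16 bits (the same steps as to16 above).
    bits = np.array(x, dtype=np.float32).view(np.uint32)
    rounding_bias = ((bits >> 16) & 1) + 0x7FFF
    return ((bits + rounding_bias) >> 16).astype(np.uint16)

def bf16_to_fp32(h):
    # Move the 16-bit pattern into the high half of a uint32 and reinterpret
    # it as float32 (the same steps as to32 above).
    return (np.array(h, dtype=np.uint32) << 16).view(np.float32)

print(bf16_to_fp32(fp32_to_bf16(np.float32(1.0))))        # 1.0 (exact)
print(bf16_to_fp32(fp32_to_bf16(np.float32(3.1415926))))  # ~3.140625
```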

[GitHub] [incubator-tvm] Menooker commented on a change in pull request #5601: [DataType] Add bfloat16

2020-06-03 Thread GitBox


Menooker commented on a change in pull request #5601:
URL: https://github.com/apache/incubator-tvm/pull/5601#discussion_r434955796



##
File path: python/tvm/_ffi/runtime_ctypes.py
##
@@ -58,7 +58,8 @@ class DataType(ctypes.Structure):
         0 : 'int',
         1 : 'uint',
         2 : 'float',
-        4 : 'handle'
+        4 : 'handle',
+        65: 'bfloat'

Review comment:
   4 is the type code of bfloat in dlpack
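
For easy reference, the DLDataTypeCode values defined by DLPack (the source of the 4 in question), listed here as a small Python mapping:

```python
# From dlpack.h: DLDataTypeCode
DLPACK_TYPE_CODE = {
    0: 'int',     # kDLInt
    1: 'uint',    # kDLUInt
    2: 'float',   # kDLFloat
    3: 'handle',  # kDLOpaqueHandle
    4: 'bfloat',  # kDLBfloat
}
```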





This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [incubator-tvm] Menooker commented on a change in pull request #5601: [DataType] Add bfloat16

2020-06-03 Thread GitBox


Menooker commented on a change in pull request #5601:
URL: https://github.com/apache/incubator-tvm/pull/5601#discussion_r434955653



##
File path: python/tvm/_ffi/runtime_ctypes.py
##
@@ -58,7 +58,8 @@ class DataType(ctypes.Structure):
         0 : 'int',
         1 : 'uint',
         2 : 'float',
-        4 : 'handle'
+        4 : 'handle',
+        65: 'bfloat'

Review comment:
   May I ask what I am supposed to do here? In the newest commit, I changed `4 : 'handle'` here to 3 and mapped 4 to bfloat. I also changed `class TypeCode(object):` and the relevant C++ class so that `NULL` maps to 21 instead of 4. I noticed that `class TypeCode(object):` has to define both `BFLOAT` and `NULL`, so I thought we needed to resolve the conflict, since both were mapped to 4.
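
As a quick illustration of the remapping described in this comment (the values are taken from the text above and shown here only to make the conflict concrete):

```python
# Before: both NULL (TVM-specific) and bfloat (DLPack) would want code 4.
old_map = {0: 'int', 1: 'uint', 2: 'float', 4: 'handle'}  # 4 collides

# After: handle follows DLPack (3), bfloat takes DLPack's 4, and the
# TVM-only NULL code moves out of the DLPack range (21 in this commit).
new_map = {0: 'int', 1: 'uint', 2: 'float', 3: 'handle', 4: 'bfloat'}
tvm_only = {21: 'null'}
```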





This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [incubator-tvm] Menooker commented on a change in pull request #5601: [DataType] Add bfloat16

2020-05-31 Thread GitBox


Menooker commented on a change in pull request #5601:
URL: https://github.com/apache/incubator-tvm/pull/5601#discussion_r433016729



##
File path: python/tvm/_ffi/runtime_ctypes.py
##
@@ -58,7 +58,8 @@ class DataType(ctypes.Structure):
         0 : 'int',
         1 : 'uint',
         2 : 'float',
-        4 : 'handle'
+        4 : 'handle',
+        65: 'bfloat'

Review comment:
   @tqchen Hi, I found an inconsistency here: in the Python binding, 4 is mapped to 'handle', while in the C++ code, 4 is Nullptr and 3 is handle. Is this expected, or a bug?





This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org




[GitHub] [incubator-tvm] Menooker commented on a change in pull request #5601: [DataType] Add bfloat16

2020-05-29 Thread GitBox


Menooker commented on a change in pull request #5601:
URL: https://github.com/apache/incubator-tvm/pull/5601#discussion_r432449575



##
File path: src/tir/transforms/bf16_legalize.cc
##
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file bf16_legalize.cc
+ * \brief legalize bf16 type by adding cast_to_fp32
+ */
+
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../../arith/ir_visitor_with_analyzer.h"
+
+namespace tvm {
+namespace tir {
+
+using arith::Analyzer;
+using arith::IRMutatorWithAnalyzer;
+
+class BF16PromoteRewriter : public StmtExprMutator {
+ public:
+  BF16PromoteRewriter() {}
+
+  Stmt operator()(Stmt s) { return VisitStmt(s); }
+
+  std::tuple<PrimExpr, PrimExpr> DoCast(PrimExpr orig_a, PrimExpr orig_b, bool* is_bfloat16) {
+    auto a = this->VisitExpr(orig_a);
+    auto b = this->VisitExpr(orig_b);
+    *is_bfloat16 = false;
+    if (a->dtype.is_bfloat16()) {
+      CHECK(b->dtype.is_bfloat16());
+      *is_bfloat16 = true;
+    } else if (b->dtype.is_bfloat16()) {
+      CHECK(a->dtype.is_bfloat16());
+      *is_bfloat16 = true;
+    }
+
+    if (*is_bfloat16) {
+      DataType fp32ty(kDLFloat, 32, 1);
+      a = CastNode::make(fp32ty, a);
+      b = CastNode::make(fp32ty, b);
+    }
+    return std::make_tuple(a, b);
+  }
+
+  PrimExpr VisitExpr_(const AddNode* op) final;
+  PrimExpr VisitExpr_(const SubNode* op) final;
+  PrimExpr VisitExpr_(const MulNode* op) final;
+  PrimExpr VisitExpr_(const DivNode* op) final;
+  PrimExpr VisitExpr_(const MinNode* op) final;
+  PrimExpr VisitExpr_(const MaxNode* op) final;
+  PrimExpr VisitExpr_(const LTNode* op) final;
+  PrimExpr VisitExpr_(const LENode* op) final;
+  PrimExpr VisitExpr_(const GTNode* op) final;
+  PrimExpr VisitExpr_(const GENode* op) final;
+};
+
+#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC)           \
+  PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) {          \
+    PrimExpr a, b;                                                  \
+    bool is_bfloat16;                                               \
+    std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16);            \
+    if (a.same_as(op->a) && b.same_as(op->b)) {                     \
+      return GetRef<PrimExpr>(op);                                  \
+    } else {                                                        \
+      auto ret = FUNC(a, b);                                        \
+      if (!is_bfloat16)                                             \
+        return ret;                                                 \
+      else                                                          \
+        return CastNode::make(DataType(kTVMBFloat, 16, 1), ret);    \
+    }                                                               \
+  }
+
+#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(OP, FUNC)   \
+  PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) {          \
+    PrimExpr a, b;                                                  \
+    bool is_bfloat16;                                               \
+    std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16);            \
+    if (a.same_as(op->a) && b.same_as(op->b)) {                     \
+      return GetRef<PrimExpr>(op);                                  \
+    } else {                                                        \
+      auto ret = FUNC(a, b);                                        \
+      return ret;                                                   \
+    }                                                               \
+  }
+
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LTNode, operator<)   // NOLINT(*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LENode, operator<=)  // NOLINT(*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GTNode, operator>)   // NOLINT(*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GENode, 

[GitHub] [incubator-tvm] Menooker commented on a change in pull request #5601: [DataType] Add bfloat16

2020-05-29 Thread GitBox


Menooker commented on a change in pull request #5601:
URL: https://github.com/apache/incubator-tvm/pull/5601#discussion_r432401610



##
File path: src/tir/transforms/bf16_legalize.cc
##
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file bf16_legalize.cc
+ * \brief legalize bf16 type by adding cast_to_fp32
+ */
+
+#include 
+#include 
+#include 
+
+#include 
+#include 
+
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../../arith/ir_visitor_with_analyzer.h"
+
+namespace tvm {
+namespace tir {
+
+using arith::Analyzer;
+using arith::IRMutatorWithAnalyzer;
+
+class BF16PromoteRewriter : public StmtExprMutator {
+ public:
+  BF16PromoteRewriter() {}
+
+  Stmt operator()(Stmt s) { return VisitStmt(s); }
+
+  std::tuple<PrimExpr, PrimExpr> DoCast(PrimExpr orig_a, PrimExpr orig_b, bool* is_bfloat16) {
+    auto a = this->VisitExpr(orig_a);
+    auto b = this->VisitExpr(orig_b);
+    *is_bfloat16 = false;
+    if (a->dtype.is_bfloat16()) {
+      CHECK(b->dtype.is_bfloat16());
+      *is_bfloat16 = true;
+    } else if (b->dtype.is_bfloat16()) {
+      CHECK(a->dtype.is_bfloat16());
+      *is_bfloat16 = true;
+    }
+
+    if (*is_bfloat16) {
+      DataType fp32ty(kDLFloat, 32, 1);
+      a = CastNode::make(fp32ty, a);
+      b = CastNode::make(fp32ty, b);
+    }
+    return std::make_tuple(a, b);
+  }
+
+  PrimExpr VisitExpr_(const AddNode* op) final;
+  PrimExpr VisitExpr_(const SubNode* op) final;
+  PrimExpr VisitExpr_(const MulNode* op) final;
+  PrimExpr VisitExpr_(const DivNode* op) final;
+  PrimExpr VisitExpr_(const MinNode* op) final;
+  PrimExpr VisitExpr_(const MaxNode* op) final;
+  PrimExpr VisitExpr_(const LTNode* op) final;
+  PrimExpr VisitExpr_(const LENode* op) final;
+  PrimExpr VisitExpr_(const GTNode* op) final;
+  PrimExpr VisitExpr_(const GENode* op) final;
+};
+
+#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC)           \
+  PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) {          \
+    PrimExpr a, b;                                                  \
+    bool is_bfloat16;                                               \
+    std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16);            \
+    if (a.same_as(op->a) && b.same_as(op->b)) {                     \
+      return GetRef<PrimExpr>(op);                                  \
+    } else {                                                        \
+      auto ret = FUNC(a, b);                                        \
+      if (!is_bfloat16)                                             \
+        return ret;                                                 \
+      else                                                          \
+        return CastNode::make(DataType(kTVMBFloat, 16, 1), ret);    \
+    }                                                               \
+  }
+
+#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(OP, FUNC)   \
+  PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) {          \
+    PrimExpr a, b;                                                  \
+    bool is_bfloat16;                                               \
+    std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16);            \
+    if (a.same_as(op->a) && b.same_as(op->b)) {                     \
+      return GetRef<PrimExpr>(op);                                  \
+    } else {                                                        \
+      auto ret = FUNC(a, b);                                        \
+      return ret;                                                   \
+    }                                                               \
+  }
+
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LTNode, operator<)   // NOLINT(*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LENode, operator<=)  // NOLINT(*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GTNode, operator>)   // NOLINT(*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GENode,