Menooker commented on a change in pull request #5601:
URL: https://github.com/apache/incubator-tvm/pull/5601#discussion_r441250362



##########
File path: src/tir/transforms/bf16_legalize.cc
##########
@@ -0,0 +1,387 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file bf16_legalize.cc
+ * \brief legalize bf16 type by adding cast_to_fp32
+ */
+
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/transform.h>
+
+#include <cmath>
+#include <tuple>
+
+#include "../../arith/ir_mutator_with_analyzer.h"
+#include "../../arith/ir_visitor_with_analyzer.h"
+
+namespace tvm {
+namespace tir {
+
+using arith::Analyzer;
+using arith::IRMutatorWithAnalyzer;
+
+/*!
+ * \brief Promotes bf16 binary arithmetic/comparisons to fp32.
+ *
+ * For each overridden binary node, both operands are mutated and, if they are
+ * bf16, cast to fp32 before the node is rebuilt. Arithmetic nodes
+ * (add/sub/mul/div/min/max) cast their fp32 result back to bf16; comparison
+ * nodes return the rebuilt result without a cast. The VisitExpr_ overrides
+ * declared below are defined out of line by the DEFINE_BIOP_* macros.
+ */
+class BF16PromoteRewriter : public StmtExprMutator {
+ public:
+  BF16PromoteRewriter() {}
+
+  /*! \brief Entry point: rewrite every expression reachable from \p s. */
+  Stmt operator()(Stmt s) { return VisitStmt(s); }
+
+  /*!
+   * \brief Mutate both operands and, if they are bf16, cast them to fp32.
+   * \param orig_a Left operand before mutation.
+   * \param orig_b Right operand before mutation.
+   * \param is_bfloat16 Out: set to true iff the mutated operands were bf16
+   *        (and have therefore been cast to fp32), false otherwise.
+   * \return The mutated (and possibly cast) operands, in order.
+   * \note A bf16 operand paired with a non-bf16 operand fails a CHECK.
+   */
+  std::tuple<PrimExpr, PrimExpr> DoCast(PrimExpr orig_a, PrimExpr orig_b, bool* is_bfloat16) {
+    auto a = this->VisitExpr(orig_a);
+    auto b = this->VisitExpr(orig_b);
+    const bool a_bf16 = a->dtype.is_bfloat16();
+    const bool b_bf16 = b->dtype.is_bfloat16();
+    // Either both operands are bf16 or neither is; a mixed pair has no
+    // promotion rule here.
+    CHECK(a_bf16 == b_bf16) << "bf16 legalization expects both operands of a binary op to be "
+                            << "bf16 or neither, got " << a->dtype << " and " << b->dtype;
+    *is_bfloat16 = a_bf16;
+
+    if (*is_bfloat16) {
+      // Preserve the lane count so vectorized bf16 operands become fp32
+      // vectors of the same width rather than lane-mismatched scalar casts.
+      a = CastNode::make(DataType(kDLFloat, 32, a->dtype.lanes()), a);
+      b = CastNode::make(DataType(kDLFloat, 32, b->dtype.lanes()), b);
+    }
+    return std::make_tuple(a, b);
+  }
+
+  PrimExpr VisitExpr_(const AddNode* op) final;
+  PrimExpr VisitExpr_(const SubNode* op) final;
+  PrimExpr VisitExpr_(const MulNode* op) final;
+  PrimExpr VisitExpr_(const DivNode* op) final;
+  PrimExpr VisitExpr_(const MinNode* op) final;
+  PrimExpr VisitExpr_(const MaxNode* op) final;
+  PrimExpr VisitExpr_(const LTNode* op) final;
+  PrimExpr VisitExpr_(const LENode* op) final;
+  PrimExpr VisitExpr_(const GTNode* op) final;
+  PrimExpr VisitExpr_(const GENode* op) final;
+};
+
+/*
+ * Defines BF16PromoteRewriter::VisitExpr_ for the arithmetic binary node OP,
+ * rebuilding it with FUNC. Operands go through DoCast; if they were bf16, the
+ * fp32 result is cast back to bf16 with the same lane count as the result.
+ * If neither operand changed, the original node is returned unchanged.
+ */
+#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC)                    \
+  PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) {                   \
+    PrimExpr a, b;                                                           \
+    bool is_bfloat16;                                                        \
+    std::tie(a, b) = DoCast(op->a, op->b, &is_bfloat16);                     \
+    if (a.same_as(op->a) && b.same_as(op->b)) {                              \
+      return GetRef<PrimExpr>(op);                                           \
+    }                                                                        \
+    auto ret = FUNC(a, b);                                                   \
+    if (!is_bfloat16) {                                                      \
+      return ret;                                                            \
+    }                                                                        \
+    return CastNode::make(DataType(kDLBfloat, 16, ret->dtype.lanes()), ret); \
+  }
+
+/*
+ * Defines BF16PromoteRewriter::VisitExpr_ for the comparison node OP,
+ * rebuilding it with FUNC. Operands are promoted through DoCast exactly as for
+ * the arithmetic nodes, but the rebuilt result is returned without any cast
+ * back to bf16. If neither operand changed, the original node is returned.
+ */
+#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(OP, FUNC)        \
+  PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) {               \
+    PrimExpr lhs;                                                        \
+    PrimExpr rhs;                                                        \
+    bool promoted;                                                       \
+    std::tie(lhs, rhs) = DoCast(op->a, op->b, &promoted);                \
+    const bool unchanged = lhs.same_as(op->a) && rhs.same_as(op->b);     \
+    return unchanged ? GetRef<PrimExpr>(op) : FUNC(lhs, rhs);            \
+  }
+
+// Arithmetic nodes: computed in fp32, result cast back to bf16.
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max)
+// Comparison nodes: operands promoted to fp32, result returned without a cast.
+// NOLINT presumably silences the linter's misreading of the bare operator tokens.
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LTNode, operator<)   // NOLINT(*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(LENode, operator<=)  // NOLINT(*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GTNode, operator>)   // NOLINT(*)
+DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH_NO_CAST(GENode, operator>=)  // NOLINT(*)
+
+/*
+ * Eliminate redundant casting between fp32 and bf16.
+ * Checks if the AST has the pattern:
+ *     cast_to_fp32(cast_to_bf16(some_fp32_expr))
+ * and replaces it with some_fp32_expr. Such cast pairs are produced by
+ * BF16PromoteRewriter for multiple bf16 ops in a row, e.g.:
+ *  X[i] + Y[i] + T[i] =>
+ *  bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i])))
+ * After this pass:
+ *  bf16(float32(X[i]) + float32(Y[i]) + float32(T[i]))
+ * Note: any matching cast pair is removed, not only ones the promoter made.
+ * Dropping the inner bf16 cast also drops the intermediate rounding to bf16,
+ * so intermediate results stay at fp32 precision.
+ */
+class BF16CastEliminationRewriter : public StmtExprMutator {
+ public:
+  BF16CastEliminationRewriter() {}
+
+  /*! \brief Entry point: rewrite every expression reachable from \p s. */
+  Stmt operator()(Stmt s) { return VisitStmt(s); }
+
+  PrimExpr VisitExpr_(const CastNode* op) final {
+    // Rewrite the operand first so nested cast pairs are collapsed bottom-up.
+    auto op_val = StmtExprMutator::VisitExpr(op->value);
+    if (op->dtype.is_float() && op->dtype.bits() == 32) {
+      // This is a cast to fp32: if its operand is a cast to bf16 whose own
+      // operand is already fp32, drop both casts and use that fp32 operand.
+      if (auto innercast = op_val.as<CastNode>()) {
+        if (innercast->dtype.is_bfloat16() && innercast->value->dtype.is_float() &&
+            innercast->value->dtype.bits() == 32) {
+          return innercast->value;
+        }
+      }
+    }
+    if (op->value.same_as(op_val)) return GetRef<PrimExpr>(op);
+    return CastNode::make(op->dtype, op_val);
+  }
+};
+
+// implementation from
+// https://github.com/pytorch/pytorch/blob/master/c10/util/BFloat16.h
+inline uint16_t round_to_nearest_even(float src) {

Review comment:
       changed




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to