This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 674167805c Revert "[Unity] Fix IndexDataTypeNormalizer so that it correctly handles corner case" (#16241)
674167805c is described below

commit 674167805cb16b4c69e3ce947edd476825ee80ad
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Dec 14 10:14:14 2023 -0500

    Revert "[Unity] Fix IndexDataTypeNormalizer so that it correctly handles corner case" (#16241)
    
    Revert "[Unity] Fix IndexDataTypeNormalizer so that it correctly handles corner case (#16235)"
    
    This reverts commit f7b0193f9d80e1b977ee6aa091697b77a8169718.
---
 include/tvm/tir/data_type_rewriter.h |  3 ---
 src/tir/ir/data_type_rewriter.cc     | 21 +++------------------
 2 files changed, 3 insertions(+), 21 deletions(-)

diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h
index 7ee3a9ba6b..8bdcc097a2 100644
--- a/include/tvm/tir/data_type_rewriter.h
+++ b/include/tvm/tir/data_type_rewriter.h
@@ -84,9 +84,6 @@ class DataTypeLegalizer : public StmtExprMutator {
   std::unordered_map<const IterVarNode*, IterVar> ivmap_;
   // a map from original vars to ones with new dtype
   std::unordered_map<const VarNode*, Var> var_remap_;
-  // number of iterations. The first iteration collects var_remap_,
-  // and the second iteration performs rewrite
-  int iter_ = 0;
 };
 
 /*!
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index 470f4645ac..aa8f2f3f6f 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -89,9 +89,6 @@ Stmt DataTypeLegalizer::VisitStmt_(const AttrStmtNode* op) {
     ICHECK(iv != nullptr) << "Expected type to be IterVarNode"
                           << ", but get " << op->node->GetTypeKey();
     PrimExpr e = VisitExpr(iv->var);
-    if (iter_ == 0) {
-      return GetRef<AttrStmt>(op);
-    }
     Var var = Downcast<Var>(e);
     if (ivmap_.find(iv) == ivmap_.end()) {
       Range dom = iv->dom;
@@ -396,9 +393,6 @@ IterVar IndexDataTypeRewriter::VisitIterVar(const IterVar& iter_var) {
 }
 
 Buffer IndexDataTypeRewriter::VisitBuffer(const Buffer& buffer) {
-  if (iter_ == 0) {
-    return buffer;
-  }
   bool is_enabled = is_enabled_;
 
   is_enabled_ = true;
@@ -588,10 +582,6 @@ IndexDataTypeNormalizer::IndexDataTypeNormalizer(DataType target_data_type)
     : target_data_type_(std::move(target_data_type)) {}
 
 PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) {
-  // collect var remap
-  VisitStmt(std::move(func->body));
-  iter_++;
-  // start rewrite
   Map<Var, Buffer> new_buffer_map = func->buffer_map;
   for (const auto& [var, buffer] : func->buffer_map) {
     new_buffer_map.Set(var, VisitBuffer(buffer));
@@ -628,15 +618,10 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const IntImmNode* op) {
 }
 
 PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) {
-  // In the first iteration, collect var_remap_
-  if (iter_ == 0) {
-    if (is_enabled_ && CanRewriteDType(op->dtype) && op->dtype != 
target_data_type_ &&
-        !var_remap_.count(op)) {
-      var_remap_[op] = GetRef<Var>(op).copy_with_dtype(target_data_type_);
-    }
-    return GetRef<Var>(op);
+  if (is_enabled_ && CanRewriteDType(op->dtype) && op->dtype != 
target_data_type_ &&
+      !var_remap_.count(op)) {
+    var_remap_[op] = GetRef<Var>(op).copy_with_dtype(target_data_type_);
   }
-  // In the second iteration, rewrite the var
   return DataTypeLegalizer::VisitExpr_(op);
 }
 

Reply via email to