This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 674167805c Revert "[Unity] Fix IndexDataTypeNormalizer so that it
correctly handles corner case" (#16241)
674167805c is described below
commit 674167805cb16b4c69e3ce947edd476825ee80ad
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Dec 14 10:14:14 2023 -0500
Revert "[Unity] Fix IndexDataTypeNormalizer so that it correctly handles
corner case" (#16241)
Revert "[Unity] Fix IndexDataTypeNormalizer so that it correctly handles
corner case (#16235)"
This reverts commit f7b0193f9d80e1b977ee6aa091697b77a8169718.
---
include/tvm/tir/data_type_rewriter.h | 3 ---
src/tir/ir/data_type_rewriter.cc | 21 +++------------------
2 files changed, 3 insertions(+), 21 deletions(-)
diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h
index 7ee3a9ba6b..8bdcc097a2 100644
--- a/include/tvm/tir/data_type_rewriter.h
+++ b/include/tvm/tir/data_type_rewriter.h
@@ -84,9 +84,6 @@ class DataTypeLegalizer : public StmtExprMutator {
std::unordered_map<const IterVarNode*, IterVar> ivmap_;
// a map from original vars to ones with new dtype
std::unordered_map<const VarNode*, Var> var_remap_;
- // number of iterations. The first iteration collects var_remap_,
- // and the second iteration performs rewrite
- int iter_ = 0;
};
/*!
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index 470f4645ac..aa8f2f3f6f 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -89,9 +89,6 @@ Stmt DataTypeLegalizer::VisitStmt_(const AttrStmtNode* op) {
ICHECK(iv != nullptr) << "Expected type to be IterVarNode"
<< ", but get " << op->node->GetTypeKey();
PrimExpr e = VisitExpr(iv->var);
- if (iter_ == 0) {
- return GetRef<AttrStmt>(op);
- }
Var var = Downcast<Var>(e);
if (ivmap_.find(iv) == ivmap_.end()) {
Range dom = iv->dom;
@@ -396,9 +393,6 @@ IterVar IndexDataTypeRewriter::VisitIterVar(const IterVar& iter_var) {
}
Buffer IndexDataTypeRewriter::VisitBuffer(const Buffer& buffer) {
- if (iter_ == 0) {
- return buffer;
- }
bool is_enabled = is_enabled_;
is_enabled_ = true;
@@ -588,10 +582,6 @@ IndexDataTypeNormalizer::IndexDataTypeNormalizer(DataType target_data_type)
: target_data_type_(std::move(target_data_type)) {}
PrimFunc IndexDataTypeNormalizer::Rewrite(PrimFunc func) {
- // collect var remap
- VisitStmt(std::move(func->body));
- iter_++;
- // start rewrite
Map<Var, Buffer> new_buffer_map = func->buffer_map;
for (const auto& [var, buffer] : func->buffer_map) {
new_buffer_map.Set(var, VisitBuffer(buffer));
@@ -628,15 +618,10 @@ PrimExpr IndexDataTypeNormalizer::VisitExpr_(const IntImmNode* op) {
}
PrimExpr IndexDataTypeNormalizer::VisitExpr_(const VarNode* op) {
- // In the first iteration, collect var_remap_
- if (iter_ == 0) {
- if (is_enabled_ && CanRewriteDType(op->dtype) && op->dtype != target_data_type_ &&
- !var_remap_.count(op)) {
- var_remap_[op] = GetRef<Var>(op).copy_with_dtype(target_data_type_);
- }
- return GetRef<Var>(op);
+ if (is_enabled_ && CanRewriteDType(op->dtype) && op->dtype != target_data_type_ &&
+ !var_remap_.count(op)) {
+ var_remap_[op] = GetRef<Var>(op).copy_with_dtype(target_data_type_);
}
- // In the second iteration, rewrite the var
return DataTypeLegalizer::VisitExpr_(op);
}