junrushao commented on code in PR #13327:
URL: https://github.com/apache/tvm/pull/13327#discussion_r1024384023


##########
src/tir/ir/data_type_rewriter.cc:
##########
@@ -191,5 +191,352 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) {
   return e;
 }
 
+Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode* op) {
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+  auto new_extents = op->extents.Map([this](const PrimExpr& e) { return 
this->VisitExpr(e); });
+  auto new_cond = VisitExpr(op->condition);
+  is_enabled_ = is_enabled;
+  auto new_body = this->VisitStmt(op->body);
+  if (!new_extents.same_as(op->extents) || !new_cond.same_as(op->condition) ||
+      !new_body.same_as(op->body)) {
+    Allocate new_allocate = GetRef<Allocate>(op);
+    auto* n = new_allocate.CopyOnWrite();
+    n->extents = std::move(new_extents);
+    n->condition = std::move(new_cond);
+    n->body = std::move(new_body);
+    return std::move(new_allocate);
+  } else {
+    return GetRef<Stmt>(op);
+  }
+}
+
+Stmt IndexDataTypeRewriter::VisitStmt_(const DeclBufferNode* op) {
+  Buffer new_buffer = VisitBuffer(op->buffer);
+  DeclBuffer decl_buffer = 
Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
+  if (!new_buffer.same_as(op->buffer)) {
+    decl_buffer.CopyOnWrite()->buffer = new_buffer;
+  }
+  return std::move(decl_buffer);
+}
+
+Stmt IndexDataTypeRewriter::VisitStmt_(const BlockRealizeNode* op) {
+  bool is_condition = is_condition_;
+  is_condition_ = true;
+  auto new_predicate = VisitExpr(op->predicate);
+  is_condition_ = is_condition;
+
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+  auto new_iter_values =
+      op->iter_values.Map([this](const PrimExpr& e) { return 
this->VisitExpr(e); });
+  is_enabled_ = is_enabled;
+  Block new_body = Downcast<Block>(this->VisitStmt(op->block));
+  if (!new_predicate.same_as(op->predicate) || 
!new_iter_values.same_as(op->iter_values) ||
+      !new_body.same_as(op->block)) {
+    BlockRealize new_block_realize = GetRef<BlockRealize>(op);
+    auto* n = new_block_realize.CopyOnWrite();
+    n->predicate = std::move(new_predicate);
+    n->iter_values = std::move(new_iter_values);
+    n->block = std::move(new_body);
+    return std::move(new_block_realize);
+  } else {
+    return GetRef<Stmt>(op);
+  }
+}
+
+Stmt IndexDataTypeRewriter::VisitStmt_(const BlockNode* op) {
+  Array<Buffer> new_alloc_buffers =
+      op->alloc_buffers.Map([this](const Buffer& buffer) { return 
this->VisitBuffer(buffer); });
+  Array<MatchBufferRegion> new_match_buffers =
+      op->match_buffers.Map([this](const MatchBufferRegion& 
match_buffer_region) {
+        Buffer new_buffer = this->VisitBuffer(match_buffer_region->buffer);
+        BufferRegion new_buffer_region = 
this->VisitBufferRegion(match_buffer_region->source);
+        if (!new_buffer.same_as(match_buffer_region->buffer) ||
+            !new_buffer_region.same_as(match_buffer_region->source)) {
+          return MatchBufferRegion(new_buffer, new_buffer_region);
+        } else {
+          return match_buffer_region;
+        }
+      });
+  Array<BufferRegion> new_reads = op->reads.Map(
+      [this](const BufferRegion& buffer_region) { return 
this->VisitBufferRegion(buffer_region); });
+  Array<BufferRegion> new_writes = op->writes.Map(
+      [this](const BufferRegion& buffer_region) { return 
this->VisitBufferRegion(buffer_region); });
+  Array<IterVar> new_iter_vars =
+      op->iter_vars.Map([this](const IterVar& iter_var) { return 
this->VisitIterVar(iter_var); });
+  Optional<Stmt> new_init = NullOpt;
+  if (op->init.defined()) {
+    new_init = this->VisitStmt(op->init.value());
+  }
+  Stmt new_body = this->VisitStmt(op->body);
+
+  if (!new_init.same_as(op->init) || !new_body.same_as(op->body) ||
+      !new_alloc_buffers.same_as(op->alloc_buffers) ||
+      !new_match_buffers.same_as(op->match_buffers) || 
!new_reads.same_as(op->reads) ||
+      !new_writes.same_as(op->writes) || new_iter_vars.same_as(op->iter_vars)) 
{
+    Block new_block = GetRef<Block>(op);
+    BlockNode* n = new_block.CopyOnWrite();
+    n->alloc_buffers = std::move(new_alloc_buffers);
+    n->match_buffers = std::move(new_match_buffers);
+    n->reads = std::move(new_reads);
+    n->writes = std::move(new_writes);
+    n->iter_vars = std::move(new_iter_vars);
+    n->init = std::move(new_init);
+    n->body = std::move(new_body);
+    return std::move(new_block);
+  }
+  return GetRef<Stmt>(op);
+}
+
+Map<String, ObjectRef> IndexDataTypeRewriter::VisitBlockAnnotations(
+    const Map<String, ObjectRef>& annotations) {
+  auto new_annotations = annotations;
+
+  std::function<ObjectRef(const ObjectRef&)> f_mutate_obj =
+      [this, &f_mutate_obj](const ObjectRef& obj) -> ObjectRef {
+    if (!obj.defined()) {
+      return obj;
+    }
+    if (obj->IsInstance<BufferNode>()) {
+      Buffer buffer = Downcast<Buffer>(obj);
+      if (Buffer new_buffer = GetRemappedBuffer(buffer); 
!new_buffer.same_as(buffer)) {
+        return new_buffer;
+      }
+    } else if (obj->IsInstance<ArrayNode>()) {
+      return Downcast<Array<ObjectRef>>(obj).Map(f_mutate_obj);
+    }
+    return obj;
+  };
+  for (const auto& [key, value] : annotations) {
+    auto new_value = f_mutate_obj(value);
+    if (!new_value.same_as(value)) {
+      new_annotations.Set(key, new_value);
+    }
+  }
+  return new_annotations;
+}
+
+Buffer IndexDataTypeRewriter::GetRemappedBuffer(const Buffer& buffer) {
+  if (auto it = buffer_remap_.find(buffer); it != buffer_remap_.end()) {
+    return (*it).second;
+  }
+  return buffer;
+}
+
+IterVar IndexDataTypeRewriter::VisitIterVar(const IterVar& iter_var) {
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+  Var new_var = Downcast<Var>(VisitExpr(iter_var->var));
+  PrimExpr min = VisitExpr(iter_var->dom->min);
+  PrimExpr extent = VisitExpr(iter_var->dom->extent);
+  is_enabled_ = is_enabled;
+  if (!new_var.same_as(iter_var->var) || !min.same_as(iter_var->dom->min) ||
+      !extent.same_as(iter_var->dom->extent)) {
+    IterVar new_iter_var = iter_var;
+    IterVarNode* n = new_iter_var.CopyOnWrite();
+    n->var = std::move(new_var);
+    n->dom = Range(min, extent);
+    return new_iter_var;
+  }
+  return iter_var;
+}
+
+Buffer IndexDataTypeRewriter::VisitBuffer(const Buffer& buffer) {
+  bool is_enabled = is_enabled_;
+
+  is_enabled_ = true;
+  Array<PrimExpr> new_shape =
+      buffer->shape.Map([&](const PrimExpr& e) { return this->VisitExpr(e); });
+  Array<PrimExpr> new_strides =
+      buffer->strides.Map([&](const PrimExpr& e) { return this->VisitExpr(e); 
});
+  auto new_elem_offset = VisitExpr(buffer->elem_offset);
+  is_enabled_ = is_enabled;
+
+  if (!buffer->shape.same_as(new_shape) || 
!buffer->strides.same_as(new_strides) ||
+      !buffer->elem_offset.same_as(new_elem_offset)) {
+    Buffer new_buffer = buffer;
+    BufferNode* new_buffer_node = new_buffer.CopyOnWrite();
+    new_buffer_node->shape = std::move(new_shape);
+    new_buffer_node->strides = std::move(new_strides);
+    new_buffer_node->elem_offset = std::move(new_elem_offset);
+    buffer_remap_.Set(buffer, new_buffer);
+    return new_buffer;
+  } else {
+    return buffer;
+  }
+}
+
+BufferRegion IndexDataTypeRewriter::VisitBufferRegion(const BufferRegion& 
buffer_region) {
+  Buffer remapped_buffer = GetRemappedBuffer(buffer_region->buffer);
+
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+  auto new_region = buffer_region->region.Map([&](const Range& range) {
+    return Range::FromMinExtent(this->VisitExpr(range->min), 
this->VisitExpr(range->extent));
+  });
+  is_enabled_ = is_enabled;
+
+  if (!remapped_buffer.same_as(buffer_region->buffer) ||
+      !new_region.same_as(buffer_region->region)) {
+    return BufferRegion(remapped_buffer, new_region);
+  } else {
+    return buffer_region;
+  }
+}
+
+Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) {
+  BufferStore store = GetRef<BufferStore>(op);
+
+  Buffer new_buffer = GetRemappedBuffer(op->buffer);
+  auto value = this->VisitExpr(op->value);
+  auto indices = VisitIndices(op->indices);
+
+  if (!new_buffer.same_as(op->buffer) || !value.same_as(op->value) ||
+      !indices.same_as(op->indices)) {
+    auto writer = store.CopyOnWrite();
+    writer->buffer = new_buffer;
+    writer->value = value;
+    writer->indices = indices;
+  }
+
+  return std::move(store);
+}
+
+PrimExpr IndexDataTypeRewriter::VisitExpr_(const BufferLoadNode* op) {
+  BufferLoad load = GetRef<BufferLoad>(op);
+
+  Buffer new_buffer = GetRemappedBuffer(op->buffer);
+  auto indices = VisitIndices(op->indices);
+
+  if (!new_buffer.same_as(op->buffer) || !indices.same_as(op->indices)) {
+    auto writer = load.CopyOnWrite();
+    writer->indices = indices;
+    writer->buffer = new_buffer;
+  }
+
+  return std::move(load);
+}
+
+Array<PrimExpr> IndexDataTypeRewriter::VisitIndices(Array<PrimExpr> indices) {
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+
+  auto fmutate = [this](const PrimExpr& index) { return 
this->VisitExpr(index); };
+  indices.MutateByApply(fmutate);
+
+  is_enabled_ = is_enabled;
+
+  return indices;
+}
+
+Stmt IndexDataTypeRewriter::VisitStmt_(const IfThenElseNode* op) {
+  bool is_condition = is_condition_;
+  is_condition_ = true;
+  PrimExpr cond = VisitExpr(op->condition);
+  is_condition_ = is_condition;
+
+  Stmt then_case = VisitStmt(op->then_case);
+  Optional<Stmt> else_case =
+      op->else_case.defined() ? 
Optional<Stmt>{VisitStmt(op->else_case.value())} : NullOpt;
+  if (!cond.same_as(op->condition) || !then_case.same_as(op->then_case) ||
+      !else_case.same_as(op->else_case)) {
+    IfThenElse new_stmt = GetRef<IfThenElse>(op);
+    auto* n = new_stmt.CopyOnWrite();
+    n->condition = std::move(cond);
+    n->then_case = std::move(then_case);
+    n->else_case = std::move(else_case);
+    return std::move(new_stmt);
+  }
+  return GetRef<Stmt>(op);
+}
+
+Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) {
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+  Var new_loop_var = Downcast<Var>(VisitExpr(op->loop_var));
+  PrimExpr min = VisitExpr(op->min);
+  PrimExpr extent = VisitExpr(op->extent);
+  is_enabled_ = is_enabled;
+
+  Stmt new_body = VisitStmt(op->body);
+
+  if (!new_loop_var.same_as(op->loop_var) || !min.same_as(op->min) || 
!extent.same_as(op->extent) ||
+      !new_body.same_as(op->body)) {
+    For new_for = GetRef<For>(op);
+    auto* n = new_for.CopyOnWrite();
+    n->loop_var = new_loop_var;
+    n->min = cast(new_loop_var.dtype(), min);
+    n->extent = cast(new_loop_var.dtype(), extent);
+    n->body = new_body;
+    return std::move(new_for);
+  } else {
+    return GetRef<Stmt>(op);
+  }
+}
+
+#define DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC)                     
    \
+  PrimExpr IndexDataTypeRewriter::VisitExpr_(const OP* op) {                   
    \
+    bool is_enabled = is_enabled_;                                             
    \
+    is_enabled_ = is_condition_ && op->a->dtype.is_int() && 
op->b->dtype.is_int(); \
+    auto result = Parent::VisitExpr_(op);                                      
    \
+    is_enabled_ = is_enabled;                                                  
    \
+    return std::move(result);                                                  
    \
+  }
+
+DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==);

Review Comment:
   Prefix it with TVM_



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to