junrushao commented on code in PR #13327:
URL: https://github.com/apache/tvm/pull/13327#discussion_r1024384378


##########
src/tir/transforms/narrow_datatype.cc:
##########
@@ -315,65 +265,25 @@ class DataTypeRewriter : public DataTypeLegalizer {
     return Parent::VisitExpr_(op);
   }
 
-  PrimExpr VisitExpr_(const EQNode* op) final;
-  PrimExpr VisitExpr_(const NENode* op) final;
-  PrimExpr VisitExpr_(const LTNode* op) final;
-  PrimExpr VisitExpr_(const LENode* op) final;
-  PrimExpr VisitExpr_(const GTNode* op) final;
-  PrimExpr VisitExpr_(const GENode* op) final;
-  PrimExpr VisitExpr_(const CallNode* op) final;
-
  private:
   // the internal visitor to deduce the narrowed dtype
   DataTypeVisitor visitor_;
   // a map from Var before rewrite to that after rewrite,
   // ensures one old Var maps to exactly one new Var
   std::unordered_map<const VarNode*, Var> vmap_;
-  // indicator of index expr to rewrite
-  bool is_index_{false};
-  // indicator of condition
-  bool is_condition_{false};
 };
 
-#define DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC)                          \
-  PrimExpr DataTypeRewriter::VisitExpr_(const OP* op) {                             \
-    bool is_index = is_index_;                                                      \
-    bool rewrite = is_condition_ && op->a->dtype.is_int() && op->b->dtype.is_int(); \
-    if (rewrite) {                                                                  \
-      is_index_ = true;                                                             \
-    }                                                                               \
-    auto result = Parent::VisitExpr_(op);                                           \
-    is_index_ = is_index;                                                           \
-    return std::move(result);                                                       \
-  }
-
-DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(EQNode, operator==);
-DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(NENode, operator!=);
-DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=);
-DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<);  // NOLINT(*)
-DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>);  // NOLINT(*)
-DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=);
-
-PrimExpr DataTypeRewriter::VisitExpr_(const CallNode* op) {
-  // handle if_then_else condition
-  if (op->op.same_as(builtin::if_then_else())) {
-    bool is_condition = is_condition_;
-    is_condition_ = true;
-    PrimExpr cond = VisitExpr(op->args[0]);
-    is_condition_ = is_condition;
-    return if_then_else(cond, VisitExpr(op->args[1]), VisitExpr(op->args[2]));
-  }
-  return Parent::VisitExpr_(op);
+Stmt NarrowDataType(Stmt stmt, int target_bits) {
+  return NarrowDataTypeRewriter(target_bits)(stmt);
 }
 
-Stmt NarrowDataType(Stmt stmt, int target_bits) { return DataTypeRewriter(target_bits)(stmt); }
-
 namespace transform {
 
 Pass NarrowDataType(int target_bits) {
   auto pass_func = [target_bits](PrimFunc f, IRModule m, PassContext ctx) {
     auto* n = f.CopyOnWrite();
-    n->body = DataTypeRewriter(target_bits)(std::move(n->body));
+    n->body = NarrowDataTypeRewriter(target_bits)(std::move(n->body));
+    // LOG(INFO) << "AfterNarrow: " << tir::AsTVMScript(f);

Review Comment:
   Remove this commented-out debug `LOG(INFO)` line?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

Reply via email to