This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 6ea16b98e7 [Unity] Delegate DataflowVar visitor to Var by default (#15688)
6ea16b98e7 is described below

commit 6ea16b98e72e30e6a73a831e14ad0e0a42a2fd04
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Sep 26 15:15:02 2023 -0500

    [Unity] Delegate DataflowVar visitor to Var by default (#15688)
    
    * [Unity] Delegate DataflowVar visitor to Var by default
    
    Prior to this commit, writing a subclass of `relax::ExprVisitor` or
    `relax::ExprMutator` required separate overrides for visiting a
    `relax::DataflowVar` and a `relax::Var`.  In the majority of cases,
    these two types should be treated identically, and failure to handle a
    `DataflowVar` equivalently would be a bug.
    
    This commit updates the `relax::ExprVisitor` and `relax::ExprMutator`
    base classes to visit `DataflowVar` by delegating to the visitor of
    `relax::Var`.  As a result, any derived class that overrides the
    `relax::Var` visitor will also update the behavior for
    `relax::DataflowVar`.  A pass that requires different behavior for
    `relax::Var` and `relax::DataflowVar` can still explicitly override
    both methods in order to provide different behavior.
    
    * Updates to Normalizer, VisitVarDef
    
    * Removed DataflowVar visitor in VarVisitor
    
    It was introduced in https://github.com/apache/tvm/pull/15698, before
    the more general fix implemented in the current PR.
---
 src/relax/analysis/analysis.cc               |  2 --
 src/relax/analysis/struct_info_analysis.cc   |  4 ---
 src/relax/analysis/udchain.cc                |  4 ---
 src/relax/analysis/well_formed.cc            | 10 -------
 src/relax/ir/binding_rewrite.cc              |  5 ----
 src/relax/ir/dataflow_matcher.cc             |  4 ---
 src/relax/ir/expr_functor.cc                 | 39 ++++++++++------------------
 src/relax/transform/canonicalize_bindings.cc |  8 ------
 src/relax/transform/convert_layout.cc        |  2 --
 src/relax/transform/fold_constant.cc         |  9 -------
 src/relax/transform/gradient.cc              |  3 ---
 src/relax/transform/lift_transform_params.cc |  4 ---
 src/relax/transform/to_mixed_precision.cc    |  9 -------
 src/relax/transform/utils.h                  | 14 ----------
 14 files changed, 14 insertions(+), 103 deletions(-)

diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc
index 7875a517a1..108fe69372 100644
--- a/src/relax/analysis/analysis.cc
+++ b/src/relax/analysis/analysis.cc
@@ -94,8 +94,6 @@ class VarVisitor : protected ExprVisitor {
 
   void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }
 
-  void VisitExpr_(const DataflowVarNode* var) final { vars_.Insert(GetRef<Var>(var)); }
-
   void VisitExpr_(const FunctionNode* op) final {
     for (const auto& param : op->params) {
       MarkBounded(param);
diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc
index ddb3fdb5c1..96e51eede8 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -189,10 +189,6 @@ class WellDefinedEraser : public StructInfoMutator,
     }
   }
 
-  Expr VisitExpr_(const DataflowVarNode* var) final {
-    return VisitExpr_(static_cast<const VarNode*>(var));
-  }
-
   Expr VisitExpr_(const VarNode* var) final {
     Optional<Expr> ret;
     if (f_var_map_ != nullptr) {
diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc
index 1c49fd581f..54c7307b96 100644
--- a/src/relax/analysis/udchain.cc
+++ b/src/relax/analysis/udchain.cc
@@ -56,10 +56,6 @@ class UDChain : public relax::ExprVisitor {
     cur_user_ = nullptr;
     ExprVisitor::VisitExpr_(op);
   }
-
-  void VisitExpr_(const DataflowVarNode* op) override {
-    VisitExpr_(static_cast<const VarNode*>(op));
-  }
 };
 
 std::pair<runtime::Map<Var, runtime::Array<Var>>, runtime::Array<Var>> FunctionUseDef(
diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc
index b37662af85..79135b943a 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -431,16 +431,6 @@ class WellFormedChecker : public relax::ExprVisitor,
     CheckStructInfo(var);
   }
 
-  void VisitVarDef(const Var& var) final {
-    if (const DataflowVarNode* lv_node = var.as<DataflowVarNode>()) {
-      VisitVarDef_(lv_node);
-    } else if (const VarNode* gv_node = var.as<VarNode>()) {
-      VisitVarDef_(gv_node);
-    } else {
-      LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey();
-    }
-  }
-
   void VisitExpr_(const tir::VarNode* op) final {
     tir::Var var = GetRef<tir::Var>(op);
     // default mode, check defined.
diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc
index 4eec0310fb..ae48b6bd69 100644
--- a/src/relax/ir/binding_rewrite.cc
+++ b/src/relax/ir/binding_rewrite.cc
@@ -71,10 +71,6 @@ void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) {
       return (op == old_var.get()) ? new_var : GetRef<Expr>(op);
     }
 
-    Expr VisitExpr_(const DataflowVarNode* op) override {
-      return (op == old_var.get()) ? new_var : GetRef<Expr>(op);
-    }
-
     BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override {
       BindingBlock res = ExprMutator::VisitBindingBlock_(op);
       if (op == to_catch) caught = Downcast<DataflowBlock>(res);
@@ -136,7 +132,6 @@ std::set<const VarNode*> GetUsedVars(Expr val) {
    public:
     std::set<const VarNode*> used_vars;
     void VisitExpr_(const VarNode* op) override { used_vars.insert(op); }
-    void VisitExpr_(const DataflowVarNode* op) override { used_vars.insert(op); }
   } uvar{};
   uvar.VisitExpr(val);
   return std::move(uvar.used_vars);
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index ab2ad4fa36..3b17f7bd1d 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -619,10 +619,6 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor {
 
     caller2callees[cur_user_].push_back(op);
   }
-
-  void VisitExpr_(const DataflowVarNode* op) override {
-    VisitExpr_(static_cast<const VarNode*>(op));
-  }
 };
 
 struct PNode {
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index 0174308802..14a704d729 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -140,10 +140,7 @@ void ExprVisitor::VisitExpr_(const VarNode* op) {
 
 // Visit the use-site of a defined DataflowVar
 void ExprVisitor::VisitExpr_(const DataflowVarNode* op) {
-  this->VisitSpan(op->span);
-  if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) {
-    this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo));
-  }
+  VisitExpr_(static_cast<const VarNode*>(op));
 }
 
 void ExprVisitor::VisitExpr_(const FunctionNode* op) {
@@ -275,7 +272,9 @@ void ExprVisitor::VisitBindingBlock_(const DataflowBlockNode* block) {
   }
 }
 
-void ExprVisitor::VisitVarDef_(const DataflowVarNode* var) { this->VisitSpan(var->span); }
+void ExprVisitor::VisitVarDef_(const DataflowVarNode* var) {
+  VisitVarDef_(static_cast<const VarNode*>(var));
+}
 
 void ExprVisitor::VisitVarDef_(const VarNode* var) { this->VisitSpan(var->span); }
 
@@ -400,9 +399,7 @@ Expr ExprMutatorBase::VisitExpr_(const VarNode* op) {
 
 // Visit the use-site of a defined DataflowVar
 Expr ExprMutatorBase::VisitExpr_(const DataflowVarNode* op) {
-  // struct info of var-use should remain stable
-  // or the var itself will get replaced
-  return GetRef<Expr>(op);
+  return VisitExpr_(static_cast<const VarNode*>(op));
 }
 
 Expr ExprMutatorBase::VisitExpr_(const FunctionNode* op) {
@@ -565,13 +562,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) {
 
 // Visit the use-site of a defined DataflowVar
 Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) {
-  auto it = var_remap_.find(op->vid);
-  if (it != var_remap_.end()) {
-    return it->second;
-  }
-
-  // default case return self.
-  return GetRef<Expr>(op);
+  return VisitExpr_(static_cast<const VarNode*>(op));
 }
 
 Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
@@ -718,16 +709,14 @@ BindingBlock ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) {
 }
 
 Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) {
-  if (auto* sinfo = var->struct_info_.as<StructInfoNode>()) {
-    StructInfo struct_info = this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo));
-    if (struct_info.same_as(var->struct_info_)) {
-      return GetRef<DataflowVar>(var);
-    } else {
-      return DataflowVar(var->vid, struct_info, var->span);
-    }
-  } else {
-    return GetRef<DataflowVar>(var);
-  }
+  Var output = VisitVarDef_(static_cast<const VarNode*>(var));
+  // Because we delegate from DataflowVar visitor to Var visitor to
+  // provide default behavior in subclasses, we may produce a Var
+  // where we should produce a DataflowVar.
+  if (!output->IsInstance<DataflowVarNode>()) {
+    output = DataflowVar(output->vid, GetStructInfo(output), output->span);
+  }
+  return output;
 }
 
 Var ExprMutator::VisitVarDef_(const VarNode* var) {
diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc
index d355c09786..ea5a612e1a 100644
--- a/src/relax/transform/canonicalize_bindings.cc
+++ b/src/relax/transform/canonicalize_bindings.cc
@@ -48,14 +48,6 @@ class BindingCanonicalizer : public ExprMutator {
     return ExprMutator::VisitExpr_(LookupBinding(v).as<VarNode>());
   }
 
-  Expr VisitExpr_(const DataflowVarNode* op) override {
-    Var v = Downcast<Var>(ExprMutator::VisitExpr_(op));
-    if (!CanCanonicalizeVar(v)) {
-      return Downcast<Expr>(v);
-    }
-    return ExprMutator::VisitExpr_(LookupBinding(v).as<VarNode>());
-  }
-
   Expr VisitExpr_(const TupleGetItemNode* tuple_get_item) override {
     if (auto tuple_var = tuple_get_item->tuple.as<Var>()) {
       if (auto tuple_value = LookupBinding(tuple_var.value())) {
diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc
index f91d221b40..6530d0d2cf 100644
--- a/src/relax/transform/convert_layout.cc
+++ b/src/relax/transform/convert_layout.cc
@@ -131,8 +131,6 @@ class LayoutConvertMutator : public ExprMutator {
 
   Expr VisitExpr_(const VarNode* op) final { return VisitVars_(GetRef<Var>(op)); }
 
-  Expr VisitExpr_(const DataflowVarNode* op) final { return VisitVars_(GetRef<Var>(op)); }
-
   bool HasUnknownDimTensor(const NLayout& nlayout) {
     bool find = false;
     auto fvisit = [&](const LayoutDecision& layout) {
diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc
index a13b7f3d93..8a78c98144 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -305,15 +305,6 @@ class ConstantFolder : public ExprMutator {
     return std::move(post_call);
   }
 
-  Expr VisitExpr_(const DataflowVarNode* op) final {
-    Optional<Expr> opt = LookupBinding(GetRef<Var>(op));
-    // `as` check checks if opt is not null and is instance of constant
-    if (opt.as<relax::ConstantNode>()) {
-      return opt.value();
-    }
-    return ExprMutator::VisitExpr_(op);
-  }
-
   Expr VisitExpr_(const VarNode* op) final {
     Optional<Expr> opt = LookupBinding(GetRef<Var>(op));
     // `as` check checks if opt is not null and is instance of constant
diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc
index 4fd1838670..e09a6d6232 100644
--- a/src/relax/transform/gradient.cc
+++ b/src/relax/transform/gradient.cc
@@ -233,9 +233,6 @@ class CheckpointGenerator : private ExprMutator {
   // Visit the use-site of a defined Var
   Expr VisitExpr_(const VarNode* op) final { return VisitVar(GetRef<Var>(op)); }
 
-  // Visit the use-site of a defined DataflowVar
-  Expr VisitExpr_(const DataflowVarNode* op) final { return VisitVar(GetRef<Var>(op)); }
-
   Expr VisitVar(const Var& var) {
     auto it = checkpoint_map_.find(var);
     if (it != checkpoint_map_.end()) {
diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc
index 7cdf03b10d..7201786c37 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -296,10 +296,6 @@ class TransformParamsLifter : ExprMutator {
     return ExprMutator::VisitExpr_(var);
   }
 
-  Expr VisitExpr_(const DataflowVarNode* var) final {
-    return VisitExpr_(static_cast<const VarNode*>(var));
-  }
-
   // Remap the original parameters to TupleGetItem from the packed tuple of transformed parameters.
   std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_remap_;
   // The plan of lifting the transform params
diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc
index d12d1080b9..c844d59356 100644
--- a/src/relax/transform/to_mixed_precision.cc
+++ b/src/relax/transform/to_mixed_precision.cc
@@ -190,8 +190,6 @@ class DTypeDecisionCollector : public ExprVisitor {
 
   void VisitExpr_(const VarNode* op) final { VisitVars_(op); }
 
-  void VisitExpr_(const DataflowVarNode* op) final { VisitVars_(op); }
-
   void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final {
     auto policy = GetMixedPrecisionInfo(call_node);
     if (policy == -1) {
@@ -451,13 +449,6 @@ class ToMixedPrecisionRewriter : public ExprMutator {
 
   Var VisitVarDef(const Var& var) { return GetRemapped(var); }
 
-  Expr VisitExpr_(const DataflowVarNode* op) final {
-    if (!builder_->CurrentBlockIsDataFlow()) {
-      return ExprMutator::VisitExpr_(op);
-    }
-    return VisitVar_(GetRef<Var>(op));
-  }
-
   void VisitBinding(const Binding& binding) {
     ExprMutator::VisitBinding(binding);
     if (!builder_->CurrentBlockIsDataFlow()) return;
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index 3d40a565bd..a51c71d788 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -214,12 +214,6 @@ class VarReplacer : public ExprMutator {
     return it == var_remap_.end() ? var : it->second;
   }
 
-  Expr VisitExpr_(const DataflowVarNode* op) final {
-    Var var = GetRef<Var>(op);
-    auto it = var_remap_.find(var->vid);
-    return it == var_remap_.end() ? var : it->second;
-  }
-
   const VarMap& var_remap_;
 };
 
@@ -296,14 +290,6 @@ class FunctionCopier : public ExprMutator {
     return SymbolicVarRenewMutator::Renew(new_func);
   }
 
-  Var VisitVarDef_(const DataflowVarNode* var) override {
-    Var new_var = ExprMutator::VisitVarDef_(var);
-    Var copied_var = DataflowVar(new_var->name_hint(), GetStructInfo(new_var), new_var->span);
-    var_remap_[var->vid] = copied_var;
-    var_map.Set(GetRef<Var>(var), copied_var);
-    return copied_var;
-  }
-
   Var VisitVarDef_(const VarNode* var) override {
     Var new_var = ExprMutator::VisitVarDef_(var);
     Var copied_var = Var(new_var->name_hint(), GetStructInfo(new_var), new_var->span);

Reply via email to