This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 6ea16b98e7 [Unity] Delegate DataflowVar visitor to Var by default (#15688)
6ea16b98e7 is described below
commit 6ea16b98e72e30e6a73a831e14ad0e0a42a2fd04
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Sep 26 15:15:02 2023 -0500
[Unity] Delegate DataflowVar visitor to Var by default (#15688)
* [Unity] Delegate DataflowVar visitor to Var by default
Prior to this commit, writing a subclass of `relax::ExprVisitor` or
`relax::ExprMutator` required separate overrides for visiting a
`relax::DataflowVar` and a `relax::Var`. In the majority of cases,
these two types should be treated identically, and failure to handle a
`DataflowVar` equivalently would be a bug.
This commit updates the `relax::ExprVisitor` and `relax::ExprMutator`
base classes to visit `DataflowVar` by delegating to the visitor of
`relax::Var`. As a result, any derived class that overrides the
`relax::Var` visitor will also update the behavior for
`relax::DataflowVar`. A pass that requires different behavior for
`relax::Var` and `relax::DataflowVar` can still explicitly override
both methods in order to provide different behavior.
* Updates to Normalizer, VisitVarDef
* Removed DataflowVar visitor in VarVisitor
It was introduced in https://github.com/apache/tvm/pull/15698, which landed
before the more general fix implemented in this PR, and is now redundant.
---
src/relax/analysis/analysis.cc | 2 --
src/relax/analysis/struct_info_analysis.cc | 4 ---
src/relax/analysis/udchain.cc | 4 ---
src/relax/analysis/well_formed.cc | 10 -------
src/relax/ir/binding_rewrite.cc | 5 ----
src/relax/ir/dataflow_matcher.cc | 4 ---
src/relax/ir/expr_functor.cc | 39 ++++++++++------------------
src/relax/transform/canonicalize_bindings.cc | 8 ------
src/relax/transform/convert_layout.cc | 2 --
src/relax/transform/fold_constant.cc | 9 -------
src/relax/transform/gradient.cc | 3 ---
src/relax/transform/lift_transform_params.cc | 4 ---
src/relax/transform/to_mixed_precision.cc | 9 -------
src/relax/transform/utils.h | 14 ----------
14 files changed, 14 insertions(+), 103 deletions(-)
diff --git a/src/relax/analysis/analysis.cc b/src/relax/analysis/analysis.cc
index 7875a517a1..108fe69372 100644
--- a/src/relax/analysis/analysis.cc
+++ b/src/relax/analysis/analysis.cc
@@ -94,8 +94,6 @@ class VarVisitor : protected ExprVisitor {
void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }
- void VisitExpr_(const DataflowVarNode* var) final { vars_.Insert(GetRef<Var>(var)); }
-
void VisitExpr_(const FunctionNode* op) final {
for (const auto& param : op->params) {
MarkBounded(param);
diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc
index ddb3fdb5c1..96e51eede8 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -189,10 +189,6 @@ class WellDefinedEraser : public StructInfoMutator,
}
}
- Expr VisitExpr_(const DataflowVarNode* var) final {
- return VisitExpr_(static_cast<const VarNode*>(var));
- }
-
Expr VisitExpr_(const VarNode* var) final {
Optional<Expr> ret;
if (f_var_map_ != nullptr) {
diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc
index 1c49fd581f..54c7307b96 100644
--- a/src/relax/analysis/udchain.cc
+++ b/src/relax/analysis/udchain.cc
@@ -56,10 +56,6 @@ class UDChain : public relax::ExprVisitor {
cur_user_ = nullptr;
ExprVisitor::VisitExpr_(op);
}
-
- void VisitExpr_(const DataflowVarNode* op) override {
- VisitExpr_(static_cast<const VarNode*>(op));
- }
};
std::pair<runtime::Map<Var, runtime::Array<Var>>, runtime::Array<Var>> FunctionUseDef(
diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc
index b37662af85..79135b943a 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -431,16 +431,6 @@ class WellFormedChecker : public relax::ExprVisitor,
CheckStructInfo(var);
}
- void VisitVarDef(const Var& var) final {
- if (const DataflowVarNode* lv_node = var.as<DataflowVarNode>()) {
- VisitVarDef_(lv_node);
- } else if (const VarNode* gv_node = var.as<VarNode>()) {
- VisitVarDef_(gv_node);
- } else {
- LOG(FATAL) << "TypeError: Invalid type: " << var->GetTypeKey();
- }
- }
-
void VisitExpr_(const tir::VarNode* op) final {
tir::Var var = GetRef<tir::Var>(op);
// default mode, check defined.
diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc
index 4eec0310fb..ae48b6bd69 100644
--- a/src/relax/ir/binding_rewrite.cc
+++ b/src/relax/ir/binding_rewrite.cc
@@ -71,10 +71,6 @@ void DataflowBlockRewriteNode::ReplaceAllUses(Var old_var, Var new_var) {
return (op == old_var.get()) ? new_var : GetRef<Expr>(op);
}
- Expr VisitExpr_(const DataflowVarNode* op) override {
- return (op == old_var.get()) ? new_var : GetRef<Expr>(op);
- }
-
BindingBlock VisitBindingBlock_(const DataflowBlockNode* op) override {
BindingBlock res = ExprMutator::VisitBindingBlock_(op);
if (op == to_catch) caught = Downcast<DataflowBlock>(res);
@@ -136,7 +132,6 @@ std::set<const VarNode*> GetUsedVars(Expr val) {
public:
std::set<const VarNode*> used_vars;
void VisitExpr_(const VarNode* op) override { used_vars.insert(op); }
- void VisitExpr_(const DataflowVarNode* op) override { used_vars.insert(op); }
} uvar{};
uvar.VisitExpr(val);
return std::move(uvar.used_vars);
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index ab2ad4fa36..3b17f7bd1d 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -619,10 +619,6 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor {
caller2callees[cur_user_].push_back(op);
}
-
- void VisitExpr_(const DataflowVarNode* op) override {
- VisitExpr_(static_cast<const VarNode*>(op));
- }
};
struct PNode {
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index 0174308802..14a704d729 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -140,10 +140,7 @@ void ExprVisitor::VisitExpr_(const VarNode* op) {
// Visit the use-site of a defined DataflowVar
void ExprVisitor::VisitExpr_(const DataflowVarNode* op) {
- this->VisitSpan(op->span);
- if (auto* sinfo = op->struct_info_.as<StructInfoNode>()) {
- this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo));
- }
+ VisitExpr_(static_cast<const VarNode*>(op));
}
void ExprVisitor::VisitExpr_(const FunctionNode* op) {
@@ -275,7 +272,9 @@ void ExprVisitor::VisitBindingBlock_(const DataflowBlockNode* block) {
}
}
-void ExprVisitor::VisitVarDef_(const DataflowVarNode* var) { this->VisitSpan(var->span); }
+void ExprVisitor::VisitVarDef_(const DataflowVarNode* var) {
+ VisitVarDef_(static_cast<const VarNode*>(var));
+}
void ExprVisitor::VisitVarDef_(const VarNode* var) { this->VisitSpan(var->span); }
@@ -400,9 +399,7 @@ Expr ExprMutatorBase::VisitExpr_(const VarNode* op) {
// Visit the use-site of a defined DataflowVar
Expr ExprMutatorBase::VisitExpr_(const DataflowVarNode* op) {
- // struct info of var-use should remain stable
- // or the var itself will get replaced
- return GetRef<Expr>(op);
+ return VisitExpr_(static_cast<const VarNode*>(op));
}
Expr ExprMutatorBase::VisitExpr_(const FunctionNode* op) {
@@ -565,13 +562,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) {
// Visit the use-site of a defined DataflowVar
Expr ExprMutator::VisitExpr_(const DataflowVarNode* op) {
- auto it = var_remap_.find(op->vid);
- if (it != var_remap_.end()) {
- return it->second;
- }
-
- // default case return self.
- return GetRef<Expr>(op);
+ return VisitExpr_(static_cast<const VarNode*>(op));
}
Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
@@ -718,16 +709,14 @@ BindingBlock ExprMutator::VisitBindingBlock_(const DataflowBlockNode* block) {
}
Var ExprMutator::VisitVarDef_(const DataflowVarNode* var) {
- if (auto* sinfo = var->struct_info_.as<StructInfoNode>()) {
- StructInfo struct_info = this->VisitExprDepStructInfoField(GetRef<StructInfo>(sinfo));
- if (struct_info.same_as(var->struct_info_)) {
- return GetRef<DataflowVar>(var);
- } else {
- return DataflowVar(var->vid, struct_info, var->span);
- }
- } else {
- return GetRef<DataflowVar>(var);
- }
+ Var output = VisitVarDef_(static_cast<const VarNode*>(var));
+ // Because we delegate from DataflowVar visitor to Var visitor to
+ // provide default behavior in subclasses, we may produce a Var
+ // where we should produce a DataflowVar.
+ if (!output->IsInstance<DataflowVarNode>()) {
+ output = DataflowVar(output->vid, GetStructInfo(output), output->span);
+ }
+ return output;
}
Var ExprMutator::VisitVarDef_(const VarNode* var) {
diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc
index d355c09786..ea5a612e1a 100644
--- a/src/relax/transform/canonicalize_bindings.cc
+++ b/src/relax/transform/canonicalize_bindings.cc
@@ -48,14 +48,6 @@ class BindingCanonicalizer : public ExprMutator {
return ExprMutator::VisitExpr_(LookupBinding(v).as<VarNode>());
}
- Expr VisitExpr_(const DataflowVarNode* op) override {
- Var v = Downcast<Var>(ExprMutator::VisitExpr_(op));
- if (!CanCanonicalizeVar(v)) {
- return Downcast<Expr>(v);
- }
- return ExprMutator::VisitExpr_(LookupBinding(v).as<VarNode>());
- }
-
Expr VisitExpr_(const TupleGetItemNode* tuple_get_item) override {
if (auto tuple_var = tuple_get_item->tuple.as<Var>()) {
if (auto tuple_value = LookupBinding(tuple_var.value())) {
diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc
index f91d221b40..6530d0d2cf 100644
--- a/src/relax/transform/convert_layout.cc
+++ b/src/relax/transform/convert_layout.cc
@@ -131,8 +131,6 @@ class LayoutConvertMutator : public ExprMutator {
Expr VisitExpr_(const VarNode* op) final { return VisitVars_(GetRef<Var>(op)); }
- Expr VisitExpr_(const DataflowVarNode* op) final { return VisitVars_(GetRef<Var>(op)); }
-
bool HasUnknownDimTensor(const NLayout& nlayout) {
bool find = false;
auto fvisit = [&](const LayoutDecision& layout) {
diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc
index a13b7f3d93..8a78c98144 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -305,15 +305,6 @@ class ConstantFolder : public ExprMutator {
return std::move(post_call);
}
- Expr VisitExpr_(const DataflowVarNode* op) final {
- Optional<Expr> opt = LookupBinding(GetRef<Var>(op));
- // `as` check checks if opt is not null and is instance of constant
- if (opt.as<relax::ConstantNode>()) {
- return opt.value();
- }
- return ExprMutator::VisitExpr_(op);
- }
-
Expr VisitExpr_(const VarNode* op) final {
Optional<Expr> opt = LookupBinding(GetRef<Var>(op));
// `as` check checks if opt is not null and is instance of constant
diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc
index 4fd1838670..e09a6d6232 100644
--- a/src/relax/transform/gradient.cc
+++ b/src/relax/transform/gradient.cc
@@ -233,9 +233,6 @@ class CheckpointGenerator : private ExprMutator {
// Visit the use-site of a defined Var
Expr VisitExpr_(const VarNode* op) final { return VisitVar(GetRef<Var>(op)); }
- // Visit the use-site of a defined DataflowVar
- Expr VisitExpr_(const DataflowVarNode* op) final { return VisitVar(GetRef<Var>(op)); }
-
Expr VisitVar(const Var& var) {
auto it = checkpoint_map_.find(var);
if (it != checkpoint_map_.end()) {
diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc
index 7cdf03b10d..7201786c37 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -296,10 +296,6 @@ class TransformParamsLifter : ExprMutator {
return ExprMutator::VisitExpr_(var);
}
- Expr VisitExpr_(const DataflowVarNode* var) final {
- return VisitExpr_(static_cast<const VarNode*>(var));
- }
-
// Remap the original parameters to TupleGetItem from the packed tuple of transformed parameters.
std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_remap_;
// The plan of lifting the transform params
diff --git a/src/relax/transform/to_mixed_precision.cc b/src/relax/transform/to_mixed_precision.cc
index d12d1080b9..c844d59356 100644
--- a/src/relax/transform/to_mixed_precision.cc
+++ b/src/relax/transform/to_mixed_precision.cc
@@ -190,8 +190,6 @@ class DTypeDecisionCollector : public ExprVisitor {
void VisitExpr_(const VarNode* op) final { VisitVars_(op); }
- void VisitExpr_(const DataflowVarNode* op) final { VisitVars_(op); }
-
void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final {
auto policy = GetMixedPrecisionInfo(call_node);
if (policy == -1) {
@@ -451,13 +449,6 @@ class ToMixedPrecisionRewriter : public ExprMutator {
Var VisitVarDef(const Var& var) { return GetRemapped(var); }
- Expr VisitExpr_(const DataflowVarNode* op) final {
- if (!builder_->CurrentBlockIsDataFlow()) {
- return ExprMutator::VisitExpr_(op);
- }
- return VisitVar_(GetRef<Var>(op));
- }
-
void VisitBinding(const Binding& binding) {
ExprMutator::VisitBinding(binding);
if (!builder_->CurrentBlockIsDataFlow()) return;
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index 3d40a565bd..a51c71d788 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -214,12 +214,6 @@ class VarReplacer : public ExprMutator {
return it == var_remap_.end() ? var : it->second;
}
- Expr VisitExpr_(const DataflowVarNode* op) final {
- Var var = GetRef<Var>(op);
- auto it = var_remap_.find(var->vid);
- return it == var_remap_.end() ? var : it->second;
- }
-
const VarMap& var_remap_;
};
@@ -296,14 +290,6 @@ class FunctionCopier : public ExprMutator {
return SymbolicVarRenewMutator::Renew(new_func);
}
- Var VisitVarDef_(const DataflowVarNode* var) override {
- Var new_var = ExprMutator::VisitVarDef_(var);
- Var copied_var = DataflowVar(new_var->name_hint(), GetStructInfo(new_var), new_var->span);
- var_remap_[var->vid] = copied_var;
- var_map.Set(GetRef<Var>(var), copied_var);
- return copied_var;
- }
-
Var VisitVarDef_(const VarNode* var) override {
Var new_var = ExprMutator::VisitVarDef_(var);
Var copied_var = Var(new_var->name_hint(), GetStructInfo(new_var), new_var->span);