This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new dfc77eb129 [Unity] Extend RemoveAllUnused to support relax::Expr (#15807)
dfc77eb129 is described below
commit dfc77eb1299b2120467f5fcb8f5ee06d370d3ff5
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Sep 26 18:09:33 2023 -0500
[Unity] Extend RemoveAllUnused to support relax::Expr (#15807)
* [Unity] Extend RemoveAllUnused to support relax::Expr
Prior to this commit, the `RemoveAllUnused` utility could only
apply to `relax::Function` instances. This commit extends the allowed
usage to apply to any `relax::Expr` that contains variable bindings,
such as `relax::SeqExpr`.
* Update unit tests for new behavior
Because `RemoveAllUnused` now handles unused variables in a
non-dataflow binding block, passes that use it as a utility function
needed a few tests updated.
---
include/tvm/relax/analysis.h | 19 +++-
src/relax/analysis/udchain.cc | 7 +-
src/relax/ir/binding_rewrite.cc | 51 ++++-----
src/relax/ir/dataflow_matcher.cc | 2 +-
src/relax/transform/dead_code_elimination.cc | 2 +-
src/relax/transform/fold_constant.cc | 2 +-
src/relax/transform/gradient.cc | 2 +-
tests/python/relax/test_analysis.py | 126 ++++++++++++++++++---
tests/python/relax/test_transform_fold_constant.py | 15 +--
tests/python/relax/test_tuning_api.py | 1 -
10 files changed, 157 insertions(+), 70 deletions(-)
diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index f515ba6201..82fb73b1bd 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -399,18 +399,25 @@ TVM_DLL Map<Var, Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb);
/*!
* \brief Get the use-def chain of variables inside a function.
*
- * \param fn The function to be analyzed.
- * \return A map from variable definitions to a set of uses and variables needed by return value.
+ * \param expr The expression to be analyzed.
+ *
+ * \return A tuple of variable usage and variable outputs. The first
+ * element is a map from variable definitions to the set of downstream
+ * users of that definition. The second element is a list of
+ * variables whose usage occurs outside of any variable binding,
+ * typically the output body of a relax::Function or a relax::SeqExpr.
*/
-std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Function& fn);
+std::pair<Map<Var, Array<Var>>, Array<Var>> FunctionUseDef(const Expr& expr);
/*!
* \brief Remove unused statements inside DataflowBlocks.
*
- * \param fn The function to remove unused statements.
- * \return The function that contains no unused statements in DataflowBlock.
+ * \param expr The expression (typically a relax::Function) from which
+ * to remove unused statements.
+ *
+ * \return The updated function with no unused statements in DataflowBlock.
*/
-TVM_DLL Function RemoveAllUnused(const Function fn);
+TVM_DLL Expr RemoveAllUnused(Expr expr);
/*!
* \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps.
diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc
index 54c7307b96..fc0ffd3292 100644
--- a/src/relax/analysis/udchain.cc
+++ b/src/relax/analysis/udchain.cc
@@ -40,14 +40,15 @@ class UDChain : public relax::ExprVisitor {
// nullptr users means it is the output of the function.
std::map<const VarNode*, std::set<const VarNode*>> to_users;
- const VarNode* cur_user_;
+ const VarNode* cur_user_{nullptr};
void VisitBinding_(const VarBindingNode* binding) override {
// init
+ auto cache = cur_user_;
cur_user_ = binding->var.get();
this->VisitVarDef(binding->var);
this->VisitExpr(binding->value);
- cur_user_ = nullptr;
+ cur_user_ = cache;
}
void VisitExpr_(const VarNode* op) override { to_users[op].insert(cur_user_); }
@@ -59,7 +60,7 @@ class UDChain : public relax::ExprVisitor {
};
std::pair<runtime::Map<Var, runtime::Array<Var>>, runtime::Array<Var>> FunctionUseDef(
- const Function& fn) {
+ const Expr& fn) {
UDChain udchain;
udchain.VisitExpr(fn);
diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc
index ae48b6bd69..f5820eca80 100644
--- a/src/relax/ir/binding_rewrite.cc
+++ b/src/relax/ir/binding_rewrite.cc
@@ -240,38 +240,31 @@ class RemoveUnusedVars : public ExprMutator {
RemoveUnusedVars(Map<Var, Array<Var>> users, Array<Var> fn_outputs)
: RemoveUnusedVars(GetUnusedVars(users, fn_outputs)) {}
- BindingBlock VisitBindingBlock_(const BindingBlockNode* block) override {
- builder_->BeginBindingBlock();
- for (Binding binding : block->bindings) {
- bool can_remove = [&]() -> bool {
- if (!unused_vars.count(binding->var)) {
- return false;
- }
- auto var_binding = binding.as<VarBindingNode>();
- if (!var_binding) {
- return false;
- }
- return var_binding->value->IsInstance<FunctionNode>();
- }();
- if (!can_remove) {
- VisitBinding(binding);
- }
+ void VisitBinding_(const VarBindingNode* binding) override {
+ bool can_remove = unused_vars.count(binding->var) &&
+ (in_dataflow_block_ || !ContainsImpureCall(binding->value));
+ if (!can_remove) {
+ ExprMutator::VisitBinding_(binding);
}
- return builder_->EndBlock();
}
BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) override {
- auto prev_dfb = GetRef<DataflowBlock>(block);
- builder_->BeginDataflowBlock();
- for (Binding binding : block->bindings) {
- if (!unused_vars.count(binding->var) || binding.as<MatchCastNode>()) {
- VisitBinding(binding);
- }
+ bool capture_output = (block == caught_rewrite.get());
+
+ bool cache = in_dataflow_block_;
+ in_dataflow_block_ = true;
+ BindingBlock output = ExprMutator::VisitBindingBlock_(block);
+ in_dataflow_block_ = cache;
+
+ if (capture_output) {
+ caught_rewrite = Downcast<DataflowBlock>(output);
}
- auto new_dfb = builder_->EndBlock();
- if (caught_rewrite == prev_dfb) caught_rewrite = Downcast<DataflowBlock>(new_dfb);
- return std::move(new_dfb);
+
+ return std::move(output);
}
+
+ private:
+ bool in_dataflow_block_{false};
};
void DataflowBlockRewriteNode::RemoveUnused(Var unused, bool allow_undef) {
@@ -322,10 +315,10 @@ void DataflowBlockRewriteNode::RemoveAllUnused() {
TVM_REGISTER_GLOBAL("relax.dfb_rewrite_remove_all_unused")
.set_body_typed([](DataflowBlockRewrite rwt) { rwt->RemoveAllUnused(); });
-Function RemoveAllUnused(Function fn) {
- auto [users, outputs] = FunctionUseDef(fn);
+Expr RemoveAllUnused(Expr expr) {
+ auto [users, outputs] = FunctionUseDef(expr);
RemoveUnusedVars remover(users, outputs);
- return Downcast<Function>(remover.VisitExpr_(fn.get()));
+ return remover.VisitExpr(std::move(expr));
}
TVM_REGISTER_GLOBAL("relax.analysis.remove_all_unused").set_body_typed(RemoveAllUnused);
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 3b17f7bd1d..e85e3c4d51 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -934,7 +934,7 @@ class PatternRewriter : ExprMutator {
params.insert(p.get());
}
PatternRewriter rewriter(pat, rewriter_func, params);
- return RemoveAllUnused(Downcast<Function>(rewriter.VisitExpr(f)));
+ return Downcast<Function>(RemoveAllUnused(rewriter.VisitExpr(f)));
}
void VisitBinding_(const VarBindingNode* binding) final {
diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc
index 494665ec71..6d9f25296a 100644
--- a/src/relax/transform/dead_code_elimination.cc
+++ b/src/relax/transform/dead_code_elimination.cc
@@ -137,7 +137,7 @@ IRModule DeadCodeElimination(const IRModule& arg_mod, Array<runtime::String> ent
IRModule updates;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<Function>()) {
- auto new_func = RemoveAllUnused(opt.value());
+ auto new_func = Downcast<Function>(RemoveAllUnused(opt.value()));
if (!new_func.same_as(base_func)) {
updates->Add(gvar, new_func);
}
diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc
index 8a78c98144..d6da79c484 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -34,7 +34,7 @@ class ConstantFolder : public ExprMutator {
public:
static Function Fold(Function func, IRModule ctx_module) {
ConstantFolder folder(std::move(ctx_module));
- func = RemoveAllUnused(Downcast<Function>(folder(func)));
+ func = Downcast<Function>(RemoveAllUnused(folder(func)));
return func;
}
diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc
index e09a6d6232..8560db6a04 100644
--- a/src/relax/transform/gradient.cc
+++ b/src/relax/transform/gradient.cc
@@ -632,7 +632,7 @@ class GradientMutator : private ExprMutator {
new_func = CallTIRWithGradEliminator::Transform(new_func);
if (remove_all_unused) {
- new_func = RemoveAllUnused(new_func);
+ new_func = Downcast<Function>(RemoveAllUnused(new_func));
}
// Step 5.3 mark the transformed function as public
diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py
index d5545a0a56..dfee026206 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -88,6 +88,13 @@ def test_chained_remove_all_unused():
def test_binding_block_remove_all_unused():
+ """Remove unused dataflow bindings
+
+ Removal of unused bindings may not remove side effects. Since
+ bindings within a dataflow block are guaranteed not to have side
+ effects, they may be removed if unused.
+ """
+
@tvm.script.ir_module
class IdentityUnused:
@R.function
@@ -117,24 +124,49 @@ def test_binding_block_remove_all_unused():
tvm.ir.assert_structural_equal(optimized, GroundTruth["main"])
-def test_binding_block_remove_all_unused_without_dataflow():
- @tvm.script.ir_module
- class IdentityUnused:
- @R.function
- def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
- lv0 = x
- unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
- unused1 = R.call_dps_packed(
- "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")
- )
- z = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32")))
- return z
+def test_binding_block_remove_unused_pure_without_dataflow():
+ """Remove unused dataflow bindings
- optimized = remove_all_unused(IdentityUnused["main"])
+ Removal of unused bindings may not remove side effects. Unused
+ bindings whose value is a pure operation
+ (e.g. `R.call_dps_packed`) may be removed, even if outside of a
+ dataflow block.
+ """
- GroundTruth = IdentityUnused
+ @R.function(private=True)
+ def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ lv0 = x
+ unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
+ unused1 = R.call_dps_packed("my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32"))
+ return x
- tvm.ir.assert_structural_equal(optimized, GroundTruth["main"])
+ @R.function(private=True)
+ def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ return x
+
+ after = remove_all_unused(before)
+ tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_binding_block_keep_impure_without_dataflow():
+ """Remove unused dataflow bindings
+
+ Removal of unused bindings may not remove side effects. Unused
+ bindings whose value is an impure operation (e.g. `R.call_packed`)
+ may not be removed, as outside of a dataflow block they may
+ contain side effects.
+ """
+
+ @R.function(private=True)
+ def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ lv0 = x
+ y = R.call_packed("vm.builtin.copy", lv0, sinfo_args=(R.Tensor((32, 32), "float32")))
+ return y
+
+ expected = before
+
+ after = remove_all_unused(before)
+ tvm.ir.assert_structural_equal(expected, after)
def test_binding_block_remove_all_unused_func_without_dataflow():
@@ -226,6 +258,70 @@ def test_edge_binding_block_fake_unused_remove_all_unused2():
tvm.ir.assert_structural_equal(optimized, IdentityUnused["main"])
+def test_remove_all_unused_from_dataflow_block():
+ """Like test_chained_remove_all_unused, but on a SeqExpr"""
+
+ @R.function
+ def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
+ unused1 = R.call_dps_packed(
+ "my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32")
+ )
+ R.output(lv0)
+ return lv0
+
+ @R.function
+ def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ with R.dataflow():
+ lv0 = x
+ R.output(lv0)
+ return lv0
+
+ after = remove_all_unused(before.body)
+ tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True)
+
+
+def test_remove_all_unused_from_binding_block():
+ """Like test_chained_remove_all_unused, but on a SeqExpr"""
+
+ @R.function
+ def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ lv0 = x
+ unused0 = R.call_dps_packed("my_sigmoid", (x,), R.Tensor((32, 32), dtype="float32"))
+ unused1 = R.call_dps_packed("my_dps_func", (unused0,), R.Tensor((32, 32), dtype="float32"))
+ return lv0
+
+ @R.function
+ def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ lv0 = x
+ return lv0
+
+ after = remove_all_unused(before.body)
+ tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True)
+
+
+def test_retain_impure_calls_unused_in_binding_block():
+ """An impure call may have side effects, and must be kept"""
+
+ @R.function
+ def before(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ lv0 = x
+ unused0 = R.call_packed("my_impure_call", x, sinfo_args=R.Tensor((32, 32), dtype="float32"))
+ unused1 = R.call_dps_packed("my_unused_call", (lv0,), R.Tensor((32, 32), dtype="float32"))
+ return lv0
+
+ @R.function
+ def expected(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
+ lv0 = x
+ unused0 = R.call_packed("my_impure_call", x, sinfo_args=R.Tensor((32, 32), dtype="float32"))
+ return lv0
+
+ after = remove_all_unused(before.body)
+ tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True)
+
+
def test_name_to_binding_var_shadowing():
@R.function
def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py
index c2a3bd5092..a4dffba114 100644
--- a/tests/python/relax/test_transform_fold_constant.py
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -73,7 +73,6 @@ def test_one_fold_addone():
@R.function
def expected(c1: R.Tensor((16, 16), "float32")):
- lv0 = c1
return c1
c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
@@ -104,7 +103,6 @@ def test_one_fold_transpose():
@R.function
def expected(c1: R.Tensor((3, 2), "float32")):
- lv0 = c1
return c1
c0_np = np.arange(2 * 3).astype("float32").reshape(2, 3)
@@ -135,8 +133,6 @@ def test_two_hop_addone():
@R.function
def expected(c1: R.Tensor((2, 2), "float32"), c2: R.Tensor((2, 2), "float32")):
- lv0 = c1
- lv1 = c2
return c2
c0_np = np.arange((2 * 2)).astype("float32").reshape(2, 2)
@@ -218,7 +214,7 @@ def test_fold_mixed_case():
lv2 = relax.call_tir(cls.sub, (c0, lv1), R.Tensor((16, 16), dtype="float32"))
# this line can not be folded because x's shape is unknown
lv3 = relax.call_tir(cls.sub, (lv2, x), R.Tensor((16, 16), dtype="float32"))
- return lv3
+ return (lv0, lv3)
@R.function
def expected(
@@ -226,19 +222,15 @@ def test_fold_mixed_case():
c1: R.Tensor((16, 16), "float32"),
c2: R.Tensor((16, 16), "float32"),
x: R.Tensor("float32", ndim=2),
- ) -> R.Tensor:
+ ):
n, m = T.int64(), T.int64()
cls = Module
x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
# this line cannot be folded because n is unknown
lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((n, 16), dtype="float32"))
- # this line can be folded
- lv1 = c1
- # this line can be folded because all inputs are const
- lv2 = c2
# this line can not be folded because x's shape is unknown
lv3 = relax.call_tir(cls.sub, (c2, x), R.Tensor((16, 16), dtype="float32"))
- return lv3
+ return (lv0, lv3)
c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
c1_np = c0_np + 1
@@ -268,7 +260,6 @@ def test_int32_fold():
@R.function
def expected(c1: R.Tensor((16, 16), "int32")):
- lv0 = c1
return c1
c0_np = np.arange((16 * 16)).astype("int32").reshape(16, 16)
diff --git a/tests/python/relax/test_tuning_api.py b/tests/python/relax/test_tuning_api.py
index 5c2f165dc3..082c9ce16a 100644
--- a/tests/python/relax/test_tuning_api.py
+++ b/tests/python/relax/test_tuning_api.py
@@ -64,7 +64,6 @@ class TestModule:
# Expected IRModule after transformation.
@R.function
def expected(c1: R.Tensor((16, 16), "int32")):
- lv0 = c1
return c1