This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 9ddd39be56 [Unity][UX] Symbolic Variables Used in Multiple Functions
(#14606)
9ddd39be56 is described below
commit 9ddd39be56bcd62905269d6386b44e9989639550
Author: Chaofan Lin <[email protected]>
AuthorDate: Fri Apr 14 01:37:41 2023 +0800
[Unity][UX] Symbolic Variables Used in Multiple Functions (#14606)
Prior to this PR, nothing prevented users from defining multiple
functions that use the same symbolic TIR var. For example, a user could
write the following script:
```
# Reproduction: one symbolic TIR var (`batch_size`) shared by two Relax
# functions in the same IRModule. Building the module fails in VMShapeLower.
import tvm
from tvm import relax, tir
from tvm.relax.transform import LegalizeOps
from tvm.script import ir as I
from tvm.script import relax as R

batch_size = tir.Var("batch_size", "int64")


@I.ir_module
class Test:
    @R.function
    def main(
        x: R.Tensor((batch_size, 10), "float32"),
    ):
        with R.dataflow():
            lv = R.sum(x, axis=1)
            lv1 = R.mean(x, axis=1)
            out = R.add(lv, lv1)
            R.output(out)
        return out

    @R.function
    def main1(
        x: R.Tensor((batch_size, 10), "float32"),
        y: R.Tensor((batch_size, 10), "float32"),
    ):
        with R.dataflow():
            out = R.subtract(x, y)
            R.output(out)
        return out


lowered_mod = LegalizeOps()(Test)
ex = relax.build(lowered_mod, "llvm")
vm = relax.VirtualMachine(ex, tvm.cpu(0))
```
However, this script cannot be built, because `VMShapeLower` does not
handle this case correctly.
Since sharing a symbolic var across functions is illegal, we need to
prevent users from doing this.
Specifically, this PR consists of two parts:
- Make the well-formed checker detect this case.
- Make the `CopyWithNewVars` utility also copy the symbolic vars that appear
in the struct info inside the function.
---
src/relax/analysis/well_formed.cc | 57 ++++++++++++++++++-----
src/relax/transform/fuse_ops.cc | 40 +---------------
src/relax/transform/utils.h | 62 +++++++++++++++++++++++++
src/relax/utils.cc | 11 ++++-
tests/python/relax/test_analysis_well_formed.py | 14 ++++++
tests/python/relax/test_utils.py | 15 ++++++
6 files changed, 147 insertions(+), 52 deletions(-)
diff --git a/src/relax/analysis/well_formed.cc
b/src/relax/analysis/well_formed.cc
index 3eeefd0be5..aeae975bf5 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -30,16 +30,17 @@
* 3. When a Function has a corresponding GlobalVar and a `global_symbol`
* attribute, the name of the GlobalVar must equal the value of the
* `global_symbol` attribute value.
- * 4. Any variable cannot used as different function parameters in the same
IRModule
- * 5. Vars are defined before use.
- * 6. Vars are defined exactly once.
- * 7. Symbolic Vars are defined before use.
- * 8. DataflowVars cannot be defined inside BindingBlock.
- * 9. Vars defined in IfNode, except the return Var, are invisible
+ * 4. A variable cannot be used as a parameter of more than one function in the same IRModule.
+ * 5. A symbolic var cannot appear in more than one function in the same IRModule.
+ * 6. Vars are defined before use.
+ * 7. Vars are defined exactly once.
+ * 8. Symbolic Vars are defined before use.
+ * 9. DataflowVars cannot be defined inside BindingBlock.
+ * 10. Vars defined in IfNode, except the return Var, are invisible
* out of the If body.(May change for new AST designs)
- * 10. SeqExpr only serves as function body, or in the true and
+ * 11. SeqExpr only serves as function body, or in the true and
* false branches in IfNode.
- * 11. The IR is in ANF:
+ * 12. The IR is in ANF:
* (a) Expressions cannot contain nested complex expressions.
* Here are the expressions that may be nested inside other
expressions:
* Var, DataflowVar, GlobalVar, Constant, ShapeExpr,
@@ -54,7 +55,7 @@
* * The cond field of If nodes
* * The op or args fields of Call nodes
* * Inside the fields of Tuple nodes
- * 12. Expr always has checked_type_ (with the exception of Op).
+ * 13. Expr always has checked_type_ (with the exception of Op).
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
@@ -92,7 +93,7 @@ class WellFormedChecker : public relax::ExprVisitor,
private:
explicit WellFormedChecker(IRModule mod, bool check_struct_info)
- : mod_(std::move(mod)), check_struct_info_(check_struct_info) {}
+ : mod_(std::move(mod)), check_struct_info_(check_struct_info),
cur_visited_func_(nullptr) {}
using relax::ExprVisitor::VisitExpr_;
using tir::ExprVisitor::VisitExpr;
@@ -196,6 +197,12 @@ class WellFormedChecker : public relax::ExprVisitor,
}
void VisitExpr_(const FunctionNode* op) final {
+ // set current visited function.
+ // for nested functions, we only set the outermost function.
+ if (cur_visited_func_ == nullptr) {
+ cur_visited_func_ = op;
+ }
+
// save the var_set_ for local function
auto prev_var_set = var_set_;
auto prev_dataflow_var_set = dataflow_var_set_;
@@ -223,7 +230,7 @@ class WellFormedChecker : public relax::ExprVisitor,
<< "Relax variable " << param->name_hint()
<< " is repeatedly used as parameters in function.");
}
- param_var_func_map_.insert({param, GetRef<Function>(op)});
+ param_var_func_map_.insert({param, cur_visited_func_});
}
// check function ret_struct_info
if (op->ret_struct_info.defined()) {
@@ -242,11 +249,18 @@ class WellFormedChecker : public relax::ExprVisitor,
dataflow_var_set_ = prev_dataflow_var_set;
var_set_ = prev_var_set;
symbolic_var_set_ = prev_symbolic_var_set;
+
+ if (cur_visited_func_ == op) {
+ cur_visited_func_ = nullptr;
+ }
}
void VisitExpr_(const CallNode* op) final {
if (IsLeafOrTuple(op->op)) {
+ const FunctionNode* prev_visited_func = cur_visited_func_;
+ cur_visited_func_ = nullptr; // close the symbolic var dup check
this->VisitExpr(op->op);
+ cur_visited_func_ = prev_visited_func;
} else {
Malformed(Diagnostic::Error(op) << "The called expression must be a leaf
expression");
}
@@ -400,6 +414,21 @@ class WellFormedChecker : public relax::ExprVisitor,
this->Malformed(Diagnostic::Error(var)
<< "Symbolic Var " << var->name_hint << " is not
defined.");
}
+
+ // don't perform the check
+ if (cur_visited_func_ == nullptr) {
+ return;
+ }
+
+ // check across functions presence
+ auto it = symbolic_var_func_map_.find(var);
+ if (it != symbolic_var_func_map_.end() && it->second != cur_visited_func_)
{
+ // TODO(relax-team): Complete this error info after we integrate printer
+ Malformed(Diagnostic::Error(var->span)
+ << "Symbolic Var " << var->name_hint
+ << " presents in different functions in the same Module.");
+ }
+ symbolic_var_func_map_.insert({var, cur_visited_func_});
}
void VisitStructInfo_(const FuncStructInfoNode* op) final {
@@ -473,6 +502,8 @@ class WellFormedChecker : public relax::ExprVisitor,
const bool check_struct_info_;
bool well_formed_ = true;
bool is_dataflow_;
+ // Current visited function.
+ const FunctionNode* cur_visited_func_;
// Current visit mode.
VisitMode mode_ = VisitMode::kDefault;
// set of context variables.
@@ -480,7 +511,9 @@ class WellFormedChecker : public relax::ExprVisitor,
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> recur_vars_;
std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual>
dataflow_var_set_;
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual>
symbolic_var_set_;
- std::unordered_map<Var, Function, ObjectPtrHash, ObjectPtrEqual>
param_var_func_map_;
+ std::unordered_map<Var, const FunctionNode*, ObjectPtrHash, ObjectPtrEqual>
param_var_func_map_;
+ std::unordered_map<tir::Var, const FunctionNode*, ObjectPtrHash,
ObjectPtrEqual>
+ symbolic_var_func_map_;
};
bool WellFormed(IRModule m, bool check_struct_info) {
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index a49ae86267..adce61f4b8 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -42,6 +42,7 @@
#include "../../relay/analysis/graph_partitioner.h"
#include "../../support/arena.h"
#include "tvm/relax/expr.h"
+#include "utils.h"
namespace tvm {
namespace relax {
@@ -345,45 +346,6 @@ class GraphCreator : public ExprVisitor {
std::unordered_set<IndexedForwardGraph::Node*> initialized_nodes_;
};
-/*!
- * \brief Renew the definition of symbolic vars in Relax.
- * \details This mutator is used to prevent the same symbolic var from being
used in different
- * functions, which is malformed.
- */
-class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator {
- public:
- static Function Renew(const Function& function) {
- SymbolicVarRenewMutator mutator;
- return Downcast<Function>(mutator.VisitExpr(function));
- }
-
- private:
- SymbolicVarRenewMutator() = default;
- using relax::ExprMutator::VisitExpr;
- using relax::ExprMutator::VisitExpr_;
- using tir::ExprMutator::VisitExpr_;
-
- PrimExpr VisitPrimExpr(const PrimExpr& expr) final { return
tir::ExprMutator::VisitExpr(expr); }
-
- // TODO(Siyuan): enhance the method to the following steps:
- // 1. Visit and replace all tir::Vars at the definition point
- // 2. Revisit the function again and update the use side.
- PrimExpr VisitExpr_(const tir::VarNode* op) final {
- auto it = var_map_.find(GetRef<tir::Var>(op));
- if (it != var_map_.end()) {
- return (*it).second;
- } else {
- auto n = make_object<tir::VarNode>(*op);
- tir::Var v(n);
- var_map_.Set(GetRef<tir::Var>(op), v);
- return v;
- }
- }
-
- private:
- Map<tir::Var, tir::Var> var_map_;
-};
-
/*!
* \brief The ExprMutator used to create a new grouped function
* \details The workflow of this ExprMutator is:
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index 9334fd8347..5363bca68b 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -28,6 +28,7 @@
#include <tvm/ir/module.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
+#include <tvm/tir/expr_functor.h>
#include <algorithm>
#include <string>
@@ -222,6 +223,67 @@ class VarReplacer : public ExprMutator {
const VarMap& var_remap_;
};
+/*!
+ * \brief Renew the definition of symbolic vars in Relax.
+ * \details This mutator is used to prevent the same symbolic var from being
used in different
+ * functions, which is malformed.
+ */
+class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator {
+ public:
+ static Function Renew(const Function& function) {
+ SymbolicVarRenewMutator mutator;
+ return Downcast<Function>(mutator.VisitExpr(function));
+ }
+
+ private:
+ SymbolicVarRenewMutator() = default;
+ using relax::ExprMutator::VisitExpr;
+ using relax::ExprMutator::VisitExpr_;
+ using tir::ExprMutator::VisitExpr_;
+
+ PrimExpr VisitPrimExpr(const PrimExpr& expr) final { return
tir::ExprMutator::VisitExpr(expr); }
+
+ // TODO(Siyuan): enhance the method to the following steps:
+ // 1. Visit and replace all tir::Vars at the definition point
+ // 2. Revisit the function again and update the use side.
+ PrimExpr VisitExpr_(const tir::VarNode* op) final {
+ auto it = var_map_.find(GetRef<tir::Var>(op));
+ if (it != var_map_.end()) {
+ return (*it).second;
+ } else {
+ auto n = make_object<tir::VarNode>(*op);
+ tir::Var v(n);
+ var_map_.Set(GetRef<tir::Var>(op), v);
+ return v;
+ }
+ }
+
+ Expr VisitExpr_(const FunctionNode* op) {
+ tvm::Array<Var> params;
+ bool all_params_unchanged = true;
+ for (Var param : op->params) {
+ Var new_param = this->VisitVarDef(param);
+ params.push_back(new_param);
+ if (!param.same_as(new_param)) {
+ var_remap_[param->vid] = new_param;
+ all_params_unchanged = false;
+ }
+ }
+
+ Expr body = this->VisitWithNewScope(op->body, params);
+
+ if (all_params_unchanged && body.same_as(op->body)) {
+ return GetRef<Expr>(op);
+ } else {
+ auto new_ret_sinfo =
this->VisitExprDepStructInfoField(op->ret_struct_info);
+ return Function(params, body, new_ret_sinfo, op->attrs);
+ }
+ }
+
+ private:
+ Map<tir::Var, tir::Var> var_map_;
+};
+
/*!
* \brief Create a Constant with a scalar
*
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index cf1d9bed98..131ed6c7d0 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -17,6 +17,8 @@
* under the License.
*/
+#include "transform/utils.h"
+
#include <tvm/relax/expr_functor.h>
namespace tvm {
@@ -109,6 +111,7 @@ bool IsLeafOrTuple(const Expr& expr) {
expr.as<OpNode>() || expr.as<TupleNode>();
}
+/*! \brief Helper to implement CopyWithNewVars.*/
class FunctionCopier : public ExprMutator {
public:
static Function Transform(Function func) {
@@ -116,7 +119,8 @@ class FunctionCopier : public ExprMutator {
// All variables that are bound inside the original function would be
copied
// to satisfy the restriction in the well-formed check: Variables in Relax
// must be bound exactly once.
- return Downcast<Function>(copier.VisitExpr(func));
+ auto new_func = Downcast<Function>(copier.VisitExpr(func));
+ return SymbolicVarRenewMutator::Renew(new_func);
}
Var VisitVarDef_(const DataflowVarNode* var) override {
@@ -134,6 +138,11 @@ class FunctionCopier : public ExprMutator {
}
};
+/*!
+ * \brief Copy a new Relax function with new remapped vars and symbolic vars.
+ * \param func The Relax function we want to copy.
+ * \return The copied function.
+ */
Function CopyWithNewVars(Function func) { return
FunctionCopier::Transform(func); }
TVM_REGISTER_GLOBAL("relax.CopyWithNewVars").set_body_typed(CopyWithNewVars);
diff --git a/tests/python/relax/test_analysis_well_formed.py
b/tests/python/relax/test_analysis_well_formed.py
index b4b68504a4..97f076dc6c 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -142,6 +142,20 @@ def test_symbolic_var():
assert not rx.analysis.well_formed(mod, check_struct_info=False)
+def test_symbolic_var_across_functions():
+ # Error: Symbolic Var s presents across different functions
+ s = tir.Var("s", "int64")
+ v0 = rx.Var("v0", R.Tensor([5, s], "float32"))
+ v1 = rx.Var("v1", R.Tensor([s, 7], "float32"))
+ bb = rx.BlockBuilder()
+ with bb.function("func1", [v0]):
+ bb.emit_func_output(v0)
+ with bb.function("func2", [v1]):
+ bb.emit_func_output(v1)
+ mod = bb.get()
+ assert not rx.analysis.well_formed(mod, check_struct_info=False)
+
+
def test_symbolic_var_invalid_type():
with pytest.raises(
tvm.TVMError, match="the value in ShapeStructInfo can only have dtype
of int64"
diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py
index 15122dab37..c55876a3ba 100644
--- a/tests/python/relax/test_utils.py
+++ b/tests/python/relax/test_utils.py
@@ -36,6 +36,21 @@ def test_copy_with_new_vars():
assert before_var != after_var
+def test_copy_with_new_vars_copied_symbolic_vars():
+ @R.function
+ def before(x: R.Tensor(("m",), "float32"), y: R.Tensor(("m",), "float32")):
+ gv = R.add(x, y)
+ return gv
+
+ after = relax.utils.copy_with_new_vars(before)
+ assert_structural_equal(after, before)
+
+ assert len(after.params) == len(before.params)
+ for before_var, after_var in zip(before.params, after.params):
+ assert before_var != after_var
+ assert before_var.struct_info.shape[0] !=
after_var.struct_info.shape[0]
+
+
def test_copy_with_new_vars_on_ir_module():
@tvm.script.ir_module
class Actual: