This is an automated email from the ASF dual-hosted git repository.
jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3c48cad WithFields for Tuples (#9533)
3c48cad is described below
commit 3c48caddd6af8381e8544ca9ab3cb182b3d7396f
Author: Lily Orth-Smith <[email protected]>
AuthorDate: Wed Nov 24 10:22:41 2021 -0800
WithFields for Tuples (#9533)
---
include/tvm/relay/expr.h | 12 ++++++++++++
src/relay/ir/expr.cc | 21 +++++++++++++++++++++
src/relay/ir/expr_functor.cc | 15 +++++----------
src/relay/transforms/annotate_target.cc | 18 ++++++++++--------
src/relay/transforms/device_planner.cc | 3 +--
src/relay/transforms/first_order_gradient.cc | 12 +++++++-----
src/relay/transforms/forward_rewrite.cc | 19 +++++++------------
src/relay/transforms/fuse_ops.cc | 12 ++++++------
src/relay/transforms/memory_alloc.cc | 8 +++++---
src/relay/transforms/partial_eval.cc | 2 ++
src/relay/transforms/partition_graph.cc | 8 ++++----
src/relay/transforms/split_args.cc | 18 ++++++++++--------
src/relay/transforms/to_a_normal_form.cc | 11 ++++++-----
src/relay/transforms/to_cps.cc | 9 +++++----
src/relay/transforms/transform_layout.h | 11 +++++++----
15 files changed, 108 insertions(+), 71 deletions(-)
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index aa34194..8077bbf 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -142,9 +142,21 @@ class Tuple : public Expr {
TVM_DLL explicit Tuple(tvm::Array<relay::Expr> fields, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode);
+ TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
};
/*!
+ * \brief Returns the tuple with given properties. A null property denotes 'no change'.
+ * Returns this if all properties are unchanged. Otherwise, returns a copy with the new fields.
+ * \param tuple The tuple to copy
+ * \param opt_fields The (optional) fields for the copied tuple. If none, ret_tuple->fields =
+ * tuple->fields.
+ * \param opt_span The (optional) span for the copied tuple. If none, ret_tuple->span = tuple->span.
+ */
+Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields = Optional<Array<Expr>>(),
+ Optional<Span> opt_span = Optional<Span>(nullptr));
+
+/*!
* \brief Local variables used in the let expression.
*
* Its semantics are similar to tvm.Var node used in TVM's low level
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index c7a81f9..59e8c9e 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -76,6 +76,27 @@ TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array<relay::Expr> fields, Span span) {
return Tuple(fields, span);
});
+Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields, Optional<Span> opt_span) {
+ Array<Expr> fields = opt_fields.value_or(tuple->fields);
+ Span span = opt_span.value_or(tuple->span);
+
+ bool all_fields_unchanged = true;
+ if (fields.size() == tuple->fields.size()) {
+ for (size_t i = 0; i < fields.size(); i++) {
+ all_fields_unchanged &= fields[i].same_as(tuple->fields[i]);
+ }
+ } else {
+ all_fields_unchanged = false;
+ }
+
+ all_fields_unchanged = all_fields_unchanged && span.same_as(tuple->span);
+ if (!all_fields_unchanged) {
+ TupleNode* cow_tuple_node = tuple.CopyOnWrite();
+ cow_tuple_node->fields = fields;
+ cow_tuple_node->span = span;
+ }
+ return std::move(tuple);
+}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index e9441f1..08c9b96 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -177,20 +177,15 @@ Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) {
return GetRef<Expr>(op);
Expr ExprMutator::VisitExpr_(const OpNode* op) { return GetRef<Expr>(op); }
-Expr ExprMutator::VisitExpr_(const TupleNode* op) {
+Expr ExprMutator::VisitExpr_(const TupleNode* tuple_node) {
tvm::Array<Expr> fields;
- bool all_fields_unchanged = true;
- for (auto field : op->fields) {
+ fields.reserve(tuple_node->fields.size());
+
+ for (auto field : tuple_node->fields) {
auto new_field = this->Mutate(field);
fields.push_back(new_field);
- all_fields_unchanged &= new_field.same_as(field);
- }
-
- if (all_fields_unchanged) {
- return GetRef<Expr>(op);
- } else {
- return Tuple(fields, op->span);
}
+ return WithFields(GetRef<Tuple>(tuple_node), std::move(fields));
}
Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc
index b12e25a..df1a858 100644
--- a/src/relay/transforms/annotate_target.cc
+++ b/src/relay/transforms/annotate_target.cc
@@ -266,11 +266,11 @@ class AnnotateTargetRewriter : public ExprRewriter {
virtual std::unique_ptr<Call> RewriteVarCall(const Call& post_call) { return nullptr; }
- Expr Rewrite_(const TupleNode* op, const Expr& post) override {
- auto expr = Downcast<Tuple>(post);
+ Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) override {
+ auto tuple = Downcast<Tuple>(post);
- auto target_n_args = AnnotateArgs(expr->fields);
- auto new_expr = Tuple(std::get<1>(target_n_args));
+ auto target_n_args = AnnotateArgs(tuple->fields);
+ auto new_expr = WithFields(std::move(tuple), std::move(std::get<1>(target_n_args)));
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
return std::move(new_expr);
}
@@ -370,13 +370,15 @@ class CallOpsTargetRewriter : public AnnotateTargetRewriter {
return new_call;
}
- Expr Rewrite_(const TupleNode* op, const Expr& post) override {
- auto expr = Downcast<Tuple>(post);
+ Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) override {
+ auto tuple = Downcast<Tuple>(post);
Array<Expr> new_fields;
- for (auto f : expr->fields) {
+ new_fields.reserve(tuple->fields.size());
+
+ for (auto f : tuple->fields) {
new_fields.push_back(InsertCompilerEndAndPropogateTarget(f));
}
- return std::move(Tuple(new_fields));
+ return WithFields(std::move(tuple), std::move(new_fields));
}
Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override {
diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc
index d6ab566..afa598b 100644
--- a/src/relay/transforms/device_planner.cc
+++ b/src/relay/transforms/device_planner.cc
@@ -786,8 +786,7 @@ class DeviceCapturer : public ExprMutator {
for (const auto& field : tuple_node->fields) {
fields.push_back(VisitChild(tuple, field));
}
- // TODO(mbs): Avoid copy
- return Tuple(std::move(fields), tuple_node->span);
+ return WithFields(std::move(tuple), std::move(fields));
}
Expr VisitExpr_(const FunctionNode* function_node) final {
diff --git a/src/relay/transforms/first_order_gradient.cc b/src/relay/transforms/first_order_gradient.cc
index 3419cb6..9408d16 100644
--- a/src/relay/transforms/first_order_gradient.cc
+++ b/src/relay/transforms/first_order_gradient.cc
@@ -195,11 +195,13 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
return ret;
}
- ADValue VisitExpr_(const TupleNode* op) final {
- auto tt = Downcast<TupleType>(op->checked_type());
+ ADValue VisitExpr_(const TupleNode* tuple_node) final {
+ auto tt = Downcast<TupleType>(tuple_node->checked_type());
std::vector<ADValue> ad_fields;
- std::vector<Expr> field_bindings;
- for (const auto& f : op->fields) {
+ Array<Expr> field_bindings;
+ field_bindings.reserve(tuple_node->fields.size());
+
+ for (const auto& f : tuple_node->fields) {
ADValue f_ad = VisitExpr(f);
if (!dynamic_cast<ADTensor*>(f_ad.get())) {
diag_ctx.EmitFatal(Diagnostic::Error(f->span)
@@ -209,7 +211,7 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
field_bindings.push_back(f_ad->get<ADTensor>().forward);
}
// reconstruct tuple using let-bound variables to avoid duplication
- auto orig = Tuple(field_bindings);
+ auto orig = WithFields(GetRef<Tuple>(tuple_node), std::move(field_bindings));
orig->checked_type_ = tt;
auto ret = std::make_shared<ADTensor>(ll, orig, diag_ctx);
// for orig = tuple(x1, ..., xn), tuple_grad(x1, ..., xn, G) = [pi(G, 1), ..., pi(G, n)]
diff --git a/src/relay/transforms/forward_rewrite.cc b/src/relay/transforms/forward_rewrite.cc
index 1212ad7..23c45a9 100644
--- a/src/relay/transforms/forward_rewrite.cc
+++ b/src/relay/transforms/forward_rewrite.cc
@@ -113,21 +113,16 @@ class ForwardRewriter : private MixedModeMutator {
}
}
- Expr Rewrite_(const TupleNode* op, const Expr& post) final {
+ Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) final {
tvm::Array<Expr> fields;
- bool all_fields_unchanged = true;
- const auto* post_node = post.as<TupleNode>();
- for (size_t i = 0; i < op->fields.size(); ++i) {
- auto new_field = this->GetTempExpr(op->fields[i], post_node->fields[i]);
- fields.push_back(new_field);
- all_fields_unchanged &= new_field.same_as(op->fields[i]);
- }
+ fields.reserve(tuple_node->fields.size());
- if (all_fields_unchanged) {
- return GetRef<Expr>(op);
- } else {
- return Tuple(fields);
+ const auto* post_tuple_node = post.as<TupleNode>();
+ for (size_t i = 0; i < tuple_node->fields.size(); ++i) {
+ fields.push_back(this->GetTempExpr(tuple_node->fields[i], post_tuple_node->fields[i]));
}
+
+ return WithFields(GetRef<Tuple>(tuple_node), std::move(fields));
}
Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc
index 960f569..247ae33 100644
--- a/src/relay/transforms/fuse_ops.cc
+++ b/src/relay/transforms/fuse_ops.cc
@@ -898,14 +898,14 @@ class FuseMutator : private MixedModeMutator {
}
}
- Expr Rewrite_(const TupleNode* tuple, const Expr& post) {
- auto* ret_group = gmap_.at(tuple)->FindRoot();
- if (ret_group->root_ref == tuple) {
- return ExprMutator::VisitExpr_(tuple);
+ Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) {
+ auto* ret_group = gmap_.at(tuple_node)->FindRoot();
+ if (ret_group->root_ref == tuple_node) {
+ return ExprMutator::VisitExpr_(tuple_node);
}
// This tuple is an intermediate node in the group
- Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
- return Tuple(new_fields);
+ Array<Expr> new_fields = GetNewArguments(tuple_node->fields, ret_group);
+ return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
}
Expr Rewrite_(const TupleGetItemNode* tuple_get, const Expr& post) {
diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc
index acea12f..ddbc606 100644
--- a/src/relay/transforms/memory_alloc.cc
+++ b/src/relay/transforms/memory_alloc.cc
@@ -84,10 +84,12 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
Function Rewrite(const Function& expr) { return Downcast<Function>(Mutate(expr)); }
private:
- Expr VisitExpr_(const TupleNode* tn) final {
+ Expr VisitExpr_(const TupleNode* tuple_node) final {
LetList& scope = scopes_.back();
Array<Expr> new_fields;
- for (auto field : tn->fields) {
+ new_fields.reserve(tuple_node->fields.size());
+
+ for (auto field : tuple_node->fields) {
auto new_field = Mutate(field);
if (new_field->IsInstance<ConstantNode>()) {
Var const_var("const", Type(nullptr));
@@ -95,7 +97,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
}
new_fields.push_back(new_field);
}
- return Tuple(new_fields);
+ return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
}
void PreVisitLetBlock_(const LetNode* let_node) final { scopes_.emplace_back(); }
diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc
index ccdd9c9..7388d9f 100644
--- a/src/relay/transforms/partial_eval.cc
+++ b/src/relay/transforms/partial_eval.cc
@@ -615,6 +615,8 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
value.push_back(ps);
expr.push_back(ps->dynamic);
}
+ // Note(@electriclilies): The partial evaluator seems to do some weird stuff with sharing.
+ // Changing Tuple(expr) to WithFields(op, expr) causes some strange failures.
return HasStatic(MkSTuple(value), ll->Push(Tuple(expr)));
}
diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index 99799fd..4a21bc8 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -455,18 +455,18 @@ IRModule FlattenTupleOutputs(IRModule module) {
// Arguments of annotation ops should be 1
ICHECK_EQ(call->args.size(), 1U);
auto annotated_op = Downcast<Call>(post)->args[0];
- if (const auto* tn = annotated_op.as<TupleNode>()) {
+ if (const auto* tuple_node = annotated_op.as<TupleNode>()) {
Array<Expr> new_fields;
+ new_fields.reserve(tuple_node->fields.size());
// Here each input of the tuple will be annotated with compiler_ends
- for (auto& tn_arg : tn->fields) {
+ for (auto& tn_arg : tuple_node->fields) {
new_fields.push_back((*make_end_op)(tn_arg, target));
}
// Return a tuple of compiler_ends in the place of the tuple that was
// annotated with a compiler_end.
- auto out = Tuple(new_fields);
- return std::move(out);
+ return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
}
}
return post;
diff --git a/src/relay/transforms/split_args.cc b/src/relay/transforms/split_args.cc
index eb647ce..fbb2d73 100644
--- a/src/relay/transforms/split_args.cc
+++ b/src/relay/transforms/split_args.cc
@@ -37,14 +37,14 @@ class ArgumentSplitter : public ExprRewriter {
Expr Rewrite_(const CallNode* call, const Expr& post) final {
if (max_function_args_ < 0) return post;
if (call->op == concat_op_) {
- auto op = call->args[0].as<TupleNode>();
+ auto tuple_node = call->args[0].as<TupleNode>();
const auto param = call->attrs.as<ConcatenateAttrs>();
int outputsNum = 1;
if (const auto* tuple_type = call->checked_type().as<TupleTypeNode>()) {
outputsNum = tuple_type->fields.size();
}
const int limit = max_function_args_ - outputsNum;
- int argsNum = op->fields.size();
+ int argsNum = tuple_node->fields.size();
if (argsNum < limit) return post;
int splitNum = argsNum / limit;
splitNum = (argsNum % limit) ? splitNum + 1 : splitNum;
@@ -54,16 +54,18 @@ class ArgumentSplitter : public ExprRewriter {
int startIdx = i * limit;
int argsCount = std::min(limit, argsNum - startIdx);
tvm::Array<Expr> args;
+ args.reserve(argsCount);
+
for (int j = 0; j < argsCount; ++j) {
- args.push_back(op->fields[j + startIdx]);
+ args.push_back(tuple_node->fields[j + startIdx]);
}
- Tuple tuple(args);
- Expr body = MakeConcatenate(tuple, param->axis);
+ Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), std::move(args));
+ Expr body = MakeConcatenate(new_tuple, param->axis);
splitted[i] = StopFusion(body);
}
- tvm::Array<Expr> tupleArgs(splitted);
- Tuple tuple(tupleArgs);
- return MakeConcatenate(tuple, param->axis);
+ tvm::Array<Expr> tuple_args(splitted);
+ Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), std::move(tuple_args));
+ return MakeConcatenate(new_tuple, param->axis);
}
return post;
}
diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc
index 0814e73..f958a60 100644
--- a/src/relay/transforms/to_a_normal_form.cc
+++ b/src/relay/transforms/to_a_normal_form.cc
@@ -248,13 +248,14 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)>, private transform::Lexi
return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v);
}
- Expr VisitExpr_(const TupleNode* t, const Var& v) final {
- Expr e = GetRef<Expr>(t);
- std::vector<Expr> fields;
- for (const auto& a : t->fields) {
+ Expr VisitExpr_(const TupleNode* tuple_node, const Var& v) final {
+ Expr e = GetRef<Expr>(tuple_node);
+ Array<Expr> fields;
+ fields.reserve(tuple_node->fields.size());
+ for (const auto& a : tuple_node->fields) {
fields.push_back(VisitExpr(a));
}
- return Compound(e, Tuple(fields), v);
+ return Compound(e, WithFields(GetRef<Tuple>(tuple_node), std::move(fields)), v);
}
Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final {
diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc
index b7f9caf..0f889cd 100644
--- a/src/relay/transforms/to_cps.cc
+++ b/src/relay/transforms/to_cps.cc
@@ -210,13 +210,14 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm,
});
}
- Expr VisitExpr_(const TupleNode* op, const MCont& k) final {
+ Expr VisitExpr_(const TupleNode* tuple_node, const MCont& k) final {
tvm::Array<Expr> fields;
+ fields.reserve(tuple_node->fields.size());
std::function<Expr()> next;
next = [&]() {
- return (fields.size() == op->fields.size())
- ? k(Tuple(fields))
- : VisitExpr(op->fields[fields.size()], [&](const Expr& v) {
+ return (fields.size() == tuple_node->fields.size())
+ ? k(WithFields(GetRef<Tuple>(tuple_node), std::move(fields)))
+ : VisitExpr(tuple_node->fields[fields.size()], [&](const Expr& v) {
fields.push_back(v);
return next();
});
diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h
index 7bfb31a..56affb5 100644
--- a/src/relay/transforms/transform_layout.h
+++ b/src/relay/transforms/transform_layout.h
@@ -32,6 +32,7 @@
#include <string>
#include <tuple>
#include <unordered_map>
+#include <utility>
#include <vector>
#include "infer_layout_utils.h"
@@ -293,12 +294,13 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
// NOTE: do not support nested tuple
if (new_arg->IsInstance<TupleNode>()) {
Tuple tuple_new_arg = Downcast<Tuple>(new_arg);
- std::vector<Expr> fields;
+ Array<Expr> fields;
+ fields.reserve(tuple_new_arg->fields.size());
for (auto x : tuple_new_arg->fields) {
Expr tmp = push_back_one_arg(x);
fields.push_back(tmp);
}
- normal_new_args.push_back(Tuple(fields));
+ normal_new_args.push_back(WithFields(tuple_new_arg, std::move(fields)));
} else {
Expr tmp = push_back_one_arg(new_arg);
normal_new_args.push_back(tmp);
@@ -375,12 +377,13 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
for (auto arg : new_call->args) {
if (arg->IsInstance<TupleNode>()) { // unflatten tuple
Tuple tuple_arg = Downcast<Tuple>(arg);
- std::vector<Expr> transformed_tuple_arg;
+ Array<Expr> transformed_tuple_arg;
+ transformed_tuple_arg.reserve(tuple_arg->fields.size());
for (auto arg_item : tuple_arg->fields) {
transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt]));
pt++;
}
- transformed_args.push_back(Tuple(transformed_tuple_arg));
+ transformed_args.push_back(WithFields(tuple_arg, std::move(transformed_tuple_arg)));
} else {
transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt]));
pt++;