This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3c48cad  WithFields for Tuples (#9533)
3c48cad is described below

commit 3c48caddd6af8381e8544ca9ab3cb182b3d7396f
Author: Lily Orth-Smith <lilyorthsm...@gmail.com>
AuthorDate: Wed Nov 24 10:22:41 2021 -0800

    WithFields for Tuples (#9533)
---
 include/tvm/relay/expr.h                     | 12 ++++++++++++
 src/relay/ir/expr.cc                         | 21 +++++++++++++++++++++
 src/relay/ir/expr_functor.cc                 | 15 +++++----------
 src/relay/transforms/annotate_target.cc      | 18 ++++++++++--------
 src/relay/transforms/device_planner.cc       |  3 +--
 src/relay/transforms/first_order_gradient.cc | 12 +++++++-----
 src/relay/transforms/forward_rewrite.cc      | 19 +++++++------------
 src/relay/transforms/fuse_ops.cc             | 12 ++++++------
 src/relay/transforms/memory_alloc.cc         |  8 +++++---
 src/relay/transforms/partial_eval.cc         |  2 ++
 src/relay/transforms/partition_graph.cc      |  8 ++++----
 src/relay/transforms/split_args.cc           | 18 ++++++++++--------
 src/relay/transforms/to_a_normal_form.cc     | 11 ++++++-----
 src/relay/transforms/to_cps.cc               |  9 +++++----
 src/relay/transforms/transform_layout.h      | 11 +++++++----
 15 files changed, 108 insertions(+), 71 deletions(-)

diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index aa34194..8077bbf 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -142,9 +142,21 @@ class Tuple : public Expr {
   TVM_DLL explicit Tuple(tvm::Array<relay::Expr> fields, Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(TupleNode);
 };
 
 /*!
+ * \brief Returns the tuple with given properties. A null property denotes 'no change'.
+ * Returns this if all properties are unchanged. Otherwise, returns a copy with the new fields.
+ * \param tuple The tuple to copy
+ * \param opt_fields The (optional) fields for the copied tuple. If none, ret_tuple->fields =
+ * tuple->fields.
+ * \param opt_span The (optional) span for the copied tuple. If none, ret_tuple->span = tuple->span.
+ */
+Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields = Optional<Array<Expr>>(),
+                 Optional<Span> opt_span = Optional<Span>(nullptr));
+
+/*!
  * \brief Local variables used in the let expression.
  *
  * Its semantics are similar to tvm.Var node used in TVM's low level
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index c7a81f9..59e8c9e 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -76,6 +76,27 @@ TVM_REGISTER_NODE_TYPE(TupleNode);
 TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array<relay::Expr> fields, Span span) {
   return Tuple(fields, span);
 });
+Tuple WithFields(Tuple tuple, Optional<Array<Expr>> opt_fields, Optional<Span> opt_span) {
+  Array<Expr> fields = opt_fields.value_or(tuple->fields);
+  Span span = opt_span.value_or(tuple->span);
+
+  bool all_fields_unchanged = true;
+  if (fields.size() == tuple->fields.size()) {
+    for (size_t i = 0; i < fields.size(); i++) {
+      all_fields_unchanged &= fields[i].same_as(tuple->fields[i]);
+    }
+  } else {
+    all_fields_unchanged = false;
+  }
+
+  all_fields_unchanged = all_fields_unchanged && span.same_as(tuple->span);
+  if (!all_fields_unchanged) {
+    TupleNode* cow_tuple_node = tuple.CopyOnWrite();
+    cow_tuple_node->fields = fields;
+    cow_tuple_node->span = span;
+  }
+  return std::move(tuple);
+}
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<TupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index e9441f1..08c9b96 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -177,20 +177,15 @@ Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) { return GetRef<Expr>(op);
 
 Expr ExprMutator::VisitExpr_(const OpNode* op) { return GetRef<Expr>(op); }
 
-Expr ExprMutator::VisitExpr_(const TupleNode* op) {
+Expr ExprMutator::VisitExpr_(const TupleNode* tuple_node) {
   tvm::Array<Expr> fields;
-  bool all_fields_unchanged = true;
-  for (auto field : op->fields) {
+  fields.reserve(tuple_node->fields.size());
+
+  for (auto field : tuple_node->fields) {
     auto new_field = this->Mutate(field);
     fields.push_back(new_field);
-    all_fields_unchanged &= new_field.same_as(field);
-  }
-
-  if (all_fields_unchanged) {
-    return GetRef<Expr>(op);
-  } else {
-    return Tuple(fields, op->span);
   }
+  return WithFields(GetRef<Tuple>(tuple_node), std::move(fields));
 }
 
 Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc
index b12e25a..df1a858 100644
--- a/src/relay/transforms/annotate_target.cc
+++ b/src/relay/transforms/annotate_target.cc
@@ -266,11 +266,11 @@ class AnnotateTargetRewriter : public ExprRewriter {
 
   virtual std::unique_ptr<Call> RewriteVarCall(const Call& post_call) { return nullptr; }
 
-  Expr Rewrite_(const TupleNode* op, const Expr& post) override {
-    auto expr = Downcast<Tuple>(post);
+  Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) override {
+    auto tuple = Downcast<Tuple>(post);
 
-    auto target_n_args = AnnotateArgs(expr->fields);
-    auto new_expr = Tuple(std::get<1>(target_n_args));
+    auto target_n_args = AnnotateArgs(tuple->fields);
+    auto new_expr = WithFields(std::move(tuple), std::move(std::get<1>(target_n_args)));
     op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
     return std::move(new_expr);
   }
@@ -370,13 +370,15 @@ class CallOpsTargetRewriter : public AnnotateTargetRewriter {
     return new_call;
   }
 
-  Expr Rewrite_(const TupleNode* op, const Expr& post) override {
-    auto expr = Downcast<Tuple>(post);
+  Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) override {
+    auto tuple = Downcast<Tuple>(post);
     Array<Expr> new_fields;
-    for (auto f : expr->fields) {
+    new_fields.reserve(tuple->fields.size());
+
+    for (auto f : tuple->fields) {
       new_fields.push_back(InsertCompilerEndAndPropogateTarget(f));
     }
-    return std::move(Tuple(new_fields));
+    return WithFields(std::move(tuple), std::move(new_fields));
   }
 
   Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) override {
diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc
index d6ab566..afa598b 100644
--- a/src/relay/transforms/device_planner.cc
+++ b/src/relay/transforms/device_planner.cc
@@ -786,8 +786,7 @@ class DeviceCapturer : public ExprMutator {
     for (const auto& field : tuple_node->fields) {
       fields.push_back(VisitChild(tuple, field));
     }
-    // TODO(mbs): Avoid copy
-    return Tuple(std::move(fields), tuple_node->span);
+    return WithFields(std::move(tuple), std::move(fields));
   }
 
   Expr VisitExpr_(const FunctionNode* function_node) final {
diff --git a/src/relay/transforms/first_order_gradient.cc b/src/relay/transforms/first_order_gradient.cc
index 3419cb6..9408d16 100644
--- a/src/relay/transforms/first_order_gradient.cc
+++ b/src/relay/transforms/first_order_gradient.cc
@@ -195,11 +195,13 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
     return ret;
   }
 
-  ADValue VisitExpr_(const TupleNode* op) final {
-    auto tt = Downcast<TupleType>(op->checked_type());
+  ADValue VisitExpr_(const TupleNode* tuple_node) final {
+    auto tt = Downcast<TupleType>(tuple_node->checked_type());
     std::vector<ADValue> ad_fields;
-    std::vector<Expr> field_bindings;
-    for (const auto& f : op->fields) {
+    Array<Expr> field_bindings;
+    field_bindings.reserve(tuple_node->fields.size());
+
+    for (const auto& f : tuple_node->fields) {
       ADValue f_ad = VisitExpr(f);
       if (!dynamic_cast<ADTensor*>(f_ad.get())) {
         diag_ctx.EmitFatal(Diagnostic::Error(f->span)
@@ -209,7 +211,7 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
       field_bindings.push_back(f_ad->get<ADTensor>().forward);
     }
     // reconstruct tuple using let-bound variables to avoid duplication
-    auto orig = Tuple(field_bindings);
+    auto orig = WithFields(GetRef<Tuple>(tuple_node), std::move(field_bindings));
     orig->checked_type_ = tt;
     auto ret = std::make_shared<ADTensor>(ll, orig, diag_ctx);
     // for orig = tuple(x1, ..., xn), tuple_grad(x1, ..., xn, G) = [pi(G, 1), ..., pi(G, n)]
diff --git a/src/relay/transforms/forward_rewrite.cc b/src/relay/transforms/forward_rewrite.cc
index 1212ad7..23c45a9 100644
--- a/src/relay/transforms/forward_rewrite.cc
+++ b/src/relay/transforms/forward_rewrite.cc
@@ -113,21 +113,16 @@ class ForwardRewriter : private MixedModeMutator {
     }
   }
 
-  Expr Rewrite_(const TupleNode* op, const Expr& post) final {
+  Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) final {
     tvm::Array<Expr> fields;
-    bool all_fields_unchanged = true;
-    const auto* post_node = post.as<TupleNode>();
-    for (size_t i = 0; i < op->fields.size(); ++i) {
-      auto new_field = this->GetTempExpr(op->fields[i], post_node->fields[i]);
-      fields.push_back(new_field);
-      all_fields_unchanged &= new_field.same_as(op->fields[i]);
-    }
+    fields.reserve(tuple_node->fields.size());
 
-    if (all_fields_unchanged) {
-      return GetRef<Expr>(op);
-    } else {
-      return Tuple(fields);
+    const auto* post_tuple_node = post.as<TupleNode>();
+    for (size_t i = 0; i < tuple_node->fields.size(); ++i) {
+      fields.push_back(this->GetTempExpr(tuple_node->fields[i], post_tuple_node->fields[i]));
     }
+
+    return WithFields(GetRef<Tuple>(tuple_node), std::move(fields));
   }
 
   Expr Rewrite_(const CallNode* call_node, const Expr& post) final {
diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc
index 960f569..247ae33 100644
--- a/src/relay/transforms/fuse_ops.cc
+++ b/src/relay/transforms/fuse_ops.cc
@@ -898,14 +898,14 @@ class FuseMutator : private MixedModeMutator {
     }
   }
 
-  Expr Rewrite_(const TupleNode* tuple, const Expr& post) {
-    auto* ret_group = gmap_.at(tuple)->FindRoot();
-    if (ret_group->root_ref == tuple) {
-      return ExprMutator::VisitExpr_(tuple);
+  Expr Rewrite_(const TupleNode* tuple_node, const Expr& post) {
+    auto* ret_group = gmap_.at(tuple_node)->FindRoot();
+    if (ret_group->root_ref == tuple_node) {
+      return ExprMutator::VisitExpr_(tuple_node);
     }
     // This tuple is an intermediate node in the group
-    Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
-    return Tuple(new_fields);
+    Array<Expr> new_fields = GetNewArguments(tuple_node->fields, ret_group);
+    return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
   }
 
   Expr Rewrite_(const TupleGetItemNode* tuple_get, const Expr& post) {
diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc
index acea12f..ddbc606 100644
--- a/src/relay/transforms/memory_alloc.cc
+++ b/src/relay/transforms/memory_alloc.cc
@@ -84,10 +84,12 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
   Function Rewrite(const Function& expr) { return Downcast<Function>(Mutate(expr)); }
 
  private:
-  Expr VisitExpr_(const TupleNode* tn) final {
+  Expr VisitExpr_(const TupleNode* tuple_node) final {
     LetList& scope = scopes_.back();
     Array<Expr> new_fields;
-    for (auto field : tn->fields) {
+    new_fields.reserve(tuple_node->fields.size());
+
+    for (auto field : tuple_node->fields) {
       auto new_field = Mutate(field);
       if (new_field->IsInstance<ConstantNode>()) {
         Var const_var("const", Type(nullptr));
@@ -95,7 +97,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
       }
       new_fields.push_back(new_field);
     }
-    return Tuple(new_fields);
+    return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
   }
 
   void PreVisitLetBlock_(const LetNode* let_node) final { scopes_.emplace_back(); }
diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc
index ccdd9c9..7388d9f 100644
--- a/src/relay/transforms/partial_eval.cc
+++ b/src/relay/transforms/partial_eval.cc
@@ -615,6 +615,8 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
       value.push_back(ps);
       expr.push_back(ps->dynamic);
     }
+    // Note(@electriclilies): The partial evaluator seems to do some weird stuff with sharing.
+    // Changing Tuple(expr) to WithFields(op, expr) causes some strange failures.
     return HasStatic(MkSTuple(value), ll->Push(Tuple(expr)));
   }
 
diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index 99799fd..4a21bc8 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -455,18 +455,18 @@ IRModule FlattenTupleOutputs(IRModule module) {
         // Arguments of annotation ops should be 1
         ICHECK_EQ(call->args.size(), 1U);
         auto annotated_op = Downcast<Call>(post)->args[0];
-        if (const auto* tn = annotated_op.as<TupleNode>()) {
+        if (const auto* tuple_node = annotated_op.as<TupleNode>()) {
           Array<Expr> new_fields;
+          new_fields.reserve(tuple_node->fields.size());
 
           // Here each input of the tuple will be annotated with compiler_ends
-          for (auto& tn_arg : tn->fields) {
+          for (auto& tn_arg : tuple_node->fields) {
             new_fields.push_back((*make_end_op)(tn_arg, target));
           }
 
           // Return a tuple of compiler_ends in the place of the tuple that was
           // annotated with a compiler_end.
-          auto out = Tuple(new_fields);
-          return std::move(out);
+          return WithFields(GetRef<Tuple>(tuple_node), std::move(new_fields));
         }
       }
       return post;
diff --git a/src/relay/transforms/split_args.cc b/src/relay/transforms/split_args.cc
index eb647ce..fbb2d73 100644
--- a/src/relay/transforms/split_args.cc
+++ b/src/relay/transforms/split_args.cc
@@ -37,14 +37,14 @@ class ArgumentSplitter : public ExprRewriter {
   Expr Rewrite_(const CallNode* call, const Expr& post) final {
     if (max_function_args_ < 0) return post;
     if (call->op == concat_op_) {
-      auto op = call->args[0].as<TupleNode>();
+      auto tuple_node = call->args[0].as<TupleNode>();
       const auto param = call->attrs.as<ConcatenateAttrs>();
       int outputsNum = 1;
       if (const auto* tuple_type = call->checked_type().as<TupleTypeNode>()) {
         outputsNum = tuple_type->fields.size();
       }
       const int limit = max_function_args_ - outputsNum;
-      int argsNum = op->fields.size();
+      int argsNum = tuple_node->fields.size();
       if (argsNum < limit) return post;
       int splitNum = argsNum / limit;
       splitNum = (argsNum % limit) ? splitNum + 1 : splitNum;
@@ -54,16 +54,18 @@ class ArgumentSplitter : public ExprRewriter {
         int startIdx = i * limit;
         int argsCount = std::min(limit, argsNum - startIdx);
         tvm::Array<Expr> args;
+        args.reserve(argsCount);
+
         for (int j = 0; j < argsCount; ++j) {
-          args.push_back(op->fields[j + startIdx]);
+          args.push_back(tuple_node->fields[j + startIdx]);
         }
-        Tuple tuple(args);
-        Expr body = MakeConcatenate(tuple, param->axis);
+        Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), std::move(args));
+        Expr body = MakeConcatenate(new_tuple, param->axis);
         splitted[i] = StopFusion(body);
       }
-      tvm::Array<Expr> tupleArgs(splitted);
-      Tuple tuple(tupleArgs);
-      return MakeConcatenate(tuple, param->axis);
+      tvm::Array<Expr> tuple_args(splitted);
+      Tuple new_tuple = WithFields(GetRef<Tuple>(tuple_node), std::move(tuple_args));
+      return MakeConcatenate(new_tuple, param->axis);
     }
     return post;
   }
diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc
index 0814e73..f958a60 100644
--- a/src/relay/transforms/to_a_normal_form.cc
+++ b/src/relay/transforms/to_a_normal_form.cc
@@ -248,13 +248,14 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)>, private transform::Lexi
     return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v);
   }
 
-  Expr VisitExpr_(const TupleNode* t, const Var& v) final {
-    Expr e = GetRef<Expr>(t);
-    std::vector<Expr> fields;
-    for (const auto& a : t->fields) {
+  Expr VisitExpr_(const TupleNode* tuple_node, const Var& v) final {
+    Expr e = GetRef<Expr>(tuple_node);
+    Array<Expr> fields;
+    fields.reserve(tuple_node->fields.size());
+    for (const auto& a : tuple_node->fields) {
       fields.push_back(VisitExpr(a));
     }
-    return Compound(e, Tuple(fields), v);
+    return Compound(e, WithFields(GetRef<Tuple>(tuple_node), std::move(fields)), v);
   }
 
   Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final {
diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc
index b7f9caf..0f889cd 100644
--- a/src/relay/transforms/to_cps.cc
+++ b/src/relay/transforms/to_cps.cc
@@ -210,13 +210,14 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm,
       });
     }
 
-    Expr VisitExpr_(const TupleNode* op, const MCont& k) final {
+    Expr VisitExpr_(const TupleNode* tuple_node, const MCont& k) final {
       tvm::Array<Expr> fields;
+      fields.reserve(tuple_node->fields.size());
       std::function<Expr()> next;
       next = [&]() {
-        return (fields.size() == op->fields.size())
-                   ? k(Tuple(fields))
-                   : VisitExpr(op->fields[fields.size()], [&](const Expr& v) {
+        return (fields.size() == tuple_node->fields.size())
+                   ? k(WithFields(GetRef<Tuple>(tuple_node), std::move(fields)))
+                   : VisitExpr(tuple_node->fields[fields.size()], [&](const Expr& v) {
                        fields.push_back(v);
                        return next();
                      });
diff --git a/src/relay/transforms/transform_layout.h b/src/relay/transforms/transform_layout.h
index 7bfb31a..56affb5 100644
--- a/src/relay/transforms/transform_layout.h
+++ b/src/relay/transforms/transform_layout.h
@@ -32,6 +32,7 @@
 #include <string>
 #include <tuple>
 #include <unordered_map>
+#include <utility>
 #include <vector>
 
 #include "infer_layout_utils.h"
@@ -293,12 +294,13 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
     // NOTE: do not support nested tuple
     if (new_arg->IsInstance<TupleNode>()) {
       Tuple tuple_new_arg = Downcast<Tuple>(new_arg);
-      std::vector<Expr> fields;
+      Array<Expr> fields;
+      fields.reserve(tuple_new_arg->fields.size());
       for (auto x : tuple_new_arg->fields) {
         Expr tmp = push_back_one_arg(x);
         fields.push_back(tmp);
       }
-      normal_new_args.push_back(Tuple(fields));
+      normal_new_args.push_back(WithFields(tuple_new_arg, std::move(fields)));
     } else {
       Expr tmp = push_back_one_arg(new_arg);
       normal_new_args.push_back(tmp);
@@ -375,12 +377,13 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
   for (auto arg : new_call->args) {
     if (arg->IsInstance<TupleNode>()) {  // unflatten tuple
       Tuple tuple_arg = Downcast<Tuple>(arg);
-      std::vector<Expr> transformed_tuple_arg;
+      Array<Expr> transformed_tuple_arg;
+      transformed_tuple_arg.reserve(tuple_arg->fields.size());
       for (auto arg_item : tuple_arg->fields) {
       transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt]));
         pt++;
       }
-      transformed_args.push_back(Tuple(transformed_tuple_arg));
+      transformed_args.push_back(WithFields(tuple_arg, std::move(transformed_tuple_arg)));
     } else {
       transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt]));
       pt++;

Reply via email to