Lunderberg commented on code in PR #12692:
URL: https://github.com/apache/tvm/pull/12692#discussion_r974662892


##########
include/tvm/runtime/container/array.h:
##########
@@ -706,6 +710,98 @@ class Array : public ObjectRef {
     }
     return static_cast<ArrayNode*>(data_.get());
   }
+
+  /*! \brief Helper method for mutate/map
+   *
+   * A helper function used internally by both `Array::Map` and
+   * `Array::MutateInPlace`.  Given an array of data, apply the
+   * mapping function to each element, returning the collected array.
+   * Applies both mutate-in-place and copy-on-write optimizations, if
+   * possible.
+   *
+   * In every path `fmap` is called once per element, in order.  One
+   * of three paths is taken:
+   *
+   *  - Mutate-in-place: if `U` is the same type as `T` and
+   *    `data.unique()` holds, each element is replaced inside the
+   *    existing array and `data` itself is returned.
+   *
+   *  - Copy-on-write: if `T` and `U` are compatible (as decided by
+   *    `is_valid_iterator_v`), no output array is created until the
+   *    first element whose mapped value is not `same_as` the input
+   *    element.  If no such element exists, `data` is returned.
+   *
+   *  - Otherwise: a new array of `arr->size()` elements is created
+   *    up front and filled with the mapped values.
+   *
+   * NOTE(review): in the mutate-in-place path each element is handed
+   * to `fmap` via `std::move` and only written back afterwards.  If
+   * `fmap` throws, that slot may be left in a moved-from state --
+   * confirm whether callers rely on the array being intact then.
+   *
+   * \param data A pointer to the ArrayNode containing input data.
+   * Passed by value to allow for mutate-in-place optimizations.
+   * May be nullptr, in which case nullptr is returned.
+   *
+   * \param fmap The mapping function
+   *
+   * \tparam F The type of the mutation function.
+   *
+   * \tparam U The output type of the mutation function.  Inferred
+   * from the callable type given.  Must inherit from ObjectRef.
+   *
+   * \return The mapped array.  Depending on whether mutate-in-place
+   * or copy-on-write optimizations were applicable, may be the same
+   * underlying array as the `data` parameter.
+   */
+  template <typename F, typename U = std::invoke_result_t<F, T>>
+  static ObjectPtr<Object> MapHelper(ObjectPtr<Object> data, F fmap) {
+    if (data == nullptr) {
+      return nullptr;
+    }
+
+    ICHECK(data->IsInstance<ArrayNode>());  // the static_casts to ArrayNode* below rely on this
+
+    constexpr bool is_same_output_type = std::is_same_v<T, U>;
+
+    if constexpr (is_same_output_type) {
+      if (data.unique()) {
+        // Mutate-in-place path.  Only taken if the output type U is
+        // the same as type T and there are no other shared copies of
+        // the array, i.e. `data.unique()` holds.
+        auto arr = static_cast<ArrayNode*>(data.get());
+        for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) {
+          T mapped = fmap(DowncastNoCheck<T>(std::move(*it)));  // element is handed to fmap via std::move
+          *it = std::move(mapped);  // result is stored back into the same slot
+        }
+        return data;
+      }
+    }
+
+    constexpr bool compatible_types = is_valid_iterator_v<T, U*> || 
is_valid_iterator_v<U, T*>;
+
+    ObjectPtr<ArrayNode> output = nullptr;  // only assigned once an output array is needed
+    auto arr = static_cast<ArrayNode*>(data.get());
+
+    auto it = arr->begin();  // shared by both loops below; the second resumes where the first stopped
+    if constexpr (compatible_types) {
+      // Copy-on-write path, if the output Array<U> might be
+      // represented by the same underlying array as the existing
+      // Array<T>.  Typically, this is for functions that map `T` to
+      // `T`, but can also apply to functions that map `T` to
+      // `Optional<T>`, or that map `T` to a subclass or superclass of
+      // `T`.
+      bool all_identical = true;
+      for (; it != arr->end(); it++) {
+        U mapped = fmap(DowncastNoCheck<T>(*it));
+        if (!mapped.same_as(*it)) {
+          all_identical = false;
+          output = ArrayNode::CreateRepeated(arr->size(), U());  // created on the first changed element
+          output->InitRange(0, arr->begin(), it);  // carry over the unchanged prefix [begin, it)
+          output->SetItem(it - arr->begin(), std::move(mapped));  // store the first changed element at its index
+          it++;  // step past it so the loop after this branch resumes at the next element
+          break;
+        }
+      }
+      if (all_identical) {
+        return data;  // nothing changed: hand back the input array
+      }
+    } else {
+      // Path for incompatible types.  The constexpr check for
+      // compatible types isn't strictly necessary, as the first
+      // mapped.same_as(*it) would return false, but we might as well
+      // avoid it altogether.
+      output = ArrayNode::CreateRepeated(arr->size(), U());
+    }
+
+    // Normal path for incompatible types, or post-copy path for
+    // copy-on-write instances.  `it` is either still at arr->begin()
+    // or just past the first changed element.
+    for (; it != arr->end(); it++) {
+      U mapped = fmap(DowncastNoCheck<T>(*it));
+      output->SetItem(it - arr->begin(), std::move(mapped));
+    }
+
+    return output;
+  }

Review Comment:
   Certainly, and thank you for pointing that out!  There are some existing 
tests in 
[`container_test.cc`](https://github.com/apache/tvm/blob/main/tests/cpp/container_test.cc#L165),
along with a large amount of usage when lowering TIR, but no tests that 
specifically exercise these edge cases.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to