Lunderberg commented on code in PR #12692:
URL: https://github.com/apache/tvm/pull/12692#discussion_r974661862


##########
include/tvm/runtime/container/array.h:
##########
@@ -706,6 +710,98 @@ class Array : public ObjectRef {
     }
     return static_cast<ArrayNode*>(data_.get());
   }
+
+  /*! \brief Helper method for mutate/map
+   *
+   * A helper function used internally by both `Array::Map` and
+   * `Array::MutateInPlace`.  Given an array of data, apply the
+   * mapping function to each element, returning the collected array.
+   * Applies both mutate-in-place and copy-on-write optimizations, if
+   * possible.
+   *
+   * \param data A pointer to the ArrayNode containing input data.
+   * Passed by value to allow for mutate-in-place optimizations.
+   *
+   * \param fmap The mapping function
+   *
+   * \tparam F The type of the mutation function.
+   *
+   * \tparam U The output type of the mutation function.  Inferred
+   * from the callable type given.  Must inherit from ObjectRef.
+   *
+   * \return The mapped array.  Depending on whether mutate-in-place
+   * or copy-on-write optimizations were applicable, may be the same
+   * underlying array as the `data` parameter.
+   */
+  template <typename F, typename U = std::invoke_result_t<F, T>>
+  static ObjectPtr<Object> MapHelper(ObjectPtr<Object> data, F fmap) {
+    if (data == nullptr) {
+      return nullptr;
+    }
+
+    ICHECK(data->IsInstance<ArrayNode>());
+
+    constexpr bool is_same_output_type = std::is_same_v<T, U>;
+
+    if constexpr (is_same_output_type) {
+      if (data.unique()) {
+        // Mutate-in-place path.  Only allowed if the output type U is
+        // the same as type T, we have a mutable this*, and there are
+        // no other shared copies of the array.
+        auto arr = static_cast<ArrayNode*>(data.get());
+        for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) {
+          T mapped = fmap(DowncastNoCheck<T>(std::move(*it)));
+          *it = std::move(mapped);
+        }
+        return data;
+      }
+    }
+
+    constexpr bool compatible_types = is_valid_iterator_v<T, U*> || is_valid_iterator_v<U, T*>;
+
+    ObjectPtr<ArrayNode> output = nullptr;
+    auto arr = static_cast<ArrayNode*>(data.get());
+
+    auto it = arr->begin();
+    if constexpr (compatible_types) {
+      // Copy-on-write path, if the output Array<U> might be
+      // represented by the same underlying array as the existing
+      // Array<T>.  Typically, this is for functions that map `T` to
+      // `T`, but can also apply to functions that map `T` to
+      // `Optional<T>`, or that map `T` to a subclass or superclass of
+      // `T`.
+      bool all_identical = true;
+      for (; it != arr->end(); it++) {
+        U mapped = fmap(DowncastNoCheck<T>(*it));
+        if (!mapped.same_as(*it)) {
+          all_identical = false;
+          output = ArrayNode::CreateRepeated(arr->size(), U());
+          output->InitRange(0, arr->begin(), it);
+          output->SetItem(it - arr->begin(), std::move(mapped));
+          it++;
+          break;
+        }
+      }
+      if (all_identical) {
+        return data;
+      }
+    } else {
+      // Path for incompatible types.  The constexpr check for
+      // compatible types isn't strictly necessary, as the first
+      // mapped.same_as(*it) would return false, but we might as well
+      // avoid it altogether.
+      output = ArrayNode::CreateRepeated(arr->size(), U());
+    }
+
+    // Normal path for incompatible types, or post-copy path for
+    // copy-on-write instances.

Review Comment:
   > What will be left over on the copy-on-write instance? 
   
   If we have compatible types and we've reached this point, then we've found at least one element for which the `mapped.same_as(*it)` check on line 776 failed.  In that case, `output` holds a mapped value for every element in the range `[arr->begin(), it)`: the leading elements that mapped to themselves, copied over by `InitRange`, followed by the first non-identical element, stored by `SetItem`.  `it` points to the next element that should be transformed, so the next loop over `it` can continue where the first loop left off.
   
   > Will there be some items that are incompatible?
   
   It's entirely possible, either at compile-time or at runtime.  For example, I could have an `Array<PrimExpr> buffer_shape` and map it to allowed ranges with `buffer_shape.Map([](PrimExpr expr) { return Range::FromMinExtent(0, expr);});`.  The input and output types are incompatible, and this is identified at compile-time.  In that case, the `if constexpr` can determine that the two arrays cannot be represented by the same underlying array, and skips the attempt to do so altogether.
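
   As a rough illustration of the compile-time case, here is a minimal standalone sketch.  It does not use TVM: `Base`, `Derived`, `kMayShareStorage`, and `ReusesInput` are hypothetical names, and the trait is only a stand-in for `is_valid_iterator_v`, assuming that "compatible" means one pointer type converts to the other.

   ```c++
   #include <cassert>
   #include <memory>
   #include <string>
   #include <type_traits>

   // Hypothetical types for illustration only; not TVM classes.
   struct Base { virtual ~Base() = default; };
   struct Derived : Base {};

   // Hypothetical stand-in for the compatibility check: could a handle to U
   // ever refer to the same object as a handle to T?
   template <typename T, typename U>
   constexpr bool kMayShareStorage =
       std::is_convertible_v<T*, U*> || std::is_convertible_v<U*, T*>;

   // Returns true if `out` refers to the same object as `in`.
   template <typename T, typename U>
   bool ReusesInput(const std::shared_ptr<T>& in, const std::shared_ptr<U>& out) {
     if constexpr (kMayShareStorage<T, U>) {
       // Only instantiated for related types, where the pointer comparison
       // is well-formed.
       return in.get() == out.get();
     } else {
       // Unrelated types can never alias, so the comparison is skipped
       // entirely at compile-time.
       assert(in != nullptr || out != nullptr || true);
       return false;
     }
   }
   ```

   With unrelated types, such as `Base` and `std::string`, the identity comparison would not even compile, which is why the discarded `if constexpr` branch is useful here and not just an optimization.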
   
   If an element turns out to be incompatible at runtime, then it will also fail the `mapped.same_as(*it)` check on line 776.  So if I have an `Array<Var>` being mapped to `Array<PrimExpr>` with `var_array.Map([&](Var var) { return var.same_as(to_replace) ? replace_with : var;});`, the output may or may not be able to reuse the input array, depending on whether `to_replace` shows up in the array.
   
   > How are those guaranteed to be at the end?
   
   Incompatible items may occur at any point in the mapped output, even at the very first iteration.  In that case, the statements executed in the conditional on `!mapped.same_as(*it)` have the same effect as the else branch on `compatible_types` followed by the first iteration of the final mapping loop.
   
   ```c++
   // Same as the else branch on `compatible_types`
   output = ArrayNode::CreateRepeated(arr->size(), U());
   
   // For the first iteration, `it` is `arr->begin()`, so this would be
   // an empty range [begin, begin), nothing is initialized, and this
   // statement has no effect.
   output->InitRange(0, arr->begin(), it);
   
   // The newly mapped item is stored to the first location of the output.
   output->SetItem(it - arr->begin(), std::move(mapped));
   
   // The loop increment that would have happened
   it++;
   
   // `it` now points to the second element of the input, and we have one
   // mapped element in the output.  We're now ready to start the second
   // loop, just at the second iteration instead of the first.
   ```
   
   Essentially, we only need to check for identical return values until we find a single non-identical element, at which point we know that we can't avoid the copy anyway.  Once we reach that first non-identical value, we don't need to repeat the function calls up to that point, because every element seen so far is either identical (and can therefore be copied from the input) or non-identical (in which case it is the first such value, and its mapped result has already been stored).
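
   The two-loop structure can be sketched outside of TVM as follows.  This is a simplified illustration under stated assumptions, not the code in the PR: `Elem`, `Vec`, and `CopyOnWriteMap` are hypothetical names, `std::shared_ptr` pointer equality stands in for `same_as`, and `push_back` replaces the `CreateRepeated`/`InitRange`/`SetItem` sequence.

   ```c++
   #include <cassert>
   #include <cstddef>
   #include <memory>
   #include <utility>
   #include <vector>

   // Hypothetical stand-ins for illustration only; not TVM types.
   using Elem = std::shared_ptr<const int>;
   using Vec = std::vector<Elem>;

   // Maps `fmap` over `input`.  Returns `input` itself when every mapped
   // element is pointer-identical to its source, and a new vector otherwise.
   // `fmap` is called exactly once per element in both cases.
   template <typename F>
   std::shared_ptr<const Vec> CopyOnWriteMap(std::shared_ptr<const Vec> input, F fmap) {
     assert(input != nullptr);
     std::shared_ptr<Vec> output = nullptr;
     auto it = input->begin();

     // First loop: search for the first element changed by the mapping.
     for (; it != input->end(); ++it) {
       Elem mapped = fmap(*it);
       if (mapped != *it) {
         // Copy the unchanged prefix [begin, it), then store the changed
         // element.  The prefix is empty if the very first element differs.
         output = std::make_shared<Vec>(input->begin(), it);
         output->reserve(input->size());
         output->push_back(std::move(mapped));
         ++it;  // The loop increment that `break` would otherwise skip.
         break;
       }
     }
     if (output == nullptr) {
       return input;  // Nothing changed, so the input can be reused.
     }

     // Second loop: resume where the first loop stopped.
     for (; it != input->end(); ++it) {
       output->push_back(fmap(*it));
     }
     return output;
   }
   ```

   Calling this with an identity function returns the original pointer, while replacing any single element yields a fresh vector that shares the untouched elements with the input.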


