You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/09/19 17:51:23 UTC

[GitHub] [tvm] janetsc commented on a diff in pull request #12692: [Containers] Add Array::Map

janetsc commented on code in PR #12692:
URL: https://github.com/apache/tvm/pull/12692#discussion_r974520581


##########
include/tvm/runtime/container/array.h:
##########
@@ -706,6 +710,98 @@ class Array : public ObjectRef {
     }
     return static_cast<ArrayNode*>(data_.get());
   }
+
+  /*! \brief Helper method for mutate/map
+   *
+   * A helper function used internally by both `Array::Map` and
+   * `Array::MutateInPlace`.  Given an array of data, apply the
+   * mapping function to each element, returning the collected array.
+   * Applies both mutate-in-place and copy-on-write optimizations, if
+   * possible.
+   *
+   * \param data A pointer to the ArrayNode containing input data.
+   * Passed by value to allow for mutate-in-place optimizations.
+   *
+   * \param fmap The mapping function
+   *
+   * \tparam F The type of the mutation function.
+   *
+   * \tparam U The output type of the mutation function.  Inferred
+   * from the callable type given.  Must inherit from ObjectRef.
+   *
+   * \return The mapped array.  Depending on whether mutate-in-place
+   * or copy-on-write optimizations were applicable, may be the same
+   * underlying array as the `data` parameter.
+   */
+  template <typename F, typename U = std::invoke_result_t<F, T>>
+  static ObjectPtr<Object> MapHelper(ObjectPtr<Object> data, F fmap) {
+    if (data == nullptr) {
+      return nullptr;
+    }
+
+    ICHECK(data->IsInstance<ArrayNode>());
+
+    constexpr bool is_same_output_type = std::is_same_v<T, U>;
+
+    if constexpr (is_same_output_type) {
+      if (data.unique()) {
+        // Mutate-in-place path.  Only allowed if the output type U is
+        // the same as type T, we have a mutable this*, and there are
+        // no other shared copies of the array.
+        auto arr = static_cast<ArrayNode*>(data.get());
+        for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) {
+          T mapped = fmap(DowncastNoCheck<T>(std::move(*it)));
+          *it = std::move(mapped);
+        }
+        return data;
+      }
+    }
+
+    constexpr bool compatible_types = is_valid_iterator_v<T, U*> || is_valid_iterator_v<U, T*>;
+
+    ObjectPtr<ArrayNode> output = nullptr;
+    auto arr = static_cast<ArrayNode*>(data.get());
+
+    auto it = arr->begin();
+    if constexpr (compatible_types) {
+      // Copy-on-write path, if the output Array<U> might be
+      // represented by the same underlying array as the existing
+      // Array<T>.  Typically, this is for functions that map `T` to
+      // `T`, but can also apply to functions that map `T` to
+      // `Optional<T>`, or that map `T` to a subclass or superclass of
+      // `T`.
+      bool all_identical = true;
+      for (; it != arr->end(); it++) {
+        U mapped = fmap(DowncastNoCheck<T>(*it));
+        if (!mapped.same_as(*it)) {
+          all_identical = false;
+          output = ArrayNode::CreateRepeated(arr->size(), U());
+          output->InitRange(0, arr->begin(), it);
+          output->SetItem(it - arr->begin(), std::move(mapped));
+          it++;
+          break;
+        }
+      }
+      if (all_identical) {
+        return data;
+      }
+    } else {
+      // Path for incompatible types.  The constexpr check for
+      // compatible types isn't strictly necessary, as the first
+      // mapped.same_as(*it) would return false, but we might as well
+      // avoid it altogether.
+      output = ArrayNode::CreateRepeated(arr->size(), U());
+    }
+
+    // Normal path for incompatible types, or post-copy path for
+    // copy-on-write instances.

Review Comment:
   What will be left over on the copy-on-write instance?  Will there be some items that are incompatible?  How are those guaranteed to be at the end?



##########
include/tvm/runtime/container/array.h:
##########
@@ -706,6 +710,98 @@ class Array : public ObjectRef {
     }
     return static_cast<ArrayNode*>(data_.get());
   }
+
+  /*! \brief Helper method for mutate/map
+   *
+   * Used internally by both `Array::Map` and `Array::MutateInPlace`.
+   * Applies `fmap` to every element of `data` in order and returns
+   * the collected array (nullptr if `data` is nullptr).  May reuse
+   * `data` itself: overwritten in place when U == T and `data` is
+   * unique, or returned as-is when every result is same_as its input.
+   *
+   * \param data A pointer to the ArrayNode containing input data.
+   * Passed by value to allow for mutate-in-place optimizations.
+   *
+   * \param fmap The mapping function
+   *
+   * \tparam F The type of the mutation function.
+   *
+   * \tparam U The output type of the mutation function.  Inferred
+   * from the callable type given.  Must inherit from ObjectRef.
+   *
+   * \return The mapped array.  Depending on whether mutate-in-place
+   * or copy-on-write optimizations were applicable, may be the same
+   * underlying array as the `data` parameter.
+   */
+  template <typename F, typename U = std::invoke_result_t<F, T>>
+  static ObjectPtr<Object> MapHelper(ObjectPtr<Object> data, F fmap) {
+    if (data == nullptr) {
+      return nullptr;
+    }
+
+    ICHECK(data->IsInstance<ArrayNode>());  // guards the static_casts to ArrayNode* below
+
+    constexpr bool is_same_output_type = std::is_same_v<T, U>;
+
+    if constexpr (is_same_output_type) {
+      if (data.unique()) {
+        // Mutate-in-place path.  Only allowed if the output type U is
+        // the same as type T, we have a mutable this*, and there are
+        // no other shared copies of the array.
+        auto arr = static_cast<ArrayNode*>(data.get());
+        for (auto it = arr->MutableBegin(); it != arr->MutableEnd(); it++) {
+          T mapped = fmap(DowncastNoCheck<T>(std::move(*it)));
+          *it = std::move(mapped);  // each slot is moved out, mapped, then refilled
+        }
+        return data;  // same node, elements overwritten in place
+      }
+    }
+
+    constexpr bool compatible_types = is_valid_iterator_v<T, U*> || is_valid_iterator_v<U, T*>;
+
+    ObjectPtr<ArrayNode> output = nullptr;  // allocated only once a new array is needed
+    auto arr = static_cast<ArrayNode*>(data.get());
+
+    auto it = arr->begin();  // shared by the scan below and the final loop
+    if constexpr (compatible_types) {
+      // Copy-on-write path, if the output Array<U> might be
+      // represented by the same underlying array as the existing
+      // Array<T>.  Typically, this is for functions that map `T` to
+      // `T`, but can also apply to functions that map `T` to
+      // `Optional<T>`, or that map `T` to a subclass or superclass of
+      // `T`.
+      bool all_identical = true;
+      for (; it != arr->end(); it++) {
+        U mapped = fmap(DowncastNoCheck<T>(*it));
+        if (!mapped.same_as(*it)) {  // first element whose result differs from its input
+          all_identical = false;
+          output = ArrayNode::CreateRepeated(arr->size(), U());
+          output->InitRange(0, arr->begin(), it);  // prefix [begin, it) mapped to itself
+          output->SetItem(it - arr->begin(), std::move(mapped));
+          it++;  // step past the element just stored; the final loop resumes from here
+          break;
+        }
+      }
+      if (all_identical) {
+        return data;  // no element changed, so the input array is the result
+      }
+    } else {
+      // Path for incompatible types.  The constexpr check isn't strictly
+      // necessary, as the first mapped.same_as(*it) would return false,
+      // but we might as well avoid it altogether.  `it` is still at
+      // arr->begin(), so the loop below maps every element.
+      output = ArrayNode::CreateRepeated(arr->size(), U());
+    }
+
+    // Map everything from `it` onward into `output`: the whole array on the
+    // incompatible path, or the elements after the first changed one otherwise.
+    for (; it != arr->end(); it++) {
+      U mapped = fmap(DowncastNoCheck<T>(*it));
+      output->SetItem(it - arr->begin(), std::move(mapped));
+    }
+
+    return output;
+  }

Review Comment:
   General comment - Can you add a unit test to exercise the edge cases in MapHelper?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org