You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2023/01/16 11:23:58 UTC

[arrow] branch master updated: GH-33607: [C++] Support optional additional arguments for inline visit functions (#33608)

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 4b48eccce2 GH-33607: [C++] Support optional additional arguments for inline visit functions (#33608)
4b48eccce2 is described below

commit 4b48eccce25cbc5cfbd5c9c1b6d3e83ccd780917
Author: Jin Shang <sh...@gmail.com>
AuthorDate: Mon Jan 16 19:23:43 2023 +0800

    GH-33607: [C++] Support optional additional arguments for inline visit functions (#33608)
    
    # Which issue does this PR close?
    
    Closes #33607
    
    # Rationale for this change
    
    Sometimes the `Visit` function needs extra arguments, most commonly an output parameter that receives the result. Currently, the visitor has to store the result in a member variable and expose a member function to retrieve it:
    ```cpp
    class ExampleVisitor {
     public:
      template <typename T>
      Status Visit(const T& arr) {
        /// Do stuff and save result to output_;
        return Status::OK();
      }
    
      std::shared_ptr<Array> GetOutput() { return output_; }
    
     private:
      std::shared_ptr<Array> output_;
    };
    
    ExampleVisitor visitor;
    RETURN_NOT_OK(VisitArrayInline(*arr, &visitor));
    *output = visitor.GetOutput();
    ```
    
    Writing a visitor would be more convenient if `VisitArrayInline` could pass additional arguments through to the `Visit` method:
    ```cpp
    class ExampleVisitorWithArg {
     public:
      template <typename T>
      Status Visit(const T& arr, std::shared_ptr<Array>* output) {
        /// Do stuff and save result to output directly;
        return Status::OK();
      }
    };
    
    ExampleVisitorWithArg visitor;
    RETURN_NOT_OK(VisitArrayInline(*arr, &visitor, &output));
    ```
    
    # Are these changes tested?
    
    I am not sure whether this feature needs dedicated tests, since no runtime behavior changes. All existing tests pass.
    
    # Are there any user-facing changes?
    
    Possibly, in rare cases. Code that declares `VisitTypeInline` as a friend function, as in [diff.cc:402](https://github.com/apache/arrow/pull/33608/files#diff-ef9d394f251ee719f9896b864246bd97a81165c08195d4846d5d99e589e5c1c0L402), must update that declaration to the new variadic signature. I am not sure whether any outside users do this.
    
    Authored-by: Jin Shang <sh...@gmail.com>
    Signed-off-by: Antoine Pitrou <an...@python.org>
---
 cpp/src/arrow/array/diff.cc         |  4 ++--
 cpp/src/arrow/testing/builder.h     | 24 ++++++++-----------
 cpp/src/arrow/visit_array_inline.h  |  9 +++++---
 cpp/src/arrow/visit_scalar_inline.h | 17 ++++++++------
 cpp/src/arrow/visit_type_inline.h   | 46 ++++++++++++++++++++++---------------
 cpp/src/parquet/column_writer.cc    | 28 +++++++++++-----------
 6 files changed, 68 insertions(+), 60 deletions(-)

diff --git a/cpp/src/arrow/array/diff.cc b/cpp/src/arrow/array/diff.cc
index 10802939a7..9fbb5df2c0 100644
--- a/cpp/src/arrow/array/diff.cc
+++ b/cpp/src/arrow/array/diff.cc
@@ -399,8 +399,8 @@ class MakeFormatterImpl {
   }
 
  private:
-  template <typename VISITOR>
-  friend Status VisitTypeInline(const DataType&, VISITOR*);
+  template <typename VISITOR, typename... ARGS>
+  friend Status VisitTypeInline(const DataType&, VISITOR*, ARGS&&... args);
 
   // factory implementation
   Status Visit(const BooleanType&) {
diff --git a/cpp/src/arrow/testing/builder.h b/cpp/src/arrow/testing/builder.h
index f8a375589c..09e8f49dea 100644
--- a/cpp/src/arrow/testing/builder.h
+++ b/cpp/src/arrow/testing/builder.h
@@ -19,6 +19,7 @@
 
 #include <cstdint>
 #include <memory>
+#include <type_traits>
 #include <vector>
 
 #include "arrow/array.h"
@@ -27,6 +28,7 @@
 #include "arrow/array/builder_time.h"
 #include "arrow/buffer.h"
 #include "arrow/testing/gtest_util.h"
+#include "arrow/type_fwd.h"
 #include "arrow/util/bit_util.h"
 #include "arrow/visit_type_inline.h"
 
@@ -159,31 +161,21 @@ Status MakeArray(const std::vector<uint8_t>& valid_bytes, const std::vector<T>&
 }
 
 template <typename Fn>
-struct VisitBuilderImpl {
+struct VisitBuilder {
   template <typename T, typename BuilderType = typename TypeTraits<T>::BuilderType,
             // need to let SFINAE drop this Visit when it would result in
             // [](NullBuilder*){}(double_builder)
             typename = decltype(std::declval<Fn>()(std::declval<BuilderType*>()))>
-  Status Visit(const T&) {
-    fn_(internal::checked_cast<BuilderType*>(builder_));
+  Status Visit(const T&, ArrayBuilder* builder, Fn&& fn) {
+    fn(internal::checked_cast<BuilderType*>(builder));
     return Status::OK();
   }
 
-  Status Visit(const DataType& t) {
+  Status Visit(const DataType& t, ArrayBuilder* builder, Fn&& fn) {
     return Status::NotImplemented("visiting builders of type ", t);
   }
-
-  Status Visit() { return VisitTypeInline(*builder_->type(), this); }
-
-  ArrayBuilder* builder_;
-  Fn fn_;
 };
 
-template <typename Fn>
-Status VisitBuilder(ArrayBuilder* builder, Fn&& fn) {
-  return VisitBuilderImpl<Fn>{builder, std::forward<Fn>(fn)}.Visit();
-}
-
 template <typename Fn>
 Result<std::shared_ptr<Array>> ArrayFromBuilderVisitor(
     const std::shared_ptr<DataType>& type, int64_t initial_capacity,
@@ -195,8 +187,10 @@ Result<std::shared_ptr<Array>> ArrayFromBuilderVisitor(
     RETURN_NOT_OK(builder->Resize(initial_capacity));
   }
 
+  VisitBuilder<Fn> visitor;
   for (int64_t i = 0; i < visitor_repetitions; ++i) {
-    RETURN_NOT_OK(VisitBuilder(builder.get(), std::forward<Fn>(fn)));
+    RETURN_NOT_OK(
+        VisitTypeInline(*builder->type(), &visitor, builder.get(), std::forward<Fn>(fn)));
   }
 
   std::shared_ptr<Array> out;
diff --git a/cpp/src/arrow/visit_array_inline.h b/cpp/src/arrow/visit_array_inline.h
index 895cc37445..cb6ff49b69 100644
--- a/cpp/src/arrow/visit_array_inline.h
+++ b/cpp/src/arrow/visit_array_inline.h
@@ -27,11 +27,14 @@ namespace arrow {
   case TYPE_CLASS##Type::type_id:                                                        \
     return visitor->Visit(                                                               \
         internal::checked_cast<const typename TypeTraits<TYPE_CLASS##Type>::ArrayType&>( \
-            array));
+            array),                                                                      \
+        std::forward<ARGS>(args)...);
 
 /// \brief Apply the visitors Visit() method specialized to the array type
 ///
 /// \tparam VISITOR Visitor type that implements Visit() for all array types.
+/// \tparam ARGS Additional arguments, if any, will be passed to the Visit function after
+/// the `arr` argument
 /// \return Status
 ///
 /// A visitor is a type that implements specialized logic for each Arrow type.
@@ -46,8 +49,8 @@ namespace arrow {
 /// ExampleVisitor visitor;
 /// VisitArrayInline(some_array, &visitor);
 /// ```
-template <typename VISITOR>
-inline Status VisitArrayInline(const Array& array, VISITOR* visitor) {
+template <typename VISITOR, typename... ARGS>
+inline Status VisitArrayInline(const Array& array, VISITOR* visitor, ARGS&&... args) {
   switch (array.type_id()) {
     ARROW_GENERATE_FOR_ALL_TYPES(ARRAY_VISIT_INLINE);
     default:
diff --git a/cpp/src/arrow/visit_scalar_inline.h b/cpp/src/arrow/visit_scalar_inline.h
index f3e8108e9c..85357f288c 100644
--- a/cpp/src/arrow/visit_scalar_inline.h
+++ b/cpp/src/arrow/visit_scalar_inline.h
@@ -28,13 +28,16 @@
 
 namespace arrow {
 
-#define SCALAR_VISIT_INLINE(TYPE_CLASS) \
-  case TYPE_CLASS##Type::type_id:       \
-    return visitor->Visit(internal::checked_cast<const TYPE_CLASS##Scalar&>(scalar));
+#define SCALAR_VISIT_INLINE(TYPE_CLASS)                                              \
+  case TYPE_CLASS##Type::type_id:                                                    \
+    return visitor->Visit(internal::checked_cast<const TYPE_CLASS##Scalar&>(scalar), \
+                          std::forward<ARGS>(args)...);
 
 /// \brief Apply the visitors Visit() method specialized to the scalar type
 ///
 /// \tparam VISITOR Visitor type that implements Visit() for all scalar types.
+/// \tparam ARGS Additional arguments, if any, will be passed to the Visit function after
+/// the `scalar` argument
 /// \return Status
 ///
 /// A visitor is a type that implements specialized logic for each Arrow type.
@@ -42,15 +45,15 @@ namespace arrow {
 ///
 /// ```
 /// class ExampleVisitor {
-///   arrow::Status Visit(arrow::Int32Scalar arr) { ... }
-///   arrow::Status Visit(arrow::Int64Scalar arr) { ... }
+///   arrow::Status Visit(arrow::Int32Scalar scalar) { ... }
+///   arrow::Status Visit(arrow::Int64Scalar scalar) { ... }
 ///   ...
 /// }
 /// ExampleVisitor visitor;
 /// VisitScalarInline(some_scalar, &visitor);
 /// ```
-template <typename VISITOR>
-inline Status VisitScalarInline(const Scalar& scalar, VISITOR* visitor) {
+template <typename VISITOR, typename... ARGS>
+inline Status VisitScalarInline(const Scalar& scalar, VISITOR* visitor, ARGS&&... args) {
   switch (scalar.type->id()) {
     ARROW_GENERATE_FOR_ALL_TYPES(SCALAR_VISIT_INLINE);
     default:
diff --git a/cpp/src/arrow/visit_type_inline.h b/cpp/src/arrow/visit_type_inline.h
index 333ceaea1b..73da58dfcc 100644
--- a/cpp/src/arrow/visit_type_inline.h
+++ b/cpp/src/arrow/visit_type_inline.h
@@ -24,13 +24,16 @@
 
 namespace arrow {
 
-#define TYPE_VISIT_INLINE(TYPE_CLASS) \
-  case TYPE_CLASS##Type::type_id:     \
-    return visitor->Visit(internal::checked_cast<const TYPE_CLASS##Type&>(type));
+#define TYPE_VISIT_INLINE(TYPE_CLASS)                                            \
+  case TYPE_CLASS##Type::type_id:                                                \
+    return visitor->Visit(internal::checked_cast<const TYPE_CLASS##Type&>(type), \
+                          std::forward<ARGS>(args)...);
 
 /// \brief Calls `visitor` with the corresponding concrete type class
 ///
 /// \tparam VISITOR Visitor type that implements Visit() for all Arrow types.
+/// \tparam ARGS Additional arguments, if any, will be passed to the Visit function after
+/// the `type` argument
 /// \return Status
 ///
 /// A visitor is a type that implements specialized logic for each Arrow type.
@@ -45,8 +48,8 @@ namespace arrow {
 /// ExampleVisitor visitor;
 /// VisitTypeInline(some_type, &visitor);
 /// ```
-template <typename VISITOR>
-inline Status VisitTypeInline(const DataType& type, VISITOR* visitor) {
+template <typename VISITOR, typename... ARGS>
+inline Status VisitTypeInline(const DataType& type, VISITOR* visitor, ARGS&&... args) {
   switch (type.id()) {
     ARROW_GENERATE_FOR_ALL_TYPES(TYPE_VISIT_INLINE);
     default:
@@ -57,12 +60,15 @@ inline Status VisitTypeInline(const DataType& type, VISITOR* visitor) {
 
 #undef TYPE_VISIT_INLINE
 
-#define TYPE_VISIT_INLINE(TYPE_CLASS)      \
-  case TYPE_CLASS##Type::type_id:          \
-    return std::forward<VISITOR>(visitor)( \
-        internal::checked_cast<const TYPE_CLASS##Type&>(type));
+#define TYPE_VISIT_INLINE(TYPE_CLASS)                          \
+  case TYPE_CLASS##Type::type_id:                              \
+    return std::forward<VISITOR>(visitor)(                     \
+        internal::checked_cast<const TYPE_CLASS##Type&>(type), \
+        std::forward<ARGS>(args)...);
 
 /// \brief Call `visitor` with the corresponding concrete type class
+/// \tparam ARGS Additional arguments, if any, will be passed to the Visit function after
+/// the `type` argument
 ///
 /// Unlike VisitTypeInline which calls `visitor.Visit`, here `visitor`
 /// itself is called.
@@ -71,31 +77,33 @@ inline Status VisitTypeInline(const DataType& type, VISITOR* visitor) {
 ///
 /// The intent is for this to be called on a generic lambda
 /// that may internally use `if constexpr` or similar constructs.
-template <typename VISITOR>
-inline auto VisitType(const DataType& type, VISITOR&& visitor)
-    -> decltype(std::forward<VISITOR>(visitor)(type)) {
+template <typename VISITOR, typename... ARGS>
+inline auto VisitType(const DataType& type, VISITOR&& visitor, ARGS&&... args)
+    -> decltype(std::forward<VISITOR>(visitor)(type, args...)) {
   switch (type.id()) {
     ARROW_GENERATE_FOR_ALL_TYPES(TYPE_VISIT_INLINE);
     default:
       break;
   }
-  return std::forward<VISITOR>(visitor)(type);
+  return std::forward<VISITOR>(visitor)(type, std::forward<ARGS>(args)...);
 }
 
 #undef TYPE_VISIT_INLINE
 
-#define TYPE_ID_VISIT_INLINE(TYPE_CLASS)            \
-  case TYPE_CLASS##Type::type_id: {                 \
-    const TYPE_CLASS##Type* concrete_ptr = NULLPTR; \
-    return visitor->Visit(concrete_ptr);            \
+#define TYPE_ID_VISIT_INLINE(TYPE_CLASS)                              \
+  case TYPE_CLASS##Type::type_id: {                                   \
+    const TYPE_CLASS##Type* concrete_ptr = NULLPTR;                   \
+    return visitor->Visit(concrete_ptr, std::forward<ARGS>(args)...); \
   }
 
 /// \brief Calls `visitor` with a nullptr of the corresponding concrete type class
 ///
 /// \tparam VISITOR Visitor type that implements Visit() for all Arrow types.
+/// \tparam ARGS Additional arguments, if any, will be passed to the Visit function after
+/// the `type` argument
 /// \return Status
-template <typename VISITOR>
-inline Status VisitTypeIdInline(Type::type id, VISITOR* visitor) {
+template <typename VISITOR, typename... ARGS>
+inline Status VisitTypeIdInline(Type::type id, VISITOR* visitor, ARGS&&... args) {
   switch (id) {
     ARROW_GENERATE_FOR_ALL_TYPES(TYPE_ID_VISIT_INLINE);
     default:
diff --git a/cpp/src/parquet/column_writer.cc b/cpp/src/parquet/column_writer.cc
index 1fc13aa3c1..27ec640cec 100644
--- a/cpp/src/parquet/column_writer.cc
+++ b/cpp/src/parquet/column_writer.cc
@@ -74,9 +74,10 @@ namespace {
 // Visitor that exracts the value buffer from a FlatArray at a given offset.
 struct ValueBufferSlicer {
   template <typename T>
-  ::arrow::enable_if_base_binary<typename T::TypeClass, Status> Visit(const T& array) {
+  ::arrow::enable_if_base_binary<typename T::TypeClass, Status> Visit(
+      const T& array, std::shared_ptr<Buffer>* buffer) {
     auto data = array.data();
-    buffer_ =
+    *buffer =
         SliceBuffer(data->buffers[1], data->offset * sizeof(typename T::offset_type),
                     data->length * sizeof(typename T::offset_type));
     return Status::OK();
@@ -84,9 +85,9 @@ struct ValueBufferSlicer {
 
   template <typename T>
   ::arrow::enable_if_fixed_size_binary<typename T::TypeClass, Status> Visit(
-      const T& array) {
+      const T& array, std::shared_ptr<Buffer>* buffer) {
     auto data = array.data();
-    buffer_ = SliceBuffer(data->buffers[1], data->offset * array.byte_width(),
+    *buffer = SliceBuffer(data->buffers[1], data->offset * array.byte_width(),
                           data->length * array.byte_width());
     return Status::OK();
   }
@@ -95,29 +96,30 @@ struct ValueBufferSlicer {
   ::arrow::enable_if_t<::arrow::has_c_type<typename T::TypeClass>::value &&
                            !std::is_same<BooleanType, typename T::TypeClass>::value,
                        Status>
-  Visit(const T& array) {
+  Visit(const T& array, std::shared_ptr<Buffer>* buffer) {
     auto data = array.data();
-    buffer_ = SliceBuffer(
+    *buffer = SliceBuffer(
         data->buffers[1],
         ::arrow::TypeTraits<typename T::TypeClass>::bytes_required(data->offset),
         ::arrow::TypeTraits<typename T::TypeClass>::bytes_required(data->length));
     return Status::OK();
   }
 
-  Status Visit(const ::arrow::BooleanArray& array) {
+  Status Visit(const ::arrow::BooleanArray& array, std::shared_ptr<Buffer>* buffer) {
     auto data = array.data();
     if (bit_util::IsMultipleOf8(data->offset)) {
-      buffer_ = SliceBuffer(data->buffers[1], bit_util::BytesForBits(data->offset),
+      *buffer = SliceBuffer(data->buffers[1], bit_util::BytesForBits(data->offset),
                             bit_util::BytesForBits(data->length));
       return Status::OK();
     }
-    PARQUET_ASSIGN_OR_THROW(buffer_,
+    PARQUET_ASSIGN_OR_THROW(*buffer,
                             ::arrow::internal::CopyBitmap(pool_, data->buffers[1]->data(),
                                                           data->offset, data->length));
     return Status::OK();
   }
 #define NOT_IMPLEMENTED_VISIT(ArrowTypePrefix)                                      \
-  Status Visit(const ::arrow::ArrowTypePrefix##Array& array) {                      \
+  Status Visit(const ::arrow::ArrowTypePrefix##Array& array,                        \
+               std::shared_ptr<Buffer>* buffer) {                                   \
     return Status::NotImplemented("Slicing not implemented for " #ArrowTypePrefix); \
   }
 
@@ -133,7 +135,6 @@ struct ValueBufferSlicer {
 #undef NOT_IMPLEMENTED_VISIT
 
   MemoryPool* pool_;
-  std::shared_ptr<Buffer> buffer_;
 };
 
 internal::LevelInfo ComputeLevelInfo(const ColumnDescriptor* descr) {
@@ -1316,10 +1317,9 @@ class TypedColumnWriterImpl : public ColumnWriterImpl, public TypedColumnWriter<
     buffers[0] = bits_buffer_;
     // Should be a leaf array.
     DCHECK_GT(buffers.size(), 1);
-    ValueBufferSlicer slicer{memory_pool, /*buffer=*/nullptr};
+    ValueBufferSlicer slicer{memory_pool};
     if (array->data()->offset > 0) {
-      RETURN_NOT_OK(::arrow::VisitArrayInline(*array, &slicer));
-      buffers[1] = slicer.buffer_;
+      RETURN_NOT_OK(::arrow::VisitArrayInline(*array, &slicer, &buffers[1]));
     }
     return ::arrow::MakeArray(std::make_shared<ArrayData>(
         array->type(), array->length(), std::move(buffers), new_null_count));