You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2023/01/16 11:23:58 UTC
[arrow] branch master updated: GH-33607: [C++] Support optional additional arguments for inline visit functions (#33608)
This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 4b48eccce2 GH-33607: [C++] Support optional additional arguments for inline visit functions (#33608)
4b48eccce2 is described below
commit 4b48eccce25cbc5cfbd5c9c1b6d3e83ccd780917
Author: Jin Shang <sh...@gmail.com>
AuthorDate: Mon Jan 16 19:23:43 2023 +0800
GH-33607: [C++] Support optional additional arguments for inline visit functions (#33608)
# Which issue does this PR close?
Closes #33607
# Rationale for this change
Sometimes the `Visit` function needs extra arguments, most commonly an output parameter in which to store the result. Currently, a visitor has to keep a member variable and expose a member function to retrieve the result:
```cpp
class ExampleVisitor {
public:
template <typename T>
Status Visit(const T& arr) {
/// Do stuff and save result to output_;
return Status::OK();
}
std::shared_ptr<Array> GetOutput() { return output_; }
private:
std::shared_ptr<Array> output_;
};
ExampleVisitor visitor;
RETURN_NOT_OK(VisitArrayInline(*arr, &visitor));
*output = visitor.GetOutput();
```
It would be more convenient to write a visitor if `VisitArrayInline` forwarded additional arguments to the `Visit` method:
```cpp
class ExampleVisitorWithArg {
public:
template <typename T>
Status Visit(const T& arr, std::shared_ptr<Array>* output) {
/// Do stuff and save result to output directly;
return Status::OK();
}
};
ExampleVisitorWithArg visitor;
RETURN_NOT_OK(VisitArrayInline(*arr, &visitor, &output));
```
# Are these changes tested?
It is unclear whether this feature needs dedicated tests, since no runtime behavior changes. All existing tests pass.
# Are there any user-facing changes?
Possibly. If `VisitTypeInline` is declared as a friend function, as in [diff.cc:402](https://github.com/apache/arrow/pull/33608/files#diff-ef9d394f251ee719f9896b864246bd97a81165c08195d4846d5d99e589e5c1c0L402), that declaration must be updated to match the new variadic signature. It is unclear whether any external users do this, so the chance of breakage is small.
Authored-by: Jin Shang <sh...@gmail.com>
Signed-off-by: Antoine Pitrou <an...@python.org>
---
cpp/src/arrow/array/diff.cc | 4 ++--
cpp/src/arrow/testing/builder.h | 24 ++++++++-----------
cpp/src/arrow/visit_array_inline.h | 9 +++++---
cpp/src/arrow/visit_scalar_inline.h | 17 ++++++++------
cpp/src/arrow/visit_type_inline.h | 46 ++++++++++++++++++++++---------------
cpp/src/parquet/column_writer.cc | 28 +++++++++++-----------
6 files changed, 68 insertions(+), 60 deletions(-)
diff --git a/cpp/src/arrow/array/diff.cc b/cpp/src/arrow/array/diff.cc
index 10802939a7..9fbb5df2c0 100644
--- a/cpp/src/arrow/array/diff.cc
+++ b/cpp/src/arrow/array/diff.cc
@@ -399,8 +399,8 @@ class MakeFormatterImpl {
}
private:
- template <typename VISITOR>
- friend Status VisitTypeInline(const DataType&, VISITOR*);
+ template <typename VISITOR, typename... ARGS>
+ friend Status VisitTypeInline(const DataType&, VISITOR*, ARGS&&... args);
// factory implementation
Status Visit(const BooleanType&) {
diff --git a/cpp/src/arrow/testing/builder.h b/cpp/src/arrow/testing/builder.h
index f8a375589c..09e8f49dea 100644
--- a/cpp/src/arrow/testing/builder.h
+++ b/cpp/src/arrow/testing/builder.h
@@ -19,6 +19,7 @@
#include <cstdint>
#include <memory>
+#include <type_traits>
#include <vector>
#include "arrow/array.h"
@@ -27,6 +28,7 @@
#include "arrow/array/builder_time.h"
#include "arrow/buffer.h"
#include "arrow/testing/gtest_util.h"
+#include "arrow/type_fwd.h"
#include "arrow/util/bit_util.h"
#include "arrow/visit_type_inline.h"
@@ -159,31 +161,21 @@ Status MakeArray(const std::vector<uint8_t>& valid_bytes, const std::vector<T>&
}
template <typename Fn>
-struct VisitBuilderImpl {
+struct VisitBuilder {
template <typename T, typename BuilderType = typename TypeTraits<T>::BuilderType,
// need to let SFINAE drop this Visit when it would result in
// [](NullBuilder*){}(double_builder)
typename = decltype(std::declval<Fn>()(std::declval<BuilderType*>()))>
- Status Visit(const T&) {
- fn_(internal::checked_cast<BuilderType*>(builder_));
+ Status Visit(const T&, ArrayBuilder* builder, Fn&& fn) {
+ fn(internal::checked_cast<BuilderType*>(builder));
return Status::OK();
}
- Status Visit(const DataType& t) {
+ Status Visit(const DataType& t, ArrayBuilder* builder, Fn&& fn) {
return Status::NotImplemented("visiting builders of type ", t);
}
-
- Status Visit() { return VisitTypeInline(*builder_->type(), this); }
-
- ArrayBuilder* builder_;
- Fn fn_;
};
-template <typename Fn>
-Status VisitBuilder(ArrayBuilder* builder, Fn&& fn) {
- return VisitBuilderImpl<Fn>{builder, std::forward<Fn>(fn)}.Visit();
-}
-
template <typename Fn>
Result<std::shared_ptr<Array>> ArrayFromBuilderVisitor(
const std::shared_ptr<DataType>& type, int64_t initial_capacity,
@@ -195,8 +187,10 @@ Result<std::shared_ptr<Array>> ArrayFromBuilderVisitor(
RETURN_NOT_OK(builder->Resize(initial_capacity));
}
+ VisitBuilder<Fn> visitor;
for (int64_t i = 0; i < visitor_repetitions; ++i) {
- RETURN_NOT_OK(VisitBuilder(builder.get(), std::forward<Fn>(fn)));
+ RETURN_NOT_OK(
+ VisitTypeInline(*builder->type(), &visitor, builder.get(), std::forward<Fn>(fn)));
}
std::shared_ptr<Array> out;
diff --git a/cpp/src/arrow/visit_array_inline.h b/cpp/src/arrow/visit_array_inline.h
index 895cc37445..cb6ff49b69 100644
--- a/cpp/src/arrow/visit_array_inline.h
+++ b/cpp/src/arrow/visit_array_inline.h
@@ -27,11 +27,14 @@ namespace arrow {
case TYPE_CLASS##Type::type_id: \
return visitor->Visit( \
internal::checked_cast<const typename TypeTraits<TYPE_CLASS##Type>::ArrayType&>( \
- array));
+ array), \
+ std::forward<ARGS>(args)...);
/// \brief Apply the visitors Visit() method specialized to the array type
///
/// \tparam VISITOR Visitor type that implements Visit() for all array types.
+/// \tparam ARGS Additional arguments, if any, will be passed to the Visit function after
+/// the `arr` argument
/// \return Status
///
/// A visitor is a type that implements specialized logic for each Arrow type.
@@ -46,8 +49,8 @@ namespace arrow {
/// ExampleVisitor visitor;
/// VisitArrayInline(some_array, &visitor);
/// ```
-template <typename VISITOR>
-inline Status VisitArrayInline(const Array& array, VISITOR* visitor) {
+template <typename VISITOR, typename... ARGS>
+inline Status VisitArrayInline(const Array& array, VISITOR* visitor, ARGS&&... args) {
switch (array.type_id()) {
ARROW_GENERATE_FOR_ALL_TYPES(ARRAY_VISIT_INLINE);
default:
diff --git a/cpp/src/arrow/visit_scalar_inline.h b/cpp/src/arrow/visit_scalar_inline.h
index f3e8108e9c..85357f288c 100644
--- a/cpp/src/arrow/visit_scalar_inline.h
+++ b/cpp/src/arrow/visit_scalar_inline.h
@@ -28,13 +28,16 @@
namespace arrow {
-#define SCALAR_VISIT_INLINE(TYPE_CLASS) \
- case TYPE_CLASS##Type::type_id: \
- return visitor->Visit(internal::checked_cast<const TYPE_CLASS##Scalar&>(scalar));
+#define SCALAR_VISIT_INLINE(TYPE_CLASS) \
+ case TYPE_CLASS##Type::type_id: \
+ return visitor->Visit(internal::checked_cast<const TYPE_CLASS##Scalar&>(scalar), \
+ std::forward<ARGS>(args)...);
/// \brief Apply the visitors Visit() method specialized to the scalar type
///
/// \tparam VISITOR Visitor type that implements Visit() for all scalar types.
+/// \tparam ARGS Additional arguments, if any, will be passed to the Visit function after
+/// the `scalar` argument
/// \return Status
///
/// A visitor is a type that implements specialized logic for each Arrow type.
@@ -42,15 +45,15 @@ namespace arrow {
///
/// ```
/// class ExampleVisitor {
-/// arrow::Status Visit(arrow::Int32Scalar arr) { ... }
-/// arrow::Status Visit(arrow::Int64Scalar arr) { ... }
+/// arrow::Status Visit(arrow::Int32Scalar scalar) { ... }
+/// arrow::Status Visit(arrow::Int64Scalar scalar) { ... }
/// ...
/// }
/// ExampleVisitor visitor;
/// VisitScalarInline(some_scalar, &visitor);
/// ```
-template <typename VISITOR>
-inline Status VisitScalarInline(const Scalar& scalar, VISITOR* visitor) {
+template <typename VISITOR, typename... ARGS>
+inline Status VisitScalarInline(const Scalar& scalar, VISITOR* visitor, ARGS&&... args) {
switch (scalar.type->id()) {
ARROW_GENERATE_FOR_ALL_TYPES(SCALAR_VISIT_INLINE);
default:
diff --git a/cpp/src/arrow/visit_type_inline.h b/cpp/src/arrow/visit_type_inline.h
index 333ceaea1b..73da58dfcc 100644
--- a/cpp/src/arrow/visit_type_inline.h
+++ b/cpp/src/arrow/visit_type_inline.h
@@ -24,13 +24,16 @@
namespace arrow {
-#define TYPE_VISIT_INLINE(TYPE_CLASS) \
- case TYPE_CLASS##Type::type_id: \
- return visitor->Visit(internal::checked_cast<const TYPE_CLASS##Type&>(type));
+#define TYPE_VISIT_INLINE(TYPE_CLASS) \
+ case TYPE_CLASS##Type::type_id: \
+ return visitor->Visit(internal::checked_cast<const TYPE_CLASS##Type&>(type), \
+ std::forward<ARGS>(args)...);
/// \brief Calls `visitor` with the corresponding concrete type class
///
/// \tparam VISITOR Visitor type that implements Visit() for all Arrow types.
+/// \tparam ARGS Additional arguments, if any, will be passed to the Visit function after
+/// the `type` argument
/// \return Status
///
/// A visitor is a type that implements specialized logic for each Arrow type.
@@ -45,8 +48,8 @@ namespace arrow {
/// ExampleVisitor visitor;
/// VisitTypeInline(some_type, &visitor);
/// ```
-template <typename VISITOR>
-inline Status VisitTypeInline(const DataType& type, VISITOR* visitor) {
+template <typename VISITOR, typename... ARGS>
+inline Status VisitTypeInline(const DataType& type, VISITOR* visitor, ARGS&&... args) {
switch (type.id()) {
ARROW_GENERATE_FOR_ALL_TYPES(TYPE_VISIT_INLINE);
default:
@@ -57,12 +60,15 @@ inline Status VisitTypeInline(const DataType& type, VISITOR* visitor) {
#undef TYPE_VISIT_INLINE
-#define TYPE_VISIT_INLINE(TYPE_CLASS) \
- case TYPE_CLASS##Type::type_id: \
- return std::forward<VISITOR>(visitor)( \
- internal::checked_cast<const TYPE_CLASS##Type&>(type));
+#define TYPE_VISIT_INLINE(TYPE_CLASS) \
+ case TYPE_CLASS##Type::type_id: \
+ return std::forward<VISITOR>(visitor)( \
+ internal::checked_cast<const TYPE_CLASS##Type&>(type), \
+ std::forward<ARGS>(args)...);
/// \brief Call `visitor` with the corresponding concrete type class
+/// \tparam ARGS Additional arguments, if any, will be passed to the Visit function after
+/// the `type` argument
///
/// Unlike VisitTypeInline which calls `visitor.Visit`, here `visitor`
/// itself is called.
@@ -71,31 +77,33 @@ inline Status VisitTypeInline(const DataType& type, VISITOR* visitor) {
///
/// The intent is for this to be called on a generic lambda
/// that may internally use `if constexpr` or similar constructs.
-template <typename VISITOR>
-inline auto VisitType(const DataType& type, VISITOR&& visitor)
- -> decltype(std::forward<VISITOR>(visitor)(type)) {
+template <typename VISITOR, typename... ARGS>
+inline auto VisitType(const DataType& type, VISITOR&& visitor, ARGS&&... args)
+ -> decltype(std::forward<VISITOR>(visitor)(type, args...)) {
switch (type.id()) {
ARROW_GENERATE_FOR_ALL_TYPES(TYPE_VISIT_INLINE);
default:
break;
}
- return std::forward<VISITOR>(visitor)(type);
+ return std::forward<VISITOR>(visitor)(type, std::forward<ARGS>(args)...);
}
#undef TYPE_VISIT_INLINE
-#define TYPE_ID_VISIT_INLINE(TYPE_CLASS) \
- case TYPE_CLASS##Type::type_id: { \
- const TYPE_CLASS##Type* concrete_ptr = NULLPTR; \
- return visitor->Visit(concrete_ptr); \
+#define TYPE_ID_VISIT_INLINE(TYPE_CLASS) \
+ case TYPE_CLASS##Type::type_id: { \
+ const TYPE_CLASS##Type* concrete_ptr = NULLPTR; \
+ return visitor->Visit(concrete_ptr, std::forward<ARGS>(args)...); \
}
/// \brief Calls `visitor` with a nullptr of the corresponding concrete type class
///
/// \tparam VISITOR Visitor type that implements Visit() for all Arrow types.
+/// \tparam ARGS Additional arguments, if any, will be passed to the Visit function after
+/// the `type` argument
/// \return Status
-template <typename VISITOR>
-inline Status VisitTypeIdInline(Type::type id, VISITOR* visitor) {
+template <typename VISITOR, typename... ARGS>
+inline Status VisitTypeIdInline(Type::type id, VISITOR* visitor, ARGS&&... args) {
switch (id) {
ARROW_GENERATE_FOR_ALL_TYPES(TYPE_ID_VISIT_INLINE);
default:
diff --git a/cpp/src/parquet/column_writer.cc b/cpp/src/parquet/column_writer.cc
index 1fc13aa3c1..27ec640cec 100644
--- a/cpp/src/parquet/column_writer.cc
+++ b/cpp/src/parquet/column_writer.cc
@@ -74,9 +74,10 @@ namespace {
// Visitor that exracts the value buffer from a FlatArray at a given offset.
struct ValueBufferSlicer {
template <typename T>
- ::arrow::enable_if_base_binary<typename T::TypeClass, Status> Visit(const T& array) {
+ ::arrow::enable_if_base_binary<typename T::TypeClass, Status> Visit(
+ const T& array, std::shared_ptr<Buffer>* buffer) {
auto data = array.data();
- buffer_ =
+ *buffer =
SliceBuffer(data->buffers[1], data->offset * sizeof(typename T::offset_type),
data->length * sizeof(typename T::offset_type));
return Status::OK();
@@ -84,9 +85,9 @@ struct ValueBufferSlicer {
template <typename T>
::arrow::enable_if_fixed_size_binary<typename T::TypeClass, Status> Visit(
- const T& array) {
+ const T& array, std::shared_ptr<Buffer>* buffer) {
auto data = array.data();
- buffer_ = SliceBuffer(data->buffers[1], data->offset * array.byte_width(),
+ *buffer = SliceBuffer(data->buffers[1], data->offset * array.byte_width(),
data->length * array.byte_width());
return Status::OK();
}
@@ -95,29 +96,30 @@ struct ValueBufferSlicer {
::arrow::enable_if_t<::arrow::has_c_type<typename T::TypeClass>::value &&
!std::is_same<BooleanType, typename T::TypeClass>::value,
Status>
- Visit(const T& array) {
+ Visit(const T& array, std::shared_ptr<Buffer>* buffer) {
auto data = array.data();
- buffer_ = SliceBuffer(
+ *buffer = SliceBuffer(
data->buffers[1],
::arrow::TypeTraits<typename T::TypeClass>::bytes_required(data->offset),
::arrow::TypeTraits<typename T::TypeClass>::bytes_required(data->length));
return Status::OK();
}
- Status Visit(const ::arrow::BooleanArray& array) {
+ Status Visit(const ::arrow::BooleanArray& array, std::shared_ptr<Buffer>* buffer) {
auto data = array.data();
if (bit_util::IsMultipleOf8(data->offset)) {
- buffer_ = SliceBuffer(data->buffers[1], bit_util::BytesForBits(data->offset),
+ *buffer = SliceBuffer(data->buffers[1], bit_util::BytesForBits(data->offset),
bit_util::BytesForBits(data->length));
return Status::OK();
}
- PARQUET_ASSIGN_OR_THROW(buffer_,
+ PARQUET_ASSIGN_OR_THROW(*buffer,
::arrow::internal::CopyBitmap(pool_, data->buffers[1]->data(),
data->offset, data->length));
return Status::OK();
}
#define NOT_IMPLEMENTED_VISIT(ArrowTypePrefix) \
- Status Visit(const ::arrow::ArrowTypePrefix##Array& array) { \
+ Status Visit(const ::arrow::ArrowTypePrefix##Array& array, \
+ std::shared_ptr<Buffer>* buffer) { \
return Status::NotImplemented("Slicing not implemented for " #ArrowTypePrefix); \
}
@@ -133,7 +135,6 @@ struct ValueBufferSlicer {
#undef NOT_IMPLEMENTED_VISIT
MemoryPool* pool_;
- std::shared_ptr<Buffer> buffer_;
};
internal::LevelInfo ComputeLevelInfo(const ColumnDescriptor* descr) {
@@ -1316,10 +1317,9 @@ class TypedColumnWriterImpl : public ColumnWriterImpl, public TypedColumnWriter<
buffers[0] = bits_buffer_;
// Should be a leaf array.
DCHECK_GT(buffers.size(), 1);
- ValueBufferSlicer slicer{memory_pool, /*buffer=*/nullptr};
+ ValueBufferSlicer slicer{memory_pool};
if (array->data()->offset > 0) {
- RETURN_NOT_OK(::arrow::VisitArrayInline(*array, &slicer));
- buffers[1] = slicer.buffer_;
+ RETURN_NOT_OK(::arrow::VisitArrayInline(*array, &slicer, &buffers[1]));
}
return ::arrow::MakeArray(std::make_shared<ArrayData>(
array->type(), array->length(), std::move(buffers), new_null_count));