You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2019/02/15 04:56:23 UTC
[arrow] branch master updated: ARROW-1896: [C++] Do not allocate
memory inside CastKernel. Clean up template instantiation to not generate
dead identity cast code
This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 47ebb1a ARROW-1896: [C++] Do not allocate memory inside CastKernel. Clean up template instantiation to not generate dead identity cast code
47ebb1a is described below
commit 47ebb1af1f6e1bcac95cf99f8258257f471f043b
Author: Wes McKinney <we...@apache.org>
AuthorDate: Thu Feb 14 22:56:13 2019 -0600
ARROW-1896: [C++] Do not allocate memory inside CastKernel. Clean up template instantiation to not generate dead identity cast code
Also resolves ARROW-4110, which has been on my list for some time.
This ended up being a huge pain.
* `detail::PrimitiveAllocatingUnaryKernel` can now allocate memory for any kind of fixed width type.
* I factored out simple bitmap propagation into `detail::PropagateNulls`
* I moved the null count resolution code one level down into `ArrayData`, since there are cases where it may be set to `kUnknownNullCount` (e.g. after a slice) and you need to know what it is. This isn't tested, but I suggest addressing this in a follow-up patch.
I also moved hand-maintained macro spaghetti for instantiating CastFunctors into a Python code-generation script. This might be the most controversial change in this patch, but the problem here is that we needed to exclude 1 macro case for each numeric type -- previously they all relied on `NUMERIC_CASES`. This means the list of generated types is slightly different for each type, leading to poor code reuse. Rather than maintaining this code by hand, it is _so much simpler_ to genera [...]
Speaking of code generation, I think we should continue to invest in code generation scripts to make generating mundane C++ code for pre-compiled kernels simpler. I checked the file in, but I'm not opposed to auto-generating the files as part of the CMake build -- we could do that in a follow-up PR.
Author: Wes McKinney <we...@apache.org>
Closes #3642 from wesm/ARROW-1896 and squashes the following commits:
57d10840c <Wes McKinney> Fix another clang warning
0d3a7b39c <Wes McKinney> Fix clang warning on macOS
8aeaf967c <Wes McKinney> Code review
ab534d174 <Wes McKinney> Fix dictionary->dense conversion for Decimal128
7a178a4be <Wes McKinney> Refactoring around kernel memory allocation, do not allocate memory inside CastKernel. Use code generation to avoid instantiating CastFunctors for identity casts that are never used
---
cpp/build-support/lint_exclusions.txt | 1 +
cpp/src/arrow/array.cc | 23 +-
cpp/src/arrow/array.h | 5 +-
cpp/src/arrow/compute/kernel.h | 4 +
cpp/src/arrow/compute/kernels/aggregate.cc | 4 +
cpp/src/arrow/compute/kernels/aggregate.h | 4 +
cpp/src/arrow/compute/kernels/boolean.cc | 21 +-
cpp/src/arrow/compute/kernels/cast-test.cc | 2 +-
cpp/src/arrow/compute/kernels/cast.cc | 439 ++++++---------------
cpp/src/arrow/compute/kernels/cast.h | 7 +-
.../kernels/generated/cast-codegen-internal.h | 226 +++++++++++
cpp/src/arrow/compute/kernels/generated/codegen.py | 134 +++++++
cpp/src/arrow/compute/kernels/hash.cc | 26 +-
cpp/src/arrow/compute/kernels/sum.cc | 4 +
cpp/src/arrow/compute/kernels/util-internal.cc | 112 ++++--
cpp/src/arrow/compute/kernels/util-internal.h | 33 +-
16 files changed, 654 insertions(+), 391 deletions(-)
diff --git a/cpp/build-support/lint_exclusions.txt b/cpp/build-support/lint_exclusions.txt
index 2964898..1187beb 100644
--- a/cpp/build-support/lint_exclusions.txt
+++ b/cpp/build-support/lint_exclusions.txt
@@ -1,3 +1,4 @@
+*codegen*
*_generated*
*windows_compatibility.h
*pyarrow_api.h
diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc
index 1569889..8ceb6ea 100644
--- a/cpp/src/arrow/array.cc
+++ b/cpp/src/arrow/array.cc
@@ -76,22 +76,23 @@ std::shared_ptr<ArrayData> ArrayData::Make(const std::shared_ptr<DataType>& type
return std::make_shared<ArrayData>(type, length, null_count, offset);
}
-// ----------------------------------------------------------------------
-// Base array class
-
-int64_t Array::null_count() const {
- if (ARROW_PREDICT_FALSE(data_->null_count < 0)) {
- if (data_->buffers[0]) {
- data_->null_count =
- data_->length - CountSetBits(null_bitmap_data_, data_->offset, data_->length);
-
+int64_t ArrayData::GetNullCount() const {
+ if (ARROW_PREDICT_FALSE(this->null_count == kUnknownNullCount)) {
+ if (this->buffers[0]) {
+ this->null_count = this->length - CountSetBits(this->buffers[0]->data(),
+ this->offset, this->length);
} else {
- data_->null_count = 0;
+ this->null_count = 0;
}
}
- return data_->null_count;
+ return this->null_count;
}
+// ----------------------------------------------------------------------
+// Base array class
+
+int64_t Array::null_count() const { return data_->GetNullCount(); }
+
bool Array::Equals(const Array& arr) const { return ArrayEquals(*this, arr); }
bool Array::Equals(const std::shared_ptr<Array>& arr) const {
diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h
index 5b4daa8..674bf7b 100644
--- a/cpp/src/arrow/array.h
+++ b/cpp/src/arrow/array.h
@@ -197,9 +197,12 @@ struct ARROW_EXPORT ArrayData {
return GetMutableValues<T>(i, offset);
}
+ /// \brief Return null count, or compute and set it if it's not known
+ int64_t GetNullCount() const;
+
std::shared_ptr<DataType> type;
int64_t length;
- int64_t null_count;
+ mutable int64_t null_count;
// The logical start point into the physical buffers (in values, not bytes).
// Note that, for child data, this must be *added* to the child data's own offset.
int64_t offset;
diff --git a/cpp/src/arrow/compute/kernel.h b/cpp/src/arrow/compute/kernel.h
index 38b78ca..b8033eb 100644
--- a/cpp/src/arrow/compute/kernel.h
+++ b/cpp/src/arrow/compute/kernel.h
@@ -188,6 +188,10 @@ class ARROW_EXPORT UnaryKernel : public OpKernel {
/// there will be a more generic mechansim for understanding the necessary
/// contracts.
virtual Status Call(FunctionContext* ctx, const Datum& input, Datum* out) = 0;
+
+ /// \brief EXPERIMENTAL The output data type of the kernel
+ /// \return the output type
+ virtual std::shared_ptr<DataType> out_type() const = 0;
};
/// \class BinaryKernel
diff --git a/cpp/src/arrow/compute/kernels/aggregate.cc b/cpp/src/arrow/compute/kernels/aggregate.cc
index e1e2dd9..3825e32 100644
--- a/cpp/src/arrow/compute/kernels/aggregate.cc
+++ b/cpp/src/arrow/compute/kernels/aggregate.cc
@@ -62,5 +62,9 @@ Status AggregateUnaryKernel::Call(FunctionContext* ctx, const Datum& input, Datu
return Status::OK();
}
+std::shared_ptr<DataType> AggregateUnaryKernel::out_type() const {
+ return aggregate_function_->out_type();
+}
+
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/aggregate.h b/cpp/src/arrow/compute/kernels/aggregate.h
index 4bc869a..2fe8263 100644
--- a/cpp/src/arrow/compute/kernels/aggregate.h
+++ b/cpp/src/arrow/compute/kernels/aggregate.h
@@ -57,6 +57,8 @@ class AggregateFunction {
virtual ~AggregateFunction() {}
+ virtual std::shared_ptr<DataType> out_type() const = 0;
+
/// State management methods.
virtual int64_t Size() const = 0;
virtual void New(void* ptr) const = 0;
@@ -103,6 +105,8 @@ class ARROW_EXPORT AggregateUnaryKernel : public UnaryKernel {
Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override;
+ std::shared_ptr<DataType> out_type() const override;
+
private:
std::shared_ptr<AggregateFunction> aggregate_function_;
};
diff --git a/cpp/src/arrow/compute/kernels/boolean.cc b/cpp/src/arrow/compute/kernels/boolean.cc
index 78ae7d4..7d8b15a 100644
--- a/cpp/src/arrow/compute/kernels/boolean.cc
+++ b/cpp/src/arrow/compute/kernels/boolean.cc
@@ -34,13 +34,17 @@ namespace arrow {
using internal::BitmapAnd;
using internal::BitmapOr;
using internal::BitmapXor;
-using internal::CopyBitmap;
using internal::CountSetBits;
using internal::InvertBitmap;
namespace compute {
-class InvertKernel : public UnaryKernel {
+class BooleanUnaryKernel : public UnaryKernel {
+ public:
+ std::shared_ptr<DataType> out_type() const override { return boolean(); }
+};
+
+class InvertKernel : public BooleanUnaryKernel {
Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override {
DCHECK_EQ(Datum::ARRAY, input.kind());
constexpr int64_t kZeroDestOffset = 0;
@@ -49,17 +53,6 @@ class InvertKernel : public UnaryKernel {
std::shared_ptr<ArrayData> result = out->array();
result->type = boolean();
- // Handle validity bitmap
- result->null_count = in_data.null_count;
- const std::shared_ptr<Buffer>& validity_bitmap = in_data.buffers[0];
- if (in_data.offset != 0 && in_data.null_count > 0) {
- DCHECK_LE(BitUtil::BytesForBits(in_data.length), validity_bitmap->size());
- CopyBitmap(validity_bitmap->data(), in_data.offset, in_data.length,
- result->buffers[0]->mutable_data(), kZeroDestOffset);
- } else {
- result->buffers[0] = validity_bitmap;
- }
-
// Handle output data buffer
if (in_data.length > 0) {
const Buffer& data_buffer = *in_data.buffers[1];
@@ -73,7 +66,7 @@ class InvertKernel : public UnaryKernel {
Status Invert(FunctionContext* ctx, const Datum& value, Datum* out) {
detail::PrimitiveAllocatingUnaryKernel kernel(
- std::unique_ptr<UnaryKernel>(new InvertKernel()));
+ std::unique_ptr<UnaryKernel>(new InvertKernel()), boolean());
std::vector<Datum> result;
RETURN_NOT_OK(detail::InvokeUnaryArrayKernel(ctx, &kernel, value, &result));
diff --git a/cpp/src/arrow/compute/kernels/cast-test.cc b/cpp/src/arrow/compute/kernels/cast-test.cc
index 961b359..e7f5a4a 100644
--- a/cpp/src/arrow/compute/kernels/cast-test.cc
+++ b/cpp/src/arrow/compute/kernels/cast-test.cc
@@ -844,7 +844,7 @@ TEST_F(TestCast, PreallocatedMemory) {
shared_ptr<Buffer> out_values;
ASSERT_OK(this->ctx_.Allocate(length * sizeof(int64_t), &out_values));
- out_data->buffers.push_back(nullptr);
+ out_data->buffers.push_back(arr->data()->buffers[0]);
out_data->buffers.push_back(out_values);
Datum out(out_data);
diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc
index 74ee7d6..5ddc59f 100644
--- a/cpp/src/arrow/compute/kernels/cast.cc
+++ b/cpp/src/arrow/compute/kernels/cast.cc
@@ -78,83 +78,9 @@ namespace compute {
constexpr int64_t kMillisecondsInDay = 86400000;
-template <typename O, typename I, typename Enable = void>
-struct is_binary_to_string {
- static constexpr bool value = false;
-};
-
-template <typename O, typename I>
-struct is_binary_to_string<
- O, I,
- typename std::enable_if<std::is_same<BinaryType, I>::value &&
- std::is_base_of<StringType, O>::value>::type> {
- static constexpr bool value = true;
-};
-
-// ----------------------------------------------------------------------
-// Zero copy casts
-
-template <typename O, typename I, typename Enable = void>
-struct is_zero_copy_cast {
- static constexpr bool value = false;
-};
-
-// TODO(wesm): ARROW-4110; this is no longer needed, but may be useful if we
-// ever _do_ want to generate identity cast kernels at compile time
-template <typename O, typename I>
-struct is_zero_copy_cast<
- O, I,
- typename std::enable_if<std::is_same<I, O>::value &&
- // Parametric types contains runtime data which
- // differentiate them, it cannot be checked statically.
- !std::is_base_of<ParametricType, O>::value>::type> {
- static constexpr bool value = true;
-};
-
-// From integers to date/time types with zero copy
-template <typename O, typename I>
-struct is_zero_copy_cast<
- O, I,
- typename std::enable_if<
- (std::is_base_of<Integer, I>::value &&
- (std::is_base_of<TimeType, O>::value || std::is_base_of<DateType, O>::value ||
- std::is_base_of<TimestampType, O>::value)) ||
- (std::is_base_of<Integer, O>::value &&
- (std::is_base_of<TimeType, I>::value || std::is_base_of<DateType, I>::value ||
- std::is_base_of<TimestampType, I>::value))>::type> {
- using O_T = typename O::c_type;
- using I_T = typename I::c_type;
-
- static constexpr bool value = sizeof(O_T) == sizeof(I_T);
-};
-
-// Binary to String doesn't require copying, the payload only needs to be
-// validated.
-template <typename O, typename I>
-struct is_zero_copy_cast<
- O, I,
- typename std::enable_if<!std::is_same<I, O>::value &&
- is_binary_to_string<O, I>::value>::type> {
- static constexpr bool value = true;
-};
-
template <typename OutType, typename InType, typename Enable = void>
struct CastFunctor {};
-// Indicated no computation required
-//
-// The case BinaryType -> StringType is special cased due to validation
-// requirements.
-template <typename O, typename I>
-struct CastFunctor<O, I,
- typename std::enable_if<is_zero_copy_cast<O, I>::value &&
- !is_binary_to_string<O, I>::value>::type> {
- void operator()(FunctionContext* ctx, const CastOptions& options,
- const ArrayData& input, ArrayData* output) {
- ZeroCopyData(input, output);
- }
-};
-
// ----------------------------------------------------------------------
// Null to other things
@@ -690,11 +616,41 @@ struct CastFunctor<Date32Type, Date64Type> {
// ----------------------------------------------------------------------
// List to List
-class ListCastKernel : public UnaryKernel {
+class CastKernelBase : public UnaryKernel {
+ public:
+ explicit CastKernelBase(std::shared_ptr<DataType> out_type)
+ : out_type_(std::move(out_type)) {}
+
+ std::shared_ptr<DataType> out_type() const override { return out_type_; }
+
+ protected:
+ std::shared_ptr<DataType> out_type_;
+};
+
+bool NeedToPreallocate(const DataType& type) {
+ return dynamic_cast<const FixedWidthType*>(&type) != nullptr;
+}
+
+Status InvokeWithAllocation(FunctionContext* ctx, UnaryKernel* func, const Datum& input,
+ Datum* out) {
+ std::vector<Datum> result;
+ if (NeedToPreallocate(*func->out_type())) {
+ // Create wrapper that allocates output memory for primitive types
+ detail::PrimitiveAllocatingUnaryKernel wrapper(func, func->out_type());
+ RETURN_NOT_OK(detail::InvokeUnaryArrayKernel(ctx, &wrapper, input, &result));
+ } else {
+ RETURN_NOT_OK(detail::InvokeUnaryArrayKernel(ctx, func, input, &result));
+ }
+ RETURN_IF_ERROR(ctx);
+ *out = detail::WrapDatumsLike(input, result);
+ return Status::OK();
+}
+
+class ListCastKernel : public CastKernelBase {
public:
ListCastKernel(std::unique_ptr<UnaryKernel> child_caster,
- const std::shared_ptr<DataType>& out_type)
- : child_caster_(std::move(child_caster)), out_type_(out_type) {}
+ std::shared_ptr<DataType> out_type)
+ : CastKernelBase(std::move(out_type)), child_caster_(std::move(child_caster)) {}
Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override {
DCHECK_EQ(Datum::ARRAY, input.kind());
@@ -718,16 +674,15 @@ class ListCastKernel : public UnaryKernel {
result->buffers = in_data.buffers;
Datum casted_child;
- RETURN_NOT_OK(child_caster_->Call(ctx, Datum(in_data.child_data[0]), &casted_child));
+ RETURN_NOT_OK(InvokeWithAllocation(ctx, child_caster_.get(), in_data.child_data[0],
+ &casted_child));
+ DCHECK_EQ(Datum::ARRAY, casted_child.kind());
result->child_data.push_back(casted_child.array());
-
- RETURN_IF_ERROR(ctx);
return Status::OK();
}
private:
std::unique_ptr<UnaryKernel> child_caster_;
- std::shared_ptr<DataType> out_type_;
};
// ----------------------------------------------------------------------
@@ -1038,9 +993,8 @@ struct CastFunctor<TimestampType, StringType> {
//
template <typename I>
-struct CastFunctor<
- StringType, I,
- typename std::enable_if<is_binary_to_string<StringType, I>::value>::type> {
+struct CastFunctor<StringType, I,
+ typename std::enable_if<std::is_same<BinaryType, I>::value>::type> {
void operator()(FunctionContext* ctx, const CastOptions& options,
const ArrayData& input, ArrayData* output) {
BinaryArray binary(input.Copy());
@@ -1085,111 +1039,46 @@ typedef std::function<void(FunctionContext*, const CastOptions& options, const A
ArrayData*)>
CastFunction;
-static Status AllocateIfNotPreallocated(FunctionContext* ctx, const ArrayData& input,
- bool can_pre_allocate_values, ArrayData* out) {
- const int64_t length = input.length;
- out->null_count = input.null_count;
-
- // Propagate bitmap unless we are null type
- std::shared_ptr<Buffer> validity_bitmap = input.buffers[0];
- if (input.type->id() == Type::NA) {
- int64_t bitmap_size = BitUtil::BytesForBits(length);
- RETURN_NOT_OK(ctx->Allocate(bitmap_size, &validity_bitmap));
- memset(validity_bitmap->mutable_data(), 0, bitmap_size);
- } else if (input.offset != 0) {
- RETURN_NOT_OK(CopyBitmap(ctx->memory_pool(), validity_bitmap->data(), input.offset,
- length, &validity_bitmap));
- }
+class IdentityCast : public CastKernelBase {
+ public:
+ using CastKernelBase::CastKernelBase;
- if (out->buffers.size() == 2) {
- // Assuming preallocated, propagage bitmap and move on
- out->buffers[0] = validity_bitmap;
+ Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override {
+ DCHECK_EQ(input.kind(), Datum::ARRAY);
+ out->value = input.array()->Copy();
return Status::OK();
- } else {
- DCHECK_EQ(0, out->buffers.size());
}
+};
- out->buffers.push_back(validity_bitmap);
-
- if (can_pre_allocate_values) {
- std::shared_ptr<Buffer> out_data;
-
- const Type::type type_id = out->type->id();
-
- if (!(is_primitive(type_id) || type_id == Type::FIXED_SIZE_BINARY ||
- type_id == Type::DECIMAL)) {
- return Status::NotImplemented("Cannot pre-allocate memory for type: ",
- out->type->ToString());
- }
-
- if (type_id != Type::NA) {
- const auto& fw_type = checked_cast<const FixedWidthType&>(*out->type);
-
- int bit_width = fw_type.bit_width();
- int64_t buffer_size = 0;
-
- if (bit_width == 1) {
- buffer_size = BitUtil::BytesForBits(length);
- } else if (bit_width % 8 == 0) {
- buffer_size = length * fw_type.bit_width() / 8;
- } else {
- DCHECK(false);
- }
-
- RETURN_NOT_OK(ctx->Allocate(buffer_size, &out_data));
- memset(out_data->mutable_data(), 0, buffer_size);
-
- out->buffers.push_back(out_data);
- }
- }
-
- return Status::OK();
-}
-
-class IdentityCast : public UnaryKernel {
+class ZeroCopyCast : public CastKernelBase {
public:
- IdentityCast() {}
+ using CastKernelBase::CastKernelBase;
Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override {
DCHECK_EQ(input.kind(), Datum::ARRAY);
- out->value = input.array()->Copy();
+ auto result = input.array()->Copy();
+ result->type = out_type_;
+ out->value = result;
return Status::OK();
}
};
-class CastKernel : public UnaryKernel {
+class CastKernel : public CastKernelBase {
public:
- CastKernel(const CastOptions& options, const CastFunction& func, bool is_zero_copy,
- bool can_pre_allocate_values, const std::shared_ptr<DataType>& out_type)
- : options_(options),
- func_(func),
- is_zero_copy_(is_zero_copy),
- can_pre_allocate_values_(can_pre_allocate_values),
- out_type_(out_type) {}
+ CastKernel(const CastOptions& options, const CastFunction& func,
+ std::shared_ptr<DataType> out_type)
+ : CastKernelBase(std::move(out_type)), options_(options), func_(func) {}
Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override {
- if (input.kind() != Datum::ARRAY)
- return Status::NotImplemented("CastKernel only supports Datum::ARRAY input");
+ DCHECK_EQ(input.kind(), Datum::ARRAY);
+ DCHECK_EQ(out->kind(), Datum::ARRAY);
const ArrayData& in_data = *input.array();
+ ArrayData* result = out->array().get();
- switch (out->kind()) {
- case Datum::NONE:
- out->value = ArrayData::Make(out_type_, in_data.length);
- break;
- case Datum::ARRAY:
- break;
- default:
- return Status::NotImplemented("CastKernel only supports Datum::ARRAY output");
- }
+ RETURN_NOT_OK(detail::PropagateNulls(ctx, in_data, result));
- ArrayData* result = out->array().get();
- if (!is_zero_copy_) {
- RETURN_NOT_OK(
- AllocateIfNotPreallocated(ctx, in_data, can_pre_allocate_values_, result));
- }
func_(ctx, options_, in_data, result);
-
RETURN_IF_ERROR(ctx);
return Status::OK();
}
@@ -1197,18 +1086,10 @@ class CastKernel : public UnaryKernel {
private:
CastOptions options_;
CastFunction func_;
- bool is_zero_copy_;
- bool can_pre_allocate_values_;
- std::shared_ptr<DataType> out_type_;
};
-// TODO(wesm): ARROW-4110 Do not generate cases that could return IdentityCast
-
#define CAST_CASE(InType, OutType) \
case OutType::type_id: \
- is_zero_copy = is_zero_copy_cast<OutType, InType>::value; \
- can_pre_allocate_values = \
- !(!is_binary_like(InType::type_id) && is_binary_like(OutType::type_id)); \
func = [](FunctionContext* ctx, const CastOptions& options, const ArrayData& input, \
ArrayData* out) { \
CastFunctor<OutType, InType> func; \
@@ -1216,130 +1097,36 @@ class CastKernel : public UnaryKernel {
}; \
break;
-#define NUMERIC_CASES(FN, IN_TYPE) \
- FN(IN_TYPE, BooleanType); \
- FN(IN_TYPE, UInt8Type); \
- FN(IN_TYPE, Int8Type); \
- FN(IN_TYPE, UInt16Type); \
- FN(IN_TYPE, Int16Type); \
- FN(IN_TYPE, UInt32Type); \
- FN(IN_TYPE, Int32Type); \
- FN(IN_TYPE, UInt64Type); \
- FN(IN_TYPE, Int64Type); \
- FN(IN_TYPE, FloatType); \
- FN(IN_TYPE, DoubleType);
-
-#define NULL_CASES(FN, IN_TYPE) \
- NUMERIC_CASES(FN, IN_TYPE) \
- FN(NullType, Time32Type); \
- FN(NullType, Date32Type); \
- FN(NullType, TimestampType); \
- FN(NullType, Time64Type); \
- FN(NullType, Date64Type);
-
-#define INT32_CASES(FN, IN_TYPE) \
- NUMERIC_CASES(FN, IN_TYPE) \
- FN(Int32Type, Time32Type); \
- FN(Int32Type, Date32Type);
-
-#define INT64_CASES(FN, IN_TYPE) \
- NUMERIC_CASES(FN, IN_TYPE) \
- FN(Int64Type, TimestampType); \
- FN(Int64Type, Time64Type); \
- FN(Int64Type, Date64Type);
-
-#define DATE32_CASES(FN, IN_TYPE) \
- FN(Date32Type, Date64Type); \
- FN(Date32Type, Int32Type);
-
-#define DATE64_CASES(FN, IN_TYPE) \
- FN(Date64Type, Date32Type); \
- FN(Date64Type, Int64Type);
-
-#define TIME32_CASES(FN, IN_TYPE) \
- FN(Time32Type, Time32Type); \
- FN(Time32Type, Time64Type); \
- FN(Time32Type, Int32Type);
-
-#define TIME64_CASES(FN, IN_TYPE) \
- FN(Time64Type, Time32Type); \
- FN(Time64Type, Time64Type); \
- FN(Time64Type, Int64Type);
-
-#define TIMESTAMP_CASES(FN, IN_TYPE) \
- FN(TimestampType, TimestampType); \
- FN(TimestampType, Date32Type); \
- FN(TimestampType, Date64Type); \
- FN(TimestampType, Int64Type);
-
-#define BINARY_CASES(FN, IN_TYPE) FN(BinaryType, StringType);
-
-#define STRING_CASES(FN, IN_TYPE) \
- FN(StringType, BooleanType); \
- FN(StringType, UInt8Type); \
- FN(StringType, Int8Type); \
- FN(StringType, UInt16Type); \
- FN(StringType, Int16Type); \
- FN(StringType, UInt32Type); \
- FN(StringType, Int32Type); \
- FN(StringType, UInt64Type); \
- FN(StringType, Int64Type); \
- FN(StringType, FloatType); \
- FN(StringType, DoubleType); \
- FN(StringType, TimestampType);
-
-#define DICTIONARY_CASES(FN, IN_TYPE) \
- FN(IN_TYPE, NullType); \
- FN(IN_TYPE, Time32Type); \
- FN(IN_TYPE, Date32Type); \
- FN(IN_TYPE, TimestampType); \
- FN(IN_TYPE, Time64Type); \
- FN(IN_TYPE, Date64Type); \
- FN(IN_TYPE, UInt8Type); \
- FN(IN_TYPE, Int8Type); \
- FN(IN_TYPE, UInt16Type); \
- FN(IN_TYPE, Int16Type); \
- FN(IN_TYPE, UInt32Type); \
- FN(IN_TYPE, Int32Type); \
- FN(IN_TYPE, UInt64Type); \
- FN(IN_TYPE, Int64Type); \
- FN(IN_TYPE, FloatType); \
- FN(IN_TYPE, DoubleType); \
- FN(IN_TYPE, FixedSizeBinaryType); \
- FN(IN_TYPE, Decimal128Type); \
- FN(IN_TYPE, BinaryType); \
- FN(IN_TYPE, StringType);
-
-#define GET_CAST_FUNCTION(CASE_GENERATOR, InType) \
- static std::unique_ptr<UnaryKernel> Get##InType##CastFunc( \
- const std::shared_ptr<DataType>& out_type, const CastOptions& options) { \
- CastFunction func; \
- bool is_zero_copy = false; \
- bool can_pre_allocate_values = true; \
- switch (out_type->id()) { \
- CASE_GENERATOR(CAST_CASE, InType); \
- default: \
- break; \
- } \
- if (func != nullptr) { \
- return std::unique_ptr<UnaryKernel>(new CastKernel( \
- options, func, is_zero_copy, can_pre_allocate_values, out_type)); \
- } \
- return nullptr; \
+#define GET_CAST_FUNCTION(CASE_GENERATOR, InType) \
+ static std::unique_ptr<UnaryKernel> Get##InType##CastFunc( \
+ std::shared_ptr<DataType> out_type, const CastOptions& options) { \
+ CastFunction func; \
+ switch (out_type->id()) { \
+ CASE_GENERATOR(CAST_CASE); \
+ default: \
+ break; \
+ } \
+ if (func != nullptr) { \
+ return std::unique_ptr<UnaryKernel>( \
+ new CastKernel(options, func, std::move(out_type))); \
+ } \
+ return nullptr; \
}
+#include "generated/cast-codegen-internal.h" // NOLINT
+
GET_CAST_FUNCTION(NULL_CASES, NullType)
-GET_CAST_FUNCTION(NUMERIC_CASES, BooleanType)
-GET_CAST_FUNCTION(NUMERIC_CASES, UInt8Type)
-GET_CAST_FUNCTION(NUMERIC_CASES, Int8Type)
-GET_CAST_FUNCTION(NUMERIC_CASES, UInt16Type)
-GET_CAST_FUNCTION(NUMERIC_CASES, Int16Type)
-GET_CAST_FUNCTION(NUMERIC_CASES, UInt32Type)
+GET_CAST_FUNCTION(BOOLEAN_CASES, BooleanType)
+GET_CAST_FUNCTION(UINT8_CASES, UInt8Type)
+GET_CAST_FUNCTION(INT8_CASES, Int8Type)
+GET_CAST_FUNCTION(UINT16_CASES, UInt16Type)
+GET_CAST_FUNCTION(INT16_CASES, Int16Type)
+GET_CAST_FUNCTION(UINT32_CASES, UInt32Type)
GET_CAST_FUNCTION(INT32_CASES, Int32Type)
-GET_CAST_FUNCTION(NUMERIC_CASES, UInt64Type)
+GET_CAST_FUNCTION(UINT64_CASES, UInt64Type)
GET_CAST_FUNCTION(INT64_CASES, Int64Type)
-GET_CAST_FUNCTION(NUMERIC_CASES, FloatType)
-GET_CAST_FUNCTION(NUMERIC_CASES, DoubleType)
+GET_CAST_FUNCTION(FLOAT_CASES, FloatType)
+GET_CAST_FUNCTION(DOUBLE_CASES, DoubleType)
GET_CAST_FUNCTION(DATE32_CASES, Date32Type)
GET_CAST_FUNCTION(DATE64_CASES, Date64Type)
GET_CAST_FUNCTION(TIME32_CASES, Time32Type)
@@ -1356,7 +1143,7 @@ GET_CAST_FUNCTION(DICTIONARY_CASES, DictionaryType)
namespace {
-Status GetListCastFunc(const DataType& in_type, const std::shared_ptr<DataType>& out_type,
+Status GetListCastFunc(const DataType& in_type, std::shared_ptr<DataType> out_type,
const CastOptions& options, std::unique_ptr<UnaryKernel>* kernel) {
if (out_type->id() != Type::LIST) {
// Kernel will be null
@@ -1367,17 +1154,42 @@ Status GetListCastFunc(const DataType& in_type, const std::shared_ptr<DataType>&
checked_cast<const ListType&>(*out_type).value_type();
std::unique_ptr<UnaryKernel> child_caster;
RETURN_NOT_OK(GetCastFunction(in_value_type, out_value_type, options, &child_caster));
- *kernel =
- std::unique_ptr<UnaryKernel>(new ListCastKernel(std::move(child_caster), out_type));
+ *kernel = std::unique_ptr<UnaryKernel>(
+ new ListCastKernel(std::move(child_caster), std::move(out_type)));
return Status::OK();
}
} // namespace
-Status GetCastFunction(const DataType& in_type, const std::shared_ptr<DataType>& out_type,
+inline bool IsZeroCopyCast(Type::type in_type, Type::type out_type) {
+ switch (in_type) {
+ case Type::INT32:
+ return (out_type == Type::DATE32) || (out_type == Type::TIME32);
+ case Type::INT64:
+ return ((out_type == Type::DATE64) || (out_type == Type::TIME64) ||
+ (out_type == Type::TIMESTAMP));
+ case Type::DATE32:
+ case Type::TIME32:
+ return out_type == Type::INT32;
+ case Type::DATE64:
+ case Type::TIME64:
+ case Type::TIMESTAMP:
+ return out_type == Type::INT64;
+ default:
+ break;
+ }
+ return false;
+}
+
+Status GetCastFunction(const DataType& in_type, std::shared_ptr<DataType> out_type,
const CastOptions& options, std::unique_ptr<UnaryKernel>* kernel) {
if (in_type.Equals(out_type)) {
- *kernel = std::unique_ptr<UnaryKernel>(new IdentityCast);
+ *kernel = std::unique_ptr<UnaryKernel>(new IdentityCast(std::move(out_type)));
+ return Status::OK();
+ }
+
+ if (IsZeroCopyCast(in_type.id(), out_type->id())) {
+ *kernel = std::unique_ptr<UnaryKernel>(new ZeroCopyCast(std::move(out_type)));
return Status::OK();
}
@@ -1403,7 +1215,7 @@ Status GetCastFunction(const DataType& in_type, const std::shared_ptr<DataType>&
CAST_FUNCTION_CASE(StringType);
CAST_FUNCTION_CASE(DictionaryType);
case Type::LIST:
- RETURN_NOT_OK(GetListCastFunc(in_type, out_type, options, kernel));
+ RETURN_NOT_OK(GetListCastFunc(in_type, std::move(out_type), options, kernel));
break;
default:
break;
@@ -1415,25 +1227,20 @@ Status GetCastFunction(const DataType& in_type, const std::shared_ptr<DataType>&
return Status::OK();
}
-Status Cast(FunctionContext* ctx, const Datum& value,
- const std::shared_ptr<DataType>& out_type, const CastOptions& options,
- Datum* out) {
+Status Cast(FunctionContext* ctx, const Datum& value, std::shared_ptr<DataType> out_type,
+ const CastOptions& options, Datum* out) {
+ const DataType& in_type = *value.type();
+
// Dynamic dispatch to obtain right cast function
std::unique_ptr<UnaryKernel> func;
- RETURN_NOT_OK(GetCastFunction(*value.type(), out_type, options, &func));
-
- std::vector<Datum> result;
- RETURN_NOT_OK(detail::InvokeUnaryArrayKernel(ctx, func.get(), value, &result));
-
- *out = detail::WrapDatumsLike(value, result);
- return Status::OK();
+ RETURN_NOT_OK(GetCastFunction(in_type, std::move(out_type), options, &func));
+ return InvokeWithAllocation(ctx, func.get(), value, out);
}
-Status Cast(FunctionContext* ctx, const Array& array,
- const std::shared_ptr<DataType>& out_type, const CastOptions& options,
- std::shared_ptr<Array>* out) {
+Status Cast(FunctionContext* ctx, const Array& array, std::shared_ptr<DataType> out_type,
+ const CastOptions& options, std::shared_ptr<Array>* out) {
Datum datum_out;
- RETURN_NOT_OK(Cast(ctx, Datum(array.data()), out_type, options, &datum_out));
+ RETURN_NOT_OK(Cast(ctx, Datum(array.data()), std::move(out_type), options, &datum_out));
DCHECK_EQ(Datum::ARRAY, datum_out.kind());
*out = MakeArray(datum_out.array());
return Status::OK();
diff --git a/cpp/src/arrow/compute/kernels/cast.h b/cpp/src/arrow/compute/kernels/cast.h
index 8c42f07..5a7c5be 100644
--- a/cpp/src/arrow/compute/kernels/cast.h
+++ b/cpp/src/arrow/compute/kernels/cast.h
@@ -62,7 +62,7 @@ struct ARROW_EXPORT CastOptions {
/// \since 0.7.0
/// \note API not yet finalized
ARROW_EXPORT
-Status GetCastFunction(const DataType& in_type, const std::shared_ptr<DataType>& to_type,
+Status GetCastFunction(const DataType& in_type, std::shared_ptr<DataType> to_type,
const CastOptions& options, std::unique_ptr<UnaryKernel>* kernel);
/// \brief Cast from one array type to another
@@ -76,7 +76,7 @@ Status GetCastFunction(const DataType& in_type, const std::shared_ptr<DataType>&
/// \note API not yet finalized
ARROW_EXPORT
Status Cast(FunctionContext* context, const Array& value,
- const std::shared_ptr<DataType>& to_type, const CastOptions& options,
+ std::shared_ptr<DataType> to_type, const CastOptions& options,
std::shared_ptr<Array>* out);
/// \brief Cast from one value to another
@@ -90,8 +90,7 @@ Status Cast(FunctionContext* context, const Array& value,
/// \note API not yet finalized
ARROW_EXPORT
Status Cast(FunctionContext* context, const Datum& value,
- const std::shared_ptr<DataType>& to_type, const CastOptions& options,
- Datum* out);
+ std::shared_ptr<DataType> to_type, const CastOptions& options, Datum* out);
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/generated/cast-codegen-internal.h b/cpp/src/arrow/compute/kernels/generated/cast-codegen-internal.h
new file mode 100644
index 0000000..cf2c036
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/generated/cast-codegen-internal.h
@@ -0,0 +1,226 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT
+// Generated by codegen.py script
+#define NULL_CASES(TEMPLATE) \
+ TEMPLATE(NullType, BooleanType) \
+ TEMPLATE(NullType, UInt8Type) \
+ TEMPLATE(NullType, Int8Type) \
+ TEMPLATE(NullType, UInt16Type) \
+ TEMPLATE(NullType, Int16Type) \
+ TEMPLATE(NullType, UInt32Type) \
+ TEMPLATE(NullType, Int32Type) \
+ TEMPLATE(NullType, UInt64Type) \
+ TEMPLATE(NullType, Int64Type) \
+ TEMPLATE(NullType, FloatType) \
+ TEMPLATE(NullType, DoubleType) \
+ TEMPLATE(NullType, Date32Type) \
+ TEMPLATE(NullType, Date64Type) \
+ TEMPLATE(NullType, Time32Type) \
+ TEMPLATE(NullType, Time64Type) \
+ TEMPLATE(NullType, TimestampType)
+
+#define BOOLEAN_CASES(TEMPLATE) \
+ TEMPLATE(BooleanType, UInt8Type) \
+ TEMPLATE(BooleanType, Int8Type) \
+ TEMPLATE(BooleanType, UInt16Type) \
+ TEMPLATE(BooleanType, Int16Type) \
+ TEMPLATE(BooleanType, UInt32Type) \
+ TEMPLATE(BooleanType, Int32Type) \
+ TEMPLATE(BooleanType, UInt64Type) \
+ TEMPLATE(BooleanType, Int64Type) \
+ TEMPLATE(BooleanType, FloatType) \
+ TEMPLATE(BooleanType, DoubleType)
+
+#define UINT8_CASES(TEMPLATE) \
+ TEMPLATE(UInt8Type, BooleanType) \
+ TEMPLATE(UInt8Type, Int8Type) \
+ TEMPLATE(UInt8Type, UInt16Type) \
+ TEMPLATE(UInt8Type, Int16Type) \
+ TEMPLATE(UInt8Type, UInt32Type) \
+ TEMPLATE(UInt8Type, Int32Type) \
+ TEMPLATE(UInt8Type, UInt64Type) \
+ TEMPLATE(UInt8Type, Int64Type) \
+ TEMPLATE(UInt8Type, FloatType) \
+ TEMPLATE(UInt8Type, DoubleType)
+
+#define INT8_CASES(TEMPLATE) \
+ TEMPLATE(Int8Type, BooleanType) \
+ TEMPLATE(Int8Type, UInt8Type) \
+ TEMPLATE(Int8Type, UInt16Type) \
+ TEMPLATE(Int8Type, Int16Type) \
+ TEMPLATE(Int8Type, UInt32Type) \
+ TEMPLATE(Int8Type, Int32Type) \
+ TEMPLATE(Int8Type, UInt64Type) \
+ TEMPLATE(Int8Type, Int64Type) \
+ TEMPLATE(Int8Type, FloatType) \
+ TEMPLATE(Int8Type, DoubleType)
+
+#define UINT16_CASES(TEMPLATE) \
+ TEMPLATE(UInt16Type, BooleanType) \
+ TEMPLATE(UInt16Type, UInt8Type) \
+ TEMPLATE(UInt16Type, Int8Type) \
+ TEMPLATE(UInt16Type, Int16Type) \
+ TEMPLATE(UInt16Type, UInt32Type) \
+ TEMPLATE(UInt16Type, Int32Type) \
+ TEMPLATE(UInt16Type, UInt64Type) \
+ TEMPLATE(UInt16Type, Int64Type) \
+ TEMPLATE(UInt16Type, FloatType) \
+ TEMPLATE(UInt16Type, DoubleType)
+
+#define INT16_CASES(TEMPLATE) \
+ TEMPLATE(Int16Type, BooleanType) \
+ TEMPLATE(Int16Type, UInt8Type) \
+ TEMPLATE(Int16Type, Int8Type) \
+ TEMPLATE(Int16Type, UInt16Type) \
+ TEMPLATE(Int16Type, UInt32Type) \
+ TEMPLATE(Int16Type, Int32Type) \
+ TEMPLATE(Int16Type, UInt64Type) \
+ TEMPLATE(Int16Type, Int64Type) \
+ TEMPLATE(Int16Type, FloatType) \
+ TEMPLATE(Int16Type, DoubleType)
+
+#define UINT32_CASES(TEMPLATE) \
+ TEMPLATE(UInt32Type, BooleanType) \
+ TEMPLATE(UInt32Type, UInt8Type) \
+ TEMPLATE(UInt32Type, Int8Type) \
+ TEMPLATE(UInt32Type, UInt16Type) \
+ TEMPLATE(UInt32Type, Int16Type) \
+ TEMPLATE(UInt32Type, Int32Type) \
+ TEMPLATE(UInt32Type, UInt64Type) \
+ TEMPLATE(UInt32Type, Int64Type) \
+ TEMPLATE(UInt32Type, FloatType) \
+ TEMPLATE(UInt32Type, DoubleType)
+
+#define UINT64_CASES(TEMPLATE) \
+ TEMPLATE(UInt64Type, BooleanType) \
+ TEMPLATE(UInt64Type, UInt8Type) \
+ TEMPLATE(UInt64Type, Int8Type) \
+ TEMPLATE(UInt64Type, UInt16Type) \
+ TEMPLATE(UInt64Type, Int16Type) \
+ TEMPLATE(UInt64Type, UInt32Type) \
+ TEMPLATE(UInt64Type, Int32Type) \
+ TEMPLATE(UInt64Type, Int64Type) \
+ TEMPLATE(UInt64Type, FloatType) \
+ TEMPLATE(UInt64Type, DoubleType)
+
+#define INT32_CASES(TEMPLATE) \
+ TEMPLATE(Int32Type, BooleanType) \
+ TEMPLATE(Int32Type, UInt8Type) \
+ TEMPLATE(Int32Type, Int8Type) \
+ TEMPLATE(Int32Type, UInt16Type) \
+ TEMPLATE(Int32Type, Int16Type) \
+ TEMPLATE(Int32Type, UInt32Type) \
+ TEMPLATE(Int32Type, UInt64Type) \
+ TEMPLATE(Int32Type, Int64Type) \
+ TEMPLATE(Int32Type, FloatType) \
+ TEMPLATE(Int32Type, DoubleType)
+
+#define INT64_CASES(TEMPLATE) \
+ TEMPLATE(Int64Type, BooleanType) \
+ TEMPLATE(Int64Type, UInt8Type) \
+ TEMPLATE(Int64Type, Int8Type) \
+ TEMPLATE(Int64Type, UInt16Type) \
+ TEMPLATE(Int64Type, Int16Type) \
+ TEMPLATE(Int64Type, UInt32Type) \
+ TEMPLATE(Int64Type, Int32Type) \
+ TEMPLATE(Int64Type, UInt64Type) \
+ TEMPLATE(Int64Type, FloatType) \
+ TEMPLATE(Int64Type, DoubleType)
+
+#define FLOAT_CASES(TEMPLATE) \
+ TEMPLATE(FloatType, BooleanType) \
+ TEMPLATE(FloatType, UInt8Type) \
+ TEMPLATE(FloatType, Int8Type) \
+ TEMPLATE(FloatType, UInt16Type) \
+ TEMPLATE(FloatType, Int16Type) \
+ TEMPLATE(FloatType, UInt32Type) \
+ TEMPLATE(FloatType, Int32Type) \
+ TEMPLATE(FloatType, UInt64Type) \
+ TEMPLATE(FloatType, Int64Type) \
+ TEMPLATE(FloatType, DoubleType)
+
+#define DOUBLE_CASES(TEMPLATE) \
+ TEMPLATE(DoubleType, BooleanType) \
+ TEMPLATE(DoubleType, UInt8Type) \
+ TEMPLATE(DoubleType, Int8Type) \
+ TEMPLATE(DoubleType, UInt16Type) \
+ TEMPLATE(DoubleType, Int16Type) \
+ TEMPLATE(DoubleType, UInt32Type) \
+ TEMPLATE(DoubleType, Int32Type) \
+ TEMPLATE(DoubleType, UInt64Type) \
+ TEMPLATE(DoubleType, Int64Type) \
+ TEMPLATE(DoubleType, FloatType)
+
+#define DATE32_CASES(TEMPLATE) \
+ TEMPLATE(Date32Type, Date64Type)
+
+#define DATE64_CASES(TEMPLATE) \
+ TEMPLATE(Date64Type, Date32Type)
+
+#define TIME32_CASES(TEMPLATE) \
+ TEMPLATE(Time32Type, Time32Type) \
+ TEMPLATE(Time32Type, Time64Type)
+
+#define TIME64_CASES(TEMPLATE) \
+ TEMPLATE(Time64Type, Time32Type) \
+ TEMPLATE(Time64Type, Time64Type)
+
+#define TIMESTAMP_CASES(TEMPLATE) \
+ TEMPLATE(TimestampType, Date32Type) \
+ TEMPLATE(TimestampType, Date64Type) \
+ TEMPLATE(TimestampType, TimestampType)
+
+#define BINARY_CASES(TEMPLATE) \
+ TEMPLATE(BinaryType, StringType)
+
+#define STRING_CASES(TEMPLATE) \
+ TEMPLATE(StringType, BooleanType) \
+ TEMPLATE(StringType, UInt8Type) \
+ TEMPLATE(StringType, Int8Type) \
+ TEMPLATE(StringType, UInt16Type) \
+ TEMPLATE(StringType, Int16Type) \
+ TEMPLATE(StringType, UInt32Type) \
+ TEMPLATE(StringType, Int32Type) \
+ TEMPLATE(StringType, UInt64Type) \
+ TEMPLATE(StringType, Int64Type) \
+ TEMPLATE(StringType, FloatType) \
+ TEMPLATE(StringType, DoubleType) \
+ TEMPLATE(StringType, TimestampType)
+
+#define DICTIONARY_CASES(TEMPLATE) \
+ TEMPLATE(DictionaryType, UInt8Type) \
+ TEMPLATE(DictionaryType, Int8Type) \
+ TEMPLATE(DictionaryType, UInt16Type) \
+ TEMPLATE(DictionaryType, Int16Type) \
+ TEMPLATE(DictionaryType, UInt32Type) \
+ TEMPLATE(DictionaryType, Int32Type) \
+ TEMPLATE(DictionaryType, UInt64Type) \
+ TEMPLATE(DictionaryType, Int64Type) \
+ TEMPLATE(DictionaryType, FloatType) \
+ TEMPLATE(DictionaryType, DoubleType) \
+ TEMPLATE(DictionaryType, Date32Type) \
+ TEMPLATE(DictionaryType, Date64Type) \
+ TEMPLATE(DictionaryType, Time32Type) \
+ TEMPLATE(DictionaryType, Time64Type) \
+ TEMPLATE(DictionaryType, TimestampType) \
+ TEMPLATE(DictionaryType, NullType) \
+ TEMPLATE(DictionaryType, BinaryType) \
+ TEMPLATE(DictionaryType, FixedSizeBinaryType) \
+ TEMPLATE(DictionaryType, StringType) \
+ TEMPLATE(DictionaryType, Decimal128Type)
diff --git a/cpp/src/arrow/compute/kernels/generated/codegen.py b/cpp/src/arrow/compute/kernels/generated/codegen.py
new file mode 100644
index 0000000..397ba66
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/generated/codegen.py
@@ -0,0 +1,134 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Generate boilerplate code for kernel instantiation and other tedious tasks
+
+
+import io
+
+
+INTEGER_TYPES = ['UInt8', 'Int8', 'UInt16', 'Int16',
+ 'UInt32', 'Int32', 'UInt64', 'Int64']
+FLOATING_TYPES = ['Float', 'Double']
+NUMERIC_TYPES = ['Boolean'] + INTEGER_TYPES + FLOATING_TYPES
+
+DATE_TIME_TYPES = ['Date32', 'Date64', 'Time32', 'Time64', 'Timestamp']
+
+
+def _format_type(name):
+ return name + "Type"
+
+
+class CastCodeGenerator(object):
+
+ def __init__(self, type_name, out_types, parametric=False,
+ exclusions=None):
+ self.type_name = type_name
+ self.out_types = out_types
+ self.parametric = parametric
+ self.exclusions = exclusions
+
+ def generate(self):
+ buf = io.StringIO()
+ print("#define {0}_CASES(TEMPLATE) \\"
+ .format(self.type_name.upper()), file=buf)
+
+ this_type = _format_type(self.type_name)
+
+ templates = []
+ for out_type in self.out_types:
+ if not self.parametric and out_type == self.type_name:
+ # Parametric types need T -> T cast generated
+ continue
+ templates.append(" TEMPLATE({0}, {1})"
+ .format(this_type, _format_type(out_type)))
+
+ print(" \\\n".join(templates), file=buf)
+ return buf.getvalue()
+
+
+# One CastCodeGenerator per input type. generate_cast_code() walks this list
+# in order, so the order here is the order of the *_CASES macros in the
+# generated text.
+CAST_GENERATORS = [
+ CastCodeGenerator('Null', NUMERIC_TYPES + DATE_TIME_TYPES),
+ CastCodeGenerator('Boolean', NUMERIC_TYPES),
+ CastCodeGenerator('UInt8', NUMERIC_TYPES),
+ CastCodeGenerator('Int8', NUMERIC_TYPES),
+ CastCodeGenerator('UInt16', NUMERIC_TYPES),
+ CastCodeGenerator('Int16', NUMERIC_TYPES),
+ CastCodeGenerator('UInt32', NUMERIC_TYPES),
+ CastCodeGenerator('UInt64', NUMERIC_TYPES),
+ CastCodeGenerator('Int32', NUMERIC_TYPES),
+ CastCodeGenerator('Int64', NUMERIC_TYPES),
+ CastCodeGenerator('Float', NUMERIC_TYPES),
+ CastCodeGenerator('Double', NUMERIC_TYPES),
+ CastCodeGenerator('Date32', ['Date64']),
+ CastCodeGenerator('Date64', ['Date32']),
+ CastCodeGenerator('Time32', ['Time32', 'Time64'],
+ parametric=True),
+ CastCodeGenerator('Time64', ['Time32', 'Time64'],
+ parametric=True),
+ CastCodeGenerator('Timestamp', ['Date32', 'Date64', 'Timestamp'],
+ parametric=True),
+ CastCodeGenerator('Binary', ['String']),
+ CastCodeGenerator('String', NUMERIC_TYPES + ['Timestamp']),
+ CastCodeGenerator('Dictionary',
+ INTEGER_TYPES + FLOATING_TYPES + DATE_TIME_TYPES +
+ ['Null', 'Binary', 'FixedSizeBinary', 'String',
+ 'Decimal128'])
+]
+
+
+def generate_cast_code():
+ blocks = [generator.generate() for generator in CAST_GENERATORS]
+ return '\n'.join(blocks)
+
+
+def write_file_with_preamble(path, code):
+ preamble = """// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT
+// Generated by codegen.py script
+"""
+
+ with open(path, 'wb') as f:
+ f.write(preamble.encode('utf-8'))
+ f.write(code.encode('utf-8'))
+
+
+def write_files():
+ cast_code = generate_cast_code()
+ write_file_with_preamble('cast-codegen-internal.h', cast_code)
+
+
+if __name__ == '__main__':
+ write_files()
diff --git a/cpp/src/arrow/compute/kernels/hash.cc b/cpp/src/arrow/compute/kernels/hash.cc
index 0513fe1..f443282 100644
--- a/cpp/src/arrow/compute/kernels/hash.cc
+++ b/cpp/src/arrow/compute/kernels/hash.cc
@@ -64,9 +64,19 @@ namespace {
// ----------------------------------------------------------------------
// Unique implementation
-class UniqueAction {
+// Base for the hash kernel actions below: stores the value type and memory
+// pool handed to the constructor so that derived actions can read them.
+class ActionBase {
public:
+ ActionBase(const std::shared_ptr<DataType>& type, MemoryPool* pool)
+ : type_(type), pool_(pool) {}
+
+ protected:
+ std::shared_ptr<DataType> type_;
+ MemoryPool* pool_;
+};
+
+class UniqueAction : public ActionBase {
+ public:
+ using ActionBase::ActionBase;
Status Reset() { return Status::OK(); }
@@ -81,15 +91,17 @@ class UniqueAction {
void ObserveNotFound(Index index) {}
Status Flush(Datum* out) { return Status::OK(); }
+
+ std::shared_ptr<DataType> out_type() const { return type_; }
};
// ----------------------------------------------------------------------
// Dictionary encode implementation
-class DictEncodeAction {
+class DictEncodeAction : public ActionBase {
public:
DictEncodeAction(const std::shared_ptr<DataType>& type, MemoryPool* pool)
- : indices_builder_(pool) {}
+ : ActionBase(type, pool), indices_builder_(pool) {}
Status Reset() {
indices_builder_.Reset();
@@ -117,6 +129,8 @@ class DictEncodeAction {
return Status::OK();
}
+ std::shared_ptr<DataType> out_type() const { return int32(); }
+
private:
Int32Builder indices_builder_;
};
@@ -184,6 +198,8 @@ class RegularHashKernelImpl : public HashKernelImpl {
return Status::OK();
}
+ std::shared_ptr<DataType> out_type() const override { return action_.out_type(); }
+
protected:
using MemoTable = typename HashTraits<Type>::MemoTableType;
@@ -221,6 +237,8 @@ class NullHashKernelImpl : public HashKernelImpl {
return Status::OK();
}
+ std::shared_ptr<DataType> out_type() const override { return null(); }
+
protected:
MemoryPool* pool_;
std::shared_ptr<DataType> type_;
diff --git a/cpp/src/arrow/compute/kernels/sum.cc b/cpp/src/arrow/compute/kernels/sum.cc
index 007412a..a1487c1 100644
--- a/cpp/src/arrow/compute/kernels/sum.cc
+++ b/cpp/src/arrow/compute/kernels/sum.cc
@@ -86,6 +86,10 @@ class SumAggregateFunction final : public AggregateFunctionStaticState<StateType
return Status::OK();
}
+ std::shared_ptr<DataType> out_type() const override {
+ return TypeTraits<typename FindAccumulatorType<ArrowType>::Type>::type_singleton();
+ }
+
private:
StateType ConsumeDense(const ArrayType& array) const {
StateType local;
diff --git a/cpp/src/arrow/compute/kernels/util-internal.cc b/cpp/src/arrow/compute/kernels/util-internal.cc
index 745b30c..60b668d 100644
--- a/cpp/src/arrow/compute/kernels/util-internal.cc
+++ b/cpp/src/arrow/compute/kernels/util-internal.cc
@@ -26,27 +26,33 @@
#include "arrow/array.h"
#include "arrow/status.h"
#include "arrow/table.h"
+#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
#include "arrow/compute/context.h"
#include "arrow/compute/kernel.h"
namespace arrow {
+
+using internal::checked_cast;
+
namespace compute {
namespace detail {
Status InvokeUnaryArrayKernel(FunctionContext* ctx, UnaryKernel* kernel,
const Datum& value, std::vector<Datum>* outputs) {
if (value.kind() == Datum::ARRAY) {
- Datum output;
- RETURN_NOT_OK(kernel->Call(ctx, value, &output));
- outputs->push_back(output);
+ Datum out;
+ out.value = ArrayData::Make(kernel->out_type(), value.array()->length);
+ RETURN_NOT_OK(kernel->Call(ctx, value, &out));
+ outputs->push_back(out);
} else if (value.kind() == Datum::CHUNKED_ARRAY) {
const ChunkedArray& array = *value.chunked_array();
for (int i = 0; i < array.num_chunks(); i++) {
- Datum output;
- RETURN_NOT_OK(kernel->Call(ctx, Datum(array.chunk(i)), &output));
- outputs->push_back(output);
+ Datum out;
+ out.value = ArrayData::Make(kernel->out_type(), array.chunk(i)->length());
+ RETURN_NOT_OK(kernel->Call(ctx, array.chunk(i), &out));
+ outputs->push_back(out);
}
} else {
return Status::Invalid("Input Datum was not array-like");
@@ -165,46 +171,96 @@ Datum WrapDatumsLike(const Datum& value, const std::vector<Datum>& datums) {
}
+// Non-owning form: stores only the raw delegate pointer and the output type.
PrimitiveAllocatingUnaryKernel::PrimitiveAllocatingUnaryKernel(
- std::unique_ptr<UnaryKernel> delegate)
- : delegate_(std::move(delegate)) {}
+ UnaryKernel* delegate, const std::shared_ptr<DataType>& out_type)
+ : delegate_(delegate), out_type_(out_type) {}
+
+// Owning form: forwards the raw pointer to the constructor above, then moves
+// the unique_ptr into owned_delegate_ so the delegate lives as long as this
+// kernel.
+PrimitiveAllocatingUnaryKernel::PrimitiveAllocatingUnaryKernel(
+ std::unique_ptr<UnaryKernel> delegate, const std::shared_ptr<DataType>& out_type)
+ : PrimitiveAllocatingUnaryKernel(delegate.get(), out_type) {
+ owned_delegate_ = std::move(delegate);
+}
inline void ZeroLastByte(Buffer* buffer) {
*(buffer->mutable_data() + (buffer->size() - 1)) = 0;
}
-Status PrimitiveAllocatingUnaryKernel::Call(FunctionContext* ctx, const Datum& input,
- Datum* out) {
- std::vector<std::shared_ptr<Buffer>> data_buffers;
- const ArrayData& in_data = *input.array();
- MemoryPool* pool = ctx->memory_pool();
+// Give `output` a validity bitmap (buffers[0]) and null count matching `input`.
+// Returns the error from ctx->Allocate if the bitmap copy cannot be allocated,
+// otherwise Status::OK().
+Status PropagateNulls(FunctionContext* ctx, const ArrayData& input, ArrayData* output) {
+ const int64_t length = input.length;
- // Handle the validity buffer.
- if (in_data.offset == 0 || in_data.null_count <= 0) {
- // Validity bitmap will be zero copied (or allocated when buffer is known).
- data_buffers.emplace_back();
- } else {
+ if (output->buffers.size() == 0) {
+ // Ensure we can assign a buffer
+ output->buffers.resize(1);
+ }
+
+ // Handle validity bitmap
+ output->null_count = input.GetNullCount();
+ // A sliced input that contains nulls gets a fresh bitmap holding its bits
+ // re-based to bit 0; every other case shares the input's bitmap buffer.
+ if (input.offset != 0 && output->null_count > 0) {
+ DCHECK(input.buffers[0]);
+ const Buffer& validity_bitmap = *input.buffers[0];
std::shared_ptr<Buffer> buffer;
- RETURN_NOT_OK(AllocateBitmap(pool, in_data.length, &buffer));
+ RETURN_NOT_OK(ctx->Allocate(BitUtil::BytesForBits(length), &buffer));
// Per spec all trailing bits should indicate nullness, since
// the last byte might only be partially set, we ensure the
// remaining bit is set.
ZeroLastByte(buffer.get());
buffer->ZeroPadding();
- data_buffers.push_back(buffer);
+ internal::CopyBitmap(validity_bitmap.data(), input.offset, length,
+ buffer->mutable_data(), 0 /* destination offset */);
+ output->buffers[0] = std::move(buffer);
+ } else {
+ output->buffers[0] = input.buffers[0];
}
- // Allocate the boolean value buffer.
+ return Status::OK();
+}
+
+// Preallocate the output value buffer, then run the delegate kernel.
+//
+// `out` must already hold an ArrayData (Datum::ARRAY). Its buffer vector is
+// resized to two entries; unless the output type is null, entry 1 receives a
+// zero-filled buffer sized for `input.array()->length` values of out_type_.
+// The delegate is then called with the same `input`/`out` pair and its Status
+// is returned. An allocation failure is returned before the delegate runs.
+Status PrimitiveAllocatingUnaryKernel::Call(FunctionContext* ctx, const Datum& input,
+                                            Datum* out) {
+  const ArrayData& in_data = *input.array();
+
+  DCHECK_EQ(out->kind(), Datum::ARRAY);
+
+  ArrayData* result = out->array().get();
+
+  result->buffers.resize(2);
+
+  const int64_t length = in_data.length;
+
+  // Allocate the value buffer
+  if (out_type_->id() != Type::NA) {
+    const auto& fw_type = checked_cast<const FixedWidthType&>(*out_type_);
+
+    const int bit_width = fw_type.bit_width();
+    int64_t buffer_size = 0;
+    if (bit_width == 1) {
+      // Bit-packed values: round up to whole bytes.
+      buffer_size = BitUtil::BytesForBits(length);
+    } else {
+      DCHECK_EQ(bit_width % 8, 0)
+          << "Only bit widths with multiple of 8 are currently supported";
+      buffer_size = length * bit_width / 8;
+    }
+
+    std::shared_ptr<Buffer> buffer;
+    RETURN_NOT_OK(ctx->Allocate(buffer_size, &buffer));
+    buffer->ZeroPadding();
+
+    // Zero every value byte. Some utility methods access the last byte of a
+    // bit-packed buffer before it might be initialized, which makes
+    // valgrind/asan unhappy; clearing the whole buffer covers that byte too.
+    if (buffer_size > 0) {
+      memset(buffer->mutable_data(), 0, buffer_size);
+    }
+    result->buffers[1] = std::move(buffer);
+  }
return delegate_->Call(ctx, input, out);
}
+// Reports the delegate's output type. NOTE(review): Call() sizes the value
+// buffer from out_type_ instead; presumably the two always agree -- confirm.
+std::shared_ptr<DataType> PrimitiveAllocatingUnaryKernel::out_type() const {
+ return delegate_->out_type();
+}
+
} // namespace detail
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/util-internal.h b/cpp/src/arrow/compute/kernels/util-internal.h
index 2252023..bd27280 100644
--- a/cpp/src/arrow/compute/kernels/util-internal.h
+++ b/cpp/src/arrow/compute/kernels/util-internal.h
@@ -62,6 +62,16 @@ ARROW_EXPORT
Status InvokeBinaryArrayKernel(FunctionContext* ctx, BinaryKernel* kernel,
const Datum& left, const Datum& right, Datum* output);
+/// \brief Assign the input's validity bitmap to the output so that the same
+/// value slots are valid/not-null in both.
+///
+/// The bitmap is copied into a newly allocated buffer when the input is sliced
+/// (non-zero offset) and contains nulls; otherwise the input's bitmap buffer
+/// is shared without copying. The output's null count is set as well.
+/// \param[in] ctx the kernel FunctionContext, used for the allocation
+/// \param[in] input the input array
+/// \param[out] output the output array
+/// \return the allocation error, if any, otherwise Status::OK
+ARROW_EXPORT
+Status PropagateNulls(FunctionContext* ctx, const ArrayData& input, ArrayData* output);
+
ARROW_EXPORT
Datum WrapArraysLike(const Datum& value,
const std::vector<std::shared_ptr<Array>>& arrays);
@@ -72,21 +82,20 @@ Datum WrapDatumsLike(const Datum& value, const std::vector<Datum>& datums);
/// \brief Kernel used to preallocate outputs for primitive types.
class PrimitiveAllocatingUnaryKernel : public UnaryKernel {
public:
- explicit PrimitiveAllocatingUnaryKernel(std::unique_ptr<UnaryKernel> delegate);
- /// \brief Sets out to be of type ArrayData with the necessary
- /// data buffers prepopulated.
- ///
- /// This method does not populate types on arrays and sets type to null.
- ///
- /// The current implementation only supports primitive boolean outputs and
- /// assumes validity bitmaps that are not sliced will be zero copied (i.e.
- /// no allocation happens for them).
- ///
- /// TODO(ARROW-1896): Make this generic enough to support casts.
+ /// \brief Wrap a delegate kernel and take ownership of it
+ PrimitiveAllocatingUnaryKernel(std::unique_ptr<UnaryKernel> delegate,
+ const std::shared_ptr<DataType>& out_type);
+ /// \brief Wrap a delegate kernel without taking ownership of it
+ PrimitiveAllocatingUnaryKernel(UnaryKernel* delegate,
+ const std::shared_ptr<DataType>& out_type);
+ /// \brief Allocates the value buffer of the ArrayData already held by out,
+ /// then invokes the delegate kernel with the same input and out
Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override;
+ std::shared_ptr<DataType> out_type() const override;
+
private:
- std::unique_ptr<UnaryKernel> delegate_;
+ // Kernel that Call() forwards to; this pointer itself does not own it.
+ UnaryKernel* delegate_;
+ // Type Call() uses to size the value buffer.
+ std::shared_ptr<DataType> out_type_;
+ // Assigned only by the unique_ptr constructor, to keep the delegate alive.
+ std::unique_ptr<UnaryKernel> owned_delegate_;
};
} // namespace detail