You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2017/09/07 17:00:42 UTC
arrow git commit: ARROW-229: [C++] Implement cast functions for
numeric types, booleans
Repository: arrow
Updated Branches:
refs/heads/master a9a80fef7 -> b0b125fd7
ARROW-229: [C++] Implement cast functions for numeric types, booleans
Implements safe and unsafe casts among booleans, signed/unsigned integers, and single- and double-precision floating point numbers.
Currently the only safety option is checking for integer overflow when casting from a larger integer type to a smaller one (or from an unsigned integer type to a signed integer type of the same width). This API should be regarded as experimental in 0.7.0. There are a number of follow-up patches we'll want to do quickly after this (exposing this in Python, incorporating it into Array.from_pandas).
Author: Wes McKinney <we...@twosigma.com>
Closes #1027 from wesm/ARROW-229 and squashes the following commits:
82fea97 [Wes McKinney] Fix MSVC warning
ead4a95 [Wes McKinney] Fix overflow check where overflow occurs in a null slot
dc7f8d9 [Wes McKinney] Some basic smoke tests to validate implemented casts
879653d [Wes McKinney] Start test suite for Cast
22308ba [Wes McKinney] Implement cast kernels for numbers. Add helper type traits
ca1c813 [Wes McKinney] Work on context
d05c274 [Wes McKinney] Start some prototyping of a cast implementation
Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/b0b125fd
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/b0b125fd
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/b0b125fd
Branch: refs/heads/master
Commit: b0b125fd74b2bb334e90d9775a670bf18ffd8a22
Parents: a9a80fe
Author: Wes McKinney <we...@twosigma.com>
Authored: Thu Sep 7 13:00:37 2017 -0400
Committer: Wes McKinney <we...@twosigma.com>
Committed: Thu Sep 7 13:00:37 2017 -0400
----------------------------------------------------------------------
cpp/CMakeLists.txt | 37 +++-
cpp/src/arrow/array.h | 2 +-
cpp/src/arrow/compute/CMakeLists.txt | 28 +++
cpp/src/arrow/compute/cast.cc | 329 +++++++++++++++++++++++++++++
cpp/src/arrow/compute/cast.h | 55 +++++
cpp/src/arrow/compute/compute-test.cc | 315 +++++++++++++++++++++++++++
cpp/src/arrow/compute/context.cc | 46 ++++
cpp/src/arrow/compute/context.h | 68 ++++++
cpp/src/arrow/memory_pool.cc | 4 +-
cpp/src/arrow/test-util.h | 10 +
cpp/src/arrow/type.h | 82 +------
cpp/src/arrow/type_traits.h | 68 ++++++
12 files changed, 956 insertions(+), 88 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 9f9d71b..24735ac 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -94,6 +94,10 @@ if("${CMAKE_SOURCE_DIR}" STREQUAL "${CMAKE_CURRENT_SOURCE_DIR}")
"Exclude deprecated APIs from build"
OFF)
+ option(ARROW_COMPUTE
+ "Build the Arrow Compute Modules"
+ ON)
+
option(ARROW_EXTRA_ERROR_CONTEXT
"Compile with extra error context (line numbers, code)"
OFF)
@@ -727,17 +731,6 @@ endif()
add_subdirectory(src/arrow)
add_subdirectory(src/arrow/io)
-if (ARROW_GPU)
- # IPC extensions required to build the GPU library
- set(ARROW_IPC ON)
- add_subdirectory(src/arrow/gpu)
-endif()
-
-if (ARROW_IPC)
- add_subdirectory(src/arrow/ipc)
- add_dependencies(arrow_dependencies metadata_fbs)
-endif()
-
set(ARROW_SRCS
src/arrow/array.cc
src/arrow/buffer.cc
@@ -763,6 +756,26 @@ set(ARROW_SRCS
src/arrow/util/key_value_metadata.cc
)
+# The compute sources are compiled only when ARROW_COMPUTE is enabled; they are
+# intentionally not part of the unconditional ARROW_SRCS list above
+if (ARROW_COMPUTE)
+ add_subdirectory(src/arrow/compute)
+ set(ARROW_SRCS ${ARROW_SRCS}
+ src/arrow/compute/cast.cc
+ src/arrow/compute/context.cc
+ )
+endif()
+
+if (ARROW_GPU)
+ # IPC extensions required to build the GPU library
+ set(ARROW_IPC ON)
+ add_subdirectory(src/arrow/gpu)
+endif()
+
+if (ARROW_IPC)
+ add_subdirectory(src/arrow/ipc)
+ add_dependencies(arrow_dependencies metadata_fbs)
+endif()
+
if (ARROW_WITH_BROTLI)
add_definitions(-DARROW_WITH_BROTLI)
SET(ARROW_SRCS src/arrow/util/compression_brotli.cc ${ARROW_SRCS})
http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/array.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h
index 8e965e8..61ab2ef 100644
--- a/cpp/src/arrow/array.h
+++ b/cpp/src/arrow/array.h
@@ -678,4 +678,4 @@ MakePrimitiveArray(const std::shared_ptr<DataType>& type,
} // namespace arrow
-#endif
+#endif // ARROW_ARRAY_H
http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/compute/CMakeLists.txt
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/CMakeLists.txt b/cpp/src/arrow/compute/CMakeLists.txt
new file mode 100644
index 0000000..a154c47
--- /dev/null
+++ b/cpp/src/arrow/compute/CMakeLists.txt
@@ -0,0 +1,28 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Public headers of the compute module
+install(
+  FILES
+    cast.h
+    context.h
+  DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/arrow/compute")
+
+# ----------------------------------------------------------------------
+# Unit tests
+
+add_arrow_test(compute-test)
http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/compute/cast.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc
new file mode 100644
index 0000000..f610f6b
--- /dev/null
+++ b/cpp/src/arrow/compute/cast.cc
@@ -0,0 +1,329 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/cast.h"
+
+#include <cstdint>
+#include <functional>
+#include <limits>
+#include <memory>
+#include <sstream>
+#include <type_traits>
+
+#include "arrow/type_traits.h"
+#include "arrow/util/logging.h"
+
+#include "arrow/compute/context.h"
+
+namespace arrow {
+namespace compute {
+
+/// State handed to every cast kernel: the caller's function context plus the
+/// CastOptions supplied to Cast
+struct CastContext {
+  FunctionContext* func_ctx;
+  CastOptions options;
+};
+
+/// Type-erased cast kernel: reads `input` and fills the preallocated `output`
+using CastFunction = std::function<void(CastContext*, const ArrayData&, ArrayData*)>;
+
+// Primary template: has no call operator, so instantiating a cast for an
+// unsupported (OutType, InType) pair fails to compile
+template <typename OutType, typename InType, typename Enable = void>
+struct CastFunctor {};
+
+// Identity cast: input and output types match, so the output simply shares
+// every buffer and child of the input (zero-copy)
+template <typename O, typename I>
+struct CastFunctor<O, I, typename std::enable_if<std::is_same<I, O>::value>::type> {
+  void operator()(CastContext* ctx, const ArrayData& input, ArrayData* output) {
+    ArrayData& result = *output;
+    result.type = input.type;
+    result.length = input.length;
+    result.offset = input.offset;
+    result.null_count = input.null_count;
+    result.buffers = input.buffers;
+    result.child_data = input.child_data;
+  }
+};
+
+// ----------------------------------------------------------------------
+// Null to other things
+
+// Casting out of NullType is not supported yet; the failure is reported
+// through the function context rather than returned directly
+template <typename T>
+struct CastFunctor<T, NullType,
+                   typename std::enable_if<!std::is_same<T, NullType>::value>::type> {
+  void operator()(CastContext* ctx, const ArrayData& input, ArrayData* output) {
+    const Status not_implemented = Status::NotImplemented("NullType");
+    ctx->func_ctx->SetStatus(not_implemented);
+  }
+};
+
+// ----------------------------------------------------------------------
+// Boolean to other things
+
+// Cast from Boolean to other numbers: true -> 1, false -> 0
+template <typename T>
+struct CastFunctor<T, BooleanType,
+                   typename std::enable_if<std::is_base_of<Number, T>::value>::type> {
+  void operator()(CastContext* ctx, const ArrayData& input, ArrayData* output) {
+    using c_type = typename T::c_type;
+    const uint8_t* data = input.buffers[1]->data();
+    // The input may be a slice whose first logical bit is at input.offset,
+    // while the freshly allocated output always starts at index 0
+    const int64_t in_offset = input.offset;
+    auto out = reinterpret_cast<c_type*>(output->buffers[1]->mutable_data());
+    constexpr auto kOne = static_cast<c_type>(1);
+    constexpr auto kZero = static_cast<c_type>(0);
+    for (int64_t i = 0; i < input.length; ++i) {
+      out[i] = BitUtil::GetBit(data, in_offset + i) ? kOne : kZero;
+    }
+  }
+};
+
+// ----------------------------------------------------------------------
+// Integers and Floating Point
+
+// True for casts between two distinct numeric types
+template <typename O, typename I>
+struct is_numeric_cast {
+  static constexpr bool value =
+      (std::is_base_of<Number, O>::value && std::is_base_of<Number, I>::value) &&
+      (!std::is_same<O, I>::value);
+};
+
+// True for integer-to-integer casts whose output type cannot hold every input
+// value: a narrower output type, or an unsigned input cast to a signed output
+// of the same width. These are the casts range-checked when
+// allow_int_overflow is false.
+// NOTE(review): signed -> unsigned casts of equal or greater width are not
+// classified as downcasts, so negative inputs wrap without an error.
+template <typename O, typename I, typename Enable = void>
+struct is_integer_downcast {
+  static constexpr bool value = false;
+};
+
+template <typename O, typename I>
+struct is_integer_downcast<
+    O, I, typename std::enable_if<std::is_base_of<Integer, O>::value &&
+                                  std::is_base_of<Integer, I>::value>::type> {
+  using O_T = typename O::c_type;
+  using I_T = typename I::c_type;
+
+  static constexpr bool value =
+      ((!std::is_same<O, I>::value) &&
+
+       // same size, but unsigned to signed
+       ((sizeof(O_T) == sizeof(I_T) && std::is_signed<O_T>::value &&
+         std::is_unsigned<I_T>::value) ||
+
+        // Smaller output size
+        (sizeof(O_T) < sizeof(I_T))));
+};
+
+// Cast from any number to Boolean: values equal to zero become false,
+// everything else becomes true
+template <typename O, typename I>
+struct CastFunctor<O, I, typename std::enable_if<std::is_same<BooleanType, O>::value &&
+                                                 std::is_base_of<Number, I>::value &&
+                                                 !std::is_same<O, I>::value>::type> {
+  void operator()(CastContext* ctx, const ArrayData& input, ArrayData* output) {
+    using in_type = typename I::c_type;
+    // Skip to the first logical element of a sliced input; the output is
+    // freshly allocated and always starts at bit 0
+    auto in_data =
+        reinterpret_cast<const in_type*>(input.buffers[1]->data()) + input.offset;
+    uint8_t* out_data = reinterpret_cast<uint8_t*>(output->buffers[1]->mutable_data());
+    for (int64_t i = 0; i < input.length; ++i) {
+      BitUtil::SetBitTo(out_data, i, (*in_data++) != 0);
+    }
+  }
+};
+
+// Integer casts that may overflow (see is_integer_downcast). Unless
+// allow_int_overflow is set, every non-null value is range-checked and the
+// first out-of-range value records Status::Invalid on the function context.
+template <typename O, typename I>
+struct CastFunctor<O, I,
+                   typename std::enable_if<is_integer_downcast<O, I>::value>::type> {
+  void operator()(CastContext* ctx, const ArrayData& input, ArrayData* output) {
+    using in_type = typename I::c_type;
+    using out_type = typename O::c_type;
+
+    auto in_offset = input.offset;
+
+    auto in_data = reinterpret_cast<const in_type*>(input.buffers[1]->data()) + in_offset;
+    auto out_data = reinterpret_cast<out_type*>(output->buffers[1]->mutable_data());
+
+    if (!ctx->options.allow_int_overflow) {
+      // Range of out_type expressed in in_type. The maximum converts exactly
+      // for every cast selected by is_integer_downcast. The minimum of a
+      // signed out_type is negative and would wrap to a huge positive value if
+      // converted to an unsigned in_type (rejecting every in-range value);
+      // unsigned inputs can never be below that minimum, so use in_type's own
+      // minimum (0) for them instead.
+      constexpr in_type kMax = static_cast<in_type>(std::numeric_limits<out_type>::max());
+      constexpr in_type kMin =
+          std::is_signed<in_type>::value
+              ? static_cast<in_type>(std::numeric_limits<out_type>::min())
+              : std::numeric_limits<in_type>::min();
+
+      if (input.null_count > 0) {
+        const uint8_t* is_valid = input.buffers[0]->data();
+        for (int64_t i = 0; i < input.length; ++i) {
+          const in_type value = in_data[i];
+          // Values in null slots are unspecified and must not raise an error
+          if (ARROW_PREDICT_FALSE(BitUtil::GetBit(is_valid, in_offset + i) &&
+                                  (value > kMax || value < kMin))) {
+            ctx->func_ctx->SetStatus(Status::Invalid("Integer value out of bounds"));
+            // The caller discards the output once an error is recorded
+            return;
+          }
+          out_data[i] = static_cast<out_type>(value);
+        }
+      } else {
+        for (int64_t i = 0; i < input.length; ++i) {
+          const in_type value = in_data[i];
+          if (ARROW_PREDICT_FALSE(value > kMax || value < kMin)) {
+            ctx->func_ctx->SetStatus(Status::Invalid("Integer value out of bounds"));
+            return;
+          }
+          out_data[i] = static_cast<out_type>(value);
+        }
+      }
+    } else {
+      for (int64_t i = 0; i < input.length; ++i) {
+        out_data[i] = static_cast<out_type>(in_data[i]);
+      }
+    }
+  }
+};
+
+// All other numeric casts (widening integer casts and any cast involving
+// floating point) are performed unchecked with static_cast.
+// NOTE(review): converting a floating point value that is out of range for an
+// integer output type (including NaN, or unspecified values in null slots) is
+// undefined behavior in C++; consider a checked path for float -> int.
+template <typename O, typename I>
+struct CastFunctor<O, I,
+                   typename std::enable_if<is_numeric_cast<O, I>::value &&
+                                           !is_integer_downcast<O, I>::value>::type> {
+  void operator()(CastContext* ctx, const ArrayData& input, ArrayData* output) {
+    using in_type = typename I::c_type;
+    using out_type = typename O::c_type;
+
+    // Honor the input slice offset; the output always starts at index 0
+    auto in_data =
+        reinterpret_cast<const in_type*>(input.buffers[1]->data()) + input.offset;
+    auto out_data = reinterpret_cast<out_type*>(output->buffers[1]->mutable_data());
+    for (int64_t i = 0; i < input.length; ++i) {
+      *out_data++ = static_cast<out_type>(*in_data++);
+    }
+  }
+};
+
+// ----------------------------------------------------------------------
+
+// One `case` of an output-type switch: returns a CastFunction that runs
+// CastFunctor<OutType, InType>. The lambda needs only its arguments, so it
+// captures nothing (capturing the enclosing `type` copied a shared_ptr into
+// every returned kernel for no use).
+#define CAST_CASE(InType, OutType)                                        \
+  case OutType::type_id:                                                  \
+    return [](CastContext* ctx, const ArrayData& input, ArrayData* out) { \
+      CastFunctor<OutType, InType> func;                                  \
+      func(ctx, input, out);                                              \
+    }
+
+// Expands FN(IN_TYPE, OutType) for every supported boolean/numeric output type
+#define NUMERIC_CASES(FN, IN_TYPE) \
+  FN(IN_TYPE, BooleanType);        \
+  FN(IN_TYPE, UInt8Type);          \
+  FN(IN_TYPE, Int8Type);           \
+  FN(IN_TYPE, UInt16Type);         \
+  FN(IN_TYPE, Int16Type);          \
+  FN(IN_TYPE, UInt32Type);         \
+  FN(IN_TYPE, Int32Type);          \
+  FN(IN_TYPE, UInt64Type);         \
+  FN(IN_TYPE, Int64Type);          \
+  FN(IN_TYPE, FloatType);          \
+  FN(IN_TYPE, DoubleType);
+
+// Defines Get<CapType>CastFunc(type): the kernel casting CapType to `type`,
+// or an empty CastFunction (nullptr) if that output type is unsupported
+#define GET_CAST_FUNCTION(CapType)                                                   \
+  static CastFunction Get##CapType##CastFunc(const std::shared_ptr<DataType>& type) { \
+    switch (type->id()) {                                                            \
+      NUMERIC_CASES(CAST_CASE, CapType);                                             \
+      default:                                                                       \
+        break;                                                                       \
+    }                                                                                \
+    return nullptr;                                                                  \
+  }
+
+// One `case` of GetCastFunction's input-type switch; expects `out_type` and
+// `out` to be in scope at the expansion site
+#define CAST_FUNCTION_CASE(CapType)          \
+  case CapType::type_id:                     \
+    *out = Get##CapType##CastFunc(out_type); \
+    break
+
+GET_CAST_FUNCTION(BooleanType);
+GET_CAST_FUNCTION(UInt8Type);
+GET_CAST_FUNCTION(Int8Type);
+GET_CAST_FUNCTION(UInt16Type);
+GET_CAST_FUNCTION(Int16Type);
+GET_CAST_FUNCTION(UInt32Type);
+GET_CAST_FUNCTION(Int32Type);
+GET_CAST_FUNCTION(UInt64Type);
+GET_CAST_FUNCTION(Int64Type);
+GET_CAST_FUNCTION(FloatType);
+GET_CAST_FUNCTION(DoubleType);
+
+// Resolve the kernel for casting `in_type` to `out_type` into *out, or report
+// NotImplemented. Unsupported input types leave *out untouched, so the caller
+// must pass in an empty CastFunction.
+static Status GetCastFunction(const DataType& in_type,
+                              const std::shared_ptr<DataType>& out_type,
+                              CastFunction* out) {
+  switch (in_type.id()) {
+    CAST_FUNCTION_CASE(BooleanType);
+    CAST_FUNCTION_CASE(UInt8Type);
+    CAST_FUNCTION_CASE(Int8Type);
+    CAST_FUNCTION_CASE(UInt16Type);
+    CAST_FUNCTION_CASE(Int16Type);
+    CAST_FUNCTION_CASE(UInt32Type);
+    CAST_FUNCTION_CASE(Int32Type);
+    CAST_FUNCTION_CASE(UInt64Type);
+    CAST_FUNCTION_CASE(Int64Type);
+    CAST_FUNCTION_CASE(FloatType);
+    CAST_FUNCTION_CASE(DoubleType);
+    default:
+      break;
+  }
+
+  if (*out) {
+    return Status::OK();
+  }
+
+  std::stringstream msg;
+  msg << "No cast implemented from " << in_type.ToString() << " to "
+      << out_type->ToString();
+  return Status::NotImplemented(msg.str());
+}
+
+// Allocate an output ArrayData of `out_type` with the same length and null
+// count as `array`. The output always has offset 0: the validity bitmap is
+// shared when that is layout-compatible, and the data buffer is newly
+// allocated from the function context and left for the cast kernel to fill.
+static Status AllocateLike(FunctionContext* ctx, const Array& array,
+                           const std::shared_ptr<DataType>& out_type,
+                           std::shared_ptr<ArrayData>* out) {
+  if (!is_primitive(out_type->id())) {
+    return Status::NotImplemented(out_type->ToString());
+  }
+
+  const auto& fw_type = static_cast<const FixedWidthType&>(*out_type);
+
+  auto result = std::make_shared<ArrayData>();
+  result->type = out_type;
+  result->length = array.length();
+  result->offset = 0;
+  result->null_count = array.null_count();
+
+  // Propagate null bitmap. Because the output starts at offset 0, the input
+  // bitmap can only be shared as-is when the input is not a slice; otherwise
+  // its bits are shifted into a freshly allocated, zero-offset bitmap.
+  // TODO(wesm): handling null bitmap when input type is NullType
+  std::shared_ptr<Buffer> in_bitmap = array.data()->buffers[0];
+  if (in_bitmap == nullptr || array.offset() == 0) {
+    result->buffers.push_back(in_bitmap);
+  } else {
+    std::shared_ptr<Buffer> out_bitmap;
+    RETURN_NOT_OK(ctx->Allocate(BitUtil::BytesForBits(array.length()), &out_bitmap));
+    const uint8_t* src = in_bitmap->data();
+    uint8_t* dst = out_bitmap->mutable_data();
+    const int64_t src_offset = array.offset();
+    for (int64_t i = 0; i < array.length(); ++i) {
+      BitUtil::SetBitTo(dst, i, BitUtil::GetBit(src, src_offset + i));
+    }
+    result->buffers.push_back(out_bitmap);
+  }
+
+  std::shared_ptr<Buffer> out_data;
+
+  const int bit_width = fw_type.bit_width();
+
+  if (bit_width == 1) {
+    RETURN_NOT_OK(ctx->Allocate(BitUtil::BytesForBits(array.length()), &out_data));
+  } else if (bit_width % 8 == 0) {
+    RETURN_NOT_OK(ctx->Allocate(array.length() * bit_width / 8, &out_data));
+  } else {
+    DCHECK(false);
+  }
+  result->buffers.push_back(out_data);
+
+  *out = result;
+  return Status::OK();
+}
+
+// Internal driver: resolve the kernel, allocate the output, run the kernel,
+// then surface (and clear) any error the kernel recorded on the context
+static Status Cast(CastContext* cast_ctx, const Array& array,
+                   const std::shared_ptr<DataType>& out_type,
+                   std::shared_ptr<Array>* out) {
+  CastFunction kernel;
+  RETURN_NOT_OK(GetCastFunction(*array.type(), out_type, &kernel));
+
+  std::shared_ptr<ArrayData> result;
+  RETURN_NOT_OK(AllocateLike(cast_ctx->func_ctx, array, out_type, &result));
+
+  kernel(cast_ctx, *array.data(), result.get());
+
+  FunctionContext* fctx = cast_ctx->func_ctx;
+  RETURN_IF_ERROR(fctx);
+  return internal::MakeArray(result, out);
+}
+
+// Public entry point: bundle the context and options, then delegate
+Status Cast(FunctionContext* ctx, const Array& array,
+            const std::shared_ptr<DataType>& out_type, const CastOptions& options,
+            std::shared_ptr<Array>* out) {
+  CastContext cast_ctx;
+  cast_ctx.func_ctx = ctx;
+  cast_ctx.options = options;
+  return Cast(&cast_ctx, array, out_type, out);
+}
+
+} // namespace compute
+} // namespace arrow
http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/compute/cast.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/cast.h b/cpp/src/arrow/compute/cast.h
new file mode 100644
index 0000000..9ca70aa
--- /dev/null
+++ b/cpp/src/arrow/compute/cast.h
@@ -0,0 +1,55 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef ARROW_COMPUTE_CAST_H
+#define ARROW_COMPUTE_CAST_H
+
+#include <memory>
+
+#include "arrow/array.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+
+using internal::ArrayData;
+
+namespace compute {
+
+class FunctionContext;
+
+/// \brief Options controlling the behavior of Cast
+struct CastOptions {
+  /// \brief If false (the default), integer casts that may overflow are
+  /// range-checked and an out-of-range non-null value makes Cast return
+  /// Status::Invalid; if true, such values are converted unchecked.
+  bool allow_int_overflow = false;
+};
+
+/// \brief Cast from one array type to another
+/// \param[in] context the FunctionContext used to allocate the result and to
+/// collect errors raised by the cast kernel
+/// \param[in] array the array to cast
+/// \param[in] to_type the type to cast to
+/// \param[in] options casting options, see CastOptions
+/// \param[out] out the resulting array
+/// \return Status::NotImplemented if no cast exists between the two types,
+/// Status::Invalid if an integer overflow is detected while
+/// options.allow_int_overflow is false
+///
+/// \since 0.7.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Status Cast(FunctionContext* context, const Array& array,
+            const std::shared_ptr<DataType>& to_type, const CastOptions& options,
+            std::shared_ptr<Array>* out);
+
+} // namespace compute
+} // namespace arrow
+
+#endif // ARROW_COMPUTE_CAST_H
http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/compute/compute-test.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/compute-test.cc b/cpp/src/arrow/compute/compute-test.cc
new file mode 100644
index 0000000..cda5755
--- /dev/null
+++ b/cpp/src/arrow/compute/compute-test.cc
@@ -0,0 +1,315 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <cstdint>
+#include <cstdlib>
+#include <memory>
+#include <numeric>
+#include <sstream>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "arrow/array.h"
+#include "arrow/buffer.h"
+#include "arrow/builder.h"
+#include "arrow/compare.h"
+#include "arrow/ipc/test-common.h"
+#include "arrow/memory_pool.h"
+#include "arrow/pretty_print.h"
+#include "arrow/status.h"
+#include "arrow/test-common.h"
+#include "arrow/test-util.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+
+#include "arrow/compute/cast.h"
+#include "arrow/compute/context.h"
+
+using std::vector;
+
+namespace arrow {
+namespace compute {
+
+// Fail the current test, printing both arrays, unless they compare equal
+void AssertArraysEqual(const Array& left, const Array& right) {
+  bool equal = false;
+  ASSERT_OK(ArrayEquals(left, right, &equal));
+  if (equal) {
+    return;
+  }
+
+  std::stringstream msg;
+  msg << "Left: ";
+  EXPECT_OK(PrettyPrint(left, 0, &msg));
+  msg << "\nRight: ";
+  EXPECT_OK(PrettyPrint(right, 0, &msg));
+  FAIL() << msg.str();
+}
+
+// Shared test base holding a FunctionContext backed by the default memory pool
+class ComputeFixture {
+ public:
+  ComputeFixture() : pool_(default_memory_pool()), ctx_(pool_) {}
+
+ protected:
+  MemoryPool* pool_;     // declared first: ctx_ is constructed from it
+  FunctionContext ctx_;  // context passed to every compute call in the tests
+};
+
+// ----------------------------------------------------------------------
+// Cast
+
+class TestCast : public ComputeFixture, public ::testing::Test {
+ public:
+  // Cast `input` to `out_type` and assert the result equals `expected`
+  void CheckPass(const Array& input, const Array& expected,
+                 const std::shared_ptr<DataType>& out_type, const CastOptions& options) {
+    std::shared_ptr<Array> casted;
+    ASSERT_OK(Cast(&ctx_, input, out_type, options, &casted));
+    AssertArraysEqual(expected, *casted);
+  }
+
+  // Build an input array (with a validity bitmap when `is_valid` is non-empty)
+  // and assert that casting it to `out_type` raises Invalid
+  template <typename InType, typename I_TYPE>
+  void CheckFails(const std::shared_ptr<DataType>& in_type,
+                  const std::vector<I_TYPE>& in_values, const std::vector<bool>& is_valid,
+                  const std::shared_ptr<DataType>& out_type, const CastOptions& options) {
+    std::shared_ptr<Array> input;
+    if (is_valid.empty()) {
+      ArrayFromVector<InType, I_TYPE>(in_type, in_values, &input);
+    } else {
+      ArrayFromVector<InType, I_TYPE>(in_type, is_valid, in_values, &input);
+    }
+    std::shared_ptr<Array> result;
+    ASSERT_RAISES(Invalid, Cast(&ctx_, *input, out_type, options, &result));
+  }
+
+  // Build input and expected arrays sharing the same validity, then CheckPass
+  template <typename InType, typename I_TYPE, typename OutType, typename O_TYPE>
+  void CheckCase(const std::shared_ptr<DataType>& in_type,
+                 const std::vector<I_TYPE>& in_values, const std::vector<bool>& is_valid,
+                 const std::shared_ptr<DataType>& out_type,
+                 const std::vector<O_TYPE>& out_values, const CastOptions& options) {
+    std::shared_ptr<Array> input;
+    std::shared_ptr<Array> expected;
+    if (is_valid.empty()) {
+      ArrayFromVector<InType, I_TYPE>(in_type, in_values, &input);
+      ArrayFromVector<OutType, O_TYPE>(out_type, out_values, &expected);
+    } else {
+      ArrayFromVector<InType, I_TYPE>(in_type, is_valid, in_values, &input);
+      ArrayFromVector<OutType, O_TYPE>(out_type, is_valid, out_values, &expected);
+    }
+    CheckPass(*input, *expected, out_type, options);
+  }
+};
+
+TEST_F(TestCast, SameTypeZeroCopy) {
+  const vector<bool> validity = {true, false, true, true, true};
+  const vector<int32_t> values = {0, 1, 2, 3, 4};
+
+  std::shared_ptr<Array> input;
+  ArrayFromVector<Int32Type, int32_t>(int32(), validity, values, &input);
+
+  std::shared_ptr<Array> casted;
+  ASSERT_OK(Cast(&this->ctx_, *input, int32(), {}, &casted));
+
+  // Casting to the same type must share (not copy) the validity and data buffers
+  for (int i = 0; i < 2; ++i) {
+    ASSERT_EQ(input->data()->buffers[i].get(), casted->data()->buffers[i].get());
+  }
+}
+
+TEST_F(TestCast, ToBoolean) {
+  // Value-initialize so allow_int_overflow is never copied while indeterminate
+  CastOptions options{};
+
+  vector<bool> is_valid = {true, false, true, true, true};
+
+  // int8, should suffice for other integers
+  vector<int8_t> v1 = {0, 1, 127, -1, 0};
+  vector<bool> e1 = {false, true, true, true, false};
+  CheckCase<Int8Type, int8_t, BooleanType, bool>(int8(), v1, is_valid, boolean(), e1,
+                                                 options);
+
+  // floating point
+  vector<double> v2 = {1.0, 0, 0, -1.0, 5.0};
+  vector<bool> e2 = {true, false, false, true, true};
+  CheckCase<DoubleType, double, BooleanType, bool>(float64(), v2, is_valid, boolean(), e2,
+                                                   options);
+}
+
+TEST_F(TestCast, ToIntUpcast) {
+  CastOptions options;
+  options.allow_int_overflow = false;
+
+  const vector<bool> valid = {true, false, true, true, true};
+
+  // Widening signed cast: int8 -> int32
+  const vector<int8_t> int8_in = {0, 1, 127, -1, 0};
+  const vector<int32_t> int32_out = {0, 1, 127, -1, 0};
+  CheckCase<Int8Type, int8_t, Int32Type, int32_t>(int8(), int8_in, valid, int32(),
+                                                  int32_out, options);
+
+  // Boolean -> int8 maps false/true to 0/1
+  const vector<bool> bool_in = {false, true, false, true, true};
+  const vector<int8_t> int8_out = {0, 1, 0, 1, 1};
+  CheckCase<BooleanType, bool, Int8Type, int8_t>(boolean(), bool_in, valid, int8(),
+                                                 int8_out, options);
+
+  // uint8 -> int16 can neither overflow nor underflow
+  const vector<uint8_t> uint8_in = {0, 100, 200, 255, 0};
+  const vector<int16_t> int16_out = {0, 100, 200, 255, 0};
+  CheckCase<UInt8Type, uint8_t, Int16Type, int16_t>(uint8(), uint8_in, valid, int16(),
+                                                    int16_out, options);
+
+  // Floating point -> integer truncates toward zero
+  const vector<double> double_in = {1.5, 0, 0.5, -1.5, 5.5};
+  const vector<int32_t> truncated = {1, 0, 0, -1, 5};
+  CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), double_in, valid, int32(),
+                                                    truncated, options);
+}
+
+TEST_F(TestCast, OverflowInNullSlot) {
+  CastOptions options;
+  options.allow_int_overflow = false;
+
+  vector<bool> is_valid = {true, false, true, true, true};
+
+  // 70000 does not fit in int16 but sits in a null slot, so the checked cast
+  // must still succeed
+  vector<int32_t> v11 = {0, 70000, 2000, 1000, 0};
+  vector<int16_t> e11 = {0, 0, 2000, 1000, 0};
+
+  std::shared_ptr<Array> expected;
+  ArrayFromVector<Int16Type, int16_t>(int16(), is_valid, e11, &expected);
+
+  // Wrap the raw values directly so the out-of-range value is guaranteed to be
+  // present in the null slot. Buffer sizes are in bytes, not elements.
+  auto buf = std::make_shared<Buffer>(
+      reinterpret_cast<const uint8_t*>(v11.data()),
+      static_cast<int64_t>(v11.size() * sizeof(int32_t)));
+  Int32Array tmp11(static_cast<int64_t>(v11.size()), buf, expected->null_bitmap(), -1);
+
+  CheckPass(tmp11, *expected, int16(), options);
+}
+
+TEST_F(TestCast, ToIntDowncastSafe) {
+  CastOptions options;
+  options.allow_int_overflow = false;
+
+  vector<bool> is_valid = {true, false, true, true, true};
+
+  // int16 to uint8, no overflow/underrun
+  vector<int16_t> v5 = {0, 100, 200, 1, 2};
+  vector<uint8_t> e5 = {0, 100, 200, 1, 2};
+  CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), v5, is_valid, uint8(), e5,
+                                                    options);
+
+  // int16 to uint8, with overflow
+  vector<int16_t> v6 = {0, 100, 256, 0, 0};
+  CheckFails<Int16Type>(int16(), v6, is_valid, uint8(), options);
+
+  // underflow
+  vector<int16_t> v7 = {0, 100, -1, 0, 0};
+  CheckFails<Int16Type>(int16(), v7, is_valid, uint8(), options);
+
+  // int32 to int16, no overflow
+  vector<int32_t> v8 = {0, 1000, 2000, 1, 2};
+  vector<int16_t> e8 = {0, 1000, 2000, 1, 2};
+  CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), v8, is_valid, int16(), e8,
+                                                    options);
+
+  // int32 to int16, overflow
+  vector<int32_t> v9 = {0, 1000, 2000, 70000, 0};
+  CheckFails<Int32Type>(int32(), v9, is_valid, int16(), options);
+
+  // int32 to int16, underflow (previously re-tested v9 by mistake)
+  vector<int32_t> v10 = {0, 1000, 2000, -70000, 0};
+  CheckFails<Int32Type>(int32(), v10, is_valid, int16(), options);
+}
+
+TEST_F(TestCast, ToIntDowncastUnsafe) {
+  CastOptions options;
+  options.allow_int_overflow = true;  // out-of-range values wrap instead of failing
+
+  const vector<bool> valid = {true, false, true, true, true};
+
+  // int16 -> uint8, every value in range
+  const vector<int16_t> in_fits = {0, 100, 200, 1, 2};
+  const vector<uint8_t> out_fits = {0, 100, 200, 1, 2};
+  CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), in_fits, valid, uint8(),
+                                                    out_fits, options);
+
+  // int16 -> uint8, 256 wraps to 0
+  const vector<int16_t> in_over = {0, 100, 256, 0, 0};
+  const vector<uint8_t> out_over = {0, 100, 0, 0, 0};
+  CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), in_over, valid, uint8(),
+                                                    out_over, options);
+
+  // int16 -> uint8, -1 wraps to 255
+  const vector<int16_t> in_under = {0, 100, -1, 0, 0};
+  const vector<uint8_t> out_under = {0, 100, 255, 0, 0};
+  CheckCase<Int16Type, int16_t, UInt8Type, uint8_t>(int16(), in_under, valid, uint8(),
+                                                    out_under, options);
+
+  // int32 -> int16, every value in range
+  const vector<int32_t> in_fits32 = {0, 1000, 2000, 1, 2};
+  const vector<int16_t> out_fits16 = {0, 1000, 2000, 1, 2};
+  CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), in_fits32, valid, int16(),
+                                                    out_fits16, options);
+
+  // int32 -> int16, 70000 wraps to 4464
+  // TODO(wesm): do we want to allow this? we could set to null
+  const vector<int32_t> in_over32 = {0, 1000, 2000, 70000, 0};
+  const vector<int16_t> out_over16 = {0, 1000, 2000, 4464, 0};
+  CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), in_over32, valid, int16(),
+                                                    out_over16, options);
+
+  // int32 -> int16, -70000 wraps to -4464
+  // TODO(wesm): do we want to allow this? we could set overflow to null
+  const vector<int32_t> in_under32 = {0, 1000, 2000, -70000, 0};
+  const vector<int16_t> out_under16 = {0, 1000, 2000, -4464, 0};
+  CheckCase<Int32Type, int32_t, Int16Type, int16_t>(int32(), in_under32, valid, int16(),
+                                                    out_under16, options);
+}
+
+TEST_F(TestCast, ToDouble) {
+  // Value-initialize so allow_int_overflow is never copied while indeterminate
+  CastOptions options{};
+  vector<bool> is_valid = {true, false, true, true, true};
+
+  // int16 to double
+  vector<int16_t> v1 = {0, 100, 200, 1, 2};
+  vector<double> e1 = {0, 100, 200, 1, 2};
+  CheckCase<Int16Type, int16_t, DoubleType, double>(int16(), v1, is_valid, float64(), e1,
+                                                    options);
+
+  // float to double
+  vector<float> v2 = {0, 100, 200, 1, 2};
+  vector<double> e2 = {0, 100, 200, 1, 2};
+  CheckCase<FloatType, float, DoubleType, double>(float32(), v2, is_valid, float64(), e2,
+                                                  options);
+
+  // bool to double
+  vector<bool> v3 = {true, true, false, false, true};
+  vector<double> e3 = {1, 1, 0, 0, 1};
+  CheckCase<BooleanType, bool, DoubleType, double>(boolean(), v3, is_valid, float64(), e3,
+                                                   options);
+}
+
+TEST_F(TestCast, UnsupportedTarget) {
+  const vector<bool> validity = {true, false, true, true, true};
+  const vector<int32_t> values = {0, 1, 2, 3, 4};
+
+  std::shared_ptr<Array> ints;
+  ArrayFromVector<Int32Type, int32_t>(int32(), validity, values, &ints);
+
+  // There is no int32 -> utf8 kernel, so Cast must report NotImplemented
+  std::shared_ptr<Array> casted;
+  ASSERT_RAISES(NotImplemented, Cast(&this->ctx_, *ints, utf8(), {}, &casted));
+}
+
+} // namespace compute
+} // namespace arrow
http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/compute/context.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/context.cc b/cpp/src/arrow/compute/context.cc
new file mode 100644
index 0000000..792dc4f
--- /dev/null
+++ b/cpp/src/arrow/compute/context.cc
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/compute/context.h"
+
+#include <memory>
+
+#include "arrow/buffer.h"
+
+namespace arrow {
+namespace compute {
+
+// The pool pointer is stored as given and used for every Allocate call
+FunctionContext::FunctionContext(MemoryPool* pool) : pool_(pool) {}
+
+MemoryPool* FunctionContext::memory_pool() const {
+  return pool_;
+}
+
+// Allocate `nbytes` from this context's pool into *out
+Status FunctionContext::Allocate(const int64_t nbytes, std::shared_ptr<Buffer>* out) {
+  Status st = AllocateBuffer(pool_, nbytes, out);
+  return st;
+}
+
+// Record `status` only while no error is pending, so the first error wins
+void FunctionContext::SetStatus(const Status& status) {
+  if (status_.ok()) {
+    status_ = status;
+  }
+}
+
+// Return the context to the OK state
+void FunctionContext::ResetStatus() {
+  status_ = Status::OK();
+}
+
+} // namespace compute
+} // namespace arrow
http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/compute/context.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/context.h b/cpp/src/arrow/compute/context.h
new file mode 100644
index 0000000..caff2e2
--- /dev/null
+++ b/cpp/src/arrow/compute/context.h
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#ifndef ARROW_COMPUTE_CONTEXT_H
+#define ARROW_COMPUTE_CONTEXT_H
+
+#include "arrow/status.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/visibility.h"
+
+namespace arrow {
+namespace compute {
+
+/// \brief Propagate an error recorded on a FunctionContext to the caller
+///
+/// If the context that `ctx` points to has an error, its Status is copied, the
+/// context is reset to OK, and the copied Status is returned from the enclosing
+/// function; otherwise nothing happens. Use it as a statement:
+/// `RETURN_IF_ERROR(ctx);`
+///
+/// The expansion is wrapped in do { ... } while (0) so it behaves as a single
+/// statement (safe inside an unbraced if/else), `ctx` is parenthesized and
+/// evaluated exactly once, and Status is fully qualified so the macro also
+/// works outside namespace arrow.
+#define RETURN_IF_ERROR(ctx)                                              \
+  do {                                                                    \
+    auto&& _arrow_ctx_to_check = (ctx);                                   \
+    if (ARROW_PREDICT_FALSE(_arrow_ctx_to_check->HasError())) {           \
+      ::arrow::Status _arrow_ctx_error = _arrow_ctx_to_check->status();   \
+      _arrow_ctx_to_check->ResetStatus();                                 \
+      return _arrow_ctx_error;                                            \
+    }                                                                     \
+  } while (0)
+
+/// \brief Container for variables and options used by function evaluation
+///
+/// Holds the MemoryPool that Allocate() draws from, plus an error Status that
+/// kernels set with SetStatus() and callers inspect with HasError()/status()
+/// (or the RETURN_IF_ERROR macro). Once an error is set it is kept until
+/// ResetStatus() is called.
+///
+/// The pool is stored as a raw pointer and the context never frees it, so the
+/// pool must outlive the context.
+class ARROW_EXPORT FunctionContext {
+ public:
+  /// \brief Construct a context that allocates from `pool`
+  /// \param[in] pool memory pool used by Allocate(); not owned by the context
+  explicit FunctionContext(MemoryPool* pool);
+
+  /// \brief Return the memory pool passed to the constructor
+  MemoryPool* memory_pool() const;
+
+  /// \brief Allocate buffer from the context's memory pool
+  /// \param[in] nbytes number of bytes to allocate
+  /// \param[out] out the newly allocated buffer
+  /// \return Status of the allocation
+  Status Allocate(int64_t nbytes, std::shared_ptr<Buffer>* out);
+
+  /// \brief Indicate that an error has occurred, to be checked by a parent caller
+  /// \param[in] status a Status instance
+  ///
+  /// \note Will not overwrite a previously set error Status, so the context keeps
+  /// the first error that occurred until FunctionContext::ResetStatus is called
+  void SetStatus(const Status& status);
+
+  /// \brief Clear any error status, returning the context to Status::OK()
+  void ResetStatus();
+
+  /// \brief Return true if an error has occurred
+  bool HasError() const { return !status_.ok(); }
+
+  /// \brief Return the current status of the context
+  const Status& status() const { return status_; }
+
+ private:
+  Status status_;
+  MemoryPool* pool_;
+};
+
+}  // namespace compute
+}  // namespace arrow
+
+#endif // ARROW_COMPUTE_CONTEXT_H
http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/memory_pool.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/memory_pool.cc b/cpp/src/arrow/memory_pool.cc
index 7fd999e..d86fb08 100644
--- a/cpp/src/arrow/memory_pool.cc
+++ b/cpp/src/arrow/memory_pool.cc
@@ -168,8 +168,8 @@ Status LoggingMemoryPool::Allocate(int64_t size, uint8_t** out) {
Status LoggingMemoryPool::Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) {
+  // Delegate to the wrapped pool, print both sizes to std::cout regardless of
+  // the outcome, and return the wrapped pool's Status.
Status s = pool_->Reallocate(old_size, new_size, ptr);
- std::cout << "Reallocate: old_size = " << old_size
- << " - new_size = " << new_size << std::endl;
+ std::cout << "Reallocate: old_size = " << old_size << " - new_size = " << new_size
+ << std::endl;
return s;
}
http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/test-util.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/test-util.h b/cpp/src/arrow/test-util.h
index 91f2bc9..22a933d 100644
--- a/cpp/src/arrow/test-util.h
+++ b/cpp/src/arrow/test-util.h
@@ -233,6 +233,16 @@ void ArrayFromVector(const std::shared_ptr<DataType>& type,
}
+/// \brief Build an Array of type `type` from `values`
+///
+/// Every element is appended with Append(), so no null slots are added.
+/// Builder failures are reported through ASSERT_OK.
template <typename TYPE, typename C_TYPE>
+void ArrayFromVector(const std::shared_ptr<DataType>& type,
+                     const std::vector<C_TYPE>& values, std::shared_ptr<Array>* out) {
+  using Builder = typename TypeTraits<TYPE>::BuilderType;
+  Builder builder(type, default_memory_pool());
+  for (const auto& value : values) {
+    ASSERT_OK(builder.Append(value));
+  }
+  ASSERT_OK(builder.Finish(out));
+}
+
+template <typename TYPE, typename C_TYPE>
void ArrayFromVector(const std::vector<bool>& is_valid, const std::vector<C_TYPE>& values,
std::shared_ptr<Array>* out) {
typename TypeTraits<TYPE>::BuilderType builder;
http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/type.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index b532cd2..283e27e 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -186,15 +186,20 @@ class ARROW_EXPORT PrimitiveCType : public FixedWidthType {
using FixedWidthType::FixedWidthType;
};
-class ARROW_EXPORT Integer : public PrimitiveCType {
+/// \brief Common base class of the Integer and FloatingPoint types
+class ARROW_EXPORT Number : public PrimitiveCType {
public:
using PrimitiveCType::PrimitiveCType;
+};
+
+/// \brief Base class for integer types; subclasses implement is_signed()
+class ARROW_EXPORT Integer : public Number {
+ public:
+ using Number::Number;
virtual bool is_signed() const = 0;
};
-class ARROW_EXPORT FloatingPoint : public PrimitiveCType {
+/// \brief Base class for floating point types; subclasses report their Precision
+class ARROW_EXPORT FloatingPoint : public Number {
public:
- using PrimitiveCType::PrimitiveCType;
+ using Number::Number;
enum Precision { HALF, SINGLE, DOUBLE };
virtual Precision precision() const = 0;
};
@@ -842,77 +847,6 @@ std::shared_ptr<Schema> schema(
std::vector<std::shared_ptr<Field>>&& fields,
const std::shared_ptr<const KeyValueMetadata>& metadata = nullptr);
-// ----------------------------------------------------------------------
-//
-
-static inline bool is_integer(Type::type type_id) {
- switch (type_id) {
- case Type::UINT8:
- case Type::INT8:
- case Type::UINT16:
- case Type::INT16:
- case Type::UINT32:
- case Type::INT32:
- case Type::UINT64:
- case Type::INT64:
- return true;
- default:
- break;
- }
- return false;
-}
-
-static inline bool is_floating(Type::type type_id) {
- switch (type_id) {
- case Type::HALF_FLOAT:
- case Type::FLOAT:
- case Type::DOUBLE:
- return true;
- default:
- break;
- }
- return false;
-}
-
-static inline bool is_primitive(Type::type type_id) {
- switch (type_id) {
- case Type::NA:
- case Type::BOOL:
- case Type::UINT8:
- case Type::INT8:
- case Type::UINT16:
- case Type::INT16:
- case Type::UINT32:
- case Type::INT32:
- case Type::UINT64:
- case Type::INT64:
- case Type::HALF_FLOAT:
- case Type::FLOAT:
- case Type::DOUBLE:
- case Type::DATE32:
- case Type::DATE64:
- case Type::TIME32:
- case Type::TIME64:
- case Type::TIMESTAMP:
- case Type::INTERVAL:
- return true;
- default:
- break;
- }
- return false;
-}
-
-static inline bool is_binary_like(Type::type type_id) {
- switch (type_id) {
- case Type::BINARY:
- case Type::STRING:
- return true;
- default:
- break;
- }
- return false;
-}
-
} // namespace arrow
#endif // ARROW_TYPE_H
http://git-wip-us.apache.org/repos/asf/arrow/blob/b0b125fd/cpp/src/arrow/type_traits.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h
index d424cc8..fbd7839 100644
--- a/cpp/src/arrow/type_traits.h
+++ b/cpp/src/arrow/type_traits.h
@@ -362,6 +362,74 @@ struct IsNumeric {
static constexpr bool value = std::is_arithmetic<c_type>::value;
};
+static inline bool is_integer(Type::type type_id) {
+  // Signed (INT8..INT64) and unsigned (UINT8..UINT64) integer type ids
+  switch (type_id) {
+    case Type::INT8:
+    case Type::INT16:
+    case Type::INT32:
+    case Type::INT64:
+    case Type::UINT8:
+    case Type::UINT16:
+    case Type::UINT32:
+    case Type::UINT64:
+      return true;
+    default:
+      return false;
+  }
+}
+
+static inline bool is_floating(Type::type type_id) {
+  // Half, single and double precision floating point type ids
+  return type_id == Type::HALF_FLOAT || type_id == Type::FLOAT ||
+         type_id == Type::DOUBLE;
+}
+
+static inline bool is_primitive(Type::type type_id) {
+  // Every integer and floating point type counts as primitive...
+  if (is_integer(type_id) || is_floating(type_id)) {
+    return true;
+  }
+  // ...as do null, boolean, and the date/time/timestamp/interval types.
+  switch (type_id) {
+    case Type::NA:
+    case Type::BOOL:
+    case Type::DATE32:
+    case Type::DATE64:
+    case Type::TIME32:
+    case Type::TIME64:
+    case Type::TIMESTAMP:
+    case Type::INTERVAL:
+      return true;
+    default:
+      return false;
+  }
+}
+
+static inline bool is_binary_like(Type::type type_id) {
+  // Only BINARY and STRING are treated as binary-like
+  return type_id == Type::BINARY || type_id == Type::STRING;
+}
+
} // namespace arrow
#endif // ARROW_TYPE_TRAITS_H