You are viewing a plain-text version of this content. The canonical link for it was a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by bk...@apache.org on 2021/02/26 17:39:14 UTC
[arrow] branch master updated: ARROW-11662: [C++] Support sorting decimal and fixed size binary data
This is an automated email from the ASF dual-hosted git repository.
bkietz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new dfd2323 ARROW-11662: [C++] Support sorting decimal and fixed size binary data
dfd2323 is described below
commit dfd232313e1538b81a38db1e59cf4a109b61a467
Author: Antoine Pitrou <an...@python.org>
AuthorDate: Fri Feb 26 12:37:52 2021 -0500
ARROW-11662: [C++] Support sorting decimal and fixed size binary data
Also enable nth_to_indices on decimal and fixed size binary data.
Closes #9577 from pitrou/ARROW-11662-sort-decimal
Authored-by: Antoine Pitrou <an...@python.org>
Signed-off-by: Benjamin Kietzman <be...@gmail.com>
---
c_glib/test/test-decimal128-data-type.rb | 4 +-
cpp/src/arrow/compute/kernels/codegen_internal.cc | 9 +
cpp/src/arrow/compute/kernels/codegen_internal.h | 32 ++
cpp/src/arrow/compute/kernels/vector_sort.cc | 79 +++--
cpp/src/arrow/compute/kernels/vector_sort_test.cc | 385 +++++++++++++++-------
cpp/src/arrow/testing/gtest_util.h | 2 +
cpp/src/arrow/testing/random.cc | 42 ++-
cpp/src/arrow/testing/random.h | 11 +
cpp/src/arrow/type.cc | 2 +-
cpp/src/arrow/type.h | 4 +-
cpp/src/arrow/type_test.cc | 6 +-
docs/source/cpp/compute.rst | 3 +-
ruby/red-arrow/test/test-decimal128-data-type.rb | 4 +-
13 files changed, 432 insertions(+), 151 deletions(-)
diff --git a/c_glib/test/test-decimal128-data-type.rb b/c_glib/test/test-decimal128-data-type.rb
index a02e3ba..b27e1ca 100644
--- a/c_glib/test/test-decimal128-data-type.rb
+++ b/c_glib/test/test-decimal128-data-type.rb
@@ -23,12 +23,12 @@ class TestDecimal128DataType < Test::Unit::TestCase
def test_name
data_type = Arrow::Decimal128DataType.new(2, 0)
- assert_equal("decimal", data_type.name)
+ assert_equal("decimal128", data_type.name)
end
def test_to_s
data_type = Arrow::Decimal128DataType.new(2, 0)
- assert_equal("decimal(2, 0)", data_type.to_s)
+ assert_equal("decimal128(2, 0)", data_type.to_s)
end
def test_precision
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc
index b321ff3..ad43b7a 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.cc
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc
@@ -48,6 +48,7 @@ std::vector<std::shared_ptr<DataType>> g_numeric_types;
std::vector<std::shared_ptr<DataType>> g_base_binary_types;
std::vector<std::shared_ptr<DataType>> g_temporal_types;
std::vector<std::shared_ptr<DataType>> g_primitive_types;
+std::vector<Type::type> g_decimal_type_ids;
static std::once_flag codegen_static_initialized;
template <typename T>
@@ -71,6 +72,9 @@ static void InitStaticData() {
// Floating point types
g_floating_types = {float32(), float64()};
+ // Decimal types
+ g_decimal_type_ids = {Type::DECIMAL128, Type::DECIMAL256};
+
// Numeric types
Extend(g_int_types, &g_numeric_types);
Extend(g_floating_types, &g_numeric_types);
@@ -132,6 +136,11 @@ const std::vector<std::shared_ptr<DataType>>& FloatingPointTypes() {
return g_floating_types;
}
+const std::vector<Type::type>& DecimalTypeIds() {
+ std::call_once(codegen_static_initialized, InitStaticData);
+ return g_decimal_type_ids;
+}
+
const std::vector<TimeUnit::type>& AllTimeUnits() {
static std::vector<TimeUnit::type> units = {TimeUnit::SECOND, TimeUnit::MILLI,
TimeUnit::MICRO, TimeUnit::NANO};
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index 8c49e79..9e2ed82 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -188,6 +188,16 @@ struct GetViewType<Decimal128Type> {
}
};
+template <>
+struct GetViewType<Decimal256Type> {
+ using T = Decimal256;
+ using PhysicalType = util::string_view;
+
+ static T LogicalValue(PhysicalType value) {
+ return Decimal256(reinterpret_cast<const uint8_t*>(value.data()));
+ }
+};
+
template <typename Type, typename Enable = void>
struct GetOutputType;
@@ -206,6 +216,11 @@ struct GetOutputType<Decimal128Type> {
using T = Decimal128;
};
+template <>
+struct GetOutputType<Decimal256Type> {
+ using T = Decimal256;
+};
+
// ----------------------------------------------------------------------
// Iteration / value access utilities
@@ -396,6 +411,7 @@ const std::vector<std::shared_ptr<DataType>>& SignedIntTypes();
const std::vector<std::shared_ptr<DataType>>& UnsignedIntTypes();
const std::vector<std::shared_ptr<DataType>>& IntTypes();
const std::vector<std::shared_ptr<DataType>>& FloatingPointTypes();
+const std::vector<Type::type>& DecimalTypeIds();
ARROW_EXPORT
const std::vector<TimeUnit::type>& AllTimeUnits();
@@ -1185,6 +1201,22 @@ ArrayKernelExec GenerateTemporal(detail::GetTypeId get_id) {
}
}
+// Generate a kernel given a templated functor for decimal types
+//
+// See "Numeric" above for description of the generator functor
+template <template <typename...> class Generator, typename Type0, typename... Args>
+ArrayKernelExec GenerateDecimal(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::DECIMAL128:
+ return Generator<Type0, Decimal128Type, Args...>::Exec;
+ case Type::DECIMAL256:
+ return Generator<Type0, Decimal256Type, Args...>::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
// END of kernel generator-dispatchers
// ----------------------------------------------------------------------
diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc
index 5170662..a29c931 100644
--- a/cpp/src/arrow/compute/kernels/vector_sort.cc
+++ b/cpp/src/arrow/compute/kernels/vector_sort.cc
@@ -52,14 +52,18 @@ namespace internal {
VISIT(FloatType) \
VISIT(DoubleType) \
VISIT(BinaryType) \
- VISIT(LargeBinaryType)
+ VISIT(LargeBinaryType) \
+ VISIT(FixedSizeBinaryType) \
+ VISIT(Decimal128Type) \
+ VISIT(Decimal256Type)
namespace {
// The target chunk in a chunked array.
template <typename ArrayType>
struct ResolvedChunk {
- using ViewType = decltype(std::declval<ArrayType>().GetView(0));
+ using V = GetViewType<typename ArrayType::TypeClass>;
+ using LogicalValueType = typename V::T;
// The target array in chunked array.
const ArrayType* array;
@@ -70,7 +74,7 @@ struct ResolvedChunk {
bool IsNull() const { return array->IsNull(index); }
- ViewType GetView() const { return array->GetView(index); }
+ LogicalValueType Value() const { return V::LogicalValue(array->GetView(index)); }
};
// ResolvedChunk specialization for untyped arrays when all is needed is null lookup
@@ -279,7 +283,7 @@ PartitionNullLikes(uint64_t* indices_begin, uint64_t* indices_end,
ChunkedArrayResolver resolver(arrays);
return partitioner(indices_begin, indices_end, [&](uint64_t ind) {
const auto chunk = resolver.Resolve<ArrayType>(ind);
- return !std::isnan(chunk.GetView());
+ return !std::isnan(chunk.Value());
});
}
@@ -318,6 +322,8 @@ struct PartitionNthToIndices {
using ArrayType = typename TypeTraits<InType>::ArrayType;
static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+ using GetView = GetViewType<InType>;
+
if (ctx->state() == nullptr) {
ctx->SetStatus(Status::Invalid("NthToIndices requires PartitionNthOptions"));
return;
@@ -343,7 +349,9 @@ struct PartitionNthToIndices {
if (nth_begin < nulls_begin) {
std::nth_element(out_begin, nth_begin, nulls_begin,
[&arr](uint64_t left, uint64_t right) {
- return arr.GetView(left) < arr.GetView(right);
+ const auto lval = GetView::LogicalValue(arr.GetView(left));
+ const auto rval = GetView::LogicalValue(arr.GetView(right));
+ return lval < rval;
});
}
}
@@ -365,6 +373,7 @@ inline void VisitRawValuesInline(const ArrayType& values,
template <typename ArrowType>
class ArrayCompareSorter {
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
+ using GetView = GetViewType<ArrowType>;
public:
// Returns where null starts.
@@ -377,14 +386,18 @@ class ArrayCompareSorter {
if (options.order == SortOrder::Ascending) {
std::stable_sort(
indices_begin, nulls_begin, [&values, &offset](uint64_t left, uint64_t right) {
- return values.GetView(left - offset) < values.GetView(right - offset);
+ const auto lhs = GetView::LogicalValue(values.GetView(left - offset));
+ const auto rhs = GetView::LogicalValue(values.GetView(right - offset));
+ return lhs < rhs;
});
} else {
std::stable_sort(
indices_begin, nulls_begin, [&values, &offset](uint64_t left, uint64_t right) {
+ const auto lhs = GetView::LogicalValue(values.GetView(left - offset));
+ const auto rhs = GetView::LogicalValue(values.GetView(right - offset));
// We don't use 'left > right' here to reduce required operator.
// If we use 'right < left' here, '<' is only required.
- return values.GetView(right - offset) < values.GetView(left - offset);
+ return rhs < lhs;
});
}
return nulls_begin;
@@ -542,8 +555,9 @@ struct ArraySorter<Type, enable_if_t<(is_integer_type<Type>::value &&
};
template <typename Type>
-struct ArraySorter<Type, enable_if_t<is_floating_type<Type>::value ||
- is_base_binary_type<Type>::value>> {
+struct ArraySorter<
+ Type, enable_if_t<is_floating_type<Type>::value || is_base_binary_type<Type>::value ||
+ is_fixed_size_binary_type<Type>::value>> {
ArrayCompareSorter<Type> impl;
};
@@ -585,12 +599,21 @@ void AddSortingKernels(VectorKernel base, VectorFunction* func) {
base.exec = GenerateNumeric<ExecTemplate, UInt64Type>(*physical_type);
DCHECK_OK(func->AddKernel(base));
}
+ for (const auto id : DecimalTypeIds()) {
+ base.signature = KernelSignature::Make({InputType::Array(id)}, uint64());
+ base.exec = GenerateDecimal<ExecTemplate, UInt64Type>(id);
+ DCHECK_OK(func->AddKernel(base));
+ }
for (const auto& ty : BaseBinaryTypes()) {
auto physical_type = GetPhysicalType(ty);
base.signature = KernelSignature::Make({InputType::Array(ty)}, uint64());
base.exec = GenerateVarBinaryBase<ExecTemplate, UInt64Type>(*physical_type);
DCHECK_OK(func->AddKernel(base));
}
+ base.signature =
+ KernelSignature::Make({InputType::Array(Type::FIXED_SIZE_BINARY)}, uint64());
+ base.exec = ExecTemplate<UInt64Type, FixedSizeBinaryType>::Exec;
+ DCHECK_OK(func->AddKernel(base));
}
// ----------------------------------------------------------------------
@@ -617,7 +640,7 @@ class ChunkedArrayCompareSorter {
std::stable_sort(indices_begin, nulls_begin, [&](uint64_t left, uint64_t right) {
const auto chunk_left = resolver.Resolve<ArrayType>(left);
const auto chunk_right = resolver.Resolve<ArrayType>(right);
- return chunk_left.GetView() < chunk_right.GetView();
+ return chunk_left.Value() < chunk_right.Value();
});
} else {
std::stable_sort(indices_begin, nulls_begin, [&](uint64_t left, uint64_t right) {
@@ -625,7 +648,7 @@ class ChunkedArrayCompareSorter {
const auto chunk_right = resolver.Resolve<ArrayType>(right);
// We don't use 'left > right' here to reduce required operator.
// If we use 'right < left' here, '<' is only required.
- return chunk_right.GetView() < chunk_left.GetView();
+ return chunk_right.Value() < chunk_left.Value();
});
}
return nulls_begin;
@@ -786,7 +809,7 @@ class ChunkedArraySorter : public TypeVisitor {
[&](uint64_t left, uint64_t right) {
const auto chunk_left = left_resolver.Resolve<ArrayType>(left);
const auto chunk_right = right_resolver.Resolve<ArrayType>(right);
- return chunk_left.GetView() < chunk_right.GetView();
+ return chunk_left.Value() < chunk_right.Value();
});
} else {
std::merge(indices_begin, indices_middle, indices_middle, indices_end, temp_indices,
@@ -796,7 +819,7 @@ class ChunkedArraySorter : public TypeVisitor {
// We don't use 'left > right' here to reduce required
// operator. If we use 'right < left' here, '<' is only
// required.
- return chunk_right.GetView() < chunk_left.GetView();
+ return chunk_right.Value() < chunk_left.Value();
});
}
// Copy back temp area into main buffer
@@ -822,14 +845,16 @@ class ChunkedArraySorter : public TypeVisitor {
template <typename ArrayType, typename Visitor>
void VisitConstantRanges(const ArrayType& array, uint64_t* indices_begin,
uint64_t* indices_end, Visitor&& visit) {
+ using GetView = GetViewType<typename ArrayType::TypeClass>;
+
if (indices_begin == indices_end) {
return;
}
auto range_start = indices_begin;
auto range_cur = range_start;
- auto last_value = array.GetView(*range_cur);
+ auto last_value = GetView::LogicalValue(array.GetView(*range_cur));
while (++range_cur != indices_end) {
- auto v = array.GetView(*range_cur);
+ auto v = GetView::LogicalValue(array.GetView(*range_cur));
if (v != last_value) {
visit(range_start, range_cur);
range_start = range_cur;
@@ -869,6 +894,8 @@ class ConcreteRecordBatchColumnSorter : public RecordBatchColumnSorter {
null_count_(array_.null_count()) {}
void SortRange(uint64_t* indices_begin, uint64_t* indices_end) {
+ using GetView = GetViewType<Type>;
+
constexpr int64_t offset = 0;
uint64_t* nulls_begin;
if (null_count_ == 0) {
@@ -889,14 +916,18 @@ class ConcreteRecordBatchColumnSorter : public RecordBatchColumnSorter {
if (order_ == SortOrder::Ascending) {
std::stable_sort(
indices_begin, null_likes_begin, [&](uint64_t left, uint64_t right) {
- return array_.GetView(left - offset) < array_.GetView(right - offset);
+ const auto lhs = GetView::LogicalValue(array_.GetView(left - offset));
+ const auto rhs = GetView::LogicalValue(array_.GetView(right - offset));
+ return lhs < rhs;
});
} else {
std::stable_sort(
indices_begin, null_likes_begin, [&](uint64_t left, uint64_t right) {
// We don't use 'left > right' here to reduce required operator.
// If we use 'right < left' here, '<' is only required.
- return array_.GetView(right - offset) < array_.GetView(left - offset);
+ const auto lhs = GetView::LogicalValue(array_.GetView(left - offset));
+ const auto rhs = GetView::LogicalValue(array_.GetView(right - offset));
+ return lhs > rhs;
});
}
@@ -1100,8 +1131,8 @@ class MultipleKeyComparator {
const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_left,
const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_right,
const SortOrder order) {
- const auto left = chunk_left.GetView();
- const auto right = chunk_right.GetView();
+ const auto left = chunk_left.Value();
+ const auto right = chunk_right.Value();
int32_t compared;
if (left == right) {
compared = 0;
@@ -1122,8 +1153,8 @@ class MultipleKeyComparator {
const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_left,
const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_right,
const SortOrder order) {
- const auto left = chunk_left.GetView();
- const auto right = chunk_right.GetView();
+ const auto left = chunk_left.Value();
+ const auto right = chunk_right.Value();
auto is_nan_left = std::isnan(left);
auto is_nan_right = std::isnan(right);
if (is_nan_left && is_nan_right) {
@@ -1439,8 +1470,8 @@ class MultipleKeyTableSorter : public TypeVisitor {
// Both values are never null nor NaN.
auto chunk_left = first_sort_key.GetChunk<ArrayType>(left);
auto chunk_right = first_sort_key.GetChunk<ArrayType>(right);
- auto value_left = chunk_left.GetView();
- auto value_right = chunk_right.GetView();
+ auto value_left = chunk_left.Value();
+ auto value_right = chunk_right.Value();
if (value_left == value_right) {
// If the left value equals to the right value,
// we need to compare the second and following
@@ -1502,7 +1533,7 @@ class MultipleKeyTableSorter : public TypeVisitor {
DCHECK_EQ(indices_end_ - nulls_begin, first_sort_key.null_count);
uint64_t* nans_begin = partitioner(indices_begin_, nulls_begin, [&](uint64_t index) {
const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
- return !std::isnan(chunk.GetView());
+ return !std::isnan(chunk.Value());
});
auto& comparator = comparator_;
// Sort all NaNs by the second and following sort keys.
diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc
index cbeaacf..a54890e 100644
--- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc
+++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc
@@ -15,11 +15,13 @@
// specific language governing permissions and limitations
// under the License.
+#include <functional>
#include <limits>
#include <memory>
#include <string>
#include <vector>
+#include "arrow/array/array_decimal.h"
#include "arrow/array/concatenate.h"
#include "arrow/compute/api_vector.h"
#include "arrow/table.h"
@@ -67,16 +69,33 @@ TypeToDataType() {
// Tests for NthToIndices
template <typename ArrayType>
+auto GetLogicalValue(const ArrayType& array, uint64_t index)
+ -> decltype(array.GetView(index)) {
+ return array.GetView(index);
+}
+
+Decimal128 GetLogicalValue(const Decimal128Array& array, uint64_t index) {
+ return Decimal128(array.Value(index));
+}
+
+Decimal256 GetLogicalValue(const Decimal256Array& array, uint64_t index) {
+ return Decimal256(array.Value(index));
+}
+
+template <typename ArrayType>
class NthComparator {
public:
bool operator()(const ArrayType& array, uint64_t lhs, uint64_t rhs) {
if (array.IsNull(rhs)) return true;
if (array.IsNull(lhs)) return false;
+ const auto lval = GetLogicalValue(array, lhs);
+ const auto rval = GetLogicalValue(array, rhs);
if (is_floating_type<typename ArrayType::TypeClass>::value) {
- if (array.GetView(rhs) != array.GetView(rhs)) return true;
- if (array.GetView(lhs) != array.GetView(lhs)) return false;
+ // NaNs ordered after non-NaNs
+ if (rval != rval) return true;
+ if (lval != lval) return false;
}
- return array.GetView(lhs) <= array.GetView(rhs);
+ return lval <= rval;
}
};
@@ -87,28 +106,29 @@ class SortComparator {
if (array.IsNull(rhs) && array.IsNull(lhs)) return lhs < rhs;
if (array.IsNull(rhs)) return true;
if (array.IsNull(lhs)) return false;
+ const auto lval = GetLogicalValue(array, lhs);
+ const auto rval = GetLogicalValue(array, rhs);
if (is_floating_type<typename ArrayType::TypeClass>::value) {
- const bool lhs_isnan = array.GetView(lhs) != array.GetView(lhs);
- const bool rhs_isnan = array.GetView(rhs) != array.GetView(rhs);
+ const bool lhs_isnan = lval != lval;
+ const bool rhs_isnan = rval != rval;
if (lhs_isnan && rhs_isnan) return lhs < rhs;
if (rhs_isnan) return true;
if (lhs_isnan) return false;
}
- if (array.GetView(lhs) == array.GetView(rhs)) return lhs < rhs;
+ if (lval == rval) return lhs < rhs;
if (order == SortOrder::Ascending) {
- return array.GetView(lhs) < array.GetView(rhs);
+ return lval < rval;
} else {
- return array.GetView(lhs) > array.GetView(rhs);
+ return lval > rval;
}
}
};
template <typename ArrowType>
-class TestNthToIndices : public TestBase {
+class TestNthToIndicesBase : public TestBase {
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
- private:
- template <typename ArrayType>
+ protected:
void Validate(const ArrayType& array, int n, UInt64Array& offsets) {
if (n >= array.length()) {
for (int i = 0; i < array.length(); ++i) {
@@ -129,21 +149,26 @@ class TestNthToIndices : public TestBase {
}
}
- protected:
void AssertNthToIndicesArray(const std::shared_ptr<Array> values, int n) {
ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> offsets, NthToIndices(*values, n));
// null_count field should have been initialized to 0, for convenience
ASSERT_EQ(offsets->data()->null_count, 0);
ASSERT_OK(offsets->ValidateFull());
- Validate<ArrayType>(*checked_pointer_cast<ArrayType>(values), n,
- *checked_pointer_cast<UInt64Array>(offsets));
+ Validate(*checked_pointer_cast<ArrayType>(values), n,
+ *checked_pointer_cast<UInt64Array>(offsets));
}
void AssertNthToIndicesJson(const std::string& values, int n) {
AssertNthToIndicesArray(ArrayFromJSON(GetType(), values), n);
}
- std::shared_ptr<DataType> GetType() { return TypeToDataType<ArrowType>(); }
+ virtual std::shared_ptr<DataType> GetType() = 0;
+};
+
+template <typename ArrowType>
+class TestNthToIndices : public TestNthToIndicesBase<ArrowType> {
+ protected:
+ std::shared_ptr<DataType> GetType() override { return TypeToDataType<ArrowType>(); }
};
template <typename ArrowType>
@@ -159,6 +184,14 @@ class TestNthToIndicesForTemporal : public TestNthToIndices<ArrowType> {};
TYPED_TEST_SUITE(TestNthToIndicesForTemporal, TemporalArrowTypes);
template <typename ArrowType>
+class TestNthToIndicesForDecimal : public TestNthToIndicesBase<ArrowType> {
+ std::shared_ptr<DataType> GetType() override {
+ return std::make_shared<ArrowType>(5, 2);
+ }
+};
+TYPED_TEST_SUITE(TestNthToIndicesForDecimal, DecimalArrowTypes);
+
+template <typename ArrowType>
class TestNthToIndicesForStrings : public TestNthToIndices<ArrowType> {};
TYPED_TEST_SUITE(TestNthToIndicesForStrings, testing::Types<StringType>);
@@ -196,6 +229,14 @@ TYPED_TEST(TestNthToIndicesForTemporal, Temporal) {
this->AssertNthToIndicesJson("[null, 1, 3, null, 2, 5]", 6);
}
+TYPED_TEST(TestNthToIndicesForDecimal, Decimal) {
+ const std::string values = R"(["123.45", null, "-123.45", "456.78", "-456.78"])";
+ this->AssertNthToIndicesJson(values, 0);
+ this->AssertNthToIndicesJson(values, 2);
+ this->AssertNthToIndicesJson(values, 4);
+ this->AssertNthToIndicesJson(values, 5);
+}
+
TYPED_TEST(TestNthToIndicesForStrings, Strings) {
this->AssertNthToIndicesJson(R"(["testing", null, "nth", "for", null, "strings"])", 0);
this->AssertNthToIndicesJson(R"(["testing", null, "nth", "for", null, "strings"])", 2);
@@ -204,31 +245,38 @@ TYPED_TEST(TestNthToIndicesForStrings, Strings) {
}
template <typename ArrowType>
-class TestNthToIndicesRandom : public TestNthToIndices<ArrowType> {};
+class TestNthToIndicesRandom : public TestNthToIndicesBase<ArrowType> {
+ public:
+ std::shared_ptr<DataType> GetType() override {
+ EXPECT_TRUE(0) << "shouldn't be used";
+ return nullptr;
+ }
+};
using NthToIndicesableTypes =
::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
- Int32Type, Int64Type, FloatType, DoubleType, StringType>;
+ Int32Type, Int64Type, FloatType, DoubleType, Decimal128Type,
+ StringType>;
class RandomImpl {
protected:
- random::RandomArrayGenerator generator;
+ random::RandomArrayGenerator generator_;
+ std::shared_ptr<DataType> type_;
+
+ explicit RandomImpl(random::SeedType seed, std::shared_ptr<DataType> type)
+ : generator_(seed), type_(std::move(type)) {}
public:
- explicit RandomImpl(random::SeedType seed) : generator(seed) {}
+ std::shared_ptr<Array> Generate(uint64_t count, double null_prob) {
+ return generator_.ArrayOf(type_, count, null_prob);
+ }
};
template <typename ArrowType>
class Random : public RandomImpl {
- using CType = typename TypeTraits<ArrowType>::CType;
-
public:
- explicit Random(random::SeedType seed) : RandomImpl(seed) {}
-
- std::shared_ptr<Array> Generate(uint64_t count, double null_prob) {
- return generator.Numeric<ArrowType>(count, std::numeric_limits<CType>::min(),
- std::numeric_limits<CType>::max(), null_prob);
- }
+ explicit Random(random::SeedType seed)
+ : RandomImpl(seed, TypeTraits<ArrowType>::type_singleton()) {}
};
template <>
@@ -236,11 +284,11 @@ class Random<FloatType> : public RandomImpl {
using CType = float;
public:
- explicit Random(random::SeedType seed) : RandomImpl(seed) {}
+ explicit Random(random::SeedType seed) : RandomImpl(seed, float32()) {}
std::shared_ptr<Array> Generate(uint64_t count, double null_prob, double nan_prob = 0) {
- return generator.Float32(count, std::numeric_limits<CType>::min(),
- std::numeric_limits<CType>::max(), null_prob, nan_prob);
+ return generator_.Float32(count, std::numeric_limits<CType>::min(),
+ std::numeric_limits<CType>::max(), null_prob, nan_prob);
}
};
@@ -249,22 +297,20 @@ class Random<DoubleType> : public RandomImpl {
using CType = double;
public:
- explicit Random(random::SeedType seed) : RandomImpl(seed) {}
+ explicit Random(random::SeedType seed) : RandomImpl(seed, float64()) {}
std::shared_ptr<Array> Generate(uint64_t count, double null_prob, double nan_prob = 0) {
- return generator.Float64(count, std::numeric_limits<CType>::min(),
- std::numeric_limits<CType>::max(), null_prob, nan_prob);
+ return generator_.Float64(count, std::numeric_limits<CType>::min(),
+ std::numeric_limits<CType>::max(), null_prob, nan_prob);
}
};
template <>
-class Random<StringType> : public RandomImpl {
+class Random<Decimal128Type> : public RandomImpl {
public:
- explicit Random(random::SeedType seed) : RandomImpl(seed) {}
-
- std::shared_ptr<Array> Generate(uint64_t count, double null_prob) {
- return generator.String(count, 1, 100, null_prob);
- }
+ explicit Random(random::SeedType seed,
+ std::shared_ptr<DataType> type = decimal128(18, 5))
+ : RandomImpl(seed, std::move(type)) {}
};
template <typename ArrowType>
@@ -272,7 +318,8 @@ class RandomRange : public RandomImpl {
using CType = typename TypeTraits<ArrowType>::CType;
public:
- explicit RandomRange(random::SeedType seed) : RandomImpl(seed) {}
+ explicit RandomRange(random::SeedType seed)
+ : RandomImpl(seed, TypeTraits<ArrowType>::type_singleton()) {}
std::shared_ptr<Array> Generate(uint64_t count, int range, double null_prob) {
CType min = std::numeric_limits<CType>::min();
@@ -280,7 +327,7 @@ class RandomRange : public RandomImpl {
if (sizeof(CType) < 4 && (range + min) > std::numeric_limits<CType>::max()) {
max = std::numeric_limits<CType>::max();
}
- return generator.Numeric<ArrowType>(count, min, max, null_prob);
+ return generator_.Numeric<ArrowType>(count, min, max, null_prob);
}
};
@@ -325,12 +372,13 @@ void AssertSortIndices(const std::shared_ptr<T>& input, Options&& options,
ArrayFromJSON(uint64(), expected));
}
-template <typename ArrowType>
-class TestArraySortIndicesKernel : public TestBase {
+class TestArraySortIndicesBase : public TestBase {
public:
+ virtual std::shared_ptr<DataType> type() = 0;
+
virtual void AssertSortIndices(const std::string& values, SortOrder order,
const std::string& expected) {
- auto type = TypeToDataType<ArrowType>();
+ auto type = this->type();
arrow::compute::AssertSortIndices(ArrayFromJSON(type, values), order,
ArrayFromJSON(uint64(), expected));
}
@@ -341,25 +389,38 @@ class TestArraySortIndicesKernel : public TestBase {
};
template <typename ArrowType>
-class TestArraySortIndicesKernelForReal : public TestArraySortIndicesKernel<ArrowType> {};
-TYPED_TEST_SUITE(TestArraySortIndicesKernelForReal, RealArrowTypes);
+class TestArraySortIndices : public TestArraySortIndicesBase {
+ public:
+ std::shared_ptr<DataType> type() override {
+ // Will choose default parameters for temporal types
+ return std::make_shared<ArrowType>();
+ }
+};
template <typename ArrowType>
-class TestArraySortIndicesKernelForIntegral
- : public TestArraySortIndicesKernel<ArrowType> {};
-TYPED_TEST_SUITE(TestArraySortIndicesKernelForIntegral, IntegralArrowTypes);
+class TestArraySortIndicesForReal : public TestArraySortIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestArraySortIndicesForReal, RealArrowTypes);
template <typename ArrowType>
-class TestArraySortIndicesKernelForTemporal
- : public TestArraySortIndicesKernel<ArrowType> {};
-TYPED_TEST_SUITE(TestArraySortIndicesKernelForTemporal, TemporalArrowTypes);
+class TestArraySortIndicesForIntegral : public TestArraySortIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestArraySortIndicesForIntegral, IntegralArrowTypes);
template <typename ArrowType>
-class TestArraySortIndicesKernelForStrings
- : public TestArraySortIndicesKernel<ArrowType> {};
-TYPED_TEST_SUITE(TestArraySortIndicesKernelForStrings, testing::Types<StringType>);
+class TestArraySortIndicesForTemporal : public TestArraySortIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestArraySortIndicesForTemporal, TemporalArrowTypes);
+
+using StringSortTestTypes = testing::Types<StringType, LargeStringType>;
+
+template <typename ArrowType>
+class TestArraySortIndicesForStrings : public TestArraySortIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestArraySortIndicesForStrings, StringSortTestTypes);
+
+class TestArraySortIndicesForFixedSizeBinary : public TestArraySortIndicesBase {
+ public:
+ std::shared_ptr<DataType> type() override { return fixed_size_binary(3); }
+};
-TYPED_TEST(TestArraySortIndicesKernelForReal, SortReal) {
+TYPED_TEST(TestArraySortIndicesForReal, SortReal) {
this->AssertSortIndices("[]", "[]");
this->AssertSortIndices("[3.4, 2.6, 6.3]", "[1, 0, 2]");
@@ -384,7 +445,7 @@ TYPED_TEST(TestArraySortIndicesKernelForReal, SortReal) {
"[1, 2, 0, 3]");
}
-TYPED_TEST(TestArraySortIndicesKernelForIntegral, SortIntegral) {
+TYPED_TEST(TestArraySortIndicesForIntegral, SortIntegral) {
this->AssertSortIndices("[]", "[]");
this->AssertSortIndices("[3, 2, 6]", "[1, 0, 2]");
@@ -402,7 +463,7 @@ TYPED_TEST(TestArraySortIndicesKernelForIntegral, SortIntegral) {
"[5, 2, 4, 1, 0, 3]");
}
-TYPED_TEST(TestArraySortIndicesKernelForTemporal, SortTemporal) {
+TYPED_TEST(TestArraySortIndicesForTemporal, SortTemporal) {
this->AssertSortIndices("[]", "[]");
this->AssertSortIndices("[3, 2, 6]", "[1, 0, 2]");
@@ -420,7 +481,7 @@ TYPED_TEST(TestArraySortIndicesKernelForTemporal, SortTemporal) {
"[5, 2, 4, 1, 0, 3]");
}
-TYPED_TEST(TestArraySortIndicesKernelForStrings, SortStrings) {
+TYPED_TEST(TestArraySortIndicesForStrings, SortStrings) {
this->AssertSortIndices("[]", "[]");
this->AssertSortIndices(R"(["a", "b", "c"])", "[0, 1, 2]");
@@ -433,37 +494,58 @@ TYPED_TEST(TestArraySortIndicesKernelForStrings, SortStrings) {
"[0, 1, 3, 2]");
}
+TEST_F(TestArraySortIndicesForFixedSizeBinary, SortFixedSizeBinary) {
+ this->AssertSortIndices("[]", "[]");
+
+ this->AssertSortIndices(R"(["def", "abc", "ghi"])", "[1, 0, 2]");
+ this->AssertSortIndices(R"(["def", "abc", "ghi"])", SortOrder::Descending, "[2, 0, 1]");
+}
+
template <typename ArrowType>
-class TestArraySortIndicesKernelForUInt8 : public TestArraySortIndicesKernel<ArrowType> {
-};
-TYPED_TEST_SUITE(TestArraySortIndicesKernelForUInt8, UInt8Type);
+class TestArraySortIndicesForUInt8 : public TestArraySortIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestArraySortIndicesForUInt8, UInt8Type);
template <typename ArrowType>
-class TestArraySortIndicesKernelForInt8 : public TestArraySortIndicesKernel<ArrowType> {};
-TYPED_TEST_SUITE(TestArraySortIndicesKernelForInt8, Int8Type);
+class TestArraySortIndicesForInt8 : public TestArraySortIndices<ArrowType> {};
+TYPED_TEST_SUITE(TestArraySortIndicesForInt8, Int8Type);
-TYPED_TEST(TestArraySortIndicesKernelForUInt8, SortUInt8) {
+TYPED_TEST(TestArraySortIndicesForUInt8, SortUInt8) {
this->AssertSortIndices("[255, null, 0, 255, 10, null, 128, 0]",
"[2, 7, 4, 6, 0, 3, 1, 5]");
}
-TYPED_TEST(TestArraySortIndicesKernelForInt8, SortInt8) {
+TYPED_TEST(TestArraySortIndicesForInt8, SortInt8) {
this->AssertSortIndices("[null, 10, 127, 0, -128, -128, null]",
"[4, 5, 3, 1, 2, 0, 6]");
}
template <typename ArrowType>
-class TestArraySortIndicesKernelRandom : public TestBase {};
+class TestArraySortIndicesForDecimal : public TestArraySortIndicesBase {
+ public:
+ std::shared_ptr<DataType> type() override { return std::make_shared<ArrowType>(5, 2); }
+};
+TYPED_TEST_SUITE(TestArraySortIndicesForDecimal, DecimalArrowTypes);
+
+TYPED_TEST(TestArraySortIndicesForDecimal, DecimalSortTestTypes) {
+ this->AssertSortIndices(R"(["123.45", null, "-123.45", "456.78", "-456.78"])",
+ "[4, 2, 0, 3, 1]");
+ this->AssertSortIndices(R"(["123.45", null, "-123.45", "456.78", "-456.78"])",
+ SortOrder::Descending, "[3, 0, 2, 4, 1]");
+}
+
+template <typename ArrowType>
+class TestArraySortIndicesRandom : public TestBase {};
template <typename ArrowType>
-class TestArraySortIndicesKernelRandomCount : public TestBase {};
+class TestArraySortIndicesRandomCount : public TestBase {};
template <typename ArrowType>
-class TestArraySortIndicesKernelRandomCompare : public TestBase {};
+class TestArraySortIndicesRandomCompare : public TestBase {};
using SortIndicesableTypes =
::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
- Int32Type, Int64Type, FloatType, DoubleType, StringType>;
+ Int32Type, Int64Type, FloatType, DoubleType, StringType,
+ Decimal128Type>;
template <typename ArrayType>
void ValidateSorted(const ArrayType& array, UInt64Array& offsets, SortOrder order) {
@@ -476,9 +558,9 @@ void ValidateSorted(const ArrayType& array, UInt64Array& offsets, SortOrder orde
}
}
-TYPED_TEST_SUITE(TestArraySortIndicesKernelRandom, SortIndicesableTypes);
+TYPED_TEST_SUITE(TestArraySortIndicesRandom, SortIndicesableTypes);
-TYPED_TEST(TestArraySortIndicesKernelRandom, SortRandomValues) {
+TYPED_TEST(TestArraySortIndicesRandom, SortRandomValues) {
using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
Random<TypeParam> rand(0x5487655);
@@ -499,9 +581,9 @@ TYPED_TEST(TestArraySortIndicesKernelRandom, SortRandomValues) {
// Long array with small value range: counting sort
// - length >= 1024(CountCompareSorter::countsort_min_len_)
// - range <= 4096(CountCompareSorter::countsort_max_range_)
-TYPED_TEST_SUITE(TestArraySortIndicesKernelRandomCount, IntegralArrowTypes);
+TYPED_TEST_SUITE(TestArraySortIndicesRandomCount, IntegralArrowTypes);
-TYPED_TEST(TestArraySortIndicesKernelRandomCount, SortRandomValuesCount) {
+TYPED_TEST(TestArraySortIndicesRandomCount, SortRandomValuesCount) {
using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
RandomRange<TypeParam> rand(0x5487656);
@@ -521,9 +603,9 @@ TYPED_TEST(TestArraySortIndicesKernelRandomCount, SortRandomValuesCount) {
}
// Long array with big value range: std::stable_sort
-TYPED_TEST_SUITE(TestArraySortIndicesKernelRandomCompare, IntegralArrowTypes);
+TYPED_TEST_SUITE(TestArraySortIndicesRandomCompare, IntegralArrowTypes);
-TYPED_TEST(TestArraySortIndicesKernelRandomCompare, SortRandomValuesCompare) {
+TYPED_TEST(TestArraySortIndicesRandomCompare, SortRandomValuesCompare) {
using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
Random<TypeParam> rand(0x5487657);
@@ -583,6 +665,22 @@ TYPED_TEST(TestChunkedArraySortIndicesForTemporal, NoNull) {
AssertSortIndices(chunked_array, SortOrder::Descending, "[5, 2, 3, 1, 4, 0, 6]");
}
+// Tests for decimal types
+template <typename ArrowType>
+class TestChunkedArraySortIndicesForDecimal : public TestChunkedArraySortIndices {
+ protected:
+ std::shared_ptr<DataType> GetType() { return std::make_shared<ArrowType>(5, 2); }
+};
+TYPED_TEST_SUITE(TestChunkedArraySortIndicesForDecimal, DecimalArrowTypes);
+
+TYPED_TEST(TestChunkedArraySortIndicesForDecimal, Basics) {
+ auto type = this->GetType();
+ auto chunked_array = ChunkedArrayFromJSON(
+ type, {R"(["123.45", "-123.45"])", R"([null, "456.78"])", R"(["-456.78", null])"});
+ AssertSortIndices(chunked_array, SortOrder::Ascending, "[4, 1, 0, 3, 2, 5]");
+ AssertSortIndices(chunked_array, SortOrder::Descending, "[3, 0, 1, 4, 2, 5]");
+}
+
// Base class for testing against random chunked array.
template <typename Type>
class TestChunkedArrayRandomBase : public TestBase {
@@ -631,6 +729,7 @@ class TestChunkedArrayRandom : public TestChunkedArrayRandomBase<Type> {
Random<Type>* rand_;
};
TYPED_TEST_SUITE(TestChunkedArrayRandom, SortIndicesableTypes);
+
TYPED_TEST(TestChunkedArrayRandom, SortIndices) { this->TestSortIndices(1000); }
// Long array with small value range: counting sort
@@ -746,19 +845,39 @@ TEST_F(TestRecordBatchSortIndices, MoreTypes) {
auto schema = ::arrow::schema({
{field("a", timestamp(TimeUnit::MICRO))},
{field("b", large_utf8())},
+ {field("c", fixed_size_binary(3))},
+ });
+ SortOptions options({SortKey("a", SortOrder::Ascending),
+ SortKey("b", SortOrder::Descending),
+ SortKey("c", SortOrder::Ascending)});
+
+ auto batch = RecordBatchFromJSON(schema,
+ R"([{"a": 3, "b": "05", "c": "aaa"},
+ {"a": 1, "b": "031", "c": "bbb"},
+ {"a": 3, "b": "05", "c": "bbb"},
+ {"a": 0, "b": "0666", "c": "aaa"},
+ {"a": 2, "b": "05", "c": "aaa"},
+ {"a": 1, "b": "05", "c": "bbb"}
+ ])");
+ AssertSortIndices(batch, options, "[3, 5, 1, 4, 0, 2]");
+}
+
+TEST_F(TestRecordBatchSortIndices, Decimal) {
+ auto schema = ::arrow::schema({
+ {field("a", decimal128(3, 1))},
+ {field("b", decimal256(4, 2))},
});
SortOptions options(
{SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)});
auto batch = RecordBatchFromJSON(schema,
- R"([{"a": 3, "b": "05"},
- {"a": 1, "b": "03"},
- {"a": 3, "b": "04"},
- {"a": 0, "b": "06"},
- {"a": 2, "b": "05"},
- {"a": 1, "b": "05"}
+ R"([{"a": "12.3", "b": "12.34"},
+ {"a": "45.6", "b": "12.34"},
+ {"a": "12.3", "b": "-12.34"},
+ {"a": "-12.3", "b": null},
+ {"a": "-12.3", "b": "-45.67"}
])");
- AssertSortIndices(batch, options, "[3, 5, 1, 4, 0, 2]");
+ AssertSortIndices(batch, options, "[4, 3, 0, 2, 1]");
}
// Test basic cases for table.
@@ -860,6 +979,44 @@ TEST_F(TestTableSortIndices, NaNAndNull) {
AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]");
}
+TEST_F(TestTableSortIndices, BinaryLike) {
+ auto schema = ::arrow::schema({
+ {field("a", large_utf8())},
+ {field("b", fixed_size_binary(3))},
+ });
+ SortOptions options(
+ {SortKey("a", SortOrder::Descending), SortKey("b", SortOrder::Ascending)});
+ auto table = TableFromJSON(schema, {R"([{"a": "one", "b": null},
+ {"a": "two", "b": "aaa"},
+ {"a": "three", "b": "bbb"},
+ {"a": "four", "b": "ccc"}
+ ])",
+ R"([{"a": "one", "b": "ddd"},
+ {"a": "two", "b": "ccc"},
+ {"a": "three", "b": "bbb"},
+ {"a": "four", "b": "aaa"}
+ ])"});
+ AssertSortIndices(table, options, "[1, 5, 2, 6, 4, 0, 7, 3]");
+}
+
+TEST_F(TestTableSortIndices, Decimal) {
+ auto schema = ::arrow::schema({
+ {field("a", decimal128(3, 1))},
+ {field("b", decimal256(4, 2))},
+ });
+ SortOptions options(
+ {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)});
+
+ auto table = TableFromJSON(schema, {R"([{"a": "12.3", "b": "12.34"},
+ {"a": "45.6", "b": "12.34"},
+ {"a": "12.3", "b": "-12.34"}
+ ])",
+ R"([{"a": "-12.3", "b": null},
+ {"a": "-12.3", "b": "-45.67"}
+ ])"});
+ AssertSortIndices(table, options, "[4, 3, 0, 2, 1]");
+}
+
// Tests for temporal types
template <typename ArrowType>
class TestTableSortIndicesForTemporal : public TestTableSortIndices {
@@ -947,6 +1104,7 @@ class TestTableSortIndicesRandom : public testing::TestWithParam<RandomParam> {
VISIT(Float)
VISIT(Double)
VISIT(String)
+ VISIT(Decimal128)
#undef VISIT
@@ -974,8 +1132,10 @@ class TestTableSortIndicesRandom : public testing::TestWithParam<RandomParam> {
template <typename Type>
int CompareType() {
using ArrayType = typename TypeTraits<Type>::ArrayType;
- auto lhs_value = checked_cast<const ArrayType*>(lhs_array_)->GetView(lhs_index_);
- auto rhs_value = checked_cast<const ArrayType*>(rhs_array_)->GetView(rhs_index_);
+ auto lhs_value =
+ GetLogicalValue(checked_cast<const ArrayType&>(*lhs_array_), lhs_index_);
+ auto rhs_value =
+ GetLogicalValue(checked_cast<const ArrayType&>(*rhs_array_), rhs_index_);
if (is_floating_type<Type>::value) {
lhs_isnan_ = lhs_value != lhs_value;
rhs_isnan_ = rhs_value != rhs_value;
@@ -1026,20 +1186,17 @@ TEST_P(TestTableSortIndicesRandom, Sort) {
const auto first_sort_key_name = std::get<0>(GetParam());
const auto null_probability = std::get<1>(GetParam());
const auto seed = 0x61549225;
- std::vector<std::string> column_names = {
- "uint8", "uint16", "uint32", "uint64", "int8", "int16",
- "int32", "int64", "float", "double", "string",
- };
- std::vector<std::shared_ptr<Field>> fields = {
- {field(column_names[0], uint8())}, {field(column_names[1], uint16())},
- {field(column_names[2], uint32())}, {field(column_names[3], uint64())},
- {field(column_names[4], int8())}, {field(column_names[5], int16())},
- {field(column_names[6], int32())}, {field(column_names[7], int64())},
- {field(column_names[8], float32())}, {field(column_names[9], float64())},
- {field(column_names[10], utf8())},
+
+ const FieldVector fields = {
+ {field("uint8", uint8())}, {field("uint16", uint16())},
+ {field("uint32", uint32())}, {field("uint64", uint64())},
+ {field("int8", int8())}, {field("int16", int16())},
+ {field("int32", int32())}, {field("int64", int64())},
+ {field("float", float32())}, {field("double", float64())},
+ {field("string", utf8())}, {field("decimal128", decimal128(18, 3))},
};
const auto length = 200;
- std::vector<std::shared_ptr<Array>> columns = {
+ ArrayVector columns = {
Random<UInt8Type>(seed).Generate(length, null_probability),
Random<UInt16Type>(seed).Generate(length, 0.0),
Random<UInt32Type>(seed).Generate(length, null_probability),
@@ -1051,22 +1208,27 @@ TEST_P(TestTableSortIndicesRandom, Sort) {
Random<FloatType>(seed).Generate(length, null_probability, 1 - null_probability),
Random<DoubleType>(seed).Generate(length, 0.0, null_probability),
Random<StringType>(seed).Generate(length, null_probability),
+ Random<Decimal128Type>(seed, fields[11]->type()).Generate(length, null_probability),
};
const auto table = Table::Make(schema(fields), columns, length);
+
+ // Generate random sort keys
std::default_random_engine engine(seed);
std::uniform_int_distribution<> distribution(0);
- const auto n_sort_keys = 5;
+ const auto n_sort_keys = 7;
std::vector<SortKey> sort_keys;
const auto first_sort_key_order =
(distribution(engine) % 2) == 0 ? SortOrder::Ascending : SortOrder::Descending;
sort_keys.emplace_back(first_sort_key_name, first_sort_key_order);
for (int i = 1; i < n_sort_keys; ++i) {
- const auto& column_name = column_names[distribution(engine) % column_names.size()];
+ const auto& field = *fields[distribution(engine) % fields.size()];
const auto order =
(distribution(engine) % 2) == 0 ? SortOrder::Ascending : SortOrder::Descending;
- sort_keys.emplace_back(column_name, order);
+ sort_keys.emplace_back(field.name(), order);
}
SortOptions options(sort_keys);
+
+ // Test with different table chunkings
for (const int64_t num_chunks : {1, 2, 20}) {
TableBatchReader reader(*table);
reader.set_chunksize((length + num_chunks - 1) / num_chunks);
@@ -1074,6 +1236,7 @@ TEST_P(TestTableSortIndicesRandom, Sort) {
ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*chunked_table), options));
Validate(*table, options, *checked_pointer_cast<UInt64Array>(offsets));
}
+
// Also validate RecordBatch sorting
TableBatchReader reader(*table);
RecordBatchVector batches;
@@ -1083,26 +1246,18 @@ TEST_P(TestTableSortIndicesRandom, Sort) {
Validate(*table, options, *checked_pointer_cast<UInt64Array>(offsets));
}
+static const auto first_sort_keys =
+ testing::Values("uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32",
+ "int64", "float", "double", "string", "decimal128");
+
INSTANTIATE_TEST_SUITE_P(NoNull, TestTableSortIndicesRandom,
- testing::Combine(testing::Values("uint8", "uint16", "uint32",
- "uint64", "int8", "int16",
- "int32", "int64", "float",
- "double", "string"),
- testing::Values(0.0)));
+ testing::Combine(first_sort_keys, testing::Values(0.0)));
INSTANTIATE_TEST_SUITE_P(MayNull, TestTableSortIndicesRandom,
- testing::Combine(testing::Values("uint8", "uint16", "uint32",
- "uint64", "int8", "int16",
- "int32", "int64", "float",
- "double", "string"),
- testing::Values(0.1, 0.5)));
+ testing::Combine(first_sort_keys, testing::Values(0.1, 0.5)));
INSTANTIATE_TEST_SUITE_P(AllNull, TestTableSortIndicesRandom,
- testing::Combine(testing::Values("uint8", "uint16", "uint32",
- "uint64", "int8", "int16",
- "int32", "int64", "float",
- "double", "string"),
- testing::Values(1.0)));
+ testing::Combine(first_sort_keys, testing::Values(1.0)));
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h
index 718d2a3..ff3b751 100644
--- a/cpp/src/arrow/testing/gtest_util.h
+++ b/cpp/src/arrow/testing/gtest_util.h
@@ -193,6 +193,8 @@ using IntegralArrowTypes = ::testing::Types<UInt8Type, UInt16Type, UInt32Type, U
using TemporalArrowTypes =
::testing::Types<Date32Type, Date64Type, TimestampType, Time32Type, Time64Type>;
+using DecimalArrowTypes = ::testing::Types<Decimal128Type, Decimal256Type>;
+
class Array;
class ChunkedArray;
class RecordBatch;
diff --git a/cpp/src/arrow/testing/random.cc b/cpp/src/arrow/testing/random.cc
index 92ac5f8..7bf5dd2 100644
--- a/cpp/src/arrow/testing/random.cc
+++ b/cpp/src/arrow/testing/random.cc
@@ -18,6 +18,7 @@
#include "arrow/testing/random.h"
#include <algorithm>
+#include <cmath>
#include <limits>
#include <memory>
#include <random>
@@ -27,6 +28,7 @@
#include <gtest/gtest.h>
#include "arrow/array.h"
+#include "arrow/array/builder_decimal.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/buffer.h"
#include "arrow/testing/gtest_util.h"
@@ -36,11 +38,13 @@
#include "arrow/util/bit_util.h"
#include "arrow/util/bitmap_reader.h"
#include "arrow/util/checked_cast.h"
+#include "arrow/util/decimal.h"
#include "arrow/util/logging.h"
namespace arrow {
using internal::checked_cast;
+using internal::checked_pointer_cast;
namespace random {
@@ -220,6 +224,36 @@ std::shared_ptr<Array> RandomArrayGenerator::Float64(int64_t size, double min, d
#undef PRIMITIVE_RAND_INTEGER_IMPL
#undef PRIMITIVE_RAND_IMPL
+std::shared_ptr<Array> RandomArrayGenerator::Decimal128(std::shared_ptr<DataType> type,
+ int64_t size,
+ double null_probability) {
+ const auto& decimal_type = checked_cast<const Decimal128Type&>(*type);
+ const auto digits = decimal_type.precision();
+ if (digits > 18) {
+ // Values with more than 18 decimal digits (plus sign) don't fit in an int64_t
+ ABORT_NOT_OK(
+ Status::NotImplemented("random decimal128 generation with precision > 18"));
+ }
+
+ // Generate unscaled int64 values in [-max, max] (max = 10^precision - 1), then wrap each in a Decimal128
+ const auto max = static_cast<int64_t>(std::llround(std::pow(10.0, digits)) - 1);
+ const auto int_array =
+ checked_pointer_cast<Int64Array>(Int64(size, -max, max, null_probability));
+
+ Decimal128Builder builder(type);
+ ABORT_NOT_OK(builder.Reserve(size));
+ for (int64_t i = 0; i < size; ++i) {
+ if (int_array->IsValid(i)) {
+ builder.UnsafeAppend(::arrow::Decimal128(int_array->Value(i)));
+ } else {
+ builder.UnsafeAppendNull();
+ }
+ }
+ std::shared_ptr<Array> array;
+ ABORT_NOT_OK(builder.Finish(&array));
+ return array;
+}
+
template <typename TypeClass>
static std::shared_ptr<Array> GenerateBinaryArray(RandomArrayGenerator* gen, int64_t size,
int32_t min_length, int32_t max_length,
@@ -480,7 +514,7 @@ struct RandomArrayGeneratorOfImpl {
}
template <typename T>
- enable_if_fixed_size_binary<T, Status> Visit(const T& t) {
+ enable_if_t<std::is_same<T, FixedSizeBinaryType>::value, Status> Visit(const T& t) {
const int32_t value_size = t.byte_width();
int64_t data_nbytes = size_ * value_size;
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> data, AllocateBuffer(data_nbytes));
@@ -494,12 +528,18 @@ struct RandomArrayGeneratorOfImpl {
return Status::OK();
}
+ Status Visit(const Decimal128Type&) {
+ out_ = rag_->Decimal128(type_, size_, null_probability_);
+ return Status::OK();
+ }
+
Status Visit(const DataType& t) {
return Status::NotImplemented("generation of random arrays of type ", t);
}
std::shared_ptr<Array> Finish() && {
DCHECK_OK(VisitTypeInline(*type_, this));
+ DCHECK(type_->Equals(out_->type()));
return std::move(out_);
}
diff --git a/cpp/src/arrow/testing/random.h b/cpp/src/arrow/testing/random.h
index f874fae..2358ab0 100644
--- a/cpp/src/arrow/testing/random.h
+++ b/cpp/src/arrow/testing/random.h
@@ -225,6 +225,17 @@ class ARROW_TESTING_EXPORT RandomArrayGenerator {
}
}
+ /// \brief Generate a random Decimal128Array
+ ///
+ /// \param[in] type the type of the array to generate
+ /// (must be a Decimal128Type with precision <= 18)
+ /// \param[in] size the size of the array to generate
+ /// \param[in] null_probability the probability of a row being null
+ ///
+ /// \return a generated Array
+ std::shared_ptr<Array> Decimal128(std::shared_ptr<DataType> type, int64_t size,
+ double null_probability = 0);
+
/// \brief Generate an array of offsets (for use in e.g. ListArray::FromArrays)
///
/// \param[in] size the size of the array to generate
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index 9192c32..0a9f505 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -2248,7 +2248,7 @@ std::shared_ptr<DataType> decimal256(int32_t precision, int32_t scale) {
std::string Decimal128Type::ToString() const {
std::stringstream s;
- s << "decimal(" << precision_ << ", " << scale_ << ")";
+ s << "decimal128(" << precision_ << ", " << scale_ << ")";
return s.str();
}
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index fafe333..60311f1 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -913,7 +913,7 @@ class ARROW_EXPORT Decimal128Type : public DecimalType {
public:
static constexpr Type::type type_id = Type::DECIMAL128;
- static constexpr const char* type_name() { return "decimal"; }
+ static constexpr const char* type_name() { return "decimal128"; }
/// Decimal128Type constructor that aborts on invalid input.
explicit Decimal128Type(int32_t precision, int32_t scale);
@@ -922,7 +922,7 @@ class ARROW_EXPORT Decimal128Type : public DecimalType {
static Result<std::shared_ptr<DataType>> Make(int32_t precision, int32_t scale);
std::string ToString() const override;
- std::string name() const override { return "decimal"; }
+ std::string name() const override { return "decimal128"; }
static constexpr int32_t kMinPrecision = 1;
static constexpr int32_t kMaxPrecision = 38;
diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc
index da93e32..fd7fd01 100644
--- a/cpp/src/arrow/type_test.cc
+++ b/cpp/src/arrow/type_test.cc
@@ -1669,7 +1669,7 @@ TEST(TypesTest, TestDecimal128Small) {
EXPECT_EQ(t1.precision(), 8);
EXPECT_EQ(t1.scale(), 4);
- EXPECT_EQ(t1.ToString(), std::string("decimal(8, 4)"));
+ EXPECT_EQ(t1.ToString(), std::string("decimal128(8, 4)"));
// Test properties
EXPECT_EQ(t1.byte_width(), 16);
@@ -1683,7 +1683,7 @@ TEST(TypesTest, TestDecimal128Medium) {
EXPECT_EQ(t1.precision(), 12);
EXPECT_EQ(t1.scale(), 5);
- EXPECT_EQ(t1.ToString(), std::string("decimal(12, 5)"));
+ EXPECT_EQ(t1.ToString(), std::string("decimal128(12, 5)"));
// Test properties
EXPECT_EQ(t1.byte_width(), 16);
@@ -1697,7 +1697,7 @@ TEST(TypesTest, TestDecimal128Large) {
EXPECT_EQ(t1.precision(), 27);
EXPECT_EQ(t1.scale(), 7);
- EXPECT_EQ(t1.ToString(), std::string("decimal(27, 7)"));
+ EXPECT_EQ(t1.ToString(), std::string("decimal128(27, 7)"));
// Test properties
EXPECT_EQ(t1.byte_width(), 16);
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 7c2eae1..0001633 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -168,7 +168,8 @@ To avoid exhaustively listing supported types, the tables below use a number
of general type categories:
* "Numeric": Integer types (Int8, etc.) and Floating-point types (Float32,
- Float64, sometimes Float16). Some functions also accept Decimal128 input.
+ Float64, sometimes Float16). Some functions also accept Decimal128 and
+ Decimal256 input.
* "Temporal": Date types (Date32, Date64), Time types (Time32, Time64),
Timestamp, Duration, Interval.
diff --git a/ruby/red-arrow/test/test-decimal128-data-type.rb b/ruby/red-arrow/test/test-decimal128-data-type.rb
index 6cdd22f..5390a7a 100644
--- a/ruby/red-arrow/test/test-decimal128-data-type.rb
+++ b/ruby/red-arrow/test/test-decimal128-data-type.rb
@@ -18,12 +18,12 @@
class Decimal128DataTypeTest < Test::Unit::TestCase
sub_test_case(".new") do
test("ordered arguments") do
- assert_equal("decimal(8, 2)",
+ assert_equal("decimal128(8, 2)",
Arrow::Decimal128DataType.new(8, 2).to_s)
end
test("description") do
- assert_equal("decimal(8, 2)",
+ assert_equal("decimal128(8, 2)",
Arrow::Decimal128DataType.new(precision: 8,
scale: 2).to_s)
end