You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2023/06/20 07:28:01 UTC
[arrow] branch main updated: GH-33206: [C++] Add support for StructArray sorting and nested sort keys (#35727)
This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 0f6522018a GH-33206: [C++] Add support for StructArray sorting and nested sort keys (#35727)
0f6522018a is described below
commit 0f6522018a3524fb2775934cd8712e4c94667974
Author: Ben Harkins <60...@users.noreply.github.com>
AuthorDate: Tue Jun 20 03:27:54 2023 -0400
GH-33206: [C++] Add support for StructArray sorting and nested sort keys (#35727)
### Rationale for this change
We don't currently support sorting `StructArray`s despite already having the high-level facilities to do so. For instance, we allow passing multiple sort keys (based on `FieldRef`s) to sort record batches and tables - but the current implementations are fairly limited since nested refs aren't allowed (due to the burden of null flattening). Since https://github.com/apache/arrow/pull/35197, we now have an easier way to do this.
### What changes are included in this PR?
- Adds support for `StructArray` in `sort_indices`
- Adds support for nested sort keys in `sort_indices` for `RecordBatch`, `ChunkedArray`, and `Table`
### Are these changes tested?
Yes (tests are included)
### Are there any user-facing changes?
Yes
* Closes: #33206
Authored-by: benibus <bp...@gmx.com>
Signed-off-by: Antoine Pitrou <an...@python.org>
---
cpp/src/arrow/compute/kernels/vector_array_sort.cc | 23 +-
cpp/src/arrow/compute/kernels/vector_sort.cc | 235 ++++++++++++++++-----
.../arrow/compute/kernels/vector_sort_internal.h | 58 +++--
cpp/src/arrow/compute/kernels/vector_sort_test.cc | 115 ++++++++--
4 files changed, 350 insertions(+), 81 deletions(-)
diff --git a/cpp/src/arrow/compute/kernels/vector_array_sort.cc b/cpp/src/arrow/compute/kernels/vector_array_sort.cc
index 1499554a96..b9cfd38a02 100644
--- a/cpp/src/arrow/compute/kernels/vector_array_sort.cc
+++ b/cpp/src/arrow/compute/kernels/vector_array_sort.cc
@@ -262,6 +262,19 @@ class ArrayCompareSorter<DictionaryType> {
}
};
+// Sorts a StructArray by comparing its child fields in order (see SortStructArray).
+template <>
+class ArrayCompareSorter<StructType> {
+ public:
+  // Sort the indices in [indices_begin, indices_end) by the rows of `array`, using
+  // `options.order` for every child field and `options.null_placement` for nulls.
+  //
+  // `offset` is the value the indices in the range are shifted by (the Dictionary
+  // sorter above forwards it unchanged to the decoded sorter); index `i` refers to row
+  // `i - offset` of `array`. SortStructArray works on 0-based row numbers, so the range
+  // is rebased before sorting and shifted back afterwards, on both the success and the
+  // error path. NOTE(review): callers currently appear to pass offset == 0 for struct
+  // inputs, in which case this is a no-op.
+  Result<NullPartitionResult> operator()(uint64_t* indices_begin, uint64_t* indices_end,
+                                         const Array& array, int64_t offset,
+                                         const ArraySortOptions& options,
+                                         ExecContext* ctx) {
+    const auto& struct_array = checked_cast<const StructArray&>(array);
+    const auto shift = static_cast<uint64_t>(offset);
+    if (shift != 0) {
+      for (uint64_t* it = indices_begin; it != indices_end; ++it) {
+        *it -= shift;
+      }
+    }
+    // The returned NullPartitionResult points into the index range itself, so it stays
+    // valid after the values are shifted back.
+    auto result = SortStructArray(ctx, indices_begin, indices_end, struct_array,
+                                  options.order, options.null_placement);
+    if (shift != 0) {
+      for (uint64_t* it = indices_begin; it != indices_end; ++it) {
+        *it += shift;
+      }
+    }
+    return result;
+  }
+};
+
template <typename ArrowType>
class ArrayCountSorter {
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
@@ -497,7 +510,7 @@ template <typename Type>
struct ArraySorter<
Type, enable_if_t<is_floating_type<Type>::value || is_base_binary_type<Type>::value ||
is_fixed_size_binary_type<Type>::value ||
- is_dictionary_type<Type>::value>> {
+ is_dictionary_type<Type>::value || is_struct_type<Type>::value>> {
ArrayCompareSorter<Type> impl;
};
@@ -606,6 +619,13 @@ void AddDictArraySortingKernels(VectorKernel base, VectorFunction* func) {
DCHECK_OK(func->AddKernel(base));
}
+// Add one kernel to `func` that accepts any STRUCT-typed input and outputs uint64
+// indices, executed by `ExecTemplate<UInt64Type, StructType>::Exec`. `base` is taken by
+// value so the caller's kernel template is left untouched.
+template <template <typename...> class ExecTemplate>
+void AddStructArraySortingKernels(VectorKernel base, VectorFunction* func) {
+  base.exec = ExecTemplate<UInt64Type, StructType>::Exec;
+  base.signature = KernelSignature::Make({Type::STRUCT}, uint64());
+  DCHECK_OK(func->AddKernel(base));
+}
+
const ArraySortOptions* GetDefaultArraySortOptions() {
static const auto kDefaultArraySortOptions = ArraySortOptions::Defaults();
return &kDefaultArraySortOptions;
@@ -661,6 +681,7 @@ void RegisterVectorArraySort(FunctionRegistry* registry) {
base.exec_chunked = ArraySortIndicesChunked;
AddArraySortingKernels<ArraySortIndices>(base, array_sort_indices.get());
AddDictArraySortingKernels<ArraySortIndices>(base, array_sort_indices.get());
+ AddStructArraySortingKernels<ArraySortIndices>(base, array_sort_indices.get());
DCHECK_OK(registry->AddFunction(std::move(array_sort_indices)));
// partition_nth_indices has a parameter so needs its init function
diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc
index 1de90cac35..5ee8cbaf6e 100644
--- a/cpp/src/arrow/compute/kernels/vector_sort.cc
+++ b/cpp/src/arrow/compute/kernels/vector_sort.cc
@@ -20,6 +20,9 @@
#include "arrow/compute/kernels/vector_sort_internal.h"
#include "arrow/compute/registry.h"
+// Make arrow::FieldPath usable as a key of unordered standard containers (such as the
+// `std::unordered_set<FieldPath>` in SortFieldPopulator) by reusing FieldPath's own
+// hash functor. Struct inheritance is public by default.
+template <>
+struct std::hash<arrow::FieldPath> : arrow::FieldPath::Hash {};
+
namespace arrow {
using internal::checked_cast;
@@ -337,26 +340,56 @@ class ConcreteRecordBatchColumnSorter<NullType> : public RecordBatchColumnSorter
const NullPlacement null_placement_;
};
+// Resolve `sort_keys` against the schema of `batch`, producing one
+// ResolvedRecordBatchSortKey per sort field returned by FindSortKeys (keys that
+// reference a struct are expanded into its leaf fields).
+Result<std::vector<ResolvedRecordBatchSortKey>> ResolveRecordBatchSortKeys(
+    const RecordBatch& batch, const std::vector<SortKey>& sort_keys) {
+  using KeyType = ResolvedRecordBatchSortKey;
+  return ::arrow::compute::internal::ResolveSortKeys<KeyType>(batch, sort_keys);
+}
+
+// Status-out-parameter variant of ResolveRecordBatchSortKeys, meant for constructor
+// member-initializer lists where a Result cannot be propagated. On failure, `*status`
+// is assigned the error and an empty vector is returned. On success, `*status` is left
+// untouched.
+//
+// `status` must point to an already-constructed Status, because it is assigned to.
+// NOTE(review): RadixRecordBatchSorter's RecordBatch constructor passes `&status_` from
+// its member-initializer list, but `status_` is declared after `sort_keys_`. Members
+// are initialized in declaration order, so that assignment happens before `status_` is
+// constructed and is then reset by its default constructor. Declaring `status_` before
+// `sort_keys_` would fix it.
+std::vector<ResolvedRecordBatchSortKey> ResolveRecordBatchSortKeys(
+    const RecordBatch& batch, const std::vector<SortKey>& sort_keys, Status* status) {
+  // Deliberately not `const`: on a const Result, `*std::move(...)` binds to the const&
+  // accessor and copies the vector instead of moving it out.
+  auto maybe_resolved = ResolveRecordBatchSortKeys(batch, sort_keys);
+  if (!maybe_resolved.ok()) {
+    *status = maybe_resolved.status();
+    return {};
+  }
+  return *std::move(maybe_resolved);
+}
+
+// Largest number of resolved sort keys for which the radix (column-by-column) record
+// batch sorter is chosen. Radix sorting is usually the faster strategy, but with many
+// sort keys it can degrade catastrophically, so above this cutoff callers fall back to
+// the multiple-key comparison sorter.
+constexpr int kMaxRadixSortKeys{8};
+
// Sort a batch using a single-pass left-to-right radix sort.
class RadixRecordBatchSorter {
public:
+ using ResolvedSortKey = ResolvedRecordBatchSortKey;
+
+ RadixRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end,
+ std::vector<ResolvedSortKey> sort_keys,
+ const SortOptions& options)
+ : sort_keys_(std::move(sort_keys)),
+ options_(options),
+ indices_begin_(indices_begin),
+ indices_end_(indices_end) {}
+
RadixRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end,
const RecordBatch& batch, const SortOptions& options)
- : batch_(batch),
+ : sort_keys_(ResolveRecordBatchSortKeys(batch, options.sort_keys, &status_)),
options_(options),
indices_begin_(indices_begin),
indices_end_(indices_end) {}
// Offset is for table sorting
Result<NullPartitionResult> Sort(int64_t offset = 0) {
- ARROW_ASSIGN_OR_RAISE(const auto sort_keys,
- ResolveSortKeys(batch_, options_.sort_keys));
+ ARROW_RETURN_NOT_OK(status_);
// Create column sorters from right to left
- std::vector<std::unique_ptr<RecordBatchColumnSorter>> column_sorts(sort_keys.size());
+ std::vector<std::unique_ptr<RecordBatchColumnSorter>> column_sorts(sort_keys_.size());
RecordBatchColumnSorter* next_column = nullptr;
- for (int64_t i = static_cast<int64_t>(sort_keys.size() - 1); i >= 0; --i) {
- ColumnSortFactory factory(sort_keys[i], options_, next_column);
+ for (int64_t i = static_cast<int64_t>(sort_keys_.size() - 1); i >= 0; --i) {
+ ColumnSortFactory factory(sort_keys_[i], options_, next_column);
ARROW_ASSIGN_OR_RAISE(column_sorts[i], factory.MakeColumnSort());
next_column = column_sorts[i].get();
}
@@ -366,16 +399,11 @@ class RadixRecordBatchSorter {
}
protected:
- struct ResolvedSortKey {
- std::shared_ptr<Array> array;
- SortOrder order;
- };
-
struct ColumnSortFactory {
ColumnSortFactory(const ResolvedSortKey& sort_key, const SortOptions& options,
RecordBatchColumnSorter* next_column)
- : physical_type(GetPhysicalType(sort_key.array->type())),
- array(GetPhysicalArray(*sort_key.array, physical_type)),
+ : physical_type(sort_key.type),
+ array(sort_key.owned_array),
order(sort_key.order),
null_placement(options.null_placement),
next_column(next_column) {}
@@ -419,19 +447,27 @@ class RadixRecordBatchSorter {
return ::arrow::compute::internal::ResolveSortKeys<ResolvedSortKey>(batch, sort_keys);
}
- const RecordBatch& batch_;
+ const std::vector<ResolvedSortKey> sort_keys_;
const SortOptions& options_;
uint64_t* indices_begin_;
uint64_t* indices_end_;
+ Status status_;
};
// Sort a batch using a single sort and multiple-key comparisons.
class MultipleKeyRecordBatchSorter : public TypeVisitor {
- private:
+ public:
using ResolvedSortKey = ResolvedRecordBatchSortKey;
- using Comparator = MultipleKeyComparator<ResolvedSortKey>;
- public:
+ MultipleKeyRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end,
+ std::vector<ResolvedSortKey> sort_keys,
+ const SortOptions& options)
+ : indices_begin_(indices_begin),
+ indices_end_(indices_end),
+ sort_keys_(std::move(sort_keys)),
+ null_placement_(options.null_placement),
+ comparator_(sort_keys_, null_placement_) {}
+
MultipleKeyRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end,
const RecordBatch& batch, const SortOptions& options)
: indices_begin_(indices_begin),
@@ -457,6 +493,8 @@ class MultipleKeyRecordBatchSorter : public TypeVisitor {
#undef VISIT
private:
+ using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
static std::vector<ResolvedSortKey> ResolveSortKeys(
const RecordBatch& batch, const std::vector<SortKey>& sort_keys, Status* status) {
const auto maybe_resolved =
@@ -850,12 +888,22 @@ class SortIndicesMetaFunction : public MetaFunction {
ExecContext* ctx) const override {
const SortOptions& sort_options = static_cast<const SortOptions&>(*options);
switch (args[0].kind()) {
- case Datum::ARRAY:
- return SortIndices(*args[0].make_array(), sort_options, ctx);
- break;
- case Datum::CHUNKED_ARRAY:
- return SortIndices(*args[0].chunked_array(), sort_options, ctx);
- break;
+ case Datum::ARRAY: {
+ auto array = args[0].make_array();
+ if (array->type_id() == Type::STRUCT) {
+ ARROW_ASSIGN_OR_RAISE(auto batch, RecordBatch::FromStructArray(array))
+ return SortIndices(*batch, sort_options, ctx);
+ }
+ return SortIndices(*array, sort_options, ctx);
+ } break;
+ case Datum::CHUNKED_ARRAY: {
+ const auto& chunked_array = args[0].chunked_array();
+ if (chunked_array->type()->id() == Type::STRUCT) {
+ ARROW_ASSIGN_OR_RAISE(auto table, ToTable(chunked_array))
+ return SortIndices(*table, sort_options, ctx);
+ }
+ return SortIndices(*chunked_array, sort_options, ctx);
+ } break;
case Datum::RECORD_BATCH: {
return SortIndices(*args[0].record_batch(), sort_options, ctx);
} break;
@@ -872,6 +920,22 @@ class SortIndicesMetaFunction : public MetaFunction {
}
private:
+  // Convert a struct-typed ChunkedArray into a Table with one column per struct field.
+  //
+  // When there are top-level nulls, the chunks are flattened explicitly instead of
+  // using `Table::FromChunkedStructArray`, since the latter doesn't take top-level
+  // validity into account for the columns.
+  //
+  // TODO: We could instead use the provided sort keys to only flatten the selected
+  // columns (via `GetFlattenedField`). Same for the Array -> RecordBatch conversion,
+  // since `RecordBatch::FromStructArray` flattens all columns as well.
+  static Result<std::shared_ptr<Table>> ToTable(
+      const std::shared_ptr<ChunkedArray>& chunked_array) {
+    if (chunked_array->null_count() != 0) {
+      ARROW_ASSIGN_OR_RAISE(auto flattened_columns, chunked_array->Flatten());
+      const auto num_rows = chunked_array->length();
+      return Table::Make(schema(chunked_array->type()->fields()),
+                         std::move(flattened_columns), num_rows);
+    }
+    // No top-level nulls: there is no parent validity to push down into the columns.
+    return Table::FromChunkedStructArray(chunked_array);
+  }
+
Result<Datum> SortIndices(const Array& values, const SortOptions& options,
ExecContext* ctx) const {
SortOrder order = SortOrder::Ascending;
@@ -908,14 +972,15 @@ class SortIndicesMetaFunction : public MetaFunction {
Result<Datum> SortIndices(const RecordBatch& batch, const SortOptions& options,
ExecContext* ctx) const {
- auto n_sort_keys = options.sort_keys.size();
+ ARROW_ASSIGN_OR_RAISE(auto sort_keys,
+ ResolveRecordBatchSortKeys(batch, options.sort_keys));
+
+ auto n_sort_keys = sort_keys.size();
if (n_sort_keys == 0) {
return Status::Invalid("Must specify one or more sort keys");
}
if (n_sort_keys == 1) {
- ARROW_ASSIGN_OR_RAISE(auto array, PrependInvalidColumn(GetColumn(
- batch, options.sort_keys[0].target)));
- return SortIndices(*array, options, ctx);
+ return SortIndices(sort_keys[0].array, options, ctx);
}
auto out_type = uint64();
@@ -930,14 +995,12 @@ class SortIndicesMetaFunction : public MetaFunction {
auto out_end = out_begin + length;
std::iota(out_begin, out_end, 0);
- // Radix sorting is consistently faster except when there is a large number
- // of sort keys, in which case it can end up degrading catastrophically.
- // Cut off above 8 sort keys.
- if (n_sort_keys <= 8) {
- RadixRecordBatchSorter sorter(out_begin, out_end, batch, options);
+ if (n_sort_keys <= kMaxRadixSortKeys) {
+ RadixRecordBatchSorter sorter(out_begin, out_end, std::move(sort_keys), options);
ARROW_RETURN_NOT_OK(sorter.Sort());
} else {
- MultipleKeyRecordBatchSorter sorter(out_begin, out_end, batch, options);
+ MultipleKeyRecordBatchSorter sorter(out_begin, out_end, std::move(sort_keys),
+ options);
ARROW_RETURN_NOT_OK(sorter.Sort());
}
return Datum(out);
@@ -949,11 +1012,6 @@ class SortIndicesMetaFunction : public MetaFunction {
if (n_sort_keys == 0) {
return Status::Invalid("Must specify one or more sort keys");
}
- if (n_sort_keys == 1) {
- ARROW_ASSIGN_OR_RAISE(auto chunked_array, PrependInvalidColumn(GetColumn(
- table, options.sort_keys[0].target)));
- return SortIndices(*chunked_array, options, ctx);
- }
auto out_type = uint64();
auto length = table.num_rows();
@@ -974,25 +1032,67 @@ class SortIndicesMetaFunction : public MetaFunction {
}
};
+// Helper for processing a vector of `SortKey`s into a vector of `SortField`s.
+//
+// Every emitted SortField refers to a leaf (non-struct) field. Keys referencing a
+// struct are recursively "expanded" into its leaf fields, depth-first and in field
+// order, all inheriting the key's sort order. Leaf paths are deduplicated, keeping the
+// first occurrence. This means overlapping keys (e.g. a struct key together with a key
+// for one of its descendants) do not produce redundant sort fields.
+//
+// Instances are single-use: FindSortKeys moves the accumulated result out.
+struct SortFieldPopulator {
+ public:
+  // Process sort keys with the given schema. Note that the output vector may be larger
+  // than the input, as keys referencing a struct are expanded into leaf fields.
+  //
+  // Propagates the error from `FindOne` (wrapped by PrependInvalidColumn) when a key
+  // cannot be resolved against `schema`.
+  Result<std::vector<SortField>> FindSortKeys(const Schema& schema,
+                                              const std::vector<SortKey>& sort_keys) {
+    sort_fields_.reserve(sort_keys.size());
+    seen_.reserve(sort_keys.size());
+
+    for (const auto& sort_key : sort_keys) {
+      ARROW_ASSIGN_OR_RAISE(auto match,
+                            PrependInvalidColumn(sort_key.target.FindOne(schema)));
+      ARROW_ASSIGN_OR_RAISE(auto schema_field, match.Get(schema));
+      AddField(*schema_field->type(), match, sort_key.order);
+    }
+
+    return std::move(sort_fields_);
+  }
+
+ protected:
+  // Emit a leaf SortField for `path`, unless that exact path was already emitted.
+  void AddLeafField(FieldPath path, SortOrder order, const DataType* type) {
+    if (seen_.insert(path).second) {
+      sort_fields_.emplace_back(std::move(path), order, type);
+    }
+  }
+
+  // Recursively emit the leaves of `fields`, whose parent path is `tmp_indices_`.
+  void AddLeafFields(const FieldVector& fields, SortOrder order) {
+    if (fields.empty()) {
+      return;
+    }
+
+    // The last component of tmp_indices_ is the index of the child being visited.
+    tmp_indices_.push_back(0);
+    for (const auto& f : fields) {
+      const auto& type = *f->type();
+      if (type.id() == Type::STRUCT) {
+        AddLeafFields(type.fields(), order);
+      } else {
+        AddLeafField(FieldPath(tmp_indices_), order, &type);
+      }
+      ++tmp_indices_.back();
+    }
+    tmp_indices_.pop_back();
+  }
+
+  // Emit the field at `path` (of type `type`), expanding it if it is a struct.
+  void AddField(const DataType& type, const FieldPath& path, SortOrder order) {
+    if (type.id() == Type::STRUCT) {
+      tmp_indices_ = path.indices();
+      AddLeafFields(type.fields(), order);
+    } else {
+      AddLeafField(path, order, &type);
+    }
+  }
+
+  std::vector<SortField> sort_fields_;
+  // Leaf paths already emitted into `sort_fields_`.
+  std::unordered_set<FieldPath> seen_;
+  // Scratch path of the struct currently being expanded by AddLeafFields.
+  std::vector<int> tmp_indices_;
+};
+
} // namespace
+// Resolve `sort_keys` against `schema` into SortFields. Keys that reference a struct
+// are expanded into its leaf fields by SortFieldPopulator.
Result<std::vector<SortField>> FindSortKeys(const Schema& schema,
const std::vector<SortKey>& sort_keys) {
- std::vector<SortField> fields;
- std::unordered_set<int> seen;
- fields.reserve(sort_keys.size());
- seen.reserve(sort_keys.size());
-
- for (const auto& sort_key : sort_keys) {
- RETURN_NOT_OK(CheckNonNested(sort_key.target));
-
- ARROW_ASSIGN_OR_RAISE(auto match,
- PrependInvalidColumn(sort_key.target.FindOne(schema)));
- if (seen.insert(match[0]).second) {
- fields.push_back({match[0], sort_key.order});
- }
- }
- return fields;
+ // A fresh populator per call: it accumulates state and moves its result out.
+ return SortFieldPopulator{}.FindSortKeys(schema, sort_keys);
}
Result<NullPartitionResult> SortChunkedArray(ExecContext* ctx, uint64_t* indices_begin,
@@ -1017,6 +1117,35 @@ Result<NullPartitionResult> SortChunkedArray(
return output;
}
+// Sort the indices in [indices_begin, indices_end) by the rows of a StructArray.
+//
+// The array is flattened into a RecordBatch (Flatten applies the struct's top-level
+// validity to the children). Its columns are then used as sort keys in field order,
+// all with `sort_order` and with nulls placed per `null_placement`. Nested struct
+// columns are further expanded into their leaf fields when the keys are resolved.
+// Depending on the number of resolved keys, either the radix sorter or the multiple-key
+// comparison sorter is used. Indices are treated as 0-based row numbers of `array`.
+Result<NullPartitionResult> SortStructArray(ExecContext* ctx, uint64_t* indices_begin,
+                                            uint64_t* indices_end,
+                                            const StructArray& array,
+                                            SortOrder sort_order,
+                                            NullPlacement null_placement) {
+  // Allocate the flattened children from the caller's pool rather than the default one.
+  MemoryPool* pool = ctx != nullptr ? ctx->memory_pool() : default_memory_pool();
+  ARROW_ASSIGN_OR_RAISE(auto columns, array.Flatten(pool));
+  auto batch = RecordBatch::Make(schema(array.type()->fields()), array.length(),
+                                 std::move(columns));
+
+  auto options = SortOptions::Defaults();
+  options.null_placement = null_placement;
+  options.sort_keys.reserve(array.num_fields());
+  for (int i = 0; i < array.num_fields(); ++i) {
+    options.sort_keys.push_back(SortKey(FieldRef(i), sort_order));
+  }
+
+  ARROW_ASSIGN_OR_RAISE(auto sort_keys,
+                        ResolveRecordBatchSortKeys(*batch, options.sort_keys));
+  // `options` and `batch` outlive the sorters below, which only borrow them.
+  if (sort_keys.size() <= kMaxRadixSortKeys) {
+    RadixRecordBatchSorter sorter(indices_begin, indices_end, std::move(sort_keys),
+                                  options);
+    return sorter.Sort();
+  }
+  MultipleKeyRecordBatchSorter sorter(indices_begin, indices_end, std::move(sort_keys),
+                                      options);
+  return sorter.Sort();
+}
+
void RegisterVectorSort(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::make_shared<SortIndicesMetaFunction>()));
}
diff --git a/cpp/src/arrow/compute/kernels/vector_sort_internal.h b/cpp/src/arrow/compute/kernels/vector_sort_internal.h
index d78e513061..d7e5575c80 100644
--- a/cpp/src/arrow/compute/kernels/vector_sort_internal.h
+++ b/cpp/src/arrow/compute/kernels/vector_sort_internal.h
@@ -464,12 +464,27 @@ Result<NullPartitionResult> SortChunkedArray(
const std::shared_ptr<DataType>& physical_type, const ArrayVector& physical_chunks,
SortOrder sort_order, NullPlacement null_placement);
+// Sort the indices in [indices_begin, indices_end) by the rows of `array`. Child fields
+// are compared in field order (nested structs are expanded into their leaf fields),
+// all using `sort_order`, with nulls placed according to `null_placement`. Returns the
+// partition of the index range into null and non-null entries.
+Result<NullPartitionResult> SortStructArray(ExecContext* ctx, uint64_t* indices_begin,
+ uint64_t* indices_end,
+ const StructArray& array,
+ SortOrder sort_order,
+ NullPlacement null_placement);
+
// ----------------------------------------------------------------------
// Helpers for Sort/SelectK/Rank implementations
struct SortField {
- int field_index;
+ SortField() = default;
+ SortField(FieldPath path, SortOrder order, const DataType* type)
+ : path(std::move(path)), order(order), type(type) {}
+ SortField(int index, SortOrder order, const DataType* type)
+ : SortField(FieldPath({index}), order, type) {}
+
+ bool is_nested() const { return path.indices().size() > 1; }
+
+ FieldPath path;
SortOrder order;
+ const DataType* type;
};
inline Status CheckNonNested(const FieldRef& ref) {
@@ -496,7 +511,10 @@ Result<std::vector<ResolvedSortKey>> ResolveSortKeys(
ARROW_ASSIGN_OR_RAISE(const auto fields, FindSortKeys(schema, sort_keys));
std::vector<ResolvedSortKey> resolved;
resolved.reserve(fields.size());
- std::transform(fields.begin(), fields.end(), std::back_inserter(resolved), factory);
+ for (const auto& f : fields) {
+ ARROW_ASSIGN_OR_RAISE(auto resolved_key, factory(f));
+ resolved.push_back(std::move(resolved_key));
+ }
return resolved;
}
@@ -504,8 +522,17 @@ template <typename ResolvedSortKey, typename TableOrBatch>
Result<std::vector<ResolvedSortKey>> ResolveSortKeys(
const TableOrBatch& table_or_batch, const std::vector<SortKey>& sort_keys) {
return ResolveSortKeys<ResolvedSortKey>(
- *table_or_batch.schema(), sort_keys, [&](const SortField& f) {
- return ResolvedSortKey{table_or_batch.column(f.field_index), f.order};
+ *table_or_batch.schema(), sort_keys,
+ [&](const SortField& f) -> Result<ResolvedSortKey> {
+ if (f.is_nested()) {
+ // TODO: Some room for improvement here, as we potentially duplicate some of the
+ // null-flattening work for nested sort keys. For instance, given two keys with
+ // paths [0,0,0,0] and [0,0,0,1], we shouldn't need to flatten the first three
+ // components more than once.
+ ARROW_ASSIGN_OR_RAISE(auto child, f.path.GetFlattened(table_or_batch));
+ return ResolvedSortKey{std::move(child), f.order};
+ }
+ return ResolvedSortKey{table_or_batch.column(f.path[0]), f.order};
});
}
@@ -737,17 +764,20 @@ struct ResolvedTableSortKey {
static Result<std::vector<ResolvedTableSortKey>> Make(
const Table& table, const RecordBatchVector& batches,
const std::vector<SortKey>& sort_keys) {
- auto factory = [&](const SortField& f) {
- const auto& type = table.schema()->field(f.field_index)->type();
+ auto factory = [&](const SortField& f) -> Result<ResolvedTableSortKey> {
// We must expose a homogenous chunking for all ResolvedSortKey,
- // so we can't simply pass `table.column(f.field_index)`
- ArrayVector chunks(batches.size());
- std::transform(batches.begin(), batches.end(), chunks.begin(),
- [&](const std::shared_ptr<RecordBatch>& batch) {
- return batch->column(f.field_index);
- });
- return ResolvedTableSortKey(type, std::move(chunks), f.order,
- table.column(f.field_index)->null_count());
+ // so we can't simply access the column from the table directly.
+ ArrayVector chunks;
+ chunks.reserve(batches.size());
+ int64_t null_count = 0;
+ for (const auto& batch : batches) {
+ ARROW_ASSIGN_OR_RAISE(auto child, f.path.GetFlattened(*batch));
+ null_count += child->null_count();
+ chunks.push_back(std::move(child));
+ }
+
+ return ResolvedTableSortKey(f.type->GetSharedPtr(), std::move(chunks), f.order,
+ null_count);
};
return ::arrow::compute::internal::ResolveSortKeys<ResolvedTableSortKey>(
diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc
index 3429a5a878..1328dddc04 100644
--- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc
+++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc
@@ -515,30 +515,28 @@ TEST(ArraySortIndicesFunction, ChunkedArray) {
// ----------------------------------------------------------------------
// Tests for SortToIndices
-template <typename T>
-void AssertSortIndices(const std::shared_ptr<T>& input, SortOrder order,
- NullPlacement null_placement,
+void AssertSortIndices(const Datum& datum, const SortOptions& options,
const std::shared_ptr<Array>& expected) {
- ArraySortOptions options(order, null_placement);
- ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(*input, options));
+ ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(datum, options));
ValidateOutput(*actual);
AssertArraysEqual(*expected, *actual, /*verbose=*/true);
}
+// Convenience overload: parse `expected` as a JSON array of uint64 indices and compare
+// it with the indices produced by SortIndices(datum, options).
+void AssertSortIndices(const Datum& datum, const SortOptions& options,
+                       const std::string& expected) {
+  const std::shared_ptr<Array> expected_indices = ArrayFromJSON(uint64(), expected);
+  AssertSortIndices(datum, options, expected_indices);
+}
+
template <typename T>
-void AssertSortIndices(const std::shared_ptr<T>& input, const SortOptions& options,
+void AssertSortIndices(const std::shared_ptr<T>& input, SortOrder order,
+ NullPlacement null_placement,
const std::shared_ptr<Array>& expected) {
- ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(Datum(*input), options));
+ ArraySortOptions options(order, null_placement);
+ ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(*input, options));
ValidateOutput(*actual);
AssertArraysEqual(*expected, *actual, /*verbose=*/true);
}
-template <typename T>
-void AssertSortIndices(const std::shared_ptr<T>& input, const SortOptions& options,
- const std::string& expected) {
- AssertSortIndices(input, options, ArrayFromJSON(uint64(), expected));
-}
-
template <typename T>
void AssertSortIndices(const std::shared_ptr<T>& input, SortOrder order,
NullPlacement null_placement, const std::string& expected) {
@@ -2115,6 +2113,97 @@ INSTANTIATE_TEST_SUITE_P(AllNull, TestTableSortIndicesRandom,
testing::Combine(first_sort_keys, num_sort_keys,
testing::Values(1.0)));
+// Sorting of nested (struct-typed) data. The same rows are sorted as a StructArray, a
+// ChunkedArray, a RecordBatch and a Table, using sort keys that reference nested fields.
+class TestNestedSortIndices : public ::testing::Test {
+ protected:
+ // A 9-row array of type struct<a: struct<a: uint8, b: struct<a: int32, b: uint32>>>.
+ // Row 5 is null at the top level, although its children still hold values.
+ static std::shared_ptr<Array> GetArray() {
+ auto struct_type = struct_({field(
+ "a",
+ struct_({field("a", uint8()),
+ field("b", struct_({field("a", int32()), field("b", uint32())}))}))});
+ auto struct_array = checked_pointer_cast<StructArray>(
+ ArrayFromJSON(struct_type,
+ R"([{"a": {"a": 5, "b": {"a": null, "b": 8 }}},
+ {"a": {"a": null, "b": {"a": 8, "b": null}}},
+ {"a": {"a": null, "b": {"a": 9, "b": 0 }}},
+ {"a": {"a": 2, "b": {"a": 4, "b": null}}},
+ {"a": {"a": 5, "b": {"a": 1, "b": 8 }}},
+ {"a": {"a": 3, "b": {"a": null, "b": 0 }}},
+ {"a": {"a": 2, "b": {"a": 4, "b": 2 }}},
+ {"a": {"a": 2, "b": {"a": 4, "b": 4 }}},
+ {"a": {"a": null, "b": {"a": 7, "b": 7 }}}])"));
+
+ // The top-level validity bitmap is created independently to test null inheritance for
+ // child fields.
+ std::shared_ptr<Buffer> parent_bitmap;
+ ARROW_CHECK_OK(
+ GetBitmapFromVector<bool>({1, 1, 1, 1, 1, 0, 1, 1, 1}, &parent_bitmap));
+
+ // Rebuild the struct from the same children, but with the parent bitmap above.
+ auto array =
+ *StructArray::Make(struct_array->fields(), struct_type->fields(), parent_bitmap);
+ ARROW_CHECK_OK(array->ValidateFull());
+ return array;
+ }
+
+ // RecordBatch built from GetArray(), with one column per top-level struct field.
+ static std::shared_ptr<RecordBatch> GetRecordBatch() {
+ auto batch = *RecordBatch::FromStructArray(GetArray());
+ ARROW_CHECK_OK(batch->ValidateFull());
+ return batch;
+ }
+
+ // GetArray() split into two chunks: rows [0, 3) and rows [3, 9).
+ static std::shared_ptr<ChunkedArray> GetChunkedArray() {
+ auto array = GetArray();
+ ArrayVector chunks(2);
+ chunks[0] = *array->SliceSafe(0, 3);
+ chunks[1] = *array->SliceSafe(3);
+ auto chunked = *ChunkedArray::Make(std::move(chunks));
+ ARROW_CHECK_OK(chunked->ValidateFull());
+ return chunked;
+ }
+
+ // Table built by flattening GetChunkedArray() into one column per struct field.
+ static std::shared_ptr<Table> GetTable() {
+ auto chunked = GetChunkedArray();
+ auto columns = *chunked->Flatten();
+ auto table =
+ Table::Make(arrow::schema(chunked->type()->fields()), std::move(columns));
+ ARROW_CHECK_OK(table->ValidateFull());
+ return table;
+ }
+
+ // Sort `datum` through SortIndices(Datum, SortOptions) using nested field references,
+ // checking both null placements, first with two keys and then with a single key.
+ void TestSort(const Datum& datum) const {
+ std::vector<SortKey> sort_keys = {SortKey(FieldRef("a", "a"), SortOrder::Ascending),
+ SortKey(FieldRef("a", "b"), SortOrder::Descending)};
+
+ SortOptions options(sort_keys, NullPlacement::AtEnd);
+ AssertSortIndices(datum, options, "[7, 6, 3, 4, 0, 2, 1, 8, 5]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(datum, options, "[5, 2, 1, 8, 3, 7, 6, 0, 4]");
+
+ // Implementations may have an optimized path for cases with one sort key.
+ // Additionally, this key references a struct containing another struct, which should
+ // work recursively
+ options.sort_keys = {SortKey(FieldRef("a"), SortOrder::Ascending)};
+ options.null_placement = NullPlacement::AtEnd;
+ AssertSortIndices(datum, options, "[6, 7, 3, 4, 0, 8, 1, 2, 5]");
+ options.null_placement = NullPlacement::AtStart;
+ AssertSortIndices(datum, options, "[5, 8, 1, 2, 3, 6, 7, 0, 4]");
+ }
+
+ // Sort GetArray() through the (order, null placement) overload of AssertSortIndices,
+ // i.e. with ArraySortOptions and no explicit sort keys.
+ void TestArraySort() const {
+ auto array = GetArray();
+ AssertSortIndices(array, SortOrder::Ascending, NullPlacement::AtEnd,
+ "[6, 7, 3, 4, 0, 8, 1, 2, 5]");
+ AssertSortIndices(array, SortOrder::Ascending, NullPlacement::AtStart,
+ "[5, 8, 1, 2, 3, 6, 7, 0, 4]");
+ }
+};
+
+TEST_F(TestNestedSortIndices, ArraySort) { TestArraySort(); }
+
+// The SortXxx tests run TestSort on each representation of the same rows.
+TEST_F(TestNestedSortIndices, SortStructArray) { TestSort(GetArray()); }
+TEST_F(TestNestedSortIndices, SortChunkedArray) { TestSort(GetChunkedArray()); }
+TEST_F(TestNestedSortIndices, SortRecordBatch) { TestSort(GetRecordBatch()); }
+TEST_F(TestNestedSortIndices, SortTable) { TestSort(GetTable()); }
+
// ----------------------------------------------------------------------
// Tests for Rank