You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2023/06/20 07:28:01 UTC

[arrow] branch main updated: GH-33206: [C++] Add support for StructArray sorting and nested sort keys (#35727)

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0f6522018a GH-33206: [C++] Add support for StructArray sorting and nested sort keys (#35727)
0f6522018a is described below

commit 0f6522018a3524fb2775934cd8712e4c94667974
Author: Ben Harkins <60...@users.noreply.github.com>
AuthorDate: Tue Jun 20 03:27:54 2023 -0400

    GH-33206: [C++] Add support for StructArray sorting and nested sort keys (#35727)
    
    
    
    ### Rationale for this change
    
    We don't currently support sorting `StructArray`s, even though we already have the high-level facilities to do so. For instance, we allow passing multiple sort keys (based on `FieldRef`s) to sort record batches and tables, but the current implementations are fairly limited because nested refs aren't allowed (due to the burden of null flattening). Now that https://github.com/apache/arrow/pull/35197 has been merged, we have an easier way to perform that null flattening.
    
    ### What changes are included in this PR?
    
    - Adds support for `StructArray` in `sort_indices`
    - Adds support for nested sort keys in `sort_indices` for `RecordBatch`, `ChunkedArray`, and `Table`
    
    ### Are these changes tested?
    
    Yes (tests are included)
    
    ### Are there any user-facing changes?
    
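    Yes: `sort_indices` now accepts `StructArray` inputs, and nested sort keys are accepted for `RecordBatch`, `ChunkedArray`, and `Table` inputs.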
    * Closes: #33206
    
    Authored-by: benibus <bp...@gmx.com>
    Signed-off-by: Antoine Pitrou <an...@python.org>
---
 cpp/src/arrow/compute/kernels/vector_array_sort.cc |  23 +-
 cpp/src/arrow/compute/kernels/vector_sort.cc       | 235 ++++++++++++++++-----
 .../arrow/compute/kernels/vector_sort_internal.h   |  58 +++--
 cpp/src/arrow/compute/kernels/vector_sort_test.cc  | 115 ++++++++--
 4 files changed, 350 insertions(+), 81 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/vector_array_sort.cc b/cpp/src/arrow/compute/kernels/vector_array_sort.cc
index 1499554a96..b9cfd38a02 100644
--- a/cpp/src/arrow/compute/kernels/vector_array_sort.cc
+++ b/cpp/src/arrow/compute/kernels/vector_array_sort.cc
@@ -262,6 +262,19 @@ class ArrayCompareSorter<DictionaryType> {
   }
 };
 
+template <>
+class ArrayCompareSorter<StructType> {
+ public:
+  Result<NullPartitionResult> operator()(uint64_t* indices_begin, uint64_t* indices_end,
+                                         const Array& array, int64_t offset,
+                                         const ArraySortOptions& options,
+                                         ExecContext* ctx) {
+    const auto& struct_array = checked_cast<const StructArray&>(array);
+    return SortStructArray(ctx, indices_begin, indices_end, struct_array, options.order,
+                           options.null_placement);
+  }
+};
+
 template <typename ArrowType>
 class ArrayCountSorter {
   using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
@@ -497,7 +510,7 @@ template <typename Type>
 struct ArraySorter<
     Type, enable_if_t<is_floating_type<Type>::value || is_base_binary_type<Type>::value ||
                       is_fixed_size_binary_type<Type>::value ||
-                      is_dictionary_type<Type>::value>> {
+                      is_dictionary_type<Type>::value || is_struct_type<Type>::value>> {
   ArrayCompareSorter<Type> impl;
 };
 
@@ -606,6 +619,13 @@ void AddDictArraySortingKernels(VectorKernel base, VectorFunction* func) {
   DCHECK_OK(func->AddKernel(base));
 }
 
+template <template <typename...> class ExecTemplate>
+void AddStructArraySortingKernels(VectorKernel base, VectorFunction* func) {
+  base.signature = KernelSignature::Make({Type::STRUCT}, uint64());
+  base.exec = ExecTemplate<UInt64Type, StructType>::Exec;
+  DCHECK_OK(func->AddKernel(base));
+}
+
 const ArraySortOptions* GetDefaultArraySortOptions() {
   static const auto kDefaultArraySortOptions = ArraySortOptions::Defaults();
   return &kDefaultArraySortOptions;
@@ -661,6 +681,7 @@ void RegisterVectorArraySort(FunctionRegistry* registry) {
   base.exec_chunked = ArraySortIndicesChunked;
   AddArraySortingKernels<ArraySortIndices>(base, array_sort_indices.get());
   AddDictArraySortingKernels<ArraySortIndices>(base, array_sort_indices.get());
+  AddStructArraySortingKernels<ArraySortIndices>(base, array_sort_indices.get());
   DCHECK_OK(registry->AddFunction(std::move(array_sort_indices)));
 
   // partition_nth_indices has a parameter so needs its init function
diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc
index 1de90cac35..5ee8cbaf6e 100644
--- a/cpp/src/arrow/compute/kernels/vector_sort.cc
+++ b/cpp/src/arrow/compute/kernels/vector_sort.cc
@@ -20,6 +20,9 @@
 #include "arrow/compute/kernels/vector_sort_internal.h"
 #include "arrow/compute/registry.h"
 
+template <>
+struct std::hash<arrow::FieldPath> : public arrow::FieldPath::Hash {};
+
 namespace arrow {
 
 using internal::checked_cast;
@@ -337,26 +340,56 @@ class ConcreteRecordBatchColumnSorter<NullType> : public RecordBatchColumnSorter
   const NullPlacement null_placement_;
 };
 
+Result<std::vector<ResolvedRecordBatchSortKey>> ResolveRecordBatchSortKeys(
+    const RecordBatch& batch, const std::vector<SortKey>& sort_keys) {
+  return ::arrow::compute::internal::ResolveSortKeys<ResolvedRecordBatchSortKey>(
+      batch, sort_keys);
+}
+
+std::vector<ResolvedRecordBatchSortKey> ResolveRecordBatchSortKeys(
+    const RecordBatch& batch, const std::vector<SortKey>& sort_keys, Status* status) {
+  const auto maybe_resolved = ResolveRecordBatchSortKeys(batch, sort_keys);
+  if (!maybe_resolved.ok()) {
+    *status = maybe_resolved.status();
+    return {};
+  }
+  return *std::move(maybe_resolved);
+}
+
+// Radix sorting is consistently faster except when there is a large number of sort keys,
+// in which case it can end up degrading catastrophically. This establishes a cutoff point
+// where we should use a different strategy.
+constexpr int kMaxRadixSortKeys = 8;
+
 // Sort a batch using a single-pass left-to-right radix sort.
 class RadixRecordBatchSorter {
  public:
+  using ResolvedSortKey = ResolvedRecordBatchSortKey;
+
+  RadixRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end,
+                         std::vector<ResolvedSortKey> sort_keys,
+                         const SortOptions& options)
+      : sort_keys_(std::move(sort_keys)),
+        options_(options),
+        indices_begin_(indices_begin),
+        indices_end_(indices_end) {}
+
   RadixRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end,
                          const RecordBatch& batch, const SortOptions& options)
-      : batch_(batch),
+      : sort_keys_(ResolveRecordBatchSortKeys(batch, options.sort_keys, &status_)),
         options_(options),
         indices_begin_(indices_begin),
         indices_end_(indices_end) {}
 
   // Offset is for table sorting
   Result<NullPartitionResult> Sort(int64_t offset = 0) {
-    ARROW_ASSIGN_OR_RAISE(const auto sort_keys,
-                          ResolveSortKeys(batch_, options_.sort_keys));
+    ARROW_RETURN_NOT_OK(status_);
 
     // Create column sorters from right to left
-    std::vector<std::unique_ptr<RecordBatchColumnSorter>> column_sorts(sort_keys.size());
+    std::vector<std::unique_ptr<RecordBatchColumnSorter>> column_sorts(sort_keys_.size());
     RecordBatchColumnSorter* next_column = nullptr;
-    for (int64_t i = static_cast<int64_t>(sort_keys.size() - 1); i >= 0; --i) {
-      ColumnSortFactory factory(sort_keys[i], options_, next_column);
+    for (int64_t i = static_cast<int64_t>(sort_keys_.size() - 1); i >= 0; --i) {
+      ColumnSortFactory factory(sort_keys_[i], options_, next_column);
       ARROW_ASSIGN_OR_RAISE(column_sorts[i], factory.MakeColumnSort());
       next_column = column_sorts[i].get();
     }
@@ -366,16 +399,11 @@ class RadixRecordBatchSorter {
   }
 
  protected:
-  struct ResolvedSortKey {
-    std::shared_ptr<Array> array;
-    SortOrder order;
-  };
-
   struct ColumnSortFactory {
     ColumnSortFactory(const ResolvedSortKey& sort_key, const SortOptions& options,
                       RecordBatchColumnSorter* next_column)
-        : physical_type(GetPhysicalType(sort_key.array->type())),
-          array(GetPhysicalArray(*sort_key.array, physical_type)),
+        : physical_type(sort_key.type),
+          array(sort_key.owned_array),
           order(sort_key.order),
           null_placement(options.null_placement),
           next_column(next_column) {}
@@ -419,19 +447,27 @@ class RadixRecordBatchSorter {
     return ::arrow::compute::internal::ResolveSortKeys<ResolvedSortKey>(batch, sort_keys);
   }
 
-  const RecordBatch& batch_;
+  const std::vector<ResolvedSortKey> sort_keys_;
   const SortOptions& options_;
   uint64_t* indices_begin_;
   uint64_t* indices_end_;
+  Status status_;
 };
 
 // Sort a batch using a single sort and multiple-key comparisons.
 class MultipleKeyRecordBatchSorter : public TypeVisitor {
- private:
+ public:
   using ResolvedSortKey = ResolvedRecordBatchSortKey;
-  using Comparator = MultipleKeyComparator<ResolvedSortKey>;
 
- public:
+  MultipleKeyRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end,
+                               std::vector<ResolvedSortKey> sort_keys,
+                               const SortOptions& options)
+      : indices_begin_(indices_begin),
+        indices_end_(indices_end),
+        sort_keys_(std::move(sort_keys)),
+        null_placement_(options.null_placement),
+        comparator_(sort_keys_, null_placement_) {}
+
   MultipleKeyRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end,
                                const RecordBatch& batch, const SortOptions& options)
       : indices_begin_(indices_begin),
@@ -457,6 +493,8 @@ class MultipleKeyRecordBatchSorter : public TypeVisitor {
 #undef VISIT
 
  private:
+  using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
   static std::vector<ResolvedSortKey> ResolveSortKeys(
       const RecordBatch& batch, const std::vector<SortKey>& sort_keys, Status* status) {
     const auto maybe_resolved =
@@ -850,12 +888,22 @@ class SortIndicesMetaFunction : public MetaFunction {
                             ExecContext* ctx) const override {
     const SortOptions& sort_options = static_cast<const SortOptions&>(*options);
     switch (args[0].kind()) {
-      case Datum::ARRAY:
-        return SortIndices(*args[0].make_array(), sort_options, ctx);
-        break;
-      case Datum::CHUNKED_ARRAY:
-        return SortIndices(*args[0].chunked_array(), sort_options, ctx);
-        break;
+      case Datum::ARRAY: {
+        auto array = args[0].make_array();
+        if (array->type_id() == Type::STRUCT) {
+          ARROW_ASSIGN_OR_RAISE(auto batch, RecordBatch::FromStructArray(array))
+          return SortIndices(*batch, sort_options, ctx);
+        }
+        return SortIndices(*array, sort_options, ctx);
+      } break;
+      case Datum::CHUNKED_ARRAY: {
+        const auto& chunked_array = args[0].chunked_array();
+        if (chunked_array->type()->id() == Type::STRUCT) {
+          ARROW_ASSIGN_OR_RAISE(auto table, ToTable(chunked_array))
+          return SortIndices(*table, sort_options, ctx);
+        }
+        return SortIndices(*chunked_array, sort_options, ctx);
+      } break;
       case Datum::RECORD_BATCH: {
         return SortIndices(*args[0].record_batch(), sort_options, ctx);
       } break;
@@ -872,6 +920,22 @@ class SortIndicesMetaFunction : public MetaFunction {
   }
 
  private:
+  static Result<std::shared_ptr<Table>> ToTable(
+      const std::shared_ptr<ChunkedArray>& chunked_array) {
+    if (chunked_array->null_count() == 0) {
+      return Table::FromChunkedStructArray(chunked_array);
+    }
+    // We avoid using `Table::FromChunkedStructArray` here since it doesn't take top-level
+    // validity into account for the columns.
+    //
+    // TODO: We could instead use the provided sort keys to only flatten the selected
+    // columns (via `GetFlattenedField`). Same for the Array -> RecordBatch conversion,
+    // since `RecordBatch::FromStructArray` flattens all columns as well.
+    ARROW_ASSIGN_OR_RAISE(auto columns, chunked_array->Flatten());
+    return Table::Make(schema(chunked_array->type()->fields()), std::move(columns),
+                       chunked_array->length());
+  }
+
   Result<Datum> SortIndices(const Array& values, const SortOptions& options,
                             ExecContext* ctx) const {
     SortOrder order = SortOrder::Ascending;
@@ -908,14 +972,15 @@ class SortIndicesMetaFunction : public MetaFunction {
 
   Result<Datum> SortIndices(const RecordBatch& batch, const SortOptions& options,
                             ExecContext* ctx) const {
-    auto n_sort_keys = options.sort_keys.size();
+    ARROW_ASSIGN_OR_RAISE(auto sort_keys,
+                          ResolveRecordBatchSortKeys(batch, options.sort_keys));
+
+    auto n_sort_keys = sort_keys.size();
     if (n_sort_keys == 0) {
       return Status::Invalid("Must specify one or more sort keys");
     }
     if (n_sort_keys == 1) {
-      ARROW_ASSIGN_OR_RAISE(auto array, PrependInvalidColumn(GetColumn(
-                                            batch, options.sort_keys[0].target)));
-      return SortIndices(*array, options, ctx);
+      return SortIndices(sort_keys[0].array, options, ctx);
     }
 
     auto out_type = uint64();
@@ -930,14 +995,12 @@ class SortIndicesMetaFunction : public MetaFunction {
     auto out_end = out_begin + length;
     std::iota(out_begin, out_end, 0);
 
-    // Radix sorting is consistently faster except when there is a large number
-    // of sort keys, in which case it can end up degrading catastrophically.
-    // Cut off above 8 sort keys.
-    if (n_sort_keys <= 8) {
-      RadixRecordBatchSorter sorter(out_begin, out_end, batch, options);
+    if (n_sort_keys <= kMaxRadixSortKeys) {
+      RadixRecordBatchSorter sorter(out_begin, out_end, std::move(sort_keys), options);
       ARROW_RETURN_NOT_OK(sorter.Sort());
     } else {
-      MultipleKeyRecordBatchSorter sorter(out_begin, out_end, batch, options);
+      MultipleKeyRecordBatchSorter sorter(out_begin, out_end, std::move(sort_keys),
+                                          options);
       ARROW_RETURN_NOT_OK(sorter.Sort());
     }
     return Datum(out);
@@ -949,11 +1012,6 @@ class SortIndicesMetaFunction : public MetaFunction {
     if (n_sort_keys == 0) {
       return Status::Invalid("Must specify one or more sort keys");
     }
-    if (n_sort_keys == 1) {
-      ARROW_ASSIGN_OR_RAISE(auto chunked_array, PrependInvalidColumn(GetColumn(
-                                                    table, options.sort_keys[0].target)));
-      return SortIndices(*chunked_array, options, ctx);
-    }
 
     auto out_type = uint64();
     auto length = table.num_rows();
@@ -974,25 +1032,67 @@ class SortIndicesMetaFunction : public MetaFunction {
   }
 };
 
+// Helper for processing a vector of `SortKeys` into a vector of `SortFields`
+struct SortFieldPopulator {
+ public:
+  // Process sort keys with the given schema. Note that the output vector may be larger
+  // than the input, as keys referencing a struct are recursively "expanded" into leaf
+  // fields.
+  Result<std::vector<SortField>> FindSortKeys(const Schema& schema,
+                                              const std::vector<SortKey>& sort_keys) {
+    sort_fields_.reserve(sort_keys.size());
+    seen_.reserve(sort_keys.size());
+
+    for (const auto& sort_key : sort_keys) {
+      ARROW_ASSIGN_OR_RAISE(auto match,
+                            PrependInvalidColumn(sort_key.target.FindOne(schema)));
+      if (seen_.insert(match).second) {
+        ARROW_ASSIGN_OR_RAISE(auto schema_field, match.Get(schema));
+        AddField(*schema_field->type(), match, sort_key.order);
+      }
+    }
+
+    return std::move(sort_fields_);
+  }
+
+ protected:
+  void AddLeafFields(const FieldVector& fields, SortOrder order) {
+    if (fields.empty()) {
+      return;
+    }
+
+    tmp_indices_.push_back(0);
+    for (const auto& f : fields) {
+      const auto& type = *f->type();
+      if (type.id() == Type::STRUCT) {
+        AddLeafFields(type.fields(), order);
+      } else {
+        sort_fields_.emplace_back(FieldPath(tmp_indices_), order, &type);
+      }
+      ++tmp_indices_.back();
+    }
+    tmp_indices_.pop_back();
+  }
+
+  void AddField(const DataType& type, const FieldPath& path, SortOrder order) {
+    if (type.id() == Type::STRUCT) {
+      tmp_indices_ = path.indices();
+      AddLeafFields(type.fields(), order);
+    } else {
+      sort_fields_.emplace_back(path, order, &type);
+    }
+  }
+
+  std::vector<SortField> sort_fields_;
+  std::unordered_set<FieldPath> seen_;
+  std::vector<int> tmp_indices_;
+};
+
 }  // namespace
 
 Result<std::vector<SortField>> FindSortKeys(const Schema& schema,
                                             const std::vector<SortKey>& sort_keys) {
-  std::vector<SortField> fields;
-  std::unordered_set<int> seen;
-  fields.reserve(sort_keys.size());
-  seen.reserve(sort_keys.size());
-
-  for (const auto& sort_key : sort_keys) {
-    RETURN_NOT_OK(CheckNonNested(sort_key.target));
-
-    ARROW_ASSIGN_OR_RAISE(auto match,
-                          PrependInvalidColumn(sort_key.target.FindOne(schema)));
-    if (seen.insert(match[0]).second) {
-      fields.push_back({match[0], sort_key.order});
-    }
-  }
-  return fields;
+  return SortFieldPopulator{}.FindSortKeys(schema, sort_keys);
 }
 
 Result<NullPartitionResult> SortChunkedArray(ExecContext* ctx, uint64_t* indices_begin,
@@ -1017,6 +1117,35 @@ Result<NullPartitionResult> SortChunkedArray(
   return output;
 }
 
+Result<NullPartitionResult> SortStructArray(ExecContext* ctx, uint64_t* indices_begin,
+                                            uint64_t* indices_end,
+                                            const StructArray& array,
+                                            SortOrder sort_order,
+                                            NullPlacement null_placement) {
+  ARROW_ASSIGN_OR_RAISE(auto columns, array.Flatten());
+  auto batch = RecordBatch::Make(schema(array.type()->fields()), array.length(),
+                                 std::move(columns));
+
+  auto options = SortOptions::Defaults();
+  options.null_placement = null_placement;
+  options.sort_keys.reserve(array.num_fields());
+  for (int i = 0; i < array.num_fields(); ++i) {
+    options.sort_keys.push_back(SortKey(FieldRef(i), sort_order));
+  }
+
+  ARROW_ASSIGN_OR_RAISE(auto sort_keys,
+                        ResolveRecordBatchSortKeys(*batch, options.sort_keys));
+  if (sort_keys.size() <= kMaxRadixSortKeys) {
+    RadixRecordBatchSorter sorter(indices_begin, indices_end, std::move(sort_keys),
+                                  options);
+    return sorter.Sort();
+  } else {
+    MultipleKeyRecordBatchSorter sorter(indices_begin, indices_end, std::move(sort_keys),
+                                        options);
+    return sorter.Sort();
+  }
+}
+
 void RegisterVectorSort(FunctionRegistry* registry) {
   DCHECK_OK(registry->AddFunction(std::make_shared<SortIndicesMetaFunction>()));
 }
diff --git a/cpp/src/arrow/compute/kernels/vector_sort_internal.h b/cpp/src/arrow/compute/kernels/vector_sort_internal.h
index d78e513061..d7e5575c80 100644
--- a/cpp/src/arrow/compute/kernels/vector_sort_internal.h
+++ b/cpp/src/arrow/compute/kernels/vector_sort_internal.h
@@ -464,12 +464,27 @@ Result<NullPartitionResult> SortChunkedArray(
     const std::shared_ptr<DataType>& physical_type, const ArrayVector& physical_chunks,
     SortOrder sort_order, NullPlacement null_placement);
 
+Result<NullPartitionResult> SortStructArray(ExecContext* ctx, uint64_t* indices_begin,
+                                            uint64_t* indices_end,
+                                            const StructArray& array,
+                                            SortOrder sort_order,
+                                            NullPlacement null_placement);
+
 // ----------------------------------------------------------------------
 // Helpers for Sort/SelectK/Rank implementations
 
 struct SortField {
-  int field_index;
+  SortField() = default;
+  SortField(FieldPath path, SortOrder order, const DataType* type)
+      : path(std::move(path)), order(order), type(type) {}
+  SortField(int index, SortOrder order, const DataType* type)
+      : SortField(FieldPath({index}), order, type) {}
+
+  bool is_nested() const { return path.indices().size() > 1; }
+
+  FieldPath path;
   SortOrder order;
+  const DataType* type;
 };
 
 inline Status CheckNonNested(const FieldRef& ref) {
@@ -496,7 +511,10 @@ Result<std::vector<ResolvedSortKey>> ResolveSortKeys(
   ARROW_ASSIGN_OR_RAISE(const auto fields, FindSortKeys(schema, sort_keys));
   std::vector<ResolvedSortKey> resolved;
   resolved.reserve(fields.size());
-  std::transform(fields.begin(), fields.end(), std::back_inserter(resolved), factory);
+  for (const auto& f : fields) {
+    ARROW_ASSIGN_OR_RAISE(auto resolved_key, factory(f));
+    resolved.push_back(std::move(resolved_key));
+  }
   return resolved;
 }
 
@@ -504,8 +522,17 @@ template <typename ResolvedSortKey, typename TableOrBatch>
 Result<std::vector<ResolvedSortKey>> ResolveSortKeys(
     const TableOrBatch& table_or_batch, const std::vector<SortKey>& sort_keys) {
   return ResolveSortKeys<ResolvedSortKey>(
-      *table_or_batch.schema(), sort_keys, [&](const SortField& f) {
-        return ResolvedSortKey{table_or_batch.column(f.field_index), f.order};
+      *table_or_batch.schema(), sort_keys,
+      [&](const SortField& f) -> Result<ResolvedSortKey> {
+        if (f.is_nested()) {
+          // TODO: Some room for improvement here, as we potentially duplicate some of the
+          // null-flattening work for nested sort keys. For instance, given two keys with
+          // paths [0,0,0,0] and [0,0,0,1], we shouldn't need to flatten the first three
+          // components more than once.
+          ARROW_ASSIGN_OR_RAISE(auto child, f.path.GetFlattened(table_or_batch));
+          return ResolvedSortKey{std::move(child), f.order};
+        }
+        return ResolvedSortKey{table_or_batch.column(f.path[0]), f.order};
       });
 }
 
@@ -737,17 +764,20 @@ struct ResolvedTableSortKey {
   static Result<std::vector<ResolvedTableSortKey>> Make(
       const Table& table, const RecordBatchVector& batches,
       const std::vector<SortKey>& sort_keys) {
-    auto factory = [&](const SortField& f) {
-      const auto& type = table.schema()->field(f.field_index)->type();
+    auto factory = [&](const SortField& f) -> Result<ResolvedTableSortKey> {
       // We must expose a homogenous chunking for all ResolvedSortKey,
-      // so we can't simply pass `table.column(f.field_index)`
-      ArrayVector chunks(batches.size());
-      std::transform(batches.begin(), batches.end(), chunks.begin(),
-                     [&](const std::shared_ptr<RecordBatch>& batch) {
-                       return batch->column(f.field_index);
-                     });
-      return ResolvedTableSortKey(type, std::move(chunks), f.order,
-                                  table.column(f.field_index)->null_count());
+      // so we can't simply access the column from the table directly.
+      ArrayVector chunks;
+      chunks.reserve(batches.size());
+      int64_t null_count = 0;
+      for (const auto& batch : batches) {
+        ARROW_ASSIGN_OR_RAISE(auto child, f.path.GetFlattened(*batch));
+        null_count += child->null_count();
+        chunks.push_back(std::move(child));
+      }
+
+      return ResolvedTableSortKey(f.type->GetSharedPtr(), std::move(chunks), f.order,
+                                  null_count);
     };
 
     return ::arrow::compute::internal::ResolveSortKeys<ResolvedTableSortKey>(
diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc
index 3429a5a878..1328dddc04 100644
--- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc
+++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc
@@ -515,30 +515,28 @@ TEST(ArraySortIndicesFunction, ChunkedArray) {
 // ----------------------------------------------------------------------
 // Tests for SortToIndices
 
-template <typename T>
-void AssertSortIndices(const std::shared_ptr<T>& input, SortOrder order,
-                       NullPlacement null_placement,
+void AssertSortIndices(const Datum& datum, const SortOptions& options,
                        const std::shared_ptr<Array>& expected) {
-  ArraySortOptions options(order, null_placement);
-  ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(*input, options));
+  ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(datum, options));
   ValidateOutput(*actual);
   AssertArraysEqual(*expected, *actual, /*verbose=*/true);
 }
 
+void AssertSortIndices(const Datum& datum, const SortOptions& options,
+                       const std::string& expected) {
+  AssertSortIndices(datum, options, ArrayFromJSON(uint64(), expected));
+}
+
 template <typename T>
-void AssertSortIndices(const std::shared_ptr<T>& input, const SortOptions& options,
+void AssertSortIndices(const std::shared_ptr<T>& input, SortOrder order,
+                       NullPlacement null_placement,
                        const std::shared_ptr<Array>& expected) {
-  ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(Datum(*input), options));
+  ArraySortOptions options(order, null_placement);
+  ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(*input, options));
   ValidateOutput(*actual);
   AssertArraysEqual(*expected, *actual, /*verbose=*/true);
 }
 
-template <typename T>
-void AssertSortIndices(const std::shared_ptr<T>& input, const SortOptions& options,
-                       const std::string& expected) {
-  AssertSortIndices(input, options, ArrayFromJSON(uint64(), expected));
-}
-
 template <typename T>
 void AssertSortIndices(const std::shared_ptr<T>& input, SortOrder order,
                        NullPlacement null_placement, const std::string& expected) {
@@ -2115,6 +2113,97 @@ INSTANTIATE_TEST_SUITE_P(AllNull, TestTableSortIndicesRandom,
                          testing::Combine(first_sort_keys, num_sort_keys,
                                           testing::Values(1.0)));
 
+class TestNestedSortIndices : public ::testing::Test {
+ protected:
+  static std::shared_ptr<Array> GetArray() {
+    auto struct_type = struct_({field(
+        "a",
+        struct_({field("a", uint8()),
+                 field("b", struct_({field("a", int32()), field("b", uint32())}))}))});
+    auto struct_array = checked_pointer_cast<StructArray>(
+        ArrayFromJSON(struct_type,
+                      R"([{"a": {"a": 5,    "b": {"a": null, "b": 8   }}},
+                          {"a": {"a": null, "b": {"a": 8,    "b": null}}},
+                          {"a": {"a": null, "b": {"a": 9,    "b": 0   }}},
+                          {"a": {"a": 2,    "b": {"a": 4,    "b": null}}},
+                          {"a": {"a": 5,    "b": {"a": 1,    "b": 8   }}},
+                          {"a": {"a": 3,    "b": {"a": null, "b": 0   }}},
+                          {"a": {"a": 2,    "b": {"a": 4,    "b": 2   }}},
+                          {"a": {"a": 2,    "b": {"a": 4,    "b": 4   }}},
+                          {"a": {"a": null, "b": {"a": 7,    "b": 7   }}}])"));
+
+    // The top-level validity bitmap is created independently to test null inheritance for
+    // child fields.
+    std::shared_ptr<Buffer> parent_bitmap;
+    ARROW_CHECK_OK(
+        GetBitmapFromVector<bool>({1, 1, 1, 1, 1, 0, 1, 1, 1}, &parent_bitmap));
+
+    auto array =
+        *StructArray::Make(struct_array->fields(), struct_type->fields(), parent_bitmap);
+    ARROW_CHECK_OK(array->ValidateFull());
+    return array;
+  }
+
+  static std::shared_ptr<RecordBatch> GetRecordBatch() {
+    auto batch = *RecordBatch::FromStructArray(GetArray());
+    ARROW_CHECK_OK(batch->ValidateFull());
+    return batch;
+  }
+
+  static std::shared_ptr<ChunkedArray> GetChunkedArray() {
+    auto array = GetArray();
+    ArrayVector chunks(2);
+    chunks[0] = *array->SliceSafe(0, 3);
+    chunks[1] = *array->SliceSafe(3);
+    auto chunked = *ChunkedArray::Make(std::move(chunks));
+    ARROW_CHECK_OK(chunked->ValidateFull());
+    return chunked;
+  }
+
+  static std::shared_ptr<Table> GetTable() {
+    auto chunked = GetChunkedArray();
+    auto columns = *chunked->Flatten();
+    auto table =
+        Table::Make(arrow::schema(chunked->type()->fields()), std::move(columns));
+    ARROW_CHECK_OK(table->ValidateFull());
+    return table;
+  }
+
+  void TestSort(const Datum& datum) const {
+    std::vector<SortKey> sort_keys = {SortKey(FieldRef("a", "a"), SortOrder::Ascending),
+                                      SortKey(FieldRef("a", "b"), SortOrder::Descending)};
+
+    SortOptions options(sort_keys, NullPlacement::AtEnd);
+    AssertSortIndices(datum, options, "[7, 6, 3, 4, 0, 2, 1, 8, 5]");
+    options.null_placement = NullPlacement::AtStart;
+    AssertSortIndices(datum, options, "[5, 2, 1, 8, 3, 7, 6, 0, 4]");
+
+    // Implementations may have an optimized path for cases with one sort key.
+    // Additionally, this key references a struct containing another struct, which should
+    // work recursively
+    options.sort_keys = {SortKey(FieldRef("a"), SortOrder::Ascending)};
+    options.null_placement = NullPlacement::AtEnd;
+    AssertSortIndices(datum, options, "[6, 7, 3, 4, 0, 8, 1, 2, 5]");
+    options.null_placement = NullPlacement::AtStart;
+    AssertSortIndices(datum, options, "[5, 8, 1, 2, 3, 6, 7, 0, 4]");
+  }
+
+  void TestArraySort() const {
+    auto array = GetArray();
+    AssertSortIndices(array, SortOrder::Ascending, NullPlacement::AtEnd,
+                      "[6, 7, 3, 4, 0, 8, 1, 2, 5]");
+    AssertSortIndices(array, SortOrder::Ascending, NullPlacement::AtStart,
+                      "[5, 8, 1, 2, 3, 6, 7, 0, 4]");
+  }
+};
+
+TEST_F(TestNestedSortIndices, ArraySort) { TestArraySort(); }
+
+TEST_F(TestNestedSortIndices, SortStructArray) { TestSort(GetArray()); }
+TEST_F(TestNestedSortIndices, SortChunkedArray) { TestSort(GetChunkedArray()); }
+TEST_F(TestNestedSortIndices, SortRecordBatch) { TestSort(GetRecordBatch()); }
+TEST_F(TestNestedSortIndices, SortTable) { TestSort(GetTable()); }
+
 // ----------------------------------------------------------------------
 // Tests for Rank