You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2020/12/14 08:25:20 UTC

[arrow] branch master updated: ARROW-10796: [C++] Implement optimized RecordBatch sorting

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 8ae596a  ARROW-10796: [C++] Implement optimized RecordBatch sorting
8ae596a is described below

commit 8ae596a9b6292f0b0ee3ae770ceeeb5bf4006ce8
Author: Antoine Pitrou <an...@python.org>
AuthorDate: Mon Dec 14 09:24:06 2020 +0100

    ARROW-10796: [C++] Implement optimized RecordBatch sorting
    
    Add two RecordBatch sorting implementations:
    * A single-pass left-to-right radix sort that's fast up to ~8 sort keys
    * A single-pass multiple-key-comparing sort that gives decent performance for large numbers of sort keys
    
    Both implementations benefit from direct indexed access into the contiguous RecordBatch columns (as opposed to table sorting, which must index into the chunks).
    
    Add some RecordBatch-sorting benchmarks.
    
    Also, add and improve tests; and fix a bug related to sorting of NaNs and nulls.
    
    Benchmarks (changes less than 10% in absolute value not shown):
    ```
                                               benchmark            baseline           contender  change %                                                                                                                                                            counters
    10   RecordBatchSortIndicesInt64Narrow/1048576/100/8    1.482m items/sec    5.410m items/sec   265.083    {'run_name': 'RecordBatchSortIndicesInt64Narrow/1048576/100/8', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 1}
    20     RecordBatchSortIndicesInt64Narrow/1048576/0/8    1.524m items/sec    5.478m items/sec   259.478      {'run_name': 'RecordBatchSortIndicesInt64Narrow/1048576/0/8', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 1}
    60   RecordBatchSortIndicesInt64Narrow/1048576/100/2    2.276m items/sec    7.803m items/sec   242.839    {'run_name': 'RecordBatchSortIndicesInt64Narrow/1048576/100/2', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 2}
    21     RecordBatchSortIndicesInt64Narrow/1048576/0/2    2.340m items/sec    7.802m items/sec   233.369      {'run_name': 'RecordBatchSortIndicesInt64Narrow/1048576/0/2', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 2}
    23     RecordBatchSortIndicesInt64Wide/1048576/100/2    4.673m items/sec    9.867m items/sec   111.164      {'run_name': 'RecordBatchSortIndicesInt64Wide/1048576/100/2', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 3}
    61       RecordBatchSortIndicesInt64Wide/1048576/0/2    4.677m items/sec    9.820m items/sec   109.971        {'run_name': 'RecordBatchSortIndicesInt64Wide/1048576/0/2', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 3}
    35     RecordBatchSortIndicesInt64Wide/1048576/100/8    4.680m items/sec    9.822m items/sec   109.850      {'run_name': 'RecordBatchSortIndicesInt64Wide/1048576/100/8', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 3}
    55       RecordBatchSortIndicesInt64Wide/1048576/0/8    4.755m items/sec    9.933m items/sec   108.895        {'run_name': 'RecordBatchSortIndicesInt64Wide/1048576/0/8', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 3}
    59      RecordBatchSortIndicesInt64Wide/1048576/0/16    4.794m items/sec    8.408m items/sec    75.389       {'run_name': 'RecordBatchSortIndicesInt64Wide/1048576/0/16', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 3}
    16    RecordBatchSortIndicesInt64Wide/1048576/100/16    4.733m items/sec    8.177m items/sec    72.780     {'run_name': 'RecordBatchSortIndicesInt64Wide/1048576/100/16', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 3}
    29    RecordBatchSortIndicesInt64Narrow/1048576/0/16    1.640m items/sec    2.627m items/sec    60.146     {'run_name': 'RecordBatchSortIndicesInt64Narrow/1048576/0/16', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 1}
    9   RecordBatchSortIndicesInt64Narrow/1048576/100/16    1.559m items/sec    2.342m items/sec    50.201   {'run_name': 'RecordBatchSortIndicesInt64Narrow/1048576/100/16', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 1}
    4          TableSortIndicesInt64Narrow/1048576/0/2/1    2.415m items/sec    2.699m items/sec    11.723          {'run_name': 'TableSortIndicesInt64Narrow/1048576/0/2/1', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 2}
    51        TableSortIndicesInt64Narrow/1048576/0/2/32    1.814m items/sec    2.023m items/sec    11.513         {'run_name': 'TableSortIndicesInt64Narrow/1048576/0/2/32', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 1}
    49        TableSortIndicesInt64Narrow/1048576/0/16/4    1.542m items/sec    1.717m items/sec    11.361         {'run_name': 'TableSortIndicesInt64Narrow/1048576/0/16/4', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 1}
    30         TableSortIndicesInt64Narrow/1048576/0/2/4    2.272m items/sec    2.516m items/sec    10.733          {'run_name': 'TableSortIndicesInt64Narrow/1048576/0/2/4', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 2}
    25         TableSortIndicesInt64Narrow/1048576/0/8/4    1.542m items/sec    1.706m items/sec    10.628          {'run_name': 'TableSortIndicesInt64Narrow/1048576/0/8/4', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 1}
    11        TableSortIndicesInt64Narrow/1048576/0/16/1    1.691m items/sec    1.866m items/sec    10.316         {'run_name': 'TableSortIndicesInt64Narrow/1048576/0/16/1', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 1}
    12         TableSortIndicesInt64Narrow/1048576/0/8/1    1.683m items/sec    1.856m items/sec    10.280          {'run_name': 'TableSortIndicesInt64Narrow/1048576/0/8/1', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 1}
    [...]
    6      RecordBatchSortIndicesInt64Narrow/1048576/0/1  185.050m items/sec  164.579m items/sec   -11.062    {'run_name': 'RecordBatchSortIndicesInt64Narrow/1048576/0/1', 'run_type': 'iteration', 'repetitions': 0, 'repetition_index': 0, 'threads': 1, 'iterations': 122}
    ```
    
    Closes #8890 from pitrou/ARROW-10796-batch-sort
    
    Authored-by: Antoine Pitrou <an...@python.org>
    Signed-off-by: Antoine Pitrou <an...@python.org>
---
 cpp/src/arrow/compute/kernels/vector_sort.cc       | 914 +++++++++++++++------
 .../arrow/compute/kernels/vector_sort_benchmark.cc | 125 ++-
 cpp/src/arrow/compute/kernels/vector_sort_test.cc  | 367 +++++++--
 cpp/src/arrow/testing/random.cc                    |  60 +-
 cpp/src/arrow/testing/random.h                     |   6 +-
 5 files changed, 1085 insertions(+), 387 deletions(-)

diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc
index 85d2557..ef0c80d 100644
--- a/cpp/src/arrow/compute/kernels/vector_sort.cc
+++ b/cpp/src/arrow/compute/kernels/vector_sort.cc
@@ -30,6 +30,7 @@
 #include "arrow/util/bit_block_counter.h"
 #include "arrow/util/checked_cast.h"
 #include "arrow/util/optional.h"
+#include "arrow/visitor_inline.h"
 
 namespace arrow {
 
@@ -38,6 +39,21 @@ using internal::checked_cast;
 namespace compute {
 namespace internal {
 
+// Visit all physical types for which sorting is implemented.
+//
+// VISIT is a caller-supplied macro taking one type name; it is expanded once
+// per listed type (e.g. to declare one Visit() overload per physical type).
+// Callers map logical column types to their physical type (GetPhysicalType)
+// before dispatching, so only physical types need to appear here.
+#define VISIT_PHYSICAL_TYPES(VISIT) \
+  VISIT(Int8Type)                   \
+  VISIT(Int16Type)                  \
+  VISIT(Int32Type)                  \
+  VISIT(Int64Type)                  \
+  VISIT(UInt8Type)                  \
+  VISIT(UInt16Type)                 \
+  VISIT(UInt32Type)                 \
+  VISIT(UInt64Type)                 \
+  VISIT(FloatType)                  \
+  VISIT(DoubleType)                 \
+  VISIT(BinaryType)                 \
+  VISIT(LargeBinaryType)
+
 namespace {
 
 // The target chunk in a chunked array.
@@ -142,15 +158,20 @@ struct ChunkedArrayResolver {
 // (such as cached raw values pointer) in a separate hierarchy of
 // physical accessors, but doing so ends up too cumbersome.
 // Instead, we simply create the desired concrete Array objects.
+std::shared_ptr<Array> GetPhysicalArray(const Array& array,
+                                        const std::shared_ptr<DataType>& physical_type) {
+  // Work on a copy of the ArrayData so the caller's array is left untouched,
+  // then relabel the copy with the physical type.
+  std::shared_ptr<ArrayData> physical_data = array.data()->Copy();
+  physical_data->type = physical_type;
+  return MakeArray(std::move(physical_data));
+}
+
+// Apply GetPhysicalArray() to every chunk of chunked_array, preserving chunk order.
 ArrayVector GetPhysicalChunks(const ChunkedArray& chunked_array,
                               const std::shared_ptr<DataType>& physical_type) {
   const auto& chunks = chunked_array.chunks();
   ArrayVector physical(chunks.size());
   std::transform(chunks.begin(), chunks.end(), physical.begin(),
                  [&](const std::shared_ptr<Array>& array) {
-                   auto new_data = array->data()->Copy();
-                   new_data->type = physical_type;
-                   return MakeArray(std::move(new_data));
+                   return GetPhysicalArray(*array, physical_type);
                  });
   return physical;
 }
@@ -634,20 +655,9 @@ class ChunkedArraySorter : public TypeVisitor {
+  // Dispatch on the physical type to the matching Visit() overload below.
   Status Sort() { return physical_type_->Accept(this); }

+  // Generate one Visit() override per physical type, forwarding to SortInternal<T>().
#define VISIT(TYPE) \
-  Status Visit(const TYPE##Type& type) override { return SortInternal<TYPE##Type>(); }
-
-  VISIT(Int8)
-  VISIT(Int16)
-  VISIT(Int32)
-  VISIT(Int64)
-  VISIT(UInt8)
-  VISIT(UInt16)
-  VISIT(UInt32)
-  VISIT(UInt64)
-  VISIT(Float)
-  VISIT(Double)
-  VISIT(Binary)
-  VISIT(LargeBinary)
+  Status Visit(const TYPE& type) override { return SortInternal<TYPE>(); }
+
+  VISIT_PHYSICAL_TYPES(VISIT)
+
#undef VISIT
 
@@ -805,6 +815,517 @@ class ChunkedArraySorter : public TypeVisitor {
 };
 
 // ----------------------------------------------------------------------
+// Record batch sorting implementation(s)
+
+// Call visit(run_begin, run_end) once for every maximal run of consecutive
+// indices whose referenced values compare equal, in order of appearance.
+// All referenced entries are assumed to be non-null.
+template <typename ArrayType, typename Visitor>
+void VisitConstantRanges(const ArrayType& array, uint64_t* indices_begin,
+                         uint64_t* indices_end, Visitor&& visit) {
+  if (indices_begin == indices_end) {
+    return;
+  }
+  uint64_t* run_begin = indices_begin;
+  auto run_value = array.GetView(*run_begin);
+  for (uint64_t* it = run_begin + 1; it != indices_end; ++it) {
+    auto value = array.GetView(*it);
+    if (value != run_value) {
+      // A new run starts here: flush the previous one.
+      visit(run_begin, it);
+      run_begin = it;
+      run_value = value;
+    }
+  }
+  // The last run is never empty: run_begin always points at an element.
+  visit(run_begin, indices_end);
+}
+
+// Abstract sorter for one sort-key column of a RecordBatch.  Runs of equal
+// values in this column are handed over to `next_column_` (when set), so
+// that ties are broken by the following sort key.
+class RecordBatchColumnSorter {
+ public:
+  explicit RecordBatchColumnSorter(RecordBatchColumnSorter* next_column = nullptr)
+      : next_column_(next_column) {}
+  virtual ~RecordBatchColumnSorter() = default;
+
+  // Reorder the indices in [indices_begin, indices_end).
+  virtual void SortRange(uint64_t* indices_begin, uint64_t* indices_end) = 0;
+
+ protected:
+  // Sorter of the next sort key, or nullptr for the last key (not owned).
+  RecordBatchColumnSorter* next_column_;
+};
+
+// Column sorter for a RecordBatch column of physical type `Type`.
+//
+// Within a range of indices, regular values are placed first (ordered by
+// `order_`), followed by "null-like" values as determined by
+// PartitionNullLikes (presumably NaNs for floating-point types), followed by
+// nulls.  Each run of equal regular values, as well as the null-like group
+// and the null group, is then sorted on the next column, if any.
+template <typename Type>
+class ConcreteRecordBatchColumnSorter : public RecordBatchColumnSorter {
+ public:
+  using ArrayType = typename TypeTraits<Type>::ArrayType;
+
+  // `array` must already be of physical type `Type` (it is downcast with
+  // checked_cast); this sorter keeps it alive.
+  ConcreteRecordBatchColumnSorter(std::shared_ptr<Array> array, SortOrder order,
+                                  RecordBatchColumnSorter* next_column = nullptr)
+      : RecordBatchColumnSorter(next_column),
+        owned_array_(std::move(array)),
+        array_(checked_cast<const ArrayType&>(*owned_array_)),
+        order_(order),
+        null_count_(array_.null_count()) {}
+
+  void SortRange(uint64_t* indices_begin, uint64_t* indices_end) override {
+    // RecordBatch columns are indexed directly: there is no chunk offset.
+    constexpr int64_t offset = 0;
+    uint64_t* nulls_begin;
+    if (null_count_ == 0) {
+      nulls_begin = indices_end;
+    } else {
+      // NOTE that null_count_ is merely an upper bound on the number of nulls
+      // in this particular range.
+      nulls_begin = PartitionNullsOnly<StablePartitioner>(indices_begin, indices_end,
+                                                          array_, offset);
+      DCHECK_LE(indices_end - nulls_begin, null_count_);
+    }
+    uint64_t* null_likes_begin = PartitionNullLikes<ArrayType, StablePartitioner>(
+        indices_begin, nulls_begin, array_, offset);
+
+    // TODO This is roughly the same as ArrayCompareSorter.
+    // Also, we would like to use a counting sort if possible.  This requires
+    // a counting sort compatible with indirect indexing.
+    if (order_ == SortOrder::Ascending) {
+      std::stable_sort(
+          indices_begin, null_likes_begin, [&](uint64_t left, uint64_t right) {
+            return array_.GetView(left - offset) < array_.GetView(right - offset);
+          });
+    } else {
+      std::stable_sort(
+          indices_begin, null_likes_begin, [&](uint64_t left, uint64_t right) {
+            // Use 'right < left' rather than 'left > right' so that only
+            // operator< is required of the value type.
+            return array_.GetView(right - offset) < array_.GetView(left - offset);
+          });
+    }
+
+    if (next_column_ != nullptr) {
+      // Break ties on the next column: first the null-like and null groups,
+      // then every run of equal values in this column.
+      SortNextColumn(null_likes_begin, nulls_begin);
+      SortNextColumn(nulls_begin, indices_end);
+      VisitConstantRanges(array_, indices_begin, null_likes_begin,
+                          [&](uint64_t* range_start, uint64_t* range_end) {
+                            SortNextColumn(range_start, range_end);
+                          });
+    }
+  }
+
+  // Sort [indices_begin, indices_end) on the next column.
+  // Requires next_column_ != nullptr.
+  void SortNextColumn(uint64_t* indices_begin, uint64_t* indices_end) {
+    // Avoid the cost of a virtual method call in trivial cases
+    if (indices_end - indices_begin > 1) {
+      next_column_->SortRange(indices_begin, indices_end);
+    }
+  }
+
+ protected:
+  const std::shared_ptr<Array> owned_array_;
+  const ArrayType& array_;
+  const SortOrder order_;
+  const int64_t null_count_;
+};
+
+// Sort a batch using a single-pass left-to-right radix sort.
+//
+// One RecordBatchColumnSorter is created per sort key; each one sorts a range
+// on its own column and forwards runs of equal values to the sorter of the
+// next key.
+class RadixRecordBatchSorter {
+ public:
+  // Sort() reorders [indices_begin, indices_end) in place.  `batch` and
+  // `options` are held by reference and must outlive this object.
+  RadixRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end,
+                         const RecordBatch& batch, const SortOptions& options)
+      : batch_(batch),
+        options_(options),
+        indices_begin_(indices_begin),
+        indices_end_(indices_end) {}
+
+  // Returns Invalid when no sort key is given or a sort key names a
+  // nonexistent column, and TypeError when a sort key column has a type
+  // without a sorter.
+  Status Sort() {
+    if (options_.sort_keys.empty()) {
+      // Without this check, column_sorts.front() below would be called on an
+      // empty vector (undefined behavior).
+      return Status::Invalid("Must specify one or more sort keys");
+    }
+    ARROW_ASSIGN_OR_RAISE(const auto sort_keys,
+                          ResolveSortKeys(batch_, options_.sort_keys));
+
+    // Create column sorters from right to left, so that each one can be
+    // handed a pointer to the sorter of the following key.
+    std::vector<std::unique_ptr<RecordBatchColumnSorter>> column_sorts(sort_keys.size());
+    RecordBatchColumnSorter* next_column = nullptr;
+    for (size_t i = sort_keys.size(); i-- > 0;) {
+      ColumnSortFactory factory(sort_keys[i], next_column);
+      ARROW_ASSIGN_OR_RAISE(column_sorts[i], factory.MakeColumnSort());
+      next_column = column_sorts[i].get();
+    }
+
+    // Sort from left to right
+    column_sorts.front()->SortRange(indices_begin_, indices_end_);
+    return Status::OK();
+  }
+
+ protected:
+  struct ResolvedSortKey {
+    std::shared_ptr<Array> array;
+    SortOrder order;
+  };
+
+  // Builds the ConcreteRecordBatchColumnSorter matching the physical type of
+  // a sort key column.
+  struct ColumnSortFactory {
+    ColumnSortFactory(const ResolvedSortKey& sort_key,
+                      RecordBatchColumnSorter* next_column)
+        : physical_type(GetPhysicalType(sort_key.array->type())),
+          array(GetPhysicalArray(*sort_key.array, physical_type)),
+          order(sort_key.order),
+          next_column(next_column) {}
+
+    Result<std::unique_ptr<RecordBatchColumnSorter>> MakeColumnSort() {
+      RETURN_NOT_OK(VisitTypeInline(*physical_type, this));
+      DCHECK_NE(result, nullptr);
+      return std::move(result);
+    }
+
+#define VISIT(TYPE) \
+  Status Visit(const TYPE& type) { return VisitGeneric(type); }
+
+    VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+    Status Visit(const DataType& type) {
+      return Status::TypeError("Unsupported type for RecordBatch sorting: ",
+                               type.ToString());
+    }
+
+    template <typename Type>
+    Status VisitGeneric(const Type&) {
+      result.reset(new ConcreteRecordBatchColumnSorter<Type>(array, order, next_column));
+      return Status::OK();
+    }
+
+    std::shared_ptr<DataType> physical_type;
+    std::shared_ptr<Array> array;
+    SortOrder order;
+    RecordBatchColumnSorter* next_column;
+    std::unique_ptr<RecordBatchColumnSorter> result;
+  };
+
+  // Look up each sort key column by name; Invalid if one does not exist.
+  static Result<std::vector<ResolvedSortKey>> ResolveSortKeys(
+      const RecordBatch& batch, const std::vector<SortKey>& sort_keys) {
+    std::vector<ResolvedSortKey> resolved;
+    resolved.reserve(sort_keys.size());
+    for (const auto& sort_key : sort_keys) {
+      auto array = batch.GetColumnByName(sort_key.name);
+      if (!array) {
+        return Status::Invalid("Nonexistent sort key column: ", sort_key.name);
+      }
+      resolved.push_back({std::move(array), sort_key.order});
+    }
+    return resolved;
+  }
+
+  const RecordBatch& batch_;
+  const SortOptions& options_;
+  uint64_t* indices_begin_;
+  uint64_t* indices_end_;
+};
+
+// Compare two records in the same RecordBatch or Table
+// (indexing is handled through ResolvedSortKey).
+//
+// ResolvedSortKey must provide `type`, `order`, `null_count` and a
+// `GetChunk<ArrayType>(index)` accessor, which are the members used below.
+template <typename ResolvedSortKey>
+class MultipleKeyComparator {
+ public:
+  // `sort_keys` is held by reference and must outlive the comparator.
+  explicit MultipleKeyComparator(const std::vector<ResolvedSortKey>& sort_keys)
+      : sort_keys_(sort_keys) {}
+
+  // The first error encountered by Compare() (e.g. an unsupported sort key
+  // type), or OK if there was none.
+  Status status() const { return status_; }
+
+  // Returns true if the left-th value should be ordered before the
+  // right-th value, false otherwise. The start_sort_key_index-th
+  // sort key and subsequent sort keys are used for comparison.
+  bool Compare(uint64_t left, uint64_t right, size_t start_sort_key_index) {
+    current_left_ = left;
+    current_right_ = right;
+    current_compared_ = 0;
+    auto num_sort_keys = sort_keys_.size();
+    for (size_t i = start_sort_key_index; i < num_sort_keys; ++i) {
+      current_sort_key_index_ = i;
+      Status st = VisitTypeInline(*sort_keys_[i].type, this);
+      if (!st.ok() && status_.ok()) {
+        // Keep the first error: a later successful Visit() must not reset
+        // status_ to OK, otherwise callers checking status() would miss it.
+        status_ = std::move(st);
+      }
+      // If the left value equals the right value, we need to
+      // continue with the next sort key.
+      if (current_compared_ != 0) {
+        break;
+      }
+    }
+    return current_compared_ < 0;
+  }
+
+#define VISIT(TYPE)                          \
+  Status Visit(const TYPE& type) {           \
+    current_compared_ = CompareType<TYPE>(); \
+    return Status::OK();                     \
+  }
+
+  VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+  Status Visit(const DataType& type) {
+    return Status::TypeError("Unsupported type for RecordBatch sorting: ",
+                             type.ToString());
+  }
+
+ private:
+  // Compares two records on the current sort key and returns -1, 0 or 1.
+  //
+  // -1: The left is less than the right.
+  // 0: The left equals the right.
+  // 1: The left is greater than the right.
+  //
+  // Nulls are handled here (they sort after all non-null values, whatever
+  // the sort order) and NaNs are handled in CompareTypeValue().
+  template <typename Type>
+  int32_t CompareType() {
+    using ArrayType = typename TypeTraits<Type>::ArrayType;
+    const auto& sort_key = sort_keys_[current_sort_key_index_];
+    auto order = sort_key.order;
+    const auto chunk_left = sort_key.template GetChunk<ArrayType>(current_left_);
+    const auto chunk_right = sort_key.template GetChunk<ArrayType>(current_right_);
+    if (sort_key.null_count > 0) {
+      auto is_null_left = chunk_left.IsNull();
+      auto is_null_right = chunk_right.IsNull();
+      if (is_null_left && is_null_right) {
+        return 0;
+      } else if (is_null_left) {
+        return 1;
+      } else if (is_null_right) {
+        return -1;
+      }
+    }
+    return CompareTypeValue<Type>(chunk_left, chunk_right, order);
+  }
+
+  // Three-way comparison of two regular (non-null, non-NaN) values,
+  // negated for descending order.
+  template <typename Value>
+  static int32_t CompareValues(const Value& left, const Value& right,
+                               const SortOrder order) {
+    int32_t compared;
+    if (left == right) {
+      compared = 0;
+    } else if (left > right) {
+      compared = 1;
+    } else {
+      compared = -1;
+    }
+    if (order == SortOrder::Descending) {
+      compared = -compared;
+    }
+    return compared;
+  }
+
+  // For non-float types. Value is never NaN.
+  template <typename Type>
+  enable_if_t<!is_floating_type<Type>::value, int32_t> CompareTypeValue(
+      const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_left,
+      const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_right,
+      const SortOrder order) {
+    const auto left = chunk_left.GetView();
+    const auto right = chunk_right.GetView();
+    return CompareValues(left, right, order);
+  }
+
+  // For float types. Value may be NaN; NaNs sort after all other values,
+  // whatever the sort order.
+  template <typename Type>
+  enable_if_t<is_floating_type<Type>::value, int32_t> CompareTypeValue(
+      const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_left,
+      const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_right,
+      const SortOrder order) {
+    const auto left = chunk_left.GetView();
+    const auto right = chunk_right.GetView();
+    auto is_nan_left = std::isnan(left);
+    auto is_nan_right = std::isnan(right);
+    if (is_nan_left && is_nan_right) {
+      return 0;
+    } else if (is_nan_left) {
+      return 1;
+    } else if (is_nan_right) {
+      return -1;
+    }
+    return CompareValues(left, right, order);
+  }
+
+  const std::vector<ResolvedSortKey>& sort_keys_;
+  Status status_;
+  int64_t current_left_;
+  int64_t current_right_;
+  size_t current_sort_key_index_;
+  int32_t current_compared_;
+};
+
+// Sort a batch using a single sort and multiple-key comparisons.
+class MultipleKeyRecordBatchSorter : public TypeVisitor {
+ private:
+  // Preprocessed sort key.
+  struct ResolvedSortKey {
+    ResolvedSortKey(const std::shared_ptr<Array>& column, const SortOrder order)
+        : type(GetPhysicalType(column->type())),
+          owned_array(GetPhysicalArray(*column, type)),
+          array(*owned_array),
+          order(order),
+          null_count(column->null_count()) {}
+
+    template <typename ArrayType>
+    ResolvedChunk<ArrayType> GetChunk(int64_t index) const {
+      return {&checked_cast<const ArrayType&>(array), index};
+    }
+
+    // Physical type of the sort key column.
+    const std::shared_ptr<DataType> type;
+    // The sort key column reinterpreted as `type`; keeps `array` alive.
+    std::shared_ptr<Array> owned_array;
+    const Array& array;
+    SortOrder order;
+    int64_t null_count;
+  };
+
+  using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+  // Sort() reorders [indices_begin, indices_end) in place.  Errors while
+  // resolving sort keys are recorded here and reported by Sort().
+  MultipleKeyRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end,
+                               const RecordBatch& batch, const SortOptions& options)
+      : indices_begin_(indices_begin),
+        indices_end_(indices_end),
+        sort_keys_(ResolveSortKeys(batch, options.sort_keys, &status_)),
+        comparator_(sort_keys_) {}
+
+  // This is optimized for the first sort key. The first sort key sort
+  // is processed in this class. The second and following sort keys
+  // are processed in Comparator.
+  Status Sort() {
+    RETURN_NOT_OK(status_);
+    if (sort_keys_.empty()) {
+      // sort_keys_[0] below would be out of bounds.
+      return Status::Invalid("Must specify one or more sort keys");
+    }
+    return sort_keys_[0].type->Accept(this);
+  }
+
+#define VISIT(TYPE) \
+  Status Visit(const TYPE& type) override { return SortInternal<TYPE>(); }
+
+  VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ private:
+  // Resolve sort key names into columns of `batch`.  On a nonexistent column,
+  // sets *status to Invalid and returns the keys resolved so far.
+  static std::vector<ResolvedSortKey> ResolveSortKeys(
+      const RecordBatch& batch, const std::vector<SortKey>& sort_keys, Status* status) {
+    std::vector<ResolvedSortKey> resolved;
+    resolved.reserve(sort_keys.size());
+    for (const auto& sort_key : sort_keys) {
+      auto array = batch.GetColumnByName(sort_key.name);
+      if (!array) {
+        *status = Status::Invalid("Nonexistent sort key column: ", sort_key.name);
+        break;
+      }
+      resolved.emplace_back(array, sort_key.order);
+    }
+    return resolved;
+  }
+
+  template <typename Type>
+  Status SortInternal() {
+    using ArrayType = typename TypeTraits<Type>::ArrayType;
+
+    auto& comparator = comparator_;
+    const auto& first_sort_key = sort_keys_[0];
+    const ArrayType& array = checked_cast<const ArrayType&>(first_sort_key.array);
+    // Move first-key nulls (and NaNs for floating-point types) to the end,
+    // already ordered by the remaining sort keys.
+    uint64_t* nulls_begin = PartitionNullsInternal<Type>(first_sort_key);
+    // Sort first-key non-nulls
+    std::stable_sort(indices_begin_, nulls_begin, [&](uint64_t left, uint64_t right) {
+      // Both values are never null nor NaN
+      // (otherwise they've been partitioned away above).
+      const auto value_left = array.GetView(left);
+      const auto value_right = array.GetView(right);
+      if (value_left != value_right) {
+        bool compared = value_left < value_right;
+        if (first_sort_key.order == SortOrder::Ascending) {
+          return compared;
+        } else {
+          return !compared;
+        }
+      }
+      // If the left value equals the right value,
+      // we need to compare the second and following
+      // sort keys.
+      return comparator.Compare(left, right, 1);
+    });
+    return comparator_.status();
+  }
+
+  // Behaves like PartitionNulls() but this supports multiple sort keys.
+  //
+  // For non-float types.
+  template <typename Type>
+  enable_if_t<!is_floating_type<Type>::value, uint64_t*> PartitionNullsInternal(
+      const ResolvedSortKey& first_sort_key) {
+    using ArrayType = typename TypeTraits<Type>::ArrayType;
+    if (first_sort_key.null_count == 0) {
+      return indices_end_;
+    }
+    const ArrayType& array = checked_cast<const ArrayType&>(first_sort_key.array);
+    StablePartitioner partitioner;
+    auto nulls_begin = partitioner(indices_begin_, indices_end_,
+                                   [&](uint64_t index) { return !array.IsNull(index); });
+    // Sort all nulls by second and following sort keys
+    // TODO: could we instead run an independent sort from the second key on
+    // this slice?
+    if (nulls_begin != indices_end_) {
+      auto& comparator = comparator_;
+      std::stable_sort(nulls_begin, indices_end_,
+                       [&comparator](uint64_t left, uint64_t right) {
+                         return comparator.Compare(left, right, 1);
+                       });
+    }
+    return nulls_begin;
+  }
+
+  // Behaves like PartitionNulls() but this supports multiple sort keys.
+  //
+  // For float types.  NaNs are placed just before the nulls; the returned
+  // pointer is the beginning of the NaN-and-null tail.
+  template <typename Type>
+  enable_if_t<is_floating_type<Type>::value, uint64_t*> PartitionNullsInternal(
+      const ResolvedSortKey& first_sort_key) {
+    using ArrayType = typename TypeTraits<Type>::ArrayType;
+    const ArrayType& array = checked_cast<const ArrayType&>(first_sort_key.array);
+    StablePartitioner partitioner;
+    uint64_t* nulls_begin;
+    if (first_sort_key.null_count == 0) {
+      nulls_begin = indices_end_;
+    } else {
+      nulls_begin = partitioner(indices_begin_, indices_end_,
+                                [&](uint64_t index) { return !array.IsNull(index); });
+    }
+    uint64_t* nans_and_nulls_begin =
+        partitioner(indices_begin_, nulls_begin,
+                    [&](uint64_t index) { return !std::isnan(array.GetView(index)); });
+    auto& comparator = comparator_;
+    if (nans_and_nulls_begin != nulls_begin) {
+      // Sort all NaNs by the second and following sort keys.
+      // TODO: could we instead run an independent sort from the second key on
+      // this slice?
+      std::stable_sort(nans_and_nulls_begin, nulls_begin,
+                       [&comparator](uint64_t left, uint64_t right) {
+                         return comparator.Compare(left, right, 1);
+                       });
+    }
+    if (nulls_begin != indices_end_) {
+      // Sort all nulls by the second and following sort keys.
+      // TODO: could we instead run an independent sort from the second key on
+      // this slice?
+      std::stable_sort(nulls_begin, indices_end_,
+                       [&comparator](uint64_t left, uint64_t right) {
+                         return comparator.Compare(left, right, 1);
+                       });
+    }
+    return nans_and_nulls_begin;
+  }
+
+  uint64_t* indices_begin_;
+  uint64_t* indices_end_;
+  // Must stay declared before sort_keys_: the constructor's ResolveSortKeys()
+  // call writes to it, so it has to be constructed first.
+  Status status_;
+  std::vector<ResolvedSortKey> sort_keys_;
+  Comparator comparator_;
+};
+
+// ----------------------------------------------------------------------
 // Table sorting implementations
 
 // Sort a table using a radix sort-like algorithm.
@@ -834,6 +1355,10 @@ class TableRadixSorter {
 // Sort a table using a single sort and multiple-key comparisons.
 class MultipleKeyTableSorter : public TypeVisitor {
  private:
+  // TODO instead of resolving chunks for each column independently, we could
+  // split the table into RecordBatches and pay the cost of chunked indexing
+  // at the first column only.
+
   // Preprocessed sort key.
   struct ResolvedSortKey {
     ResolvedSortKey(const ChunkedArray& chunked_array, const SortOrder order)
@@ -861,229 +1386,76 @@ class MultipleKeyTableSorter : public TypeVisitor {
     const ChunkedArrayResolver resolver;
   };
 
-  // Compare two records in the same table.
-  class Comparer : public TypeVisitor {
-   public:
-    Comparer(const Table& table, const std::vector<SortKey>& sort_keys)
-        : TypeVisitor(),
-          status_(Status::OK()),
-          sort_keys_(ResolveSortKeys(table, sort_keys, &status_)) {}
-
-    Status status() { return status_; }
-
-    const std::vector<ResolvedSortKey>& sort_keys() { return sort_keys_; }
-
-    // Returns true if the left-th value should be ordered before the
-    // right-th value, false otherwise. The start_sort_key_index-th
-    // sort key and subsequent sort keys are used for comparison.
-    bool Compare(uint64_t left, uint64_t right, size_t start_sort_key_index) {
-      current_left_ = left;
-      current_right_ = right;
-      current_compared_ = 0;
-      auto num_sort_keys = sort_keys_.size();
-      for (size_t i = start_sort_key_index; i < num_sort_keys; ++i) {
-        current_sort_key_index_ = i;
-        status_ = sort_keys_[i].type->Accept(this);
-        // If the left value equals to the right value, we need to
-        // continue to sort.
-        if (current_compared_ != 0) {
-          break;
-        }
-      }
-      return current_compared_ < 0;
-    }
-
-#define VISIT(TYPE)                                \
-  Status Visit(const TYPE##Type& type) override {  \
-    current_compared_ = CompareType<TYPE##Type>(); \
-    return Status::OK();                           \
-  }
-
-    VISIT(Int8)
-    VISIT(Int16)
-    VISIT(Int32)
-    VISIT(Int64)
-    VISIT(UInt8)
-    VISIT(UInt16)
-    VISIT(UInt32)
-    VISIT(UInt64)
-    VISIT(Float)
-    VISIT(Double)
-    VISIT(Binary)
-    VISIT(LargeBinary)
-
-#undef VISIT
-
-   private:
-    // Compares two records in the same table and returns -1, 0 or 1.
-    //
-    // -1: The left is less than the right.
-    // 0: The left equals to the right.
-    // 1: The left is greater than the right.
-    //
-    // This supports null and NaN. Null is processed in this and NaN
-    // is processed in CompareTypeValue().
-    template <typename Type>
-    int32_t CompareType() {
-      using ArrayType = typename TypeTraits<Type>::ArrayType;
-      const auto& sort_key = sort_keys_[current_sort_key_index_];
-      auto order = sort_key.order;
-      const auto chunk_left = sort_key.GetChunk<ArrayType>(current_left_);
-      const auto chunk_right = sort_key.GetChunk<ArrayType>(current_right_);
-      if (sort_key.null_count > 0) {
-        auto is_null_left = chunk_left.IsNull();
-        auto is_null_right = chunk_right.IsNull();
-        if (is_null_left && is_null_right) {
-          return 0;
-        } else if (is_null_left) {
-          return 1;
-        } else if (is_null_right) {
-          return -1;
-        }
-      }
-      return CompareTypeValue<Type>(chunk_left, chunk_right, order);
-    }
-
-    // For non-float types. Value is never NaN.
-    template <typename Type>
-    enable_if_t<!is_floating_type<Type>::value, int32_t> CompareTypeValue(
-        const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_left,
-        const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_right,
-        const SortOrder order) {
-      const auto left = chunk_left.GetView();
-      const auto right = chunk_right.GetView();
-      int32_t compared;
-      if (left == right) {
-        compared = 0;
-      } else if (left > right) {
-        compared = 1;
-      } else {
-        compared = -1;
-      }
-      if (order == SortOrder::Descending) {
-        compared = -compared;
-      }
-      return compared;
-    }
-
-    // For float types. Value may be NaN.
-    template <typename Type>
-    enable_if_t<is_floating_type<Type>::value, int32_t> CompareTypeValue(
-        const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_left,
-        const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_right,
-        const SortOrder order) {
-      const auto left = chunk_left.GetView();
-      const auto right = chunk_right.GetView();
-      auto is_nan_left = std::isnan(left);
-      auto is_nan_right = std::isnan(right);
-      if (is_nan_left && is_nan_right) {
-        return 0;
-      } else if (is_nan_left) {
-        return 1;
-      } else if (is_nan_right) {
-        return -1;
-      }
-      int32_t compared;
-      if (left == right) {
-        compared = 0;
-      } else if (left > right) {
-        compared = 1;
-      } else {
-        compared = -1;
-      }
-      if (order == SortOrder::Descending) {
-        compared = -compared;
-      }
-      return compared;
-    }
-
-    static std::vector<ResolvedSortKey> ResolveSortKeys(
-        const Table& table, const std::vector<SortKey>& sort_keys, Status* status) {
-      std::vector<ResolvedSortKey> resolved;
-      resolved.reserve(sort_keys.size());
-      for (const auto& sort_key : sort_keys) {
-        const auto& chunked_array = table.GetColumnByName(sort_key.name);
-        if (!chunked_array) {
-          *status = Status::Invalid("Nonexistent sort key column: ", sort_key.name);
-          break;
-        }
-        resolved.emplace_back(*chunked_array, sort_key.order);
-      }
-      return resolved;
-    }
-
-    Status status_;
-    const std::vector<ResolvedSortKey> sort_keys_;
-    int64_t current_left_;
-    int64_t current_right_;
-    size_t current_sort_key_index_;
-    int32_t current_compared_;
-  };
+  using Comparator = MultipleKeyComparator<ResolvedSortKey>;  // multi-key tie-breaker
 
  public:
   MultipleKeyTableSorter(uint64_t* indices_begin, uint64_t* indices_end,
                          const Table& table, const SortOptions& options)
       : indices_begin_(indices_begin),
         indices_end_(indices_end),
-        comparer_(table, options.sort_keys) {}
+        sort_keys_(ResolveSortKeys(table, options.sort_keys, &status_)),
+        comparator_(sort_keys_) {}  // sort_keys_ (declared earlier) is built first
 
   // This is optimized for the first sort key. The first sort key sort
   // is processed in this class. The second and following sort keys
-  // are processed in Comparer.
+  // are processed by comparator_, starting at sort key index 1.
   Status Sort() {
-    ARROW_RETURN_NOT_OK(comparer_.status());
-    return comparer_.sort_keys()[0].type->Accept(this);
+    ARROW_RETURN_NOT_OK(status_);  // e.g. a nonexistent sort key column
+    return sort_keys_[0].type->Accept(this);  // dispatch to SortInternal<T>()
   }
 
 #define VISIT(TYPE) \
-  Status Visit(const TYPE##Type& type) override { return SortInternal<TYPE##Type>(); }
-
-  VISIT(Int8)
-  VISIT(Int16)
-  VISIT(Int32)
-  VISIT(Int64)
-  VISIT(UInt8)
-  VISIT(UInt16)
-  VISIT(UInt32)
-  VISIT(UInt64)
-  VISIT(Float)
-  VISIT(Double)
-  VISIT(Binary)
-  VISIT(LargeBinary)
+  Status Visit(const TYPE& type) override { return SortInternal<TYPE>(); }
+
+  VISIT_PHYSICAL_TYPES(VISIT)  // presumably one Visit() overload per physical type
 
 #undef VISIT
 
  private:
+  static std::vector<ResolvedSortKey> ResolveSortKeys(  // stops at first missing key
+      const Table& table, const std::vector<SortKey>& sort_keys, Status* status) {
+    std::vector<ResolvedSortKey> keys;
+    keys.reserve(sort_keys.size());
+    for (const SortKey& key : sort_keys) {
+      const auto column = table.GetColumnByName(key.name);
+      if (column == nullptr) {
+        *status = Status::Invalid("Nonexistent sort key column: ", key.name);
+        return keys;
+      }
+      keys.emplace_back(*column, key.order);
+    }
+    return keys;
+  }
+  // Sort by the first key; ties, nulls and NaNs are ordered by the later keys.
   template <typename Type>
   Status SortInternal() {
     using ArrayType = typename TypeTraits<Type>::ArrayType;
 
-    auto& comparer = comparer_;
-    const auto& first_sort_key = comparer.sort_keys()[0];
+    auto& comparator = comparator_;
+    const auto& first_sort_key = sort_keys_[0];
     auto nulls_begin = indices_end_;
     nulls_begin = PartitionNullsInternal<Type>(first_sort_key);
-    std::stable_sort(indices_begin_, nulls_begin,
-                     [&first_sort_key, &comparer](uint64_t left, uint64_t right) {
-                       // Both values are never null nor NaN.
-                       auto chunk_left = first_sort_key.GetChunk<ArrayType>(left);
-                       auto chunk_right = first_sort_key.GetChunk<ArrayType>(right);
-                       auto value_left = chunk_left.GetView();
-                       auto value_right = chunk_right.GetView();
-                       if (value_left == value_right) {
-                         // If the left value equals to the right value,
-                         // we need to compare the second and following
-                         // sort keys.
-                         return comparer.Compare(left, right, 1);
-                       } else {
-                         auto compared = value_left < value_right;
-                         if (first_sort_key.order == SortOrder::Ascending) {
-                           return compared;
-                         } else {
-                           return !compared;
-                         }
-                       }
-                     });
-    return Status::OK();
+    std::stable_sort(indices_begin_, nulls_begin, [&](uint64_t left, uint64_t right) {
+      // Neither value is null or NaN: those lie in [nulls_begin, indices_end_).
+      auto chunk_left = first_sort_key.GetChunk<ArrayType>(left);
+      auto chunk_right = first_sort_key.GetChunk<ArrayType>(right);
+      auto value_left = chunk_left.GetView();
+      auto value_right = chunk_right.GetView();
+      if (value_left == value_right) {
+        // The first-key values tie: compare the second and
+        // following sort keys instead. Errors, if any, surface
+        // through comparator_.status(), returned below.
+        return comparator.Compare(left, right, 1);
+      } else {
+        auto compared = value_left < value_right;
+        if (first_sort_key.order == SortOrder::Ascending) {
+          return compared;
+        } else {
+          return !compared;  // values differ, so this means value_left > value_right
+        }
+      }
+    });
+    return comparator_.status();  // report any error hit while tie-breaking
   }
 
   // Behaves like PatitionNulls() but this supports multiple sort keys.
@@ -1102,11 +1474,11 @@ class MultipleKeyTableSorter : public TypeVisitor {
           const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
           return !chunk.IsNull();
         });
-    auto& comparer = comparer_;
-    std::stable_sort(nulls_begin, indices_end_,
-                     [&comparer](uint64_t left, uint64_t right) {
-                       return comparer.Compare(left, right, 1);
-                     });
+    DCHECK_EQ(indices_end_ - nulls_begin, first_sort_key.null_count);
+    auto& comparator = comparator_;
+    std::stable_sort(nulls_begin, indices_end_, [&](uint64_t left, uint64_t right) {
+      return comparator.Compare(left, right, 1);
+    });
     return nulls_begin;
   }
 
@@ -1118,45 +1490,37 @@ class MultipleKeyTableSorter : public TypeVisitor {
       const ResolvedSortKey& first_sort_key) {
     using ArrayType = typename TypeTraits<Type>::ArrayType;
     StablePartitioner partitioner;
+    uint64_t* nulls_begin;
     if (first_sort_key.null_count == 0) {
-      return partitioner(indices_begin_, indices_end_, [&first_sort_key](uint64_t index) {
+      nulls_begin = indices_end_;
+    } else {
+      nulls_begin = partitioner(indices_begin_, indices_end_, [&](uint64_t index) {
         const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
-        return !std::isnan(chunk.GetView());
+        return !chunk.IsNull();
       });
     }
-    auto nans_and_nulls_begin =
-        partitioner(indices_begin_, indices_end_, [&first_sort_key](uint64_t index) {
-          const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
-          return !chunk.IsNull() && !std::isnan(chunk.GetView());
-        });
-    auto nulls_begin = nans_and_nulls_begin;
-    if (first_sort_key.null_count < static_cast<int64_t>(indices_end_ - nulls_begin)) {
-      // move nulls after NaN
-      nulls_begin = partitioner(
-          nans_and_nulls_begin, indices_end_, [&first_sort_key](uint64_t index) {
-            const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
-            return !chunk.IsNull();
-          });
-    }
-    auto& comparer = comparer_;
-    if (nans_and_nulls_begin != nulls_begin) {
-      // Sort all NaNs by the second and following sort keys.
-      std::stable_sort(nans_and_nulls_begin, nulls_begin,
-                       [&comparer](uint64_t left, uint64_t right) {
-                         return comparer.Compare(left, right, 1);
-                       });
-    }
+    DCHECK_EQ(indices_end_ - nulls_begin, first_sort_key.null_count);
+    uint64_t* nans_begin = partitioner(indices_begin_, nulls_begin, [&](uint64_t index) {
+      const auto chunk = first_sort_key.GetChunk<ArrayType>(index);
+      return !std::isnan(chunk.GetView());
+    });
+    auto& comparator = comparator_;
+    // Sort all NaNs by the second and following sort keys.
+    std::stable_sort(nans_begin, nulls_begin, [&](uint64_t left, uint64_t right) {
+      return comparator.Compare(left, right, 1);
+    });
     // Sort all nulls by the second and following sort keys.
-    std::stable_sort(nulls_begin, indices_end_,
-                     [&comparer](uint64_t left, uint64_t right) {
-                       return comparer.Compare(left, right, 1);
-                     });
-    return nans_and_nulls_begin;
+    std::stable_sort(nulls_begin, indices_end_, [&](uint64_t left, uint64_t right) {
+      return comparator.Compare(left, right, 1);
+    });
+    return nans_begin;
   }
 
   uint64_t* indices_begin_;
   uint64_t* indices_end_;
-  Comparer comparer_;
+  Status status_;  // must precede sort_keys_: ResolveSortKeys() writes it during init
+  std::vector<ResolvedSortKey> sort_keys_;
+  Comparator comparator_;  // built from sort_keys_, so must stay declared after it
 };
 
 // ----------------------------------------------------------------------
@@ -1188,9 +1552,7 @@ class SortIndicesMetaFunction : public MetaFunction {
         return SortIndices(*args[0].chunked_array(), sort_options, ctx);
         break;
       case Datum::RECORD_BATCH: {
-        ARROW_ASSIGN_OR_RAISE(auto table,
-                              Table::FromRecordBatches({args[0].record_batch()}));
-        return SortIndices(*table, sort_options, ctx);
+        return SortIndices(*args[0].record_batch(), sort_options, ctx);
       } break;
       case Datum::TABLE:
         return SortIndices(*args[0].table(), sort_options, ctx);
@@ -1239,6 +1601,46 @@ class SortIndicesMetaFunction : public MetaFunction {
     return Datum(out);
   }
 
+  Result<Datum> SortIndices(const RecordBatch& batch, const SortOptions& options,
+                            ExecContext* ctx) const {
+    auto n_sort_keys = options.sort_keys.size();
+    if (n_sort_keys == 0) {
+      return Status::Invalid("Must specify one or more sort keys");
+    }
+    if (n_sort_keys == 1) {
+      auto array = batch.GetColumnByName(options.sort_keys[0].name);
+      if (!array) {
+        return Status::Invalid("Nonexistent sort key column: ",
+                               options.sort_keys[0].name);
+      }
+      return SortIndices(*array, options, ctx);
+    }
+
+    auto out_type = uint64();
+    auto length = batch.num_rows();
+    auto buffer_size = BitUtil::BytesForBits(
+        length * std::static_pointer_cast<UInt64Type>(out_type)->bit_width());
+    BufferVector buffers(2);
+    ARROW_ASSIGN_OR_RAISE(buffers[1],
+                          AllocateResizableBuffer(buffer_size, ctx->memory_pool()));
+    auto out = std::make_shared<ArrayData>(out_type, length, buffers, 0);
+    auto out_begin = out->GetMutableValues<uint64_t>(1);
+    auto out_end = out_begin + length;
+    std::iota(out_begin, out_end, 0);
+
+    // Radix sorting is consistently faster except when there is a large number
+    // of sort keys, in which case it can end up degrading catastrophically.
+    // Cut off above 8 sort keys.
+    if (n_sort_keys <= 8) {
+      RadixRecordBatchSorter sorter(out_begin, out_end, batch, options);
+      ARROW_RETURN_NOT_OK(sorter.Sort());
+    } else {
+      MultipleKeyRecordBatchSorter sorter(out_begin, out_end, batch, options);
+      ARROW_RETURN_NOT_OK(sorter.Sort());
+    }
+    return Datum(out);
+  }
+
   Result<Datum> SortIndices(const Table& table, const SortOptions& options,
                             ExecContext* ctx) const {
     auto n_sort_keys = options.sort_keys.size();
@@ -1330,6 +1732,8 @@ void RegisterVectorSort(FunctionRegistry* registry) {
   DCHECK_OK(registry->AddFunction(std::move(part_indices)));
 }
 
+#undef VISIT_PHYSICAL_TYPES
+
 }  // namespace internal
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc
index f48d69e..820c51b 100644
--- a/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc
+++ b/cpp/src/arrow/compute/kernels/vector_sort_benchmark.cc
@@ -23,6 +23,7 @@
 #include "arrow/testing/gtest_util.h"
 #include "arrow/testing/random.h"
 #include "arrow/util/benchmark_util.h"
+#include "arrow/util/logging.h"
 
 namespace arrow {
 namespace compute {
@@ -90,16 +91,15 @@ static void ChunkedArraySortIndicesInt64Wide(benchmark::State& state) {
   ChunkedArraySortIndicesInt64Benchmark(state, min, max);
 }
 
-static void TableSortIndicesBenchmark(benchmark::State& state,
-                                      const std::shared_ptr<Table>& table,
+static void DatumSortIndicesBenchmark(benchmark::State& state, const Datum& datum,
                                       const SortOptions& options) {
   for (auto _ : state) {
-    ABORT_NOT_OK(SortIndices(Datum(*table), options).status());
+    ABORT_NOT_OK(SortIndices(datum, options).status());  // one sort per iteration
   }
 }
 
 // Extract benchmark args from benchmark::State
-struct TableSortIndicesArgs {
+struct RecordBatchSortIndicesArgs {
   // the number of records
   const int64_t num_records;
 
@@ -109,20 +109,18 @@ struct TableSortIndicesArgs {
   // the number of columns
   const int64_t num_columns;
 
-  // the number of chunks in each generated column
-  const int64_t num_chunks;
-
   // Extract args
-  explicit TableSortIndicesArgs(benchmark::State& state)
+  explicit RecordBatchSortIndicesArgs(benchmark::State& state)
       : num_records(state.range(0)),
         null_proportion(ComputeNullProportion(state.range(1))),
         num_columns(state.range(2)),
-        num_chunks(state.range(3)),
         state_(state) {}
 
-  ~TableSortIndicesArgs() { state_.SetItemsProcessed(state_.iterations() * num_records); }
+  ~RecordBatchSortIndicesArgs() {
+    state_.SetItemsProcessed(state_.iterations() * num_records);
+  }
 
- private:
+ protected:
   double ComputeNullProportion(int64_t inverse_null_proportion) {
     if (inverse_null_proportion == 0) {
       return 0.0;
@@ -134,37 +132,86 @@ struct TableSortIndicesArgs {
   benchmark::State& state_;
 };
 
-static void TableSortIndicesInt64(benchmark::State& state, int64_t min, int64_t max) {
-  TableSortIndicesArgs args(state);
+struct TableSortIndicesArgs : public RecordBatchSortIndicesArgs {
+  // the number of chunks in each generated column
+  const int64_t num_chunks;
 
-  auto rand = random::RandomArrayGenerator(kSeed);
-  std::vector<std::shared_ptr<Field>> fields;
+  // Extract args
+  explicit TableSortIndicesArgs(benchmark::State& state)
+      : RecordBatchSortIndicesArgs(state), num_chunks(state.range(3)) {}
+};
+
+struct BatchOrTableBenchmarkData {
+  std::shared_ptr<Schema> schema;
   std::vector<SortKey> sort_keys;
-  std::vector<std::shared_ptr<ChunkedArray>> columns;
+  ChunkedArrayVector columns;
+};
+// Random int64 columns (one per sort key), each split into num_chunks chunks.
+BatchOrTableBenchmarkData MakeBatchOrTableBenchmarkDataInt64(
+    const RecordBatchSortIndicesArgs& args, int64_t num_chunks, int64_t min_value,
+    int64_t max_value) {
+  auto rand = random::RandomArrayGenerator(kSeed);
+  FieldVector fields;
+  BatchOrTableBenchmarkData data;
+
   for (int64_t i = 0; i < args.num_columns; ++i) {
     auto name = std::to_string(i);
     fields.push_back(field(name, int64()));
     auto order = (i % 2) == 0 ? SortOrder::Ascending : SortOrder::Descending;
-    sort_keys.emplace_back(name, order);
-    std::vector<std::shared_ptr<Array>> arrays;
-    if ((args.num_records % args.num_chunks) != 0) {
-      Status::Invalid("The number of chunks (", args.num_chunks,
+    data.sort_keys.emplace_back(name, order);
+    ArrayVector chunks;
+    if ((args.num_records % num_chunks) != 0) {
+      Status::Invalid("The number of records (", args.num_records,
                       ") must be "
-                      "a multiple of the number of records (",
-                      args.num_records, ")")
+                      "a multiple of the number of chunks (",
+                      num_chunks, ")")
           .Abort();
     }
-    auto num_records_in_array = args.num_records / args.num_chunks;
-    for (int64_t j = 0; j < args.num_chunks; ++j) {
-      arrays.push_back(rand.Int64(num_records_in_array, min, max, args.null_proportion));
+    auto num_records_in_array = args.num_records / num_chunks;
+    for (int64_t j = 0; j < num_chunks; ++j) {
+      chunks.push_back(
+          rand.Int64(num_records_in_array, min_value, max_value, args.null_proportion));
     }
-    ASSIGN_OR_ABORT(auto chunked_array, ChunkedArray::Make(arrays, int64()));
-    columns.push_back(chunked_array);
+    ASSIGN_OR_ABORT(auto chunked_array, ChunkedArray::Make(chunks, int64()));
+    data.columns.push_back(chunked_array);
+  }
+
+  data.schema = schema(fields);
+  return data;
+}
+
+static void RecordBatchSortIndicesInt64(benchmark::State& state, int64_t min,
+                                        int64_t max) {
+  RecordBatchSortIndicesArgs args(state);
+
+  auto data = MakeBatchOrTableBenchmarkDataInt64(args, /*num_chunks=*/1, min, max);
+  ArrayVector columns;
+  for (const auto& chunked : data.columns) {
+    ARROW_CHECK_EQ(chunked->num_chunks(), 1);
+    columns.push_back(chunked->chunk(0));
   }
 
-  auto table = Table::Make(schema(fields), columns, args.num_records);
-  SortOptions options(sort_keys);
-  TableSortIndicesBenchmark(state, table, options);
+  auto batch = RecordBatch::Make(data.schema, args.num_records, columns);
+  SortOptions options(data.sort_keys);
+  DatumSortIndicesBenchmark(state, Datum(*batch), options);
+}
+
+static void TableSortIndicesInt64(benchmark::State& state, int64_t min, int64_t max) {
+  TableSortIndicesArgs args(state);
+
+  auto data = MakeBatchOrTableBenchmarkDataInt64(args, args.num_chunks, min, max);
+  auto table = Table::Make(data.schema, data.columns, args.num_records);
+  SortOptions options(data.sort_keys);
+  DatumSortIndicesBenchmark(state, Datum(*table), options);
+}
+
+static void RecordBatchSortIndicesInt64Narrow(benchmark::State& state) {
+  RecordBatchSortIndicesInt64(state, -100, 100);
+}
+
+static void RecordBatchSortIndicesInt64Wide(benchmark::State& state) {
+  RecordBatchSortIndicesInt64(state, std::numeric_limits<int64_t>::min(),
+                              std::numeric_limits<int64_t>::max());
 }
 
 static void TableSortIndicesInt64Narrow(benchmark::State& state) {
@@ -180,28 +227,40 @@ BENCHMARK(ArraySortIndicesInt64Narrow)
     ->Apply(RegressionSetArgs)
     ->Args({1 << 20, 100})
     ->Args({1 << 23, 100})
-    ->MinTime(1.0)
     ->Unit(benchmark::TimeUnit::kNanosecond);
 
 BENCHMARK(ArraySortIndicesInt64Wide)
     ->Apply(RegressionSetArgs)
     ->Args({1 << 20, 100})
     ->Args({1 << 23, 100})
-    ->MinTime(1.0)
     ->Unit(benchmark::TimeUnit::kNanosecond);
 
 BENCHMARK(ChunkedArraySortIndicesInt64Narrow)
     ->Apply(RegressionSetArgs)
     ->Args({1 << 20, 100})
     ->Args({1 << 23, 100})
-    ->MinTime(1.0)
     ->Unit(benchmark::TimeUnit::kNanosecond);
 
 BENCHMARK(ChunkedArraySortIndicesInt64Wide)
     ->Apply(RegressionSetArgs)
     ->Args({1 << 20, 100})
     ->Args({1 << 23, 100})
-    ->MinTime(1.0)
+    ->Unit(benchmark::TimeUnit::kNanosecond);
+
+BENCHMARK(RecordBatchSortIndicesInt64Narrow)  // values between -100 and 100
+    ->ArgsProduct({
+        {1 << 20},      // the number of records
+        {100, 0},       // inverse null proportion (0 = no nulls)
+        {16, 8, 2, 1},  // the number of columns (more than 8 disables radix sort)
+    })
+    ->Unit(benchmark::TimeUnit::kNanosecond);
+
+BENCHMARK(RecordBatchSortIndicesInt64Wide)  // values over the whole int64 range
+    ->ArgsProduct({
+        {1 << 20},      // the number of records
+        {100, 0},       // inverse null proportion (0 = no nulls)
+        {16, 8, 2, 1},  // the number of columns (more than 8 disables radix sort)
+    })
     ->Unit(benchmark::TimeUnit::kNanosecond);
 
 BENCHMARK(TableSortIndicesInt64Narrow)
@@ -211,7 +270,6 @@ BENCHMARK(TableSortIndicesInt64Narrow)
         {16, 8, 2, 1},  // the number of columns
         {32, 4, 1},     // the number of chunks
     })
-    ->MinTime(1.0)
     ->Unit(benchmark::TimeUnit::kNanosecond);
 
 BENCHMARK(TableSortIndicesInt64Wide)
@@ -221,7 +279,6 @@ BENCHMARK(TableSortIndicesInt64Wide)
         {16, 8, 2, 1},  // the number of columns
         {32, 4, 1},     // the number of chunks
     })
-    ->MinTime(1.0)
     ->Unit(benchmark::TimeUnit::kNanosecond);
 
 }  // namespace compute
diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc
index 4c42cff..0c9cad5 100644
--- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc
+++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc
@@ -63,6 +63,9 @@ TypeToDataType() {
   return time64(TimeUnit::NANO);
 }
 
+// ----------------------------------------------------------------------
+// Tests for NthToIndices
+
 template <typename ArrayType>
 class NthComparator {
  public:
@@ -227,6 +230,32 @@ class Random : public RandomImpl {
 };
 
 template <>
+class Random<FloatType> : public RandomImpl {
+  using CType = float;
+
+ public:
+  explicit Random(random::SeedType seed) : RandomImpl(seed) {}
+  // Halved symmetric bounds: min() is positive-only, and max() - lowest() overflows.
+  std::shared_ptr<Array> Generate(uint64_t count, double null_prob, double nan_prob = 0) {
+    return generator.Float32(count, std::numeric_limits<CType>::lowest() / 2,
+                             std::numeric_limits<CType>::max() / 2, null_prob, nan_prob);
+  }
+};
+
+template <>
+class Random<DoubleType> : public RandomImpl {
+  using CType = double;
+
+ public:
+  explicit Random(random::SeedType seed) : RandomImpl(seed) {}
+  // Halved symmetric bounds: min() is positive-only, and max() - lowest() overflows.
+  std::shared_ptr<Array> Generate(uint64_t count, double null_prob, double nan_prob = 0) {
+    return generator.Float64(count, std::numeric_limits<CType>::lowest() / 2,
+                             std::numeric_limits<CType>::max() / 2, null_prob, nan_prob);
+  }
+};
+
+template <>
 class Random<StringType> : public RandomImpl {
  public:
   explicit Random(random::SeedType seed) : RandomImpl(seed) {}
@@ -267,24 +296,41 @@ TYPED_TEST(TestNthToIndicesRandom, RandomValues) {
   }
 }
 
-using arrow::internal::checked_pointer_cast;
+// ----------------------------------------------------------------------
+// Tests for SortIndices
+
+template <typename T>  // used with Array and ChunkedArray inputs
+void AssertSortIndices(const std::shared_ptr<T>& input, SortOrder order,
+                       const std::shared_ptr<Array>& expected) {
+  ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(*input, order));
+  ASSERT_OK(actual->ValidateFull());
+  AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+}
+
+template <typename T>  // e.g. RecordBatch; the input is wrapped in a Datum
+void AssertSortIndices(const std::shared_ptr<T>& input, const SortOptions& options,
+                       const std::shared_ptr<Array>& expected) {
+  ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(Datum(*input), options));
+  ASSERT_OK(actual->ValidateFull());
+  AssertArraysEqual(*expected, *actual, /*verbose=*/true);
+}
+
+// `Options` may be either SortOptions or SortOrder; `expected` is JSON text.
+template <typename T, typename Options>
+void AssertSortIndices(const std::shared_ptr<T>& input, Options&& options,
+                       const std::string& expected) {
+  AssertSortIndices(input, std::forward<Options>(options),
+                    ArrayFromJSON(uint64(), expected));
+}
 
 template <typename ArrowType>
 class TestArraySortIndicesKernel : public TestBase {
- private:
-  void AssertArraysSortIndices(const std::shared_ptr<Array> values, SortOrder order,
-                               const std::shared_ptr<Array> expected) {
-    ASSERT_OK_AND_ASSIGN(std::shared_ptr<Array> actual, SortIndices(*values, order));
-    ASSERT_OK(actual->ValidateFull());
-    AssertArraysEqual(*expected, *actual);
-  }
-
- protected:
+ public:
   virtual void AssertSortIndices(const std::string& values, SortOrder order,
                                  const std::string& expected) {
     auto type = TypeToDataType<ArrowType>();
-    AssertArraysSortIndices(ArrayFromJSON(type, values), order,
-                            ArrayFromJSON(uint64(), expected));
+    arrow::compute::AssertSortIndices(ArrayFromJSON(type, values), order,
+                                      ArrayFromJSON(uint64(), expected));
   }
 
   virtual void AssertSortIndices(const std::string& values, const std::string& expected) {
@@ -494,19 +540,7 @@ TYPED_TEST(TestArraySortIndicesKernelRandomCompare, SortRandomValuesCompare) {
 }
 
 // Test basic cases for chunked array.
-class TestChunkedArraySortIndices : public ::testing::Test {
- protected:
-  void AssertSortIndices(const std::shared_ptr<ChunkedArray> chunked_array,
-                         SortOrder order, const std::shared_ptr<Array> expected) {
-    ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(*chunked_array, order));
-    AssertArraysEqual(*expected, *actual, /*verbose=*/true);
-  }
-
-  void AssertSortIndices(const std::shared_ptr<ChunkedArray> chunked_array,
-                         SortOrder order, const std::string expected) {
-    AssertSortIndices(chunked_array, order, ArrayFromJSON(uint64(), expected));
-  }
-};
+class TestChunkedArraySortIndices : public ::testing::Test {};  // no fixture state
 
 TEST_F(TestChunkedArraySortIndices, Null) {
   auto chunked_array = ChunkedArrayFromJSON(uint8(), {
@@ -514,8 +548,8 @@ TEST_F(TestChunkedArraySortIndices, Null) {
                                                          "[3, null, 2]",
                                                          "[1]",
                                                      });
-  this->AssertSortIndices(chunked_array, SortOrder::Ascending, "[1, 5, 4, 2, 0, 3]");
-  this->AssertSortIndices(chunked_array, SortOrder::Descending, "[2, 4, 1, 5, 0, 3]");
+  AssertSortIndices(chunked_array, SortOrder::Ascending, "[1, 5, 4, 2, 0, 3]");
+  AssertSortIndices(chunked_array, SortOrder::Descending, "[2, 4, 1, 5, 0, 3]");
 }
 
 TEST_F(TestChunkedArraySortIndices, NaN) {
@@ -524,8 +558,8 @@ TEST_F(TestChunkedArraySortIndices, NaN) {
                                                            "[3, null, NaN]",
                                                            "[NaN, 1]",
                                                        });
-  this->AssertSortIndices(chunked_array, SortOrder::Ascending, "[1, 6, 2, 4, 5, 0, 3]");
-  this->AssertSortIndices(chunked_array, SortOrder::Descending, "[2, 1, 6, 4, 5, 0, 3]");
+  AssertSortIndices(chunked_array, SortOrder::Ascending, "[1, 6, 2, 4, 5, 0, 3]");
+  AssertSortIndices(chunked_array, SortOrder::Descending, "[2, 1, 6, 4, 5, 0, 3]");
 }
 
 // Tests for temporal types
@@ -543,8 +577,8 @@ TYPED_TEST(TestChunkedArraySortIndicesForTemporal, NoNull) {
                                                       "[3, 2, 1]",
                                                       "[5, 0]",
                                                   });
-  this->AssertSortIndices(chunked_array, SortOrder::Ascending, "[0, 6, 1, 4, 3, 2, 5]");
-  this->AssertSortIndices(chunked_array, SortOrder::Descending, "[5, 2, 3, 1, 4, 0, 6]");
+  AssertSortIndices(chunked_array, SortOrder::Ascending, "[0, 6, 1, 4, 3, 2, 5]");
+  AssertSortIndices(chunked_array, SortOrder::Descending, "[5, 2, 3, 1, 4, 0, 6]");
 }
 
 // Base class for testing against random chunked array.
@@ -622,57 +656,206 @@ class TestChunkedArrayRandomNarrow : public TestChunkedArrayRandomBase<Type> {
 TYPED_TEST_SUITE(TestChunkedArrayRandomNarrow, IntegralArrowTypes);
+// Runs the shared randomized sort check once per integral Arrow type. The 1000
+// is presumably the generated array length -- TODO confirm against TestSortIndices.
 TYPED_TEST(TestChunkedArrayRandomNarrow, SortIndices) { this->TestSortIndices(1000); }
 
-// Test basic cases for table.
-class TestTableSortIndices : public ::testing::Test {
- protected:
-  void AssertSortIndices(const std::shared_ptr<Table> table, const SortOptions& options,
-                         const std::shared_ptr<Array> expected) {
-    ASSERT_OK_AND_ASSIGN(auto actual, SortIndices(Datum(*table), options));
-    AssertArraysEqual(*expected, *actual);
-  }
+// Test basic cases for record batch.
+// The fixture is deliberately empty: these tests call a non-member
+// AssertSortIndices() helper rather than a fixture method.
+class TestRecordBatchSortIndices : public ::testing::Test {};
 
-  void AssertSortIndices(const std::shared_ptr<Table> table, const SortOptions& options,
-                         const std::string expected) {
-    AssertSortIndices(table, options, ArrayFromJSON(uint64(), expected));
-  }
-};
+TEST_F(TestRecordBatchSortIndices, NoNull) {
+  // Primary key "a" ascending; equal "a" values are ordered by "b" descending,
+  // and fully equal rows (1 and 6) keep their input order.
+  const SortOptions sort_options({SortKey("a", SortOrder::Ascending),
+                                  SortKey("b", SortOrder::Descending)});
+  const auto batch_schema =
+      ::arrow::schema({field("a", uint8()), field("b", uint32())});
+  const auto batch = RecordBatchFromJSON(batch_schema, R"([
+      {"a": 3, "b": 5},
+      {"a": 1, "b": 3},
+      {"a": 3, "b": 4},
+      {"a": 0, "b": 6},
+      {"a": 2, "b": 5},
+      {"a": 1, "b": 5},
+      {"a": 1, "b": 3}
+  ])");
+  AssertSortIndices(batch, sort_options, "[3, 5, 1, 6, 4, 0, 2]");
+}
+
+TEST_F(TestRecordBatchSortIndices, Null) {
+  // Nulls in either key sort after all non-null values, whether that key is
+  // ascending ("a") or descending ("b").
+  const SortOptions sort_options({SortKey("a", SortOrder::Ascending),
+                                  SortKey("b", SortOrder::Descending)});
+  const auto batch_schema =
+      ::arrow::schema({field("a", uint8()), field("b", uint32())});
+  const auto batch = RecordBatchFromJSON(batch_schema, R"([
+      {"a": null, "b": 5},
+      {"a": 1, "b": 3},
+      {"a": 3, "b": null},
+      {"a": null, "b": null},
+      {"a": 2, "b": 5},
+      {"a": 1, "b": 5}
+  ])");
+  AssertSortIndices(batch, sort_options, "[5, 1, 4, 2, 0, 3]");
+}
+
+TEST_F(TestRecordBatchSortIndices, NaN) {
+  // NaN sorts after every other value, whether the key is ascending ("a") or
+  // descending ("b").
+  const SortOptions sort_options({SortKey("a", SortOrder::Ascending),
+                                  SortKey("b", SortOrder::Descending)});
+  const auto batch_schema =
+      ::arrow::schema({field("a", float32()), field("b", float64())});
+  const auto batch = RecordBatchFromJSON(batch_schema, R"([
+      {"a": 3, "b": 5},
+      {"a": 1, "b": NaN},
+      {"a": 3, "b": 4},
+      {"a": 0, "b": 6},
+      {"a": NaN, "b": 5},
+      {"a": NaN, "b": NaN},
+      {"a": NaN, "b": 5},
+      {"a": 1, "b": 5}
+  ])");
+  AssertSortIndices(batch, sort_options, "[3, 7, 1, 0, 2, 4, 6, 5]");
+}
+
+TEST_F(TestRecordBatchSortIndices, NaNAndNull) {
+  // Ordering within each key: regular values first, then NaN, then null.
+  const SortOptions sort_options({SortKey("a", SortOrder::Ascending),
+                                  SortKey("b", SortOrder::Descending)});
+  const auto batch_schema =
+      ::arrow::schema({field("a", float32()), field("b", float64())});
+  const auto batch = RecordBatchFromJSON(batch_schema, R"([
+      {"a": null, "b": 5},
+      {"a": 1, "b": 3},
+      {"a": 3, "b": null},
+      {"a": null, "b": null},
+      {"a": NaN, "b": null},
+      {"a": NaN, "b": NaN},
+      {"a": NaN, "b": 5},
+      {"a": 1, "b": 5}
+  ])");
+  AssertSortIndices(batch, sort_options, "[7, 1, 2, 6, 5, 4, 0, 3]");
+}
+
+TEST_F(TestRecordBatchSortIndices, MoreTypes) {
+  // Keys on a timestamp column ("a", ascending) and a large_utf8 column
+  // ("b", descending, compared as strings).
+  const SortOptions sort_options({SortKey("a", SortOrder::Ascending),
+                                  SortKey("b", SortOrder::Descending)});
+  const auto batch_schema = ::arrow::schema(
+      {field("a", timestamp(TimeUnit::MICRO)), field("b", large_utf8())});
+  const auto batch = RecordBatchFromJSON(batch_schema, R"([
+      {"a": 3, "b": "05"},
+      {"a": 1, "b": "03"},
+      {"a": 3, "b": "04"},
+      {"a": 0, "b": "06"},
+      {"a": 2, "b": "05"},
+      {"a": 1, "b": "05"}
+  ])");
+  AssertSortIndices(batch, sort_options, "[3, 5, 1, 4, 0, 2]");
+}
+
+// Fixture for the basic table sorting tests; it carries no state of its own.
+struct TestTableSortIndices : public ::testing::Test {};
 
 TEST_F(TestTableSortIndices, Null) {
-  auto table = TableFromJSON(schema({
-                                 {field("a", uint8())},
-                                 {field("b", uint8())},
-                             }),
-                             {"["
-                              "{\"a\": null, \"b\": 5},"
-                              "{\"a\": 1,    \"b\": 3},"
-                              "{\"a\": 3,    \"b\": null},"
-                              "{\"a\": null, \"b\": null},"
-                              "{\"a\": 2,    \"b\": 5},"
-                              "{\"a\": 1,    \"b\": 5}"
-                              "]"});
+  auto schema = ::arrow::schema({
+      {field("a", uint8())},
+      {field("b", uint32())},
+  });
+  // Sort by "a" ascending, then by "b" descending; in both directions nulls
+  // sort after all non-null values.
   SortOptions options(
       {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)});
-  this->AssertSortIndices(table, options, "[5, 1, 4, 2, 0, 3]");
+  std::shared_ptr<Table> table;
+
+  table = TableFromJSON(schema, {R"([{"a": null, "b": 5},
+                                     {"a": 1,    "b": 3},
+                                     {"a": 3,    "b": null},
+                                     {"a": null, "b": null},
+                                     {"a": 2,    "b": 5},
+                                     {"a": 1,    "b": 5}
+                                    ])"});
+  AssertSortIndices(table, options, "[5, 1, 4, 2, 0, 3]");
+
+  // Same data, several chunks
+  table = TableFromJSON(schema, {R"([{"a": null, "b": 5},
+                                     {"a": 1,    "b": 3},
+                                     {"a": 3,    "b": null}
+                                    ])",
+                                 R"([{"a": null, "b": null},
+                                     {"a": 2,    "b": 5},
+                                     {"a": 1,    "b": 5}
+                                    ])"});
+  AssertSortIndices(table, options, "[5, 1, 4, 2, 0, 3]");
 }
 
 TEST_F(TestTableSortIndices, NaN) {
-  auto table = TableFromJSON(schema({
-                                 {field("a", float32())},
-                                 {field("b", float32())},
-                             }),
-                             {"["
-                              "{\"a\": null, \"b\": 5},"
-                              "{\"a\": 1,    \"b\": 3},"
-                              "{\"a\": 3,    \"b\": null},"
-                              "{\"a\": null, \"b\": null},"
-                              "{\"a\": NaN,  \"b\": null},"
-                              "{\"a\": NaN,  \"b\": NaN},"
-                              "{\"a\": NaN,  \"b\": 5},"
-                              "{\"a\": 1,    \"b\": 5}"
-                              "]"});
+  auto schema = ::arrow::schema({
+      {field("a", float32())},
+      {field("b", float64())},
+  });
+  // NaN sorts after every non-NaN value in either direction: the a = NaN rows
+  // come last, and within a = 1 the b = NaN row follows the b = 5 row.
   SortOptions options(
       {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)});
-  this->AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]");
+  std::shared_ptr<Table> table;
+  table = TableFromJSON(schema, {R"([{"a": 3,    "b": 5},
+                                     {"a": 1,    "b": NaN},
+                                     {"a": 3,    "b": 4},
+                                     {"a": 0,    "b": 6},
+                                     {"a": NaN,  "b": 5},
+                                     {"a": NaN,  "b": NaN},
+                                     {"a": NaN,  "b": 5},
+                                     {"a": 1,    "b": 5}
+                                    ])"});
+  AssertSortIndices(table, options, "[3, 7, 1, 0, 2, 4, 6, 5]");
+
+  // Same data, several chunks
+  table = TableFromJSON(schema, {R"([{"a": 3,    "b": 5},
+                                     {"a": 1,    "b": NaN},
+                                     {"a": 3,    "b": 4},
+                                     {"a": 0,    "b": 6}
+                                    ])",
+                                 R"([{"a": NaN,  "b": 5},
+                                     {"a": NaN,  "b": NaN},
+                                     {"a": NaN,  "b": 5},
+                                     {"a": 1,    "b": 5}
+                                    ])"});
+  AssertSortIndices(table, options, "[3, 7, 1, 0, 2, 4, 6, 5]");
+}
+
+TEST_F(TestTableSortIndices, NaNAndNull) {
+  // Within each key, regular values come first, then NaN, then null.
+  const SortOptions sort_options({SortKey("a", SortOrder::Ascending),
+                                  SortKey("b", SortOrder::Descending)});
+  const auto table_schema =
+      ::arrow::schema({field("a", float32()), field("b", float64())});
+  const std::string expected = "[7, 1, 2, 6, 5, 4, 0, 3]";
+
+  // All rows in a single chunk.
+  const auto one_chunk = TableFromJSON(
+      table_schema, {R"([{"a": null, "b": 5}, {"a": 1, "b": 3},
+                         {"a": 3, "b": null}, {"a": null, "b": null},
+                         {"a": NaN, "b": null}, {"a": NaN, "b": NaN},
+                         {"a": NaN, "b": 5}, {"a": 1, "b": 5}])"});
+  AssertSortIndices(one_chunk, sort_options, expected);
+
+  // The same rows split over two chunks must produce the same permutation.
+  const auto two_chunks = TableFromJSON(
+      table_schema, {R"([{"a": null, "b": 5}, {"a": 1, "b": 3},
+                         {"a": 3, "b": null}, {"a": null, "b": null}])",
+                     R"([{"a": NaN, "b": null}, {"a": NaN, "b": NaN},
+                         {"a": NaN, "b": 5}, {"a": 1, "b": 5}])"});
+  AssertSortIndices(two_chunks, sort_options, expected);
 }
 
 // Tests for temporal types
@@ -701,7 +884,7 @@ TYPED_TEST(TestTableSortIndicesForTemporal, NoNull) {
                               "]"});
   SortOptions options(
       {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)});
-  this->AssertSortIndices(table, options, "[0, 6, 1, 4, 7, 3, 2, 5]");
+  AssertSortIndices(table, options, "[0, 6, 1, 4, 7, 3, 2, 5]");
 }
 
 // For random table tests.
@@ -733,7 +916,8 @@ class TestTableSortIndicesRandom : public testing::TestWithParam<RandomParam> {
         if (lhs_array_->IsNull(lhs_index_)) return false;
         status_ = lhs_array_->type()->Accept(this);
         if (compared_ == 0) continue;
-        if (pair.second == SortOrder::Ascending) {
+        // If either value is NaN, it must sort after the other regardless of order
+        if (pair.second == SortOrder::Ascending || lhs_isnan_ || rhs_isnan_) {
           return compared_ < 0;
         } else {
           return compared_ > 0;
@@ -791,11 +975,14 @@ class TestTableSortIndicesRandom : public testing::TestWithParam<RandomParam> {
       auto lhs_value = checked_cast<const ArrayType*>(lhs_array_)->GetView(lhs_index_);
       auto rhs_value = checked_cast<const ArrayType*>(rhs_array_)->GetView(rhs_index_);
       if (is_floating_type<Type>::value) {
-        const bool lhs_isnan = lhs_value != lhs_value;
-        const bool rhs_isnan = rhs_value != rhs_value;
-        if (lhs_isnan && rhs_isnan) return 0;
-        if (rhs_isnan) return 1;
-        if (lhs_isnan) return -1;
+        lhs_isnan_ = lhs_value != lhs_value;
+        rhs_isnan_ = rhs_value != rhs_value;
+        if (lhs_isnan_ && rhs_isnan_) return 0;
+        // NaN is considered greater than non-NaN
+        if (rhs_isnan_) return -1;
+        if (lhs_isnan_) return 1;
+      } else {
+        lhs_isnan_ = rhs_isnan_ = false;
       }
       if (lhs_value == rhs_value) {
         return 0;
@@ -814,6 +1001,7 @@ class TestTableSortIndicesRandom : public testing::TestWithParam<RandomParam> {
     int64_t rhs_;
     const Array* rhs_array_;
     int64_t rhs_index_;
+    bool lhs_isnan_, rhs_isnan_;
     int compared_;
     Status status_;
   };
@@ -826,8 +1014,8 @@ class TestTableSortIndicesRandom : public testing::TestWithParam<RandomParam> {
     for (int i = 1; i < table.num_rows(); i++) {
       uint64_t lhs = offsets.Value(i - 1);
       uint64_t rhs = offsets.Value(i);
-      ASSERT_TRUE(comparator(lhs, rhs));
       ASSERT_OK(comparator.status());
+      ASSERT_TRUE(comparator(lhs, rhs)) << "lhs = " << lhs << ", rhs = " << rhs;
     }
   }
 };
@@ -851,15 +1039,15 @@ TEST_P(TestTableSortIndicesRandom, Sort) {
   const auto length = 200;
   std::vector<std::shared_ptr<Array>> columns = {
       Random<UInt8Type>(seed).Generate(length, null_probability),
-      Random<UInt16Type>(seed).Generate(length, null_probability),
+      Random<UInt16Type>(seed).Generate(length, 0.0),
       Random<UInt32Type>(seed).Generate(length, null_probability),
-      Random<UInt64Type>(seed).Generate(length, null_probability),
-      Random<Int8Type>(seed).Generate(length, null_probability),
+      Random<UInt64Type>(seed).Generate(length, 0.0),
+      Random<Int8Type>(seed).Generate(length, 0.0),
       Random<Int16Type>(seed).Generate(length, null_probability),
-      Random<Int32Type>(seed).Generate(length, null_probability),
+      Random<Int32Type>(seed).Generate(length, 0.0),
       Random<Int64Type>(seed).Generate(length, null_probability),
-      Random<FloatType>(seed).Generate(length, null_probability),
-      Random<DoubleType>(seed).Generate(length, null_probability),
+      Random<FloatType>(seed).Generate(length, null_probability, 1 - null_probability),
+      Random<DoubleType>(seed).Generate(length, 0.0, null_probability),
       Random<StringType>(seed).Generate(length, null_probability),
   };
   const auto table = Table::Make(schema(fields), columns, length);
@@ -884,6 +1072,13 @@ TEST_P(TestTableSortIndicesRandom, Sort) {
     ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*chunked_table), options));
     Validate(*table, options, *checked_pointer_cast<UInt64Array>(offsets));
   }
+  // Also validate RecordBatch sorting
+  TableBatchReader reader(*table);
+  RecordBatchVector batches;
+  ASSERT_OK(reader.ReadAll(&batches));
+  ASSERT_EQ(batches.size(), 1);
+  ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*batches[0]), options));
+  Validate(*table, options, *checked_pointer_cast<UInt64Array>(offsets));
 }
 
 INSTANTIATE_TEST_SUITE_P(NoNull, TestTableSortIndicesRandom,
diff --git a/cpp/src/arrow/testing/random.cc b/cpp/src/arrow/testing/random.cc
index 5d25250..92ac5f8 100644
--- a/cpp/src/arrow/testing/random.cc
+++ b/cpp/src/arrow/testing/random.cc
@@ -18,8 +18,10 @@
 #include "arrow/testing/random.h"
 
 #include <algorithm>
+#include <limits>
 #include <memory>
 #include <random>
+#include <type_traits>
 #include <vector>
 
 #include <gtest/gtest.h>
@@ -46,20 +48,48 @@ namespace {
 
 template <typename ValueType, typename DistributionType>
 struct GenerateOptions {
-  GenerateOptions(SeedType seed, ValueType min, ValueType max, double probability)
-      : min_(min), max_(max), seed_(seed), probability_(probability) {}
+  // `nan_probability` defaults to 0.0, which selects the NaN-free generator, so
+  // existing callers keep their previous output. Integral value types ignore it:
+  // their GenerateTypedData overload never reads nan_probability_.
+  GenerateOptions(SeedType seed, ValueType min, ValueType max, double probability,
+                  double nan_probability = 0.0)
+      : min_(min),
+        max_(max),
+        seed_(seed),
+        probability_(probability),
+        nan_probability_(nan_probability) {}
 
+  // Views the raw output buffer as `n` ValueType slots. Overload resolution on
+  // the enable_if'd GenerateTypedData adds NaN injection only for floating point.
   void GenerateData(uint8_t* buffer, size_t n) {
     GenerateTypedData(reinterpret_cast<ValueType*>(buffer), n);
   }
 
-  void GenerateTypedData(ValueType* data, size_t n) {
+  // Non-floating-point value types cannot represent NaN, so nan_probability_ is
+  // ignored and the plain generator is used.
+  template <typename V>
+  typename std::enable_if<!std::is_floating_point<V>::value>::type GenerateTypedData(
+      V* data, size_t n) {
+    GenerateTypedDataNoNan(data, n);
+  }
+
+  // Floating-point value types: each slot is independently set to NaN with
+  // probability nan_probability_. When that probability is exactly 0.0 we defer
+  // to GenerateTypedDataNoNan, which reproduces the generator as it was before
+  // NaN support was added.
+  template <typename V>
+  typename std::enable_if<std::is_floating_point<V>::value>::type GenerateTypedData(
+      V* data, size_t n) {
+    if (nan_probability_ == 0.0) {
+      GenerateTypedDataNoNan(data, n);
+      return;
+    }
     std::default_random_engine rng(seed_++);
     DistributionType dist(min_, max_);
+    std::bernoulli_distribution nan_dist(nan_probability_);
+    const ValueType nan_value = std::numeric_limits<ValueType>::quiet_NaN();
 
     // A static cast is required due to the int16 -> int8 handling.
+    // (With the uniform_real_distribution used by Float32/Float64 it is a no-op.)
+    // dist(rng) is only drawn for non-NaN slots, so every emitted NaN shifts the
+    // stream of subsequent values.
-    std::generate(data, data + n,
-                  [&dist, &rng] { return static_cast<ValueType>(dist(rng)); });
+    std::generate(data, data + n, [&] {
+      return nan_dist(rng) ? nan_value : static_cast<ValueType>(dist(rng));
+    });
+  }
+
+  // Fills `data[0..n)` with draws from DistributionType(min_, max_), with no NaN
+  // injection. Each call consumes one seed, so successive calls differ.
+  void GenerateTypedDataNoNan(ValueType* data, size_t n) {
+    DistributionType distribution(min_, max_);
+    std::default_random_engine engine(seed_++);
+    // A static cast is required due to the int16 -> int8 handling.
+    for (ValueType* out = data; out != data + n; ++out) {
+      *out = static_cast<ValueType>(distribution(engine));
+    }
   }
 
   void GenerateBitmap(uint8_t* buffer, size_t n, int64_t* null_count) {
@@ -82,6 +112,7 @@ struct GenerateOptions {
   ValueType max_;
   SeedType seed_;
   double probability_;
+  double nan_probability_;
 };
 
 }  // namespace
@@ -170,14 +201,23 @@ PRIMITIVE_RAND_INTEGER_IMPL(Int64, int64_t, Int64Type)
 // Generate 16bit values for half-float
 PRIMITIVE_RAND_INTEGER_IMPL(Float16, int16_t, HalfFloatType)
 
-#define PRIMITIVE_RAND_FLOAT_IMPL(Name, CType, ArrowType) \
-  PRIMITIVE_RAND_IMPL(Name, CType, ArrowType, std::uniform_real_distribution<CType>)
+// Draws float32 values from std::uniform_real_distribution over [min, max), then
+// applies nulls and NaNs with the requested probabilities.
+std::shared_ptr<Array> RandomArrayGenerator::Float32(int64_t size, float min, float max,
+                                                     double null_probability,
+                                                     double nan_probability) {
+  using Distribution = std::uniform_real_distribution<float>;
+  using GenOptions = GenerateOptions<float, Distribution>;
+  GenOptions gen_options(seed(), min, max, null_probability, nan_probability);
+  return GenerateNumericArray<FloatType, GenOptions>(size, gen_options);
+}
 
-PRIMITIVE_RAND_FLOAT_IMPL(Float32, float, FloatType)
-PRIMITIVE_RAND_FLOAT_IMPL(Float64, double, DoubleType)
+// Draws float64 values from std::uniform_real_distribution over [min, max), then
+// applies nulls and NaNs with the requested probabilities.
+std::shared_ptr<Array> RandomArrayGenerator::Float64(int64_t size, double min, double max,
+                                                     double null_probability,
+                                                     double nan_probability) {
+  using Distribution = std::uniform_real_distribution<double>;
+  using GenOptions = GenerateOptions<double, Distribution>;
+  GenOptions gen_options(seed(), min, max, null_probability, nan_probability);
+  return GenerateNumericArray<DoubleType, GenOptions>(size, gen_options);
+}
 
 #undef PRIMITIVE_RAND_INTEGER_IMPL
-#undef PRIMITIVE_RAND_FLOAT_IMPL
 #undef PRIMITIVE_RAND_IMPL
 
 template <typename TypeClass>
diff --git a/cpp/src/arrow/testing/random.h b/cpp/src/arrow/testing/random.h
index a1a16c0..f874fae 100644
--- a/cpp/src/arrow/testing/random.h
+++ b/cpp/src/arrow/testing/random.h
@@ -165,10 +165,11 @@ class ARROW_TESTING_EXPORT RandomArrayGenerator {
   /// \param[in] min the lower bound of the uniform distribution
   /// \param[in] max the upper bound of the uniform distribution
   /// \param[in] null_probability the probability of a row being null
+  /// \param[in] nan_probability the probability of a value being NaN; decided
+  ///   independently of null_probability
   ///
   /// \return a generated Array
   std::shared_ptr<Array> Float32(int64_t size, float min, float max,
-                                 double null_probability = 0);
+                                 double null_probability = 0, double nan_probability = 0);
 
   /// \brief Generate a random DoubleArray
   ///
@@ -176,10 +177,11 @@ class ARROW_TESTING_EXPORT RandomArrayGenerator {
   /// \param[in] min the lower bound of the uniform distribution
   /// \param[in] max the upper bound of the uniform distribution
   /// \param[in] null_probability the probability of a row being null
+  /// \param[in] nan_probability the probability of a value being NaN; decided
+  ///   independently of null_probability
   ///
   /// \return a generated Array
   std::shared_ptr<Array> Float64(int64_t size, double min, double max,
-                                 double null_probability = 0);
+                                 double null_probability = 0, double nan_probability = 0);
 
   template <typename ArrowType, typename CType = typename ArrowType::c_type>
   std::shared_ptr<Array> Numeric(int64_t size, CType min, CType max,