You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2020/12/11 01:04:08 UTC

[GitHub] [arrow] kou commented on a change in pull request #8890: ARROW-10796: [C++] Implement optimized RecordBatch sorting

kou commented on a change in pull request #8890:
URL: https://github.com/apache/arrow/pull/8890#discussion_r540613034



##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -834,6 +1355,10 @@ class TableRadixSorter {
 // Sort a table using a single sort and multiple-key comparisons.
 class MultipleKeyTableSorter : public TypeVisitor {
  private:
+  // TODO instead of resolving chunks for each column independently, we could
+  // split the table into RecordBatches and pay the cost of chunked indexing
+  // at the first column only.

Review comment:
       Can we always do it?
   My understanding is that each chunked array in a table can have a different number of chunks. For example, the following table is valid:
   
   ```text
   a: [[0, 1], [2, 3, 4]]
   b: [[10], [11, 12], [13], [14]]
   ```
   
   I'm not sure we can split the table into record batches efficiently.

##########
File path: cpp/src/arrow/compute/kernels/vector_sort.cc
##########
@@ -804,6 +814,517 @@ class ChunkedArraySorter : public TypeVisitor {
   ExecContext* ctx_;
 };
 
+// ----------------------------------------------------------------------
+// Record batch sorting implementation(s)
+
// Calls `visit(run_begin, run_end)` once for every maximal run of consecutive
// positions in [indices_begin, indices_end) whose values in `array` compare
// equal (each value is tested against the run's first value with operator!=).
// Runs are reported left to right; an empty index range reports nothing.
// All referenced entries are assumed to be non-null.
template <typename ArrayType, typename Visitor>
void VisitConstantRanges(const ArrayType& array, uint64_t* indices_begin,
                         uint64_t* indices_end, Visitor&& visit) {
  if (indices_begin == indices_end) {
    return;
  }
  uint64_t* run_begin = indices_begin;
  auto run_value = array.GetView(*run_begin);
  for (uint64_t* cursor = indices_begin + 1; cursor != indices_end; ++cursor) {
    auto value = array.GetView(*cursor);
    if (value != run_value) {
      // `cursor` opens a new run: report the one that just ended.
      visit(run_begin, cursor);
      run_begin = cursor;
      run_value = value;
    }
  }
  // The trailing run is never empty: run_begin always points before indices_end.
  visit(run_begin, indices_end);
}
+
// Abstract per-column stage of a RecordBatch sort. A stage reorders a range of
// row indices by its own column and may defer ranges of equal values to
// `next_column_`, the stage for the following sort key (nullptr if none).
class RecordBatchColumnSorter {
 public:
  explicit RecordBatchColumnSorter(RecordBatchColumnSorter* next_column = nullptr)
      : next_column_{next_column} {}
  virtual ~RecordBatchColumnSorter() = default;

  // Reorders the row indices in [indices_begin, indices_end) in place.
  virtual void SortRange(uint64_t* indices_begin, uint64_t* indices_end) = 0;

 protected:
  // Non-owning link to the next sort key's stage; may be nullptr.
  RecordBatchColumnSorter* next_column_;
};
+
+template <typename Type>
+class ConcreteRecordBatchColumnSorter : public RecordBatchColumnSorter {
+ public:
+  using ArrayType = typename TypeTraits<Type>::ArrayType;
+
+  ConcreteRecordBatchColumnSorter(std::shared_ptr<Array> array, SortOrder order,
+                                  RecordBatchColumnSorter* next_column = nullptr)
+      : RecordBatchColumnSorter(next_column),
+        owned_array_(std::move(array)),
+        array_(checked_cast<const ArrayType&>(*owned_array_)),
+        order_(order),
+        null_count_(array_.null_count()) {}
+
+  void SortRange(uint64_t* indices_begin, uint64_t* indices_end) {
+    constexpr int64_t offset = 0;
+    uint64_t* nulls_begin;
+    if (null_count_ == 0) {
+      nulls_begin = indices_end;
+    } else {
+      // NOTE that null_count_ is merely an upper bound on the number of nulls
+      // in this particular range.
+      nulls_begin = PartitionNullsOnly<StablePartitioner>(indices_begin, indices_end,
+                                                          array_, offset);
+      DCHECK_LE(indices_end - nulls_begin, null_count_);
+    }
+    uint64_t* null_likes_begin = PartitionNullLikes<ArrayType, StablePartitioner>(
+        indices_begin, nulls_begin, array_, offset);
+
+    // TODO This is roughly the same as ArrayCompareSorter.
+    // Also, we would like to use a counting sort if possible.  This requires
+    // a counting sort compatible with indirect indexing.
+    if (order_ == SortOrder::Ascending) {
+      std::stable_sort(
+          indices_begin, null_likes_begin, [&](uint64_t left, uint64_t right) {
+            return array_.GetView(left - offset) < array_.GetView(right - offset);
+          });
+    } else {
+      std::stable_sort(
+          indices_begin, null_likes_begin, [&](uint64_t left, uint64_t right) {
+            // We don't use 'left > right' here to reduce required operator.
+            // If we use 'right < left' here, '<' is only required.
+            return array_.GetView(right - offset) < array_.GetView(left - offset);
+          });
+    }
+
+    if (next_column_ != nullptr) {
+      // Visit all ranges of equal values in this column and sort them on
+      // the next column.
+      SortNextColumn(null_likes_begin, nulls_begin);
+      SortNextColumn(nulls_begin, indices_end);
+      VisitConstantRanges(array_, indices_begin, null_likes_begin,
+                          [&](uint64_t* range_start, uint64_t* range_end) {
+                            SortNextColumn(range_start, range_end);
+                          });
+    }
+  }
+
+  void SortNextColumn(uint64_t* indices_begin, uint64_t* indices_end) {
+    // Avoid the cost of a virtual method call in trivial cases
+    if (indices_end - indices_begin > 1) {
+      next_column_->SortRange(indices_begin, indices_end);
+    }
+  }
+
+ protected:
+  const std::shared_ptr<Array> owned_array_;
+  const ArrayType& array_;
+  const SortOrder order_;
+  const int64_t null_count_;
+};
+
+// Sort a batch using a single-pass left-to-right radix sort.
+class RadixRecordBatchSorter {
+ public:
+  RadixRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end,
+                         const RecordBatch& batch, const SortOptions& options)
+      : batch_(batch),
+        options_(options),
+        indices_begin_(indices_begin),
+        indices_end_(indices_end) {}
+
+  Status Sort() {
+    ARROW_ASSIGN_OR_RAISE(const auto sort_keys,
+                          ResolveSortKeys(batch_, options_.sort_keys));
+
+    // Create column sorters from right to left
+    std::vector<std::unique_ptr<RecordBatchColumnSorter>> column_sorts(sort_keys.size());
+    RecordBatchColumnSorter* next_column = nullptr;
+    for (int64_t i = static_cast<int64_t>(sort_keys.size() - 1); i >= 0; --i) {
+      ColumnSortFactory factory(sort_keys[i], next_column);
+      ARROW_ASSIGN_OR_RAISE(column_sorts[i], factory.MakeColumnSort());
+      next_column = column_sorts[i].get();
+    }
+
+    // Sort from left to right
+    column_sorts.front()->SortRange(indices_begin_, indices_end_);
+    return Status::OK();
+  }
+
+ protected:
+  struct ResolvedSortKey {
+    std::shared_ptr<Array> array;
+    SortOrder order;
+  };
+
+  struct ColumnSortFactory {
+    ColumnSortFactory(const ResolvedSortKey& sort_key,
+                      RecordBatchColumnSorter* next_column)
+        : physical_type(GetPhysicalType(sort_key.array->type())),
+          array(GetPhysicalArray(*sort_key.array, physical_type)),
+          order(sort_key.order),
+          next_column(next_column) {}
+
+    Result<std::unique_ptr<RecordBatchColumnSorter>> MakeColumnSort() {
+      RETURN_NOT_OK(VisitTypeInline(*physical_type, this));
+      DCHECK_NE(result, nullptr);
+      return std::move(result);
+    }
+
+#define VISIT(TYPE) \
+  Status Visit(const TYPE& type) { return VisitGeneric(type); }
+
+    VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+    Status Visit(const DataType& type) {
+      return Status::TypeError("Unsupported type for RecordBatch sorting: ",
+                               type.ToString());
+    }
+
+    template <typename Type>
+    Status VisitGeneric(const Type&) {
+      result.reset(new ConcreteRecordBatchColumnSorter<Type>(array, order, next_column));
+      return Status::OK();
+    }
+
+    std::shared_ptr<DataType> physical_type;
+    std::shared_ptr<Array> array;
+    SortOrder order;
+    RecordBatchColumnSorter* next_column;
+    std::unique_ptr<RecordBatchColumnSorter> result;
+  };
+
+  static Result<std::vector<ResolvedSortKey>> ResolveSortKeys(
+      const RecordBatch& batch, const std::vector<SortKey>& sort_keys) {
+    std::vector<ResolvedSortKey> resolved;
+    resolved.reserve(sort_keys.size());
+    for (const auto& sort_key : sort_keys) {
+      auto array = batch.GetColumnByName(sort_key.name);
+      if (!array) {
+        return Status::Invalid("Nonexistent sort key column: ", sort_key.name);
+      }
+      resolved.push_back({std::move(array), sort_key.order});
+    }
+    return resolved;
+  }
+
+  const RecordBatch& batch_;
+  const SortOptions& options_;
+  uint64_t* indices_begin_;
+  uint64_t* indices_end_;
+};
+
+// Compare two records in the same RecordBatch or Table
+// (indexing is handled through ResolvedSortKey)
+template <typename ResolvedSortKey>
+class MultipleKeyComparator {
+ public:
+  explicit MultipleKeyComparator(const std::vector<ResolvedSortKey>& sort_keys)
+      : sort_keys_(sort_keys) {}
+
+  Status status() const { return status_; }
+
+  // Returns true if the left-th value should be ordered before the
+  // right-th value, false otherwise. The start_sort_key_index-th
+  // sort key and subsequent sort keys are used for comparison.
+  bool Compare(uint64_t left, uint64_t right, size_t start_sort_key_index) {
+    current_left_ = left;
+    current_right_ = right;
+    current_compared_ = 0;
+    auto num_sort_keys = sort_keys_.size();
+    for (size_t i = start_sort_key_index; i < num_sort_keys; ++i) {
+      current_sort_key_index_ = i;
+      status_ = VisitTypeInline(*sort_keys_[i].type, this);
+      // If the left value equals to the right value, we need to
+      // continue to sort.
+      if (current_compared_ != 0) {
+        break;
+      }
+    }
+    return current_compared_ < 0;
+  }
+
+#define VISIT(TYPE)                          \
+  Status Visit(const TYPE& type) {           \
+    current_compared_ = CompareType<TYPE>(); \
+    return Status::OK();                     \
+  }
+
+  VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+  Status Visit(const DataType& type) {
+    return Status::TypeError("Unsupported type for RecordBatch sorting: ",
+                             type.ToString());
+  }
+
+ private:
+  // Compares two records in the same table and returns -1, 0 or 1.
+  //
+  // -1: The left is less than the right.
+  // 0: The left equals to the right.
+  // 1: The left is greater than the right.
+  //
+  // This supports null and NaN. Null is processed in this and NaN
+  // is processed in CompareTypeValue().
+  template <typename Type>
+  int32_t CompareType() {
+    using ArrayType = typename TypeTraits<Type>::ArrayType;
+    const auto& sort_key = sort_keys_[current_sort_key_index_];
+    auto order = sort_key.order;
+    const auto chunk_left = sort_key.template GetChunk<ArrayType>(current_left_);
+    const auto chunk_right = sort_key.template GetChunk<ArrayType>(current_right_);
+    if (sort_key.null_count > 0) {
+      auto is_null_left = chunk_left.IsNull();
+      auto is_null_right = chunk_right.IsNull();
+      if (is_null_left && is_null_right) {
+        return 0;
+      } else if (is_null_left) {
+        return 1;
+      } else if (is_null_right) {
+        return -1;
+      }
+    }
+    return CompareTypeValue<Type>(chunk_left, chunk_right, order);
+  }
+
+  // For non-float types. Value is never NaN.
+  template <typename Type>
+  enable_if_t<!is_floating_type<Type>::value, int32_t> CompareTypeValue(
+      const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_left,
+      const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_right,
+      const SortOrder order) {
+    const auto left = chunk_left.GetView();
+    const auto right = chunk_right.GetView();
+    int32_t compared;
+    if (left == right) {
+      compared = 0;
+    } else if (left > right) {
+      compared = 1;
+    } else {
+      compared = -1;
+    }
+    if (order == SortOrder::Descending) {
+      compared = -compared;
+    }
+    return compared;
+  }
+
+  // For float types. Value may be NaN.
+  template <typename Type>
+  enable_if_t<is_floating_type<Type>::value, int32_t> CompareTypeValue(
+      const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_left,
+      const ResolvedChunk<typename TypeTraits<Type>::ArrayType>& chunk_right,
+      const SortOrder order) {
+    const auto left = chunk_left.GetView();
+    const auto right = chunk_right.GetView();
+    auto is_nan_left = std::isnan(left);
+    auto is_nan_right = std::isnan(right);
+    if (is_nan_left && is_nan_right) {
+      return 0;
+    } else if (is_nan_left) {
+      return 1;
+    } else if (is_nan_right) {
+      return -1;
+    }
+    int32_t compared;
+    if (left == right) {
+      compared = 0;
+    } else if (left > right) {
+      compared = 1;
+    } else {
+      compared = -1;
+    }
+    if (order == SortOrder::Descending) {
+      compared = -compared;
+    }
+    return compared;
+  }
+
+  const std::vector<ResolvedSortKey>& sort_keys_;
+  Status status_;
+  int64_t current_left_;
+  int64_t current_right_;
+  size_t current_sort_key_index_;
+  int32_t current_compared_;
+};
+
+// Sort a batch using a single sort and multiple-key comparisons.
+class MultipleKeyRecordBatchSorter : public TypeVisitor {
+ private:
+  // Preprocessed sort key.
+  struct ResolvedSortKey {
+    ResolvedSortKey(const std::shared_ptr<Array>& array, const SortOrder order)
+        : type(GetPhysicalType(array->type())),
+          owned_array(GetPhysicalArray(*array, type)),
+          array(*owned_array),
+          order(order),
+          null_count(array->null_count()) {}
+
+    template <typename ArrayType>
+    ResolvedChunk<ArrayType> GetChunk(int64_t index) const {
+      return {&checked_cast<const ArrayType&>(array), index};
+    }
+
+    const std::shared_ptr<DataType> type;
+    std::shared_ptr<Array> owned_array;
+    const Array& array;
+    SortOrder order;
+    int64_t null_count;
+  };
+
+  using Comparator = MultipleKeyComparator<ResolvedSortKey>;
+
+ public:
+  MultipleKeyRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end,
+                               const RecordBatch& batch, const SortOptions& options)
+      : indices_begin_(indices_begin),
+        indices_end_(indices_end),
+        sort_keys_(ResolveSortKeys(batch, options.sort_keys, &status_)),
+        comparator_(sort_keys_) {}
+
+  // This is optimized for the first sort key. The first sort key sort
+  // is processed in this class. The second and following sort keys
+  // are processed in Comparator.
+  Status Sort() {
+    RETURN_NOT_OK(status_);
+    return sort_keys_[0].type->Accept(this);
+  }
+
+#define VISIT(TYPE) \
+  Status Visit(const TYPE& type) override { return SortInternal<TYPE>(); }
+
+  VISIT_PHYSICAL_TYPES(VISIT)
+
+#undef VISIT
+
+ private:
+  static std::vector<ResolvedSortKey> ResolveSortKeys(
+      const RecordBatch& batch, const std::vector<SortKey>& sort_keys, Status* status) {
+    std::vector<ResolvedSortKey> resolved;
+    for (const auto& sort_key : sort_keys) {
+      auto array = batch.GetColumnByName(sort_key.name);
+      if (!array) {
+        *status = Status::Invalid("Nonexistent sort key column: ", sort_key.name);
+        break;
+      }
+      resolved.emplace_back(array, sort_key.order);
+    }
+    return resolved;
+  }
+
+  template <typename Type>
+  Status SortInternal() {
+    using ArrayType = typename TypeTraits<Type>::ArrayType;
+
+    auto& comparator = comparator_;
+    const auto& first_sort_key = sort_keys_[0];
+    const ArrayType& array = checked_cast<const ArrayType&>(first_sort_key.array);
+    auto nulls_begin = indices_end_;
+    nulls_begin = PartitionNullsInternal<Type>(first_sort_key);
+    // Sort first-key non-nulls
+    std::stable_sort(indices_begin_, nulls_begin, [&](uint64_t left, uint64_t right) {
+      // Both values are never null nor NaN
+      // (otherwise they've been partitioned away above).
+      const auto value_left = array.GetView(left);
+      const auto value_right = array.GetView(right);
+      if (value_left != value_right) {
+        bool compared = value_left < value_right;
+        if (first_sort_key.order == SortOrder::Ascending) {
+          return compared;
+        } else {
+          return !compared;
+        }
+      }
+      // If the left value equals to the right value,
+      // we need to compare the second and following
+      // sort keys.
+      return comparator.Compare(left, right, 1);
+    });
+    return comparator_.status();
+  }
+
+  // Behaves like PatitionNulls() but this supports multiple sort keys.
+  //
+  // For non-float types.
+  template <typename Type>
+  enable_if_t<!is_floating_type<Type>::value, uint64_t*> PartitionNullsInternal(
+      const ResolvedSortKey& first_sort_key) {
+    using ArrayType = typename TypeTraits<Type>::ArrayType;
+    if (first_sort_key.null_count == 0) {
+      return indices_end_;
+    }
+    const ArrayType& array = checked_cast<const ArrayType&>(first_sort_key.array);
+    StablePartitioner partitioner;
+    auto nulls_begin = partitioner(indices_begin_, indices_end_,
+                                   [&](uint64_t index) { return !array.IsNull(index); });
+    // Sort all nulls by second and following sort keys
+    // TODO: could we instead run an independent sort from the second key on
+    // this slice?

Review comment:
       Do you mean something like `ConcreteRecordBatchColumnSorter`'s `next_column_`?
   That would work.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org