You are viewing a plain text version of this content. The canonical link to the original was not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by ap...@apache.org on 2022/06/16 10:27:31 UTC

[arrow] branch master updated: ARROW-16706: [Python] Expose RankOptions (#13327)

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 56cac6da37 ARROW-16706: [Python] Expose RankOptions (#13327)
56cac6da37 is described below

commit 56cac6da37928b7e7fb9a9a039b8ffdeebe933a9
Author: Raúl Cumplido <ra...@gmail.com>
AuthorDate: Thu Jun 16 12:27:24 2022 +0200

    ARROW-16706: [Python] Expose RankOptions (#13327)
    
    Authored-by: Raúl Cumplido <ra...@gmail.com>
    Signed-off-by: Antoine Pitrou <pi...@free.fr>
---
 python/pyarrow/_compute.pyx          | 62 ++++++++++++++++++++++++++++++++++++
 python/pyarrow/compute.py            |  1 +
 python/pyarrow/includes/libarrow.pxd | 15 +++++++++
 python/pyarrow/tests/test_compute.py | 52 ++++++++++++++++++++++++++++++
 4 files changed, 130 insertions(+)

diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index f020a69eff..936fe2f0cd 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -2060,6 +2060,68 @@ class RandomOptions(_RandomOptions):
         self._set_options(initializer)
 
 
+cdef class _RankOptions(FunctionOptions):
+
+    _tiebreaker_map = {
+        "min": CRankOptionsTiebreaker_Min,
+        "max": CRankOptionsTiebreaker_Max,
+        "first": CRankOptionsTiebreaker_First,
+        "dense": CRankOptionsTiebreaker_Dense,
+    }
+
+    def _set_options(self, sort_keys, null_placement, tiebreaker):
+        cdef vector[CSortKey] c_sort_keys
+        if isinstance(sort_keys, str):
+            c_sort_keys.push_back(
+                CSortKey(tobytes(""), unwrap_sort_order(sort_keys))
+            )
+        else:
+            for name, order in sort_keys:
+                c_sort_keys.push_back(
+                    CSortKey(tobytes(name), unwrap_sort_order(order))
+                )
+        try:
+            self.wrapped.reset(
+                new CRankOptions(c_sort_keys,
+                                 unwrap_null_placement(null_placement),
+                                 self._tiebreaker_map[tiebreaker])
+            )
+        except KeyError:
+            _raise_invalid_function_option(tiebreaker, "tiebreaker")
+
+
+class RankOptions(_RankOptions):
+    """
+    Options for the `rank` function.
+
+    Parameters
+    ----------
+    sort_keys : sequence of (name, order) tuples or str, default "ascending"
+        Names of field/column keys to sort the input on,
+        along with the order each field/column is sorted in.
+        Accepted values for `order` are "ascending", "descending".
+        Alternatively, one can simply pass "ascending" or "descending" as a string
+        if the input is array-like.
+    null_placement : str, default "at_end"
+        Where nulls in input should be sorted.
+        Accepted values are "at_start", "at_end".
+    tiebreaker : str, default "first"
+        Configure how ties between equal values are handled.
+        Accepted values are:
+
+        - "min": Ties get the smallest possible rank in sorted order.
+        - "max": Ties get the largest possible rank in sorted order.
+        - "first": Ranks are assigned in order of when ties appear in the
+                   input. This ensures the ranks are a stable permutation
+                   of the input.
+        - "dense": The ranks span a dense [1, M] interval where M is the
+                   number of distinct values in the input.
+    """
+
+    def __init__(self, sort_keys="ascending", *, null_placement="at_end", tiebreaker="first"):
+        self._set_options(sort_keys, null_placement, tiebreaker)
+
+
 def _group_by(args, keys, aggregations):
     cdef:
         vector[CDatum] c_args
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index f591b95c01..526f0e4f7b 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -50,6 +50,7 @@ from pyarrow._compute import (  # noqa
     PartitionNthOptions,
     QuantileOptions,
     RandomOptions,
+    RankOptions,
     ReplaceSliceOptions,
     ReplaceSubstringOptions,
     RoundOptions,
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 39fc130c8a..8597874ea1 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2344,6 +2344,21 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
         @staticmethod
         CRandomOptions FromSeed(uint64_t seed)
 
+    cdef enum CRankOptionsTiebreaker \
+            "arrow::compute::RankOptions::Tiebreaker":
+        CRankOptionsTiebreaker_Min "arrow::compute::RankOptions::Min"
+        CRankOptionsTiebreaker_Max "arrow::compute::RankOptions::Max"
+        CRankOptionsTiebreaker_First "arrow::compute::RankOptions::First"
+        CRankOptionsTiebreaker_Dense "arrow::compute::RankOptions::Dense"
+
+    cdef cppclass CRankOptions \
+            "arrow::compute::RankOptions"(CFunctionOptions):
+        CRankOptions(vector[CSortKey] sort_keys, CNullPlacement,
+                     CRankOptionsTiebreaker tiebreaker)
+        vector[CSortKey] sort_keys
+        CNullPlacement null_placement
+        CRankOptionsTiebreaker tiebreaker
+
     cdef enum DatumType" arrow::Datum::type":
         DatumType_NONE" arrow::Datum::NONE"
         DatumType_SCALAR" arrow::Datum::SCALAR"
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index ae1c2f7712..67857ed6ec 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -150,6 +150,8 @@ def test_option_class_equality():
         pc.CumulativeSumOptions(start=0, skip_nulls=False),
         pc.QuantileOptions(),
         pc.RandomOptions(),
+        pc.RankOptions(sort_keys="ascending",
+                       null_placement="at_start", tiebreaker="max"),
         pc.ReplaceSliceOptions(0, 1, "a"),
         pc.ReplaceSubstringOptions("a", "b"),
         pc.RoundOptions(2, "towards_infinity"),
@@ -2718,6 +2720,56 @@ def test_random():
         pc.random(100, initializer=[])
 
 
+@pytest.mark.parametrize(
+    "tiebreaker,expected_values",
+    [("min", [3, 1, 4, 6, 4, 6, 1]),
+     ("max", [3, 2, 5, 7, 5, 7, 2]),
+     ("first", [3, 1, 4, 6, 5, 7, 2]),
+     ("dense", [2, 1, 3, 4, 3, 4, 1])]
+)
+def test_rank_options_tiebreaker(tiebreaker, expected_values):
+    arr = pa.array([1.2, 0.0, 5.3, None, 5.3, None, 0.0])
+    rank_options = pc.RankOptions(sort_keys="ascending",
+                                  null_placement="at_end",
+                                  tiebreaker=tiebreaker)
+    result = pc.rank(arr, options=rank_options)
+    expected = pa.array(expected_values, type=pa.uint64())
+    assert result.equals(expected)
+
+
+def test_rank_options():
+    arr = pa.array([1.2, 0.0, 5.3, None, 5.3, None, 0.0])
+    expected = pa.array([3, 1, 4, 6, 5, 7, 2], type=pa.uint64())
+
+    # Ensure rank can be called without specifying options
+    result = pc.rank(arr)
+    assert result.equals(expected)
+
+    # Ensure default RankOptions
+    result = pc.rank(arr, options=pc.RankOptions())
+    assert result.equals(expected)
+
+    # Ensure sort_keys tuple usage
+    result = pc.rank(arr, options=pc.RankOptions(
+        sort_keys=[("b", "ascending")])
+    )
+    assert result.equals(expected)
+
+    result = pc.rank(arr, null_placement="at_start")
+    expected_at_start = pa.array([5, 3, 6, 1, 7, 2, 4], type=pa.uint64())
+    assert result.equals(expected_at_start)
+
+    result = pc.rank(arr, sort_keys="descending")
+    expected_descending = pa.array([3, 4, 1, 6, 2, 7, 5], type=pa.uint64())
+    assert result.equals(expected_descending)
+
+    with pytest.raises(ValueError,
+                       match=r'"NonExisting" is not a valid tiebreaker'):
+        pc.RankOptions(sort_keys="descending",
+                       null_placement="at_end",
+                       tiebreaker="NonExisting")
+
+
 def test_expression_serialization():
     a = pc.scalar(1)
     b = pc.scalar(1.1)