You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2020/06/01 15:13:59 UTC

[arrow] branch master updated: ARROW-5854: [Python] Expose compare kernels on Array class

This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new b2287a2  ARROW-5854: [Python] Expose compare kernels on Array class
b2287a2 is described below

commit b2287a20f2304df45153683e5c9b668d315f271a
Author: Joris Van den Bossche <jo...@gmail.com>
AuthorDate: Mon Jun 1 10:13:25 2020 -0500

    ARROW-5854: [Python] Expose compare kernels on Array class
    
    Closes #7273 from jorisvandenbossche/ARROW-5854-array-compare
    
    Authored-by: Joris Van den Bossche <jo...@gmail.com>
    Signed-off-by: Wes McKinney <we...@apache.org>
---
 python/pyarrow/array.pxi             | 29 ++++++++++++--
 python/pyarrow/table.pxi             | 10 ++---
 python/pyarrow/tests/test_array.py   | 10 -----
 python/pyarrow/tests/test_compute.py | 75 ++++++++++++++++++++++++++++++++++++
 python/pyarrow/tests/test_table.py   |  3 --
 5 files changed, 104 insertions(+), 23 deletions(-)

diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index c25bbbe2..c575819 100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -17,6 +17,27 @@
 
 import warnings
 
+from cpython.object cimport Py_LT, Py_EQ, Py_GT, Py_LE, Py_NE, Py_GE
+
+
+cdef str _op_to_function_name(int op):
+    # Map a CPython rich-comparison opcode to the name of the matching
+    # compute function; an opcode outside the six handled here is
+    # rejected explicitly instead of falling through with no name set.
+    if op == Py_EQ:
+        return "equal"
+    if op == Py_NE:
+        return "not_equal"
+    if op == Py_GT:
+        return "greater"
+    if op == Py_GE:
+        return "greater_equal"
+    if op == Py_LT:
+        return "less"
+    if op == Py_LE:
+        return "less_equal"
+    raise ValueError("unsupported comparison opcode: {}".format(op))
+
 
 cdef _sequence_to_array(object sequence, object mask, object size,
                         DataType type, CMemoryPool* pool, c_bool from_pandas):
@@ -602,14 +623,14 @@ cdef class Array(_PandasConvertible):
         self.ap = sp_array.get()
         self.type = pyarrow_wrap_data_type(self.sp_array.get().type())
 
-    def __eq__(self, other):
-        raise NotImplementedError('Comparisons with pyarrow.Array are not '
-                                  'implemented')
-
     def _debug_print(self):
         with nogil:
             check_status(DebugPrint(deref(self.ap), 0))
 
+    def __richcmp__(self, other, int op):
+        # Dispatch the comparison to the compute function selected by `op`.
+        return _pc().call_function(_op_to_function_name(op), [self, other])
+
     def diff(self, Array other):
         """
         Compare contents of this array against another one.
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 043da82..afcece5 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -41,6 +41,10 @@ cdef class ChunkedArray(_PandasConvertible):
     def __reduce__(self):
         return chunked_array, (self.chunks, self.type)
 
+    def __richcmp__(self, other, int op):
+        # Dispatch the comparison to the compute function selected by `op`.
+        return _pc().call_function(_op_to_function_name(op), [self, other])
+
     @property
     def data(self):
         import warnings
@@ -173,12 +177,6 @@ cdef class ChunkedArray(_PandasConvertible):
             else:
                 index -= self.chunked_array.chunk(j).get().length()
 
-    def __eq__(self, other):
-        try:
-            return self.equals(other)
-        except TypeError:
-            return NotImplemented
-
     def equals(self, ChunkedArray other):
         """
         Return whether the contents of two chunked arrays are equal.
diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py
index 9f63a35..46f7b2e 100644
--- a/python/pyarrow/tests/test_array.py
+++ b/python/pyarrow/tests/test_array.py
@@ -395,16 +395,6 @@ def test_array_ref_to_ndarray_base():
     assert sys.getrefcount(arr) == (refcount + 1)
 
 
-def test_array_eq_raises():
-    # ARROW-2150: we are raising when comparing arrays until we define the
-    # behavior to either be elementwise comparisons or data equality
-    arr1 = pa.array([1, 2, 3], type=pa.int32())
-    arr2 = pa.array([1, 2, 3], type=pa.int32())
-
-    with pytest.raises(NotImplementedError):
-        arr1 == arr2
-
-
 def test_array_from_buffers():
     values_buf = pa.py_buffer(np.int16([4, 5, 6, 7]))
     nulls_buf = pa.py_buffer(np.uint8([0b00001101]))
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 09c4d02..fd262e1 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -238,3 +238,78 @@ def test_filter_errors():
         with pytest.raises(pa.ArrowInvalid,
                            match="must all be the same length"):
             obj.filter(mask)
+
+
+@pytest.mark.parametrize("typ", ["array", "chunked_array"])
+def test_compare_array(typ):
+    # Exercise all six comparison operators between two arrays of the
+    # parametrized container type.  Each position with a null on either
+    # side is expected to be null in the result.
+    def con(values):
+        if typ == "array":
+            return pa.array(values)
+        return pa.chunked_array([values])
+
+    arr1 = con([1, 2, 3, 4, None])
+    arr2 = con([1, 1, 4, None, 4])
+
+    eq = arr1 == arr2
+    ne = arr1 != arr2
+    lt = arr1 < arr2
+    le = arr1 <= arr2
+    gt = arr1 > arr2
+    ge = arr1 >= arr2
+
+    # Expected values are wrapped with the same constructor as the inputs.
+    assert eq.equals(con([True, False, False, None, None]))
+    assert ne.equals(con([False, True, True, None, None]))
+    assert lt.equals(con([False, False, True, None, None]))
+    assert le.equals(con([True, False, True, None, None]))
+    assert gt.equals(con([False, True, False, None, None]))
+    assert ge.equals(con([True, True, False, None, None]))
+
+
+@pytest.mark.parametrize("typ", ["array", "chunked_array"])
+def test_compare_scalar(typ):
+    # Exercise all six comparison operators between an array of the
+    # parametrized container type and a scalar.  The null element is
+    # expected to stay null in every result.
+    def con(values):
+        if typ == "array":
+            return pa.array(values)
+        return pa.chunked_array([values])
+
+    arr = con([1, 2, 3, None])
+    # TODO this is a hacky way to construct a scalar ..
+    scalar = pa.array([2]).sum()
+
+    eq = arr == scalar
+    ne = arr != scalar
+    lt = arr < scalar
+    le = arr <= scalar
+    gt = arr > scalar
+    ge = arr >= scalar
+
+    # Expected values are wrapped with the same constructor as the input.
+    assert eq.equals(con([False, True, False, None]))
+    assert ne.equals(con([True, False, True, None]))
+    assert lt.equals(con([True, False, False, None]))
+    assert le.equals(con([True, True, False, None]))
+    assert gt.equals(con([False, False, True, None]))
+    assert ge.equals(con([False, True, True, None]))
+
+
+def test_compare_chunked_array_mixed():
+    # Equality across an Array and ChunkedArrays with differing chunk
+    # layouts; each result is checked against one single-chunk expected.
+    arr = pa.array([1, 2, 3, 4, None])
+    arr_chunked = pa.chunked_array([[1, 2, 3], [4, None]])
+    arr_chunked2 = pa.chunked_array([[1, 2], [3, 4, None]])
+
+    expected = pa.chunked_array([[True, True, True, True, None]])
+
+    pairs = [(arr, arr_chunked), (arr_chunked, arr),
+             (arr_chunked, arr_chunked2)]
+    for left, right in pairs:
+        result = left == right
+        assert result.equals(expected)
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index 491cca0..69ee405 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -128,8 +128,6 @@ def test_chunked_array_equals():
             y = pa.chunked_array(yarrs)
         assert x.equals(y)
         assert y.equals(x)
-        assert x == y
-        assert x != str(y)
 
     def ne(xarrs, yarrs):
         if isinstance(xarrs, pa.ChunkedArray):
@@ -142,7 +140,6 @@ def test_chunked_array_equals():
             y = pa.chunked_array(yarrs)
         assert not x.equals(y)
         assert not y.equals(x)
-        assert x != y
 
     eq(pa.chunked_array([], type=pa.int32()),
        pa.chunked_array([], type=pa.int32()))