You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ks...@apache.org on 2018/12/15 21:20:31 UTC

[arrow] branch master updated: ARROW-3230: [Python] Missing comparisons on ChunkedArray, Table

This is an automated email from the ASF dual-hosted git repository.

kszucs pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 2e8cfca  ARROW-3230: [Python] Missing comparisons on ChunkedArray, Table
2e8cfca is described below

commit 2e8cfcac93596fb630310ca975b72a62208381d7
Author: Tanya Schlusser <ta...@tickel.net>
AuthorDate: Sat Dec 15 22:20:12 2018 +0100

    ARROW-3230: [Python] Missing comparisons on ChunkedArray, Table
    
    Add `__eq__` method to `Table`, `Column`, and `ChunkedArray`, plus relevant tests.
    
    Author: Tanya Schlusser <ta...@tickel.net>
    Author: Krisztián Szűcs <sz...@gmail.com>
    
    Closes #3183 from tanyaschlusser/ARROW-3230 and squashes the following commits:
    
    0ea512e0 <Krisztián Szűcs> minor fixes
    2ea12f3c <Tanya Schlusser> Add '__eq__' method to Table, Column, and ChunkedArray and remove '__richcmp__' from Column
    47d24973 <Tanya Schlusser> Add '==' and '!=' tests for Table, Column, and ChunkedArray
---
 python/pyarrow/table.pxi           | 26 ++++++++++++++++++--------
 python/pyarrow/tests/test_table.py |  9 +++++++++
 2 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index cf3411d..4d52f26 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -117,6 +117,12 @@ cdef class ChunkedArray:
             else:
                 index -= self.chunked_array.chunk(j).get().length()
 
+    def __eq__(self, other):
+        try:
+            return self.equals(other)
+        except TypeError:
+            return NotImplemented
+
     def equals(self, ChunkedArray other):
         """
         Return whether the contents of two chunked arrays are equal
@@ -411,14 +417,6 @@ cdef class Column:
 
         return result.getvalue()
 
-    def __richcmp__(Column self, Column other, int op):
-        if op == cp.Py_EQ:
-            return self.equals(other)
-        elif op == cp.Py_NE:
-            return not self.equals(other)
-        else:
-            raise TypeError('Invalid comparison')
-
     def __getitem__(self, key):
         return self.data[key]
 
@@ -540,6 +538,12 @@ cdef class Column:
     def __array__(self, dtype=None):
         return self.data.__array__(dtype=dtype)
 
+    def __eq__(self, other):
+        try:
+            return self.equals(other)
+        except TypeError:
+            return NotImplemented
+
     def equals(self, Column other):
         """
         Check if contents of two columns are equal
@@ -1111,6 +1115,12 @@ cdef class Table:
 
         return pyarrow_wrap_table(flattened)
 
+    def __eq__(self, other):
+        try:
+            return self.equals(other)
+        except TypeError:
+            return NotImplemented
+
     def equals(self, Table other):
         """
         Check if contents of two tables are equal
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index ecbf93b..847b1a4 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -117,6 +117,8 @@ def test_chunked_array_equals():
             y = pa.chunked_array(yarrs)
         assert x.equals(y)
         assert y.equals(x)
+        assert x == y
+        assert x != str(y)
 
     def ne(xarrs, yarrs):
         if isinstance(xarrs, pa.ChunkedArray):
@@ -129,6 +131,7 @@ def test_chunked_array_equals():
             y = pa.chunked_array(yarrs)
         assert not x.equals(y)
         assert not y.equals(x)
+        assert x != y
 
     eq(pa.chunked_array([], type=pa.int32()),
        pa.chunked_array([], type=pa.int32()))
@@ -224,6 +227,9 @@ def test_column_basics():
     assert len(column) == 5
     assert column.shape == (5,)
     assert column.to_pylist() == [-10, -5, 0, 5, 10]
+    assert column == pa.Column.from_array("a", column.data)
+    assert column != pa.Column.from_array("b", column.data)
+    assert column != column.data
 
 
 def test_column_factory_function():
@@ -577,6 +583,9 @@ def test_table_basics():
             col.data.chunk(col.data.num_chunks)
 
     assert table.columns == columns
+    assert table == pa.Table.from_arrays(columns)
+    assert table != pa.Table.from_arrays(columns[1:])
+    assert table != columns
 
 
 def test_table_from_arrays_preserves_column_metadata():