You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink itself is not preserved in this plain-text rendering).
Posted to commits@arrow.apache.org by ap...@apache.org on 2022/04/18 10:46:29 UTC

[arrow] branch master updated: ARROW-11259: [Python] Allow to create field reference to nested field

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new e0668e2024 ARROW-11259: [Python] Allow to create field reference to nested field
e0668e2024 is described below

commit e0668e2024a7929abf2b0d8b53d2886a2e3a3379
Author: Alenka Frim <fr...@gmail.com>
AuthorDate: Mon Apr 18 12:46:22 2022 +0200

    ARROW-11259: [Python] Allow to create field reference to nested field
    
    This PR tries to redo the work from https://github.com/apache/arrow/pull/9799.
    
    It will unblock:
    - https://issues.apache.org/jira/browse/ARROW-13798
    - https://issues.apache.org/jira/browse/ARROW-14596
    
    cc @jorisvandenbossche @pitrou
    
    Closes #12863 from AlenkaF/ARROW-11259
    
    Lead-authored-by: Alenka Frim <fr...@gmail.com>
    Co-authored-by: Antoine Pitrou <an...@python.org>
    Signed-off-by: Antoine Pitrou <an...@python.org>
---
 cpp/src/arrow/compute/exec/expression.cc      | 48 ++++++++++++++++++----
 cpp/src/arrow/compute/exec/expression_test.cc |  6 +++
 cpp/src/arrow/type.h                          |  5 +++
 python/pyarrow/_compute.pyx                   | 22 ++++++++--
 python/pyarrow/compute.py                     | 38 +++++++++++++++--
 python/pyarrow/includes/libarrow.pxd          |  3 +-
 python/pyarrow/tests/test_compute.py          |  8 +++-
 python/pyarrow/tests/test_dataset.py          | 59 +++++++++++++++++++++------
 8 files changed, 161 insertions(+), 28 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc
index 4249179e1b..1ef5c6e7b9 100644
--- a/cpp/src/arrow/compute/exec/expression.cc
+++ b/cpp/src/arrow/compute/exec/expression.cc
@@ -993,6 +993,21 @@ Result<std::shared_ptr<Buffer>> Serialize(const Expression& expr) {
       return std::to_string(ret);
     }
 
+    Status VisitFieldRef(const FieldRef& ref) {
+      if (ref.nested_refs()) {
+        metadata_->Append("nested_field_ref", std::to_string(ref.nested_refs()->size()));
+        for (const auto& child : *ref.nested_refs()) {
+          RETURN_NOT_OK(VisitFieldRef(child));
+        }
+        return Status::OK();
+      }
+      if (!ref.name()) {
+        return Status::NotImplemented("Serialization of non-name field_refs");
+      }
+      metadata_->Append("field_ref", *ref.name());
+      return Status::OK();
+    }
+
     Status Visit(const Expression& expr) {
       if (auto lit = expr.literal()) {
         if (!lit->is_scalar()) {
@@ -1004,11 +1019,7 @@ Result<std::shared_ptr<Buffer>> Serialize(const Expression& expr) {
       }
 
       if (auto ref = expr.field_ref()) {
-        if (!ref->name()) {
-          return Status::NotImplemented("Serialization of non-name field_refs");
-        }
-        metadata_->Append("field_ref", *ref->name());
-        return Status::OK();
+        return VisitFieldRef(*ref);
       }
 
       auto call = CallNotNull(expr);
@@ -1067,10 +1078,13 @@ Result<Expression> Deserialize(std::shared_ptr<Buffer> buffer) {
 
     const KeyValueMetadata& metadata() { return *batch_.schema()->metadata(); }
 
+    bool ParseInteger(const std::string& s, int32_t* value) {
+      return ::arrow::internal::ParseValue<Int32Type>(s.data(), s.length(), value);
+    }
+
     Result<std::shared_ptr<Scalar>> GetScalar(const std::string& i) {
       int32_t column_index;
-      if (!::arrow::internal::ParseValue<Int32Type>(i.data(), i.length(),
-                                                    &column_index)) {
+      if (!ParseInteger(i, &column_index)) {
         return Status::Invalid("Couldn't parse column_index");
       }
       if (column_index >= batch_.num_columns()) {
@@ -1093,6 +1107,26 @@ Result<Expression> Deserialize(std::shared_ptr<Buffer> buffer) {
         return literal(std::move(scalar));
       }
 
+      if (key == "nested_field_ref") {
+        int32_t size;
+        if (!ParseInteger(value, &size)) {
+          return Status::Invalid("Couldn't parse nested field ref length");
+        }
+        if (size <= 0) {
+          return Status::Invalid("nested field ref length must be > 0");
+        }
+        std::vector<FieldRef> nested;
+        nested.reserve(size);
+        while (size-- > 0) {
+          ARROW_ASSIGN_OR_RAISE(auto ref, GetOne());
+          if (!ref.field_ref()) {
+            return Status::Invalid("invalid nested field ref");
+          }
+          nested.push_back(*ref.field_ref());
+        }
+        return field_ref(FieldRef(std::move(nested)));
+      }
+
       if (key == "field_ref") {
         return field_ref(value);
       }
diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc
index 30ddef6901..f916bc2a1c 100644
--- a/cpp/src/arrow/compute/exec/expression_test.cc
+++ b/cpp/src/arrow/compute/exec/expression_test.cc
@@ -1376,12 +1376,18 @@ TEST(Expression, SerializationRoundTrips) {
 
   ExpectRoundTrips(field_ref("field"));
 
+  ExpectRoundTrips(field_ref(FieldRef("foo", "bar", "baz")));
+
   ExpectRoundTrips(greater(field_ref("a"), literal(0.25)));
 
   ExpectRoundTrips(
       or_({equal(field_ref("a"), literal(1)), not_equal(field_ref("b"), literal("hello")),
            equal(field_ref("b"), literal("foo bar"))}));
 
+  ExpectRoundTrips(or_({equal(field_ref(FieldRef("a", "b")), literal(1)),
+                        not_equal(field_ref("b"), literal("hello")),
+                        equal(field_ref(FieldRef("c", "d")), literal("foo bar"))}));
+
   ExpectRoundTrips(not_(field_ref("alpha")));
 
   ExpectRoundTrips(call("is_in", {literal(1)},
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index 83dc6fa569..440b95ce59 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -1615,6 +1615,11 @@ class ARROW_EXPORT FieldRef {
   /// Equivalent to a single index string of indices.
   FieldRef(int index) : impl_(FieldPath({index})) {}  // NOLINT runtime/explicit
 
+  /// Construct a nested FieldRef.
+  FieldRef(std::vector<FieldRef> refs) {  // NOLINT runtime/explicit
+    Flatten(std::move(refs));
+  }
+
   /// Convenience constructor for nested FieldRefs: each argument will be used to
   /// construct a FieldRef
   template <typename A0, typename A1, typename... A>
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index 2f18ab9986..73c188edbb 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -2219,10 +2219,26 @@ cdef class Expression(_Weakrefable):
 
     @staticmethod
     def _field(name_or_idx not None):
-        if isinstance(name_or_idx, str):
-            return Expression.wrap(CMakeFieldExpression(tobytes(name_or_idx)))
-        else:
+        cdef:
+            CFieldRef c_field
+
+        if isinstance(name_or_idx, int):
             return Expression.wrap(CMakeFieldExpressionByIndex(name_or_idx))
+        else:
+            c_field = CFieldRef(<c_string> tobytes(name_or_idx))
+            return Expression.wrap(CMakeFieldExpression(c_field))
+
+    @staticmethod
+    def _nested_field(tuple names not None):
+        cdef:
+            vector[CFieldRef] nested
+
+        if len(names) == 0:
+            raise ValueError("nested field reference should be non-empty")
+        nested.reserve(len(names))
+        for name in names:
+            nested.push_back(CFieldRef(<c_string> tobytes(name)))
+        return Expression.wrap(CMakeFieldExpression(CFieldRef(move(nested))))
 
     @staticmethod
     def _scalar(value):
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index 6cd65123e8..40751eab26 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -591,22 +591,52 @@ def bottom_k_unstable(values, k, sort_keys=None, *, memory_pool=None):
     return call_function("select_k_unstable", [values], options, memory_pool)
 
 
-def field(name_or_index):
+def field(*name_or_index):
     """Reference a column of the dataset.
 
     Stores only the field's name. Type and other information is known only when
     the expression is bound to a dataset having an explicit scheme.
 
+    Nested references are allowed by passing multiple names or a tuple of
+    names. For example ``('foo', 'bar')`` references the field named "bar"
+    inside the field named "foo".
+
     Parameters
     ----------
-    name_or_index : string or int
-        The name or index of the field the expression references to.
+    *name_or_index : string, multiple strings, tuple or int
+        The name or index of the (possibly nested) field the expression
+        references to.
 
     Returns
     -------
     field_expr : Expression
+
+    Examples
+    --------
+    >>> import pyarrow.compute as pc
+    >>> pc.field("a")
+    <pyarrow.compute.Expression a>
+    >>> pc.field(1)
+    <pyarrow.compute.Expression FieldPath(1)>
+    >>> pc.field(("a", "b"))
+    <pyarrow.compute.Expression FieldRef.Nested(FieldRef.Name(a) ...
+    >>> pc.field("a", "b")
+    <pyarrow.compute.Expression FieldRef.Nested(FieldRef.Name(a) ...
     """
-    return Expression._field(name_or_index)
+    n = len(name_or_index)
+    if n == 1:
+        if isinstance(name_or_index[0], (str, int)):
+            return Expression._field(name_or_index[0])
+        elif isinstance(name_or_index[0], tuple):
+            return Expression._nested_field(name_or_index[0])
+        else:
+            raise TypeError(
+                "field reference should be str, multiple str, tuple or "
+                f"integer, got {type(name_or_index[0])}"
+            )
+    # In case of multiple strings not supplied in a tuple
+    else:
+        return Expression._nested_field(name_or_index)
 
 
 def scalar(value):
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 34e467fb8e..85b114bc48 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -427,6 +427,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
         CFieldRef()
         CFieldRef(c_string name)
         CFieldRef(int index)
+        CFieldRef(vector[CFieldRef])
         const c_string* name() const
 
     cdef cppclass CFieldRefHash" arrow::FieldRef::Hash":
@@ -2402,7 +2403,7 @@ cdef extern from "arrow/compute/exec/expression.h" \
         "arrow::compute::literal"(shared_ptr[CScalar] value)
 
     cdef CExpression CMakeFieldExpression \
-        "arrow::compute::field_ref"(c_string name)
+        "arrow::compute::field_ref"(CFieldRef)
 
     cdef CExpression CMakeFieldExpressionByIndex \
         "arrow::compute::field_ref"(int idx)
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 46d302c214..064572001f 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -2652,7 +2652,9 @@ def test_expression_serialization():
                  d.is_valid(), a.cast(pa.int32(), safe=False),
                  a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]),
                  pc.field('i64') > 5, pc.field('i64') == 5,
-                 pc.field('i64') == 7, pc.field('i64').is_null()]
+                 pc.field('i64') == 7, pc.field('i64').is_null(),
+                 pc.field(('foo', 'bar')) == 'value',
+                 pc.field('foo', 'bar') == 'value']
     for expr in all_exprs:
         assert isinstance(expr, pc.Expression)
         restored = pickle.loads(pickle.dumps(expr))
@@ -2666,6 +2668,8 @@ def test_expression_construction():
     false = pc.scalar(False)
     string = pc.scalar("string")
     field = pc.field("field")
+    nested_field = pc.field(("nested", "field"))
+    nested_field2 = pc.field("nested", "field")
 
     zero | one == string
     ~true == false
@@ -2673,6 +2677,8 @@ def test_expression_construction():
         field.cast(typ) == true
 
     field.isin([1, 2])
+    nested_field.isin(["foo", "bar"])
+    nested_field2.isin(["foo", "bar"])
 
     with pytest.raises(TypeError):
         field.isin(1)
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 511f8206cc..9a7f5ea213 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -102,13 +102,15 @@ def mockfs():
                 list(range(5)),
                 list(map(float, range(5))),
                 list(map(str, range(5))),
-                [i] * 5
+                [i] * 5,
+                [{'a': j % 3, 'b': str(j % 3)} for j in range(5)],
             ]
             schema = pa.schema([
-                pa.field('i64', pa.int64()),
-                pa.field('f64', pa.float64()),
-                pa.field('str', pa.string()),
-                pa.field('const', pa.int64()),
+                ('i64', pa.int64()),
+                ('f64', pa.float64()),
+                ('str', pa.string()),
+                ('const', pa.int64()),
+                ('struct', pa.struct({'a': pa.int64(), 'b': pa.string()})),
             ])
             batch = pa.record_batch(data, schema=schema)
             table = pa.Table.from_batches([batch])
@@ -383,14 +385,41 @@ def test_dataset(dataset, dataset_reader):
     assert len(table) == 10
 
     condition = ds.field('i64') == 1
-    result = dataset.to_table(use_threads=True, filter=condition).to_pydict()
+    result = dataset.to_table(use_threads=True, filter=condition)
+    # Don't rely on the scanning order
+    result = result.sort_by('group').to_pydict()
 
-    # don't rely on the scanning order
     assert result['i64'] == [1, 1]
     assert result['f64'] == [1., 1.]
     assert sorted(result['group']) == [1, 2]
     assert sorted(result['key']) == ['xxx', 'yyy']
 
+    # Filtering on a nested field ref
+    condition = ds.field(('struct', 'b')) == '1'
+    result = dataset.to_table(use_threads=True, filter=condition)
+    result = result.sort_by('group').to_pydict()
+
+    assert result['i64'] == [1, 4, 1, 4]
+    assert result['f64'] == [1.0, 4.0, 1.0, 4.0]
+    assert result['group'] == [1, 1, 2, 2]
+    assert result['key'] == ['xxx', 'xxx', 'yyy', 'yyy']
+
+    # Projecting on a nested field ref expression
+    projection = {
+        'i64': ds.field('i64'),
+        'f64': ds.field('f64'),
+        'new': ds.field(('struct', 'b')) == '1',
+    }
+    result = dataset.to_table(use_threads=True, columns=projection)
+    result = result.sort_by('i64').to_pydict()
+
+    assert list(result) == ['i64', 'f64', 'new']
+    assert result['i64'] == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
+    assert result['f64'] == [0.0, 0.0, 1.0, 1.0,
+                             2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
+    assert result['new'] == [False, False, True, True, False, False,
+                             False, False, True, True]
+
 
 @pytest.mark.parquet
 def test_scanner(dataset, dataset_reader):
@@ -808,6 +837,8 @@ def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer):
         pa.field('f64', pa.float64()),
         pa.field('str', pa.dictionary(pa.int32(), pa.string())),
         pa.field('const', pa.int64()),
+        pa.field('struct', pa.struct({'a': pa.int64(),
+                                      'b': pa.string()})),
         pa.field('group', pa.int32()),
         pa.field('key', pa.string()),
     ]), check_metadata=False)
@@ -827,6 +858,8 @@ def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer):
         pa.array([0, 1, 2, 3, 4], type=pa.int32()),
         pa.array("0 1 2 3 4".split(), type=pa.string())
     )
+    expected_struct = pa.array([{'a': i % 3, 'b': str(i % 3)}
+                                for i in range(5)])
     iterator = scanner.scan_batches()
     for (batch, fragment), group, key in zip(iterator, [1, 2], ['xxx', 'yyy']):
         expected_group = pa.array([group] * 5, type=pa.int32())
@@ -834,18 +867,19 @@ def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer):
         expected_const = pa.array([group - 1] * 5, type=pa.int64())
         # Can't compare or really introspect expressions from Python
         assert fragment.partition_expression is not None
-        assert batch.num_columns == 6
+        assert batch.num_columns == 7
         assert batch[0].equals(expected_i64)
         assert batch[1].equals(expected_f64)
         assert batch[2].equals(expected_str)
         assert batch[3].equals(expected_const)
-        assert batch[4].equals(expected_group)
-        assert batch[5].equals(expected_key)
+        assert batch[4].equals(expected_struct)
+        assert batch[5].equals(expected_group)
+        assert batch[6].equals(expected_key)
 
     table = dataset.to_table()
     assert isinstance(table, pa.Table)
     assert len(table) == 10
-    assert table.num_columns == 6
+    assert table.num_columns == 7
 
 
 @pytest.mark.parquet
@@ -1480,6 +1514,7 @@ def test_partitioning_factory(mockfs):
         ("f64", pa.float64()),
         ("str", pa.string()),
         ("const", pa.int64()),
+        ("struct", pa.struct({'a': pa.int64(), 'b': pa.string()})),
         ("group", pa.int32()),
         ("key", pa.string()),
     ])
@@ -2047,7 +2082,7 @@ def test_construct_from_mixed_child_datasets(mockfs):
 
     table = dataset.to_table()
     assert len(table) == 20
-    assert table.num_columns == 4
+    assert table.num_columns == 5
 
     assert len(dataset.children) == 2
     for child in dataset.children: