You are viewing a plain-text version of this content. (The link to the canonical version was lost in this plain-text copy.)
Posted to commits@arrow.apache.org by ap...@apache.org on 2022/04/18 10:46:29 UTC
[arrow] branch master updated: ARROW-11259: [Python] Allow to create field reference to nested field
This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new e0668e2024 ARROW-11259: [Python] Allow to create field reference to nested field
e0668e2024 is described below
commit e0668e2024a7929abf2b0d8b53d2886a2e3a3379
Author: Alenka Frim <fr...@gmail.com>
AuthorDate: Mon Apr 18 12:46:22 2022 +0200
ARROW-11259: [Python] Allow to create field reference to nested field
This PR tries to redo the work from https://github.com/apache/arrow/pull/9799.
It will unblock:
- https://issues.apache.org/jira/browse/ARROW-13798
- https://issues.apache.org/jira/browse/ARROW-14596
cc @jorisvandenbossche @pitrou
Closes #12863 from AlenkaF/ARROW-11259
Lead-authored-by: Alenka Frim <fr...@gmail.com>
Co-authored-by: Antoine Pitrou <an...@python.org>
Signed-off-by: Antoine Pitrou <an...@python.org>
---
cpp/src/arrow/compute/exec/expression.cc | 48 ++++++++++++++++++----
cpp/src/arrow/compute/exec/expression_test.cc | 6 +++
cpp/src/arrow/type.h | 5 +++
python/pyarrow/_compute.pyx | 22 ++++++++--
python/pyarrow/compute.py | 38 +++++++++++++++--
python/pyarrow/includes/libarrow.pxd | 3 +-
python/pyarrow/tests/test_compute.py | 8 +++-
python/pyarrow/tests/test_dataset.py | 59 +++++++++++++++++++++------
8 files changed, 161 insertions(+), 28 deletions(-)
diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc
index 4249179e1b..1ef5c6e7b9 100644
--- a/cpp/src/arrow/compute/exec/expression.cc
+++ b/cpp/src/arrow/compute/exec/expression.cc
@@ -993,6 +993,21 @@ Result<std::shared_ptr<Buffer>> Serialize(const Expression& expr) {
return std::to_string(ret);
}
+ Status VisitFieldRef(const FieldRef& ref) {  // Serialize `ref` as metadata_ entries.
+ if (ref.nested_refs()) {  // Nested: write the child count first so Deserialize knows
+ metadata_->Append("nested_field_ref", std::to_string(ref.nested_refs()->size()));
+ for (const auto& child : *ref.nested_refs()) {  // how many entries follow, then each child.
+ RETURN_NOT_OK(VisitFieldRef(child));  // Propagate the first child that fails.
+ }
+ return Status::OK();
+ }
+ if (!ref.name()) {  // Neither nested nor a name, e.g. an index-based FieldPath ref.
+ return Status::NotImplemented("Serialization of non-name field_refs");
+ }
+ metadata_->Append("field_ref", *ref.name());  // Leaf: a plain name ref.
+ return Status::OK();
+ }
+
Status Visit(const Expression& expr) {
if (auto lit = expr.literal()) {
if (!lit->is_scalar()) {
@@ -1004,11 +1019,7 @@ Result<std::shared_ptr<Buffer>> Serialize(const Expression& expr) {
}
if (auto ref = expr.field_ref()) {
- if (!ref->name()) {
- return Status::NotImplemented("Serialization of non-name field_refs");
- }
- metadata_->Append("field_ref", *ref->name());
- return Status::OK();
+ return VisitFieldRef(*ref);
}
auto call = CallNotNull(expr);
@@ -1067,10 +1078,13 @@ Result<Expression> Deserialize(std::shared_ptr<Buffer> buffer) {
const KeyValueMetadata& metadata() { return *batch_.schema()->metadata(); }
+ bool ParseInteger(const std::string& s, int32_t* value) {  // Parse `s` as int32; false on failure.
+ return ::arrow::internal::ParseValue<Int32Type>(s.data(), s.length(), value);  // Shared helper.
+ }
+
Result<std::shared_ptr<Scalar>> GetScalar(const std::string& i) {
int32_t column_index;
- if (!::arrow::internal::ParseValue<Int32Type>(i.data(), i.length(),
- &column_index)) {
+ if (!ParseInteger(i, &column_index)) {
return Status::Invalid("Couldn't parse column_index");
}
if (column_index >= batch_.num_columns()) {
@@ -1093,6 +1107,26 @@ Result<Expression> Deserialize(std::shared_ptr<Buffer> buffer) {
return literal(std::move(scalar));
}
+ if (key == "nested_field_ref") {
+ int32_t size;
+ if (!ParseInteger(value, &size)) {
+ return Status::Invalid("Couldn't parse nested field ref length");
+ }
+ if (size <= 0) {
+ return Status::Invalid("nested field ref length must be > 0");
+ }
+ std::vector<FieldRef> nested;
+ nested.reserve(size);
+ while (size-- > 0) {
+ ARROW_ASSIGN_OR_RAISE(auto ref, GetOne());
+ if (!ref.field_ref()) {
+ return Status::Invalid("invalid nested field ref");
+ }
+ nested.push_back(*ref.field_ref());
+ }
+ return field_ref(FieldRef(std::move(nested)));
+ }
+
if (key == "field_ref") {
return field_ref(value);
}
diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc
index 30ddef6901..f916bc2a1c 100644
--- a/cpp/src/arrow/compute/exec/expression_test.cc
+++ b/cpp/src/arrow/compute/exec/expression_test.cc
@@ -1376,12 +1376,18 @@ TEST(Expression, SerializationRoundTrips) {
ExpectRoundTrips(field_ref("field"));
+ ExpectRoundTrips(field_ref(FieldRef("foo", "bar", "baz")));
+
ExpectRoundTrips(greater(field_ref("a"), literal(0.25)));
ExpectRoundTrips(
or_({equal(field_ref("a"), literal(1)), not_equal(field_ref("b"), literal("hello")),
equal(field_ref("b"), literal("foo bar"))}));
+ ExpectRoundTrips(or_({equal(field_ref(FieldRef("a", "b")), literal(1)),
+ not_equal(field_ref("b"), literal("hello")),
+ equal(field_ref(FieldRef("c", "d")), literal("foo bar"))}));
+
ExpectRoundTrips(not_(field_ref("alpha")));
ExpectRoundTrips(call("is_in", {literal(1)},
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index 83dc6fa569..440b95ce59 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -1615,6 +1615,11 @@ class ARROW_EXPORT FieldRef {
/// Equivalent to a single index string of indices.
FieldRef(int index) : impl_(FieldPath({index})) {} // NOLINT runtime/explicit
+ /// Construct a nested FieldRef from child refs (outermost first); intentionally implicit.
+ FieldRef(std::vector<FieldRef> refs) { // NOLINT runtime/explicit
+ Flatten(std::move(refs));  // Presumably merges nested children into one flat list.
+ }
+
/// Convenience constructor for nested FieldRefs: each argument will be used to
/// construct a FieldRef
template <typename A0, typename A1, typename... A>
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index 2f18ab9986..73c188edbb 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -2219,10 +2219,26 @@ cdef class Expression(_Weakrefable):
@staticmethod
def _field(name_or_idx not None):
- if isinstance(name_or_idx, str):
- return Expression.wrap(CMakeFieldExpression(tobytes(name_or_idx)))
- else:
+ cdef:
+ CFieldRef c_field
+
+ if isinstance(name_or_idx, int):  # Ints select a field by position.
return Expression.wrap(CMakeFieldExpressionByIndex(name_or_idx))
+ else:  # Anything else is treated as a field name.
+ c_field = CFieldRef(<c_string> tobytes(name_or_idx))  # Name -> bytes for the C++ ctor.
+ return Expression.wrap(CMakeFieldExpression(c_field))
+
+ @staticmethod
+ def _nested_field(tuple names not None):  # field_ref Expression for a path, outermost first.
+ cdef:
+ vector[CFieldRef] nested
+
+ if len(names) == 0:  # An empty path references nothing.
+ raise ValueError("nested field reference should be non-empty")
+ nested.reserve(len(names))
+ for name in names:
+ nested.push_back(CFieldRef(<c_string> tobytes(name)))  # Each element becomes a name ref.
+ return Expression.wrap(CMakeFieldExpression(CFieldRef(move(nested))))  # Vector ctor nests them.
@staticmethod
def _scalar(value):
diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py
index 6cd65123e8..40751eab26 100644
--- a/python/pyarrow/compute.py
+++ b/python/pyarrow/compute.py
@@ -591,22 +591,52 @@ def bottom_k_unstable(values, k, sort_keys=None, *, memory_pool=None):
return call_function("select_k_unstable", [values], options, memory_pool)
-def field(name_or_index):
+def field(*name_or_index):
"""Reference a column of the dataset.
Stores only the field's name. Type and other information is known only when
the expression is bound to a dataset having an explicit scheme.
+ Nested references are allowed by passing multiple names or a tuple of
+ names. For example ``('foo', 'bar')`` references the field named "bar"
+ inside the field named "foo".
+
Parameters
----------
- name_or_index : string or int
- The name or index of the field the expression references to.
+ *name_or_index : string, multiple strings, tuple or int
+ The name or index of the (possibly nested) field the expression
+ references.
Returns
-------
field_expr : Expression
+
+ Examples
+ --------
+ >>> import pyarrow.compute as pc
+ >>> pc.field("a")
+ <pyarrow.compute.Expression a>
+ >>> pc.field(1)
+ <pyarrow.compute.Expression FieldPath(1)>
+ >>> pc.field(("a", "b"))
+ <pyarrow.compute.Expression FieldRef.Nested(FieldRef.Name(a) ...
+ >>> pc.field("a", "b")
+ <pyarrow.compute.Expression FieldRef.Nested(FieldRef.Name(a) ...
"""
- return Expression._field(name_or_index)
+ n = len(name_or_index)
+ if n == 1:  # A single argument may be a name, an index, or a tuple path.
+ if isinstance(name_or_index[0], (str, int)):
+ return Expression._field(name_or_index[0])
+ elif isinstance(name_or_index[0], tuple):
+ return Expression._nested_field(name_or_index[0])
+ else:
+ raise TypeError(
+ "field reference should be str, multiple str, tuple or "
+ f"integer, got {type(name_or_index[0])}"
+ )
+ # Zero or several positional names form a nested path (empty -> ValueError in _nested_field)
+ else:
+ return Expression._nested_field(name_or_index)
def scalar(value):
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index 34e467fb8e..85b114bc48 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -427,6 +427,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
CFieldRef()
CFieldRef(c_string name)
CFieldRef(int index)
+ CFieldRef(vector[CFieldRef])
const c_string* name() const
cdef cppclass CFieldRefHash" arrow::FieldRef::Hash":
@@ -2402,7 +2403,7 @@ cdef extern from "arrow/compute/exec/expression.h" \
"arrow::compute::literal"(shared_ptr[CScalar] value)
cdef CExpression CMakeFieldExpression \
- "arrow::compute::field_ref"(c_string name)
+ "arrow::compute::field_ref"(CFieldRef)
cdef CExpression CMakeFieldExpressionByIndex \
"arrow::compute::field_ref"(int idx)
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 46d302c214..064572001f 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -2652,7 +2652,9 @@ def test_expression_serialization():
d.is_valid(), a.cast(pa.int32(), safe=False),
a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]),
pc.field('i64') > 5, pc.field('i64') == 5,
- pc.field('i64') == 7, pc.field('i64').is_null()]
+ pc.field('i64') == 7, pc.field('i64').is_null(),
+ pc.field(('foo', 'bar')) == 'value',
+ pc.field('foo', 'bar') == 'value']
for expr in all_exprs:
assert isinstance(expr, pc.Expression)
restored = pickle.loads(pickle.dumps(expr))
@@ -2666,6 +2668,8 @@ def test_expression_construction():
false = pc.scalar(False)
string = pc.scalar("string")
field = pc.field("field")
+ nested_field = pc.field(("nested", "field"))
+ nested_field2 = pc.field("nested", "field")
zero | one == string
~true == false
@@ -2673,6 +2677,8 @@ def test_expression_construction():
field.cast(typ) == true
field.isin([1, 2])
+ nested_field.isin(["foo", "bar"])
+ nested_field2.isin(["foo", "bar"])
with pytest.raises(TypeError):
field.isin(1)
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 511f8206cc..9a7f5ea213 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -102,13 +102,15 @@ def mockfs():
list(range(5)),
list(map(float, range(5))),
list(map(str, range(5))),
- [i] * 5
+ [i] * 5,
+ [{'a': j % 3, 'b': str(j % 3)} for j in range(5)],
]
schema = pa.schema([
- pa.field('i64', pa.int64()),
- pa.field('f64', pa.float64()),
- pa.field('str', pa.string()),
- pa.field('const', pa.int64()),
+ ('i64', pa.int64()),
+ ('f64', pa.float64()),
+ ('str', pa.string()),
+ ('const', pa.int64()),
+ ('struct', pa.struct({'a': pa.int64(), 'b': pa.string()})),
])
batch = pa.record_batch(data, schema=schema)
table = pa.Table.from_batches([batch])
@@ -383,14 +385,41 @@ def test_dataset(dataset, dataset_reader):
assert len(table) == 10
condition = ds.field('i64') == 1
- result = dataset.to_table(use_threads=True, filter=condition).to_pydict()
+ result = dataset.to_table(use_threads=True, filter=condition)
+ # Don't rely on the scanning order
+ result = result.sort_by('group').to_pydict()
- # don't rely on the scanning order
assert result['i64'] == [1, 1]
assert result['f64'] == [1., 1.]
assert sorted(result['group']) == [1, 2]
assert sorted(result['key']) == ['xxx', 'yyy']
+ # Filtering on a nested field ref
+ condition = ds.field(('struct', 'b')) == '1'
+ result = dataset.to_table(use_threads=True, filter=condition)
+ result = result.sort_by('group').to_pydict()
+
+ assert result['i64'] == [1, 4, 1, 4]
+ assert result['f64'] == [1.0, 4.0, 1.0, 4.0]
+ assert result['group'] == [1, 1, 2, 2]
+ assert result['key'] == ['xxx', 'xxx', 'yyy', 'yyy']
+
+ # Projecting on a nested field ref expression
+ projection = {
+ 'i64': ds.field('i64'),
+ 'f64': ds.field('f64'),
+ 'new': ds.field(('struct', 'b')) == '1',
+ }
+ result = dataset.to_table(use_threads=True, columns=projection)
+ result = result.sort_by('i64').to_pydict()
+
+ assert list(result) == ['i64', 'f64', 'new']
+ assert result['i64'] == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
+ assert result['f64'] == [0.0, 0.0, 1.0, 1.0,
+ 2.0, 2.0, 3.0, 3.0, 4.0, 4.0]
+ assert result['new'] == [False, False, True, True, False, False,
+ False, False, True, True]
+
@pytest.mark.parquet
def test_scanner(dataset, dataset_reader):
@@ -808,6 +837,8 @@ def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer):
pa.field('f64', pa.float64()),
pa.field('str', pa.dictionary(pa.int32(), pa.string())),
pa.field('const', pa.int64()),
+ pa.field('struct', pa.struct({'a': pa.int64(),
+ 'b': pa.string()})),
pa.field('group', pa.int32()),
pa.field('key', pa.string()),
]), check_metadata=False)
@@ -827,6 +858,8 @@ def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer):
pa.array([0, 1, 2, 3, 4], type=pa.int32()),
pa.array("0 1 2 3 4".split(), type=pa.string())
)
+ expected_struct = pa.array([{'a': i % 3, 'b': str(i % 3)}
+ for i in range(5)])
iterator = scanner.scan_batches()
for (batch, fragment), group, key in zip(iterator, [1, 2], ['xxx', 'yyy']):
expected_group = pa.array([group] * 5, type=pa.int32())
@@ -834,18 +867,19 @@ def test_filesystem_factory(mockfs, paths_or_selector, pre_buffer):
expected_const = pa.array([group - 1] * 5, type=pa.int64())
# Can't compare or really introspect expressions from Python
assert fragment.partition_expression is not None
- assert batch.num_columns == 6
+ assert batch.num_columns == 7
assert batch[0].equals(expected_i64)
assert batch[1].equals(expected_f64)
assert batch[2].equals(expected_str)
assert batch[3].equals(expected_const)
- assert batch[4].equals(expected_group)
- assert batch[5].equals(expected_key)
+ assert batch[4].equals(expected_struct)
+ assert batch[5].equals(expected_group)
+ assert batch[6].equals(expected_key)
table = dataset.to_table()
assert isinstance(table, pa.Table)
assert len(table) == 10
- assert table.num_columns == 6
+ assert table.num_columns == 7
@pytest.mark.parquet
@@ -1480,6 +1514,7 @@ def test_partitioning_factory(mockfs):
("f64", pa.float64()),
("str", pa.string()),
("const", pa.int64()),
+ ("struct", pa.struct({'a': pa.int64(), 'b': pa.string()})),
("group", pa.int32()),
("key", pa.string()),
])
@@ -2047,7 +2082,7 @@ def test_construct_from_mixed_child_datasets(mockfs):
table = dataset.to_table()
assert len(table) == 20
- assert table.num_columns == 4
+ assert table.num_columns == 5
assert len(dataset.children) == 2
for child in dataset.children: