You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by jo...@apache.org on 2022/11/22 10:57:14 UTC
[arrow] branch master updated: ARROW-17989: [C++][Python] Enable struct_field kernel to accept string field names (#14495)
This is an automated email from the ASF dual-hosted git repository.
jorisvandenbossche pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new b1110ae377 ARROW-17989: [C++][Python] Enable struct_field kernel to accept string field names (#14495)
b1110ae377 is described below
commit b1110ae377c66bc3b666f9c287afdf4907bb1952
Author: Miles Granger <mi...@gmail.com>
AuthorDate: Tue Nov 22 11:57:08 2022 +0100
ARROW-17989: [C++][Python] Enable struct_field kernel to accept string field names (#14495)
Will close [ARROW-17989](https://issues.apache.org/jira/browse/ARROW-17989)
Allows selecting fields by name (or by dot path such as `.a.b`) in `pc.struct_field`, in addition to integer indices:
```python
In [1]: arr = pa.array([{'a': {'b': 1}, 'c': 2}])
In [2]: pc.struct_field(arr, 'c')
Out[2]:
<pyarrow.lib.Int64Array object at 0x7f1442da3d60>
[
2
]
In [3]: pc.struct_field(arr, '.a.b')
Out[3]:
<pyarrow.lib.Int64Array object at 0x7f14436d0f40>
[
1
]
# And indices as before...
In [4]: pc.struct_field(arr, [0, 0])
Out[4]:
<pyarrow.lib.Int64Array object at 0x7f14436d0ee0>
[
1
]
In [5]:
```
Lead-authored-by: Miles Granger <mi...@gmail.com>
Co-authored-by: Antoine Pitrou <an...@python.org>
Co-authored-by: Joris Van den Bossche <jo...@gmail.com>
Signed-off-by: Joris Van den Bossche <jo...@gmail.com>
---
cpp/src/arrow/compute/api_scalar.cc | 11 ++-
cpp/src/arrow/compute/api_scalar.h | 7 +-
cpp/src/arrow/compute/kernels/scalar_nested.cc | 27 ++++++--
.../arrow/compute/kernels/scalar_nested_test.cc | 46 +++++++++++--
.../arrow/engine/substrait/expression_internal.cc | 18 +++--
cpp/src/arrow/type.cc | 78 +++++++++++++++++-----
cpp/src/arrow/type.h | 3 +
cpp/src/arrow/type_test.cc | 52 ++++++++++++++-
python/pyarrow/_compute.pyx | 39 ++++++++++-
python/pyarrow/includes/libarrow.pxd | 6 ++
python/pyarrow/tests/test_compute.py | 32 +++++++--
11 files changed, 267 insertions(+), 52 deletions(-)
diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc
index 5de6eade5b..425274043e 100644
--- a/cpp/src/arrow/compute/api_scalar.cc
+++ b/cpp/src/arrow/compute/api_scalar.cc
@@ -365,7 +365,7 @@ static auto kStrptimeOptionsType = GetFunctionOptionsType<StrptimeOptions>(
DataMember("unit", &StrptimeOptions::unit),
DataMember("error_is_null", &StrptimeOptions::error_is_null));
static auto kStructFieldOptionsType = GetFunctionOptionsType<StructFieldOptions>(
- DataMember("indices", &StructFieldOptions::indices));
+ DataMember("field_ref", &StructFieldOptions::field_ref));
static auto kTrimOptionsType = GetFunctionOptionsType<TrimOptions>(
DataMember("characters", &TrimOptions::characters));
static auto kUtf8NormalizeOptionsType = GetFunctionOptionsType<Utf8NormalizeOptions>(
@@ -578,8 +578,13 @@ StrptimeOptions::StrptimeOptions() : StrptimeOptions("", TimeUnit::MICRO, false)
constexpr char StrptimeOptions::kTypeName[];
StructFieldOptions::StructFieldOptions(std::vector<int> indices)
- : FunctionOptions(internal::kStructFieldOptionsType), indices(std::move(indices)) {}
-StructFieldOptions::StructFieldOptions() : StructFieldOptions(std::vector<int>()) {}
+ : FunctionOptions(internal::kStructFieldOptionsType), field_ref(std::move(indices)) {}
+StructFieldOptions::StructFieldOptions(std::initializer_list<int> indices)
+ : FunctionOptions(internal::kStructFieldOptionsType), field_ref(std::move(indices)) {}
+StructFieldOptions::StructFieldOptions(FieldRef ref)
+ : FunctionOptions(internal::kStructFieldOptionsType), field_ref(std::move(ref)) {}
+StructFieldOptions::StructFieldOptions()
+ : FunctionOptions(internal::kStructFieldOptionsType) {}
constexpr char StructFieldOptions::kTypeName[];
TrimOptions::TrimOptions(std::string characters)
diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h
index f15d9c667f..1c27757fcf 100644
--- a/cpp/src/arrow/compute/api_scalar.h
+++ b/cpp/src/arrow/compute/api_scalar.h
@@ -278,12 +278,13 @@ class ARROW_EXPORT SetLookupOptions : public FunctionOptions {
class ARROW_EXPORT StructFieldOptions : public FunctionOptions {
public:
explicit StructFieldOptions(std::vector<int> indices);
+ explicit StructFieldOptions(std::initializer_list<int>);
+ explicit StructFieldOptions(FieldRef field_ref);
StructFieldOptions();
static constexpr char const kTypeName[] = "StructFieldOptions";
- /// The child indices to extract. For instance, to get the 2nd child
- /// of the 1st child of a struct or union, this would be {0, 1}.
- std::vector<int> indices;
+ /// The FieldRef specifying what to extract from struct or union.
+ FieldRef field_ref;
};
class ARROW_EXPORT StrptimeOptions : public FunctionOptions {
diff --git a/cpp/src/arrow/compute/kernels/scalar_nested.cc b/cpp/src/arrow/compute/kernels/scalar_nested.cc
index 5af6b78182..fb1cd9220b 100644
--- a/cpp/src/arrow/compute/kernels/scalar_nested.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_nested.cc
@@ -388,9 +388,17 @@ const FunctionDoc list_element_doc(
struct StructFieldFunctor {
static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
const auto& options = OptionsWrapper<StructFieldOptions>::Get(ctx);
-
std::shared_ptr<Array> current = MakeArray(batch[0].array.ToArrayData());
- for (const auto& index : options.indices) {
+
+ FieldPath field_path;
+ if (options.field_ref.IsNested() || options.field_ref.IsName()) {
+ ARROW_ASSIGN_OR_RAISE(field_path, options.field_ref.FindOne(*current->type()));
+ } else {
+ DCHECK(options.field_ref.IsFieldPath());
+ field_path = *options.field_ref.field_path();
+ }
+
+ for (const auto& index : field_path.indices()) {
RETURN_NOT_OK(CheckIndex(index, *current->type()));
switch (current->type()->id()) {
case Type::STRUCT: {
@@ -421,7 +429,8 @@ struct StructFieldFunctor {
ArrayData(int32(), union_array.length(),
{std::move(take_bitmap), union_array.value_offsets()},
kUnknownNullCount, union_array.offset()));
- // Do not slice the child since the indices are relative to the unsliced array.
+ // Do not slice the child since the indices are relative to the unsliced
+ // array.
ARROW_ASSIGN_OR_RAISE(
Datum result,
CallFunction("take", {union_array.field(index), std::move(take_indices)}));
@@ -463,9 +472,17 @@ struct StructFieldFunctor {
Result<TypeHolder> ResolveStructFieldType(KernelContext* ctx,
const std::vector<TypeHolder>& types) {
- const auto& options = OptionsWrapper<StructFieldOptions>::Get(ctx);
+ const auto& field_ref = OptionsWrapper<StructFieldOptions>::Get(ctx).field_ref;
const DataType* type = types.front().type;
- for (const auto& index : options.indices) {
+
+ FieldPath field_path;
+ if (field_ref.IsNested() || field_ref.IsName()) {
+ ARROW_ASSIGN_OR_RAISE(field_path, field_ref.FindOne(*type));
+ } else {
+ field_path = *field_ref.field_path();
+ }
+
+ for (const auto& index : field_path.indices()) {
RETURN_NOT_OK(StructFieldFunctor::CheckIndex(index, *type));
type = type->field(index)->type().get();
}
diff --git a/cpp/src/arrow/compute/kernels/scalar_nested_test.cc b/cpp/src/arrow/compute/kernels/scalar_nested_test.cc
index ec1e7ceeae..744f188908 100644
--- a/cpp/src/arrow/compute/kernels/scalar_nested_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_nested_test.cc
@@ -261,6 +261,13 @@ TEST(TestScalarNested, StructField) {
StructFieldOptions invalid2({2, 4});
StructFieldOptions invalid3({3});
StructFieldOptions invalid4({0, 1});
+
+ // Test using FieldRefs
+ StructFieldOptions extract0_field_ref_path(FieldRef(FieldPath({0})));
+ StructFieldOptions extract0_field_ref_name(FieldRef("a"));
+ ASSERT_OK_AND_ASSIGN(auto field_ref, FieldRef::FromDotPath(".c.d"));
+ StructFieldOptions extract20_field_ref_nest(field_ref);
+
FieldVector fields = {field("a", int32()), field("b", utf8()),
field("c", struct_({
field("d", int64()),
@@ -278,16 +285,25 @@ TEST(TestScalarNested, StructField) {
&extract0);
CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[10, 11, 12, null]"),
&extract20);
+
+ CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, 3, null]"),
+ &extract0_field_ref_path);
+ CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, 3, null]"),
+ &extract0_field_ref_name);
+ CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[10, 11, 12, null]"),
+ &extract20_field_ref_nest);
+
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
CallFunction("struct_field", {arr}, &invalid1));
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
- ::testing::HasSubstr("out-of-bounds field reference"),
+ ::testing::HasSubstr("No match for FieldRef"),
CallFunction("struct_field", {arr}, &invalid2));
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
CallFunction("struct_field", {arr}, &invalid3));
- EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError, ::testing::HasSubstr("cannot subscript"),
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("No match for FieldRef"),
CallFunction("struct_field", {arr}, &invalid4));
}
{
@@ -303,16 +319,25 @@ TEST(TestScalarNested, StructField) {
&extract0);
CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[null, null, null, 10]"),
&extract20);
+
+ CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, null, null]"),
+ &extract0_field_ref_path);
+ CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, null, null]"),
+ &extract0_field_ref_name);
+ CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[null, null, null, 10]"),
+ &extract20_field_ref_nest);
+
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
CallFunction("struct_field", {arr}, &invalid1));
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
- ::testing::HasSubstr("out-of-bounds field reference"),
+ ::testing::HasSubstr("No match for FieldRef"),
CallFunction("struct_field", {arr}, &invalid2));
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
CallFunction("struct_field", {arr}, &invalid3));
- EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError, ::testing::HasSubstr("cannot subscript"),
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("No match for FieldRef"),
CallFunction("struct_field", {arr}, &invalid4));
// Test edge cases for union representation
@@ -352,16 +377,25 @@ TEST(TestScalarNested, StructField) {
&extract0);
CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[null, null, null, 10]"),
&extract20);
+
+ CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, null, null]"),
+ &extract0_field_ref_path);
+ CheckScalar("struct_field", {arr}, ArrayFromJSON(int32(), "[1, null, null, null]"),
+ &extract0_field_ref_name);
+ CheckScalar("struct_field", {arr}, ArrayFromJSON(int64(), "[null, null, null, 10]"),
+ &extract20_field_ref_nest);
+
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
CallFunction("struct_field", {arr}, &invalid1));
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
- ::testing::HasSubstr("out-of-bounds field reference"),
+ ::testing::HasSubstr("No match for FieldRef"),
CallFunction("struct_field", {arr}, &invalid2));
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("out-of-bounds field reference"),
CallFunction("struct_field", {arr}, &invalid3));
- EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError, ::testing::HasSubstr("cannot subscript"),
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
+ ::testing::HasSubstr("No match for FieldRef"),
CallFunction("struct_field", {arr}, &invalid4));
}
{
diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc
index 7495d1a34e..b988bf195a 100644
--- a/cpp/src/arrow/engine/substrait/expression_internal.cc
+++ b/cpp/src/arrow/engine/substrait/expression_internal.cc
@@ -170,9 +170,10 @@ Result<compute::Expression> FromProto(const substrait::Expression& expr,
out = compute::field_ref(FieldRef(*out_ref, index));
} else if (out->call() && out->call()->function_name == "struct_field") {
// Nested StructFields on top of an arbitrary expression
- std::static_pointer_cast<arrow::compute::StructFieldOptions>(
- out->call()->options)
- ->indices.push_back(index);
+ auto* field_options =
+ checked_cast<compute::StructFieldOptions*>(out->call()->options.get());
+ field_options->field_ref =
+ FieldRef(std::move(field_options->field_ref), index);
} else {
// First StructField on top of an arbitrary expression
out = compute::call("struct_field", {std::move(*out)},
@@ -1019,13 +1020,16 @@ Result<std::unique_ptr<substrait::Expression>> ToProto(
if (call->function_name == "struct_field") {
// catch the special case of calls convertible to a StructField
+ const auto& field_options =
+ checked_cast<const compute::StructFieldOptions&>(*call->options);
+ const DataType& struct_type = *call->arguments[0].type();
+ DCHECK_EQ(struct_type.id(), Type::STRUCT);
+
+ ARROW_ASSIGN_OR_RAISE(auto field_path, field_options.field_ref.FindOne(struct_type));
out = std::move(arguments[0]);
- for (int index :
- checked_cast<const arrow::compute::StructFieldOptions&>(*call->options)
- .indices) {
+ for (int index : field_path.indices()) {
ARROW_ASSIGN_OR_RAISE(out, MakeStructFieldReference(std::move(out), index));
}
-
return std::move(out);
}
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index 6b63a1f8b7..4247ac2360 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -20,6 +20,7 @@
#include <algorithm>
#include <climits>
#include <cstddef>
+#include <iterator>
#include <limits>
#include <memory>
#include <mutex>
@@ -1161,36 +1162,72 @@ Result<std::shared_ptr<ArrayData>> FieldPath::Get(const ArrayData& data) const {
return FieldPathGetImpl::Get(this, data.child_data);
}
-FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) {
- DCHECK_GT(std::get<FieldPath>(impl_).indices().size(), 0);
-}
+FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) {}
void FieldRef::Flatten(std::vector<FieldRef> children) {
+ ARROW_CHECK(!children.empty());
+
// flatten children
struct Visitor {
- void operator()(std::string&& name) { out->push_back(FieldRef(std::move(name))); }
+ void operator()(std::string&& name, std::vector<FieldRef>* out) {
+ out->push_back(FieldRef(std::move(name)));
+ }
- void operator()(FieldPath&& indices) { out->push_back(FieldRef(std::move(indices))); }
+ void operator()(FieldPath&& path, std::vector<FieldRef>* out) {
+ if (path.indices().empty()) {
+ return;
+ }
+ out->push_back(FieldRef(std::move(path)));
+ }
- void operator()(std::vector<FieldRef>&& children) {
- out->reserve(out->size() + children.size());
+ void operator()(std::vector<FieldRef>&& children, std::vector<FieldRef>* out) {
+ if (children.empty()) {
+ return;
+ }
+ // First flatten children into temporary result
+ std::vector<FieldRef> flattened_children;
+ flattened_children.reserve(children.size());
for (auto&& child : children) {
- std::visit(*this, std::move(child.impl_));
+ std::visit(std::bind(*this, std::placeholders::_1, &flattened_children),
+ std::move(child.impl_));
+ }
+ // If all children are FieldPaths, concatenate them into a single FieldPath
+ int64_t n_indices = 0;
+ for (const auto& child : flattened_children) {
+ const FieldPath* path = child.field_path();
+ if (!path) {
+ n_indices = -1;
+ break;
+ }
+ n_indices += static_cast<int64_t>(path->indices().size());
+ }
+ if (n_indices == 0) {
+ return;
+ } else if (n_indices > 0) {
+ std::vector<int> indices(n_indices);
+ auto out_indices = indices.begin();
+ for (const auto& child : flattened_children) {
+ for (int index : *child.field_path()) {
+ *out_indices++ = index;
+ }
+ }
+ DCHECK_EQ(out_indices, indices.end());
+ out->push_back(FieldRef(std::move(indices)));
+ } else {
+ // ... otherwise, just transfer them to the final result
+ out->insert(out->end(), std::move_iterator(flattened_children.begin()),
+ std::move_iterator(flattened_children.end()));
}
}
-
- std::vector<FieldRef>* out;
};
std::vector<FieldRef> out;
- Visitor visitor{&out};
- visitor(std::move(children));
+ Visitor visitor;
+ visitor(std::move(children), &out);
- DCHECK(!out.empty());
- DCHECK(std::none_of(out.begin(), out.end(),
- [](const FieldRef& ref) { return ref.IsNested(); }));
-
- if (out.size() == 1) {
+ if (out.empty()) {
+ impl_ = std::vector<int>();
+ } else if (out.size() == 1) {
impl_ = std::move(out[0].impl_);
} else {
impl_ = std::move(out);
@@ -1199,7 +1236,7 @@ void FieldRef::Flatten(std::vector<FieldRef> children) {
Result<FieldRef> FieldRef::FromDotPath(const std::string& dot_path_arg) {
if (dot_path_arg.empty()) {
- return Status::Invalid("Dot path was empty");
+ return FieldRef();
}
std::vector<FieldRef> children;
@@ -1449,6 +1486,11 @@ std::vector<FieldPath> FieldRef::FindAll(const RecordBatch& batch) const {
void PrintTo(const FieldRef& ref, std::ostream* os) { *os << ref.ToString(); }
+std::ostream& operator<<(std::ostream& os, const FieldRef& ref) {
+ os << ref.ToString();
+ return os;
+}
+
// ----------------------------------------------------------------------
// Schema implementation
diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h
index 3bb92bf26f..415aaacf1c 100644
--- a/cpp/src/arrow/type.h
+++ b/cpp/src/arrow/type.h
@@ -1851,6 +1851,9 @@ class ARROW_EXPORT FieldRef : public util::EqualityComparable<FieldRef> {
ARROW_EXPORT void PrintTo(const FieldRef& ref, std::ostream* os);
+ARROW_EXPORT
+std::ostream& operator<<(std::ostream& os, const FieldRef&);
+
// ----------------------------------------------------------------------
// Schema
diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc
index c6ce1887de..954ad63c8a 100644
--- a/cpp/src/arrow/type_test.cc
+++ b/cpp/src/arrow/type_test.cc
@@ -414,12 +414,26 @@ TEST(TestFieldRef, FromDotPath) {
ASSERT_OK_AND_EQ(FieldRef(R"([y]\tho.\)"), FieldRef::FromDotPath(R"(.\[y\]\\tho\.\)"));
- ASSERT_RAISES(Invalid, FieldRef::FromDotPath(R"()"));
+ ASSERT_OK_AND_EQ(FieldRef(), FieldRef::FromDotPath(R"()"));
+
ASSERT_RAISES(Invalid, FieldRef::FromDotPath(R"(alpha)"));
ASSERT_RAISES(Invalid, FieldRef::FromDotPath(R"([134234)"));
ASSERT_RAISES(Invalid, FieldRef::FromDotPath(R"([1stuf])"));
}
+TEST(TestFieldRef, DotPathRoundTrip) {
+ auto check_roundtrip = [](const FieldRef& ref) {
+ auto dot_path = ref.ToDotPath();
+ ASSERT_OK_AND_EQ(ref, FieldRef::FromDotPath(dot_path));
+ };
+
+ check_roundtrip(FieldRef());
+ check_roundtrip(FieldRef("foo"));
+ check_roundtrip(FieldRef("foo", 1, "bar", 2, 3));
+ check_roundtrip(FieldRef(1, 2, 3));
+ check_roundtrip(FieldRef("foo", 1, FieldRef("bar", 2, 3), FieldRef()));
+}
+
TEST(TestFieldPath, Nested) {
auto f0 = field("alpha", int32());
auto f1_0 = field("alpha", int32());
@@ -456,6 +470,42 @@ TEST(TestFieldRef, Nested) {
ElementsAre(FieldPath{2, 1, 0}, FieldPath{2, 1, 1}));
}
+TEST(TestFieldRef, Flatten) {
+ FieldRef ref;
+
+ auto assert_name = [](const FieldRef& ref, const std::string& expected) {
+ ASSERT_TRUE(ref.IsName());
+ ASSERT_EQ(*ref.name(), expected);
+ };
+
+ auto assert_path = [](const FieldRef& ref, const std::vector<int>& expected) {
+ ASSERT_TRUE(ref.IsFieldPath());
+ ASSERT_EQ(ref.field_path()->indices(), expected);
+ };
+
+ auto assert_nested = [](const FieldRef& ref, const std::vector<FieldRef>& expected) {
+ ASSERT_TRUE(ref.IsNested());
+ ASSERT_EQ(*ref.nested_refs(), expected);
+ };
+
+ assert_path(FieldRef(), {});
+ assert_path(FieldRef(1, 2, 3), {1, 2, 3});
+ // If all leaves are field paths, they are fully flattened
+ assert_path(FieldRef(1, FieldRef(2, 3)), {1, 2, 3});
+ assert_path(FieldRef(1, FieldRef(2, 3), FieldRef(), FieldRef(FieldRef(4), FieldRef(5))),
+ {1, 2, 3, 4, 5});
+ assert_path(FieldRef(FieldRef(), FieldRef(FieldRef(), FieldRef())), {});
+
+ assert_name(FieldRef("foo"), "foo");
+
+ // Nested empty field refs are optimized away
+ assert_nested(FieldRef("foo", 1, FieldRef(), FieldRef(FieldRef(), "bar")),
+ {FieldRef("foo"), FieldRef(1), FieldRef("bar")});
+ // For now, subsequences of indices are not concatenated
+ assert_nested(FieldRef("foo", FieldRef("bar"), FieldRef(1, 2), FieldRef(3)),
+ {FieldRef("foo"), FieldRef("bar"), FieldRef(1, 2), FieldRef(3)});
+}
+
using TestSchema = ::testing::Test;
TEST_F(TestSchema, Basics) {
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index 659af0afba..c75c5bf189 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -1361,7 +1361,37 @@ class MakeStructOptions(_MakeStructOptions):
cdef class _StructFieldOptions(FunctionOptions):
def _set_options(self, indices):
- self.wrapped.reset(new CStructFieldOptions(indices))
+ cdef:
+ CFieldRef field_ref
+ const CFieldRef* field_ref_ptr
+
+ if isinstance(indices, (list, tuple)):
+ if len(indices):
+ indices = Expression._nested_field(tuple(indices))
+ else:
+ # Allow empty indices; effectively return same array
+ self.wrapped.reset(
+ new CStructFieldOptions(<vector[int]>indices))
+ return
+
+ if isinstance(indices, Expression):
+ field_ref_ptr = (<Expression>indices).unwrap().field_ref()
+ if field_ref_ptr is NULL:
+ raise ValueError("Unable to get CFieldRef from Expression")
+ field_ref = <CFieldRef>deref(field_ref_ptr)
+ elif isinstance(indices, (bytes, str)):
+ if indices.startswith(b'.' if isinstance(indices, bytes) else '.'):
+ field_ref = GetResultValue(
+ CFieldRef.FromDotPath(<c_string>tobytes(indices)))
+ else:
+ field_ref = CFieldRef(<c_string>tobytes(indices))
+ elif isinstance(indices, int):
+ field_ref = CFieldRef(<int> indices)
+ else:
+ raise TypeError("Expected List[str], List[int], List[bytes], "
+ "Expression, bytes, str, or int. "
+ f"Got: {type(indices)}")
+ self.wrapped.reset(new CStructFieldOptions(field_ref))
class StructFieldOptions(_StructFieldOptions):
@@ -1370,7 +1400,7 @@ class StructFieldOptions(_StructFieldOptions):
Parameters
----------
- indices : sequence of int
+ indices : List[str], List[bytes], List[int], Expression, bytes, str, or int
List of indices for chained field lookup, for example `[4, 1]`
will look up the second nested field in the fifth outer field.
"""
@@ -2442,7 +2472,10 @@ cdef class Expression(_Weakrefable):
raise ValueError("nested field reference should be non-empty")
nested.reserve(len(names))
for name in names:
- nested.push_back(CFieldRef(<c_string> tobytes(name)))
+ if isinstance(name, int):
+ nested.push_back(CFieldRef(<int>name))
+ else:
+ nested.push_back(CFieldRef(<c_string> tobytes(name)))
return Expression.wrap(CMakeFieldExpression(CFieldRef(move(nested))))
@staticmethod
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index bc82a42089..9cea340a30 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -434,6 +434,9 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
CFieldRef(c_string name)
CFieldRef(int index)
CFieldRef(vector[CFieldRef])
+
+ @staticmethod
+ CResult[CFieldRef] FromDotPath(c_string& dot_path)
const c_string* name() const
cdef cppclass CFieldRefHash" arrow::FieldRef::Hash":
@@ -2291,7 +2294,9 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
cdef cppclass CStructFieldOptions \
"arrow::compute::StructFieldOptions"(CFunctionOptions):
CStructFieldOptions(vector[int] indices)
+ CStructFieldOptions(CFieldRef field_ref)
vector[int] indices
+ CFieldRef field_ref
ctypedef enum CSortOrder" arrow::compute::SortOrder":
CSortOrder_Ascending \
@@ -2496,6 +2501,7 @@ cdef extern from "arrow/compute/exec/expression.h" \
c_bool Equals(const CExpression& other) const
c_string ToString() const
CResult[CExpression] Bind(const CSchema&)
+ const CFieldRef* field_ref() const
cdef CExpression CMakeScalarExpression \
"arrow::compute::literal"(shared_ptr[CScalar] value)
diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py
index 3d03c7d86a..68b3303fe7 100644
--- a/python/pyarrow/tests/test_compute.py
+++ b/python/pyarrow/tests/test_compute.py
@@ -2690,14 +2690,32 @@ def test_struct_fields_options():
c = pa.StructArray.from_arrays([a, b], ["a", "b"])
arr = pa.StructArray.from_arrays([a, c], ["a", "c"])
- assert pc.struct_field(arr,
- indices=[1, 1]) == pa.array(["bar", None, ""])
- assert pc.struct_field(arr, [1, 1]) == pa.array(["bar", None, ""])
- assert pc.struct_field(arr, [0]) == pa.array([4, 5, 6], type=pa.int64())
+ assert pc.struct_field(arr, '.c.b') == b
+ assert pc.struct_field(arr, b'.c.b') == b
+ assert pc.struct_field(arr, ['c', 'b']) == b
+ assert pc.struct_field(arr, [1, 'b']) == b
+ assert pc.struct_field(arr, (b'c', 'b')) == b
+ assert pc.struct_field(arr, pc.field(('c', 'b'))) == b
+
+ assert pc.struct_field(arr, '.a') == a
+ assert pc.struct_field(arr, ['a']) == a
+ assert pc.struct_field(arr, 'a') == a
+ assert pc.struct_field(arr, pc.field(('a',))) == a
+
+ assert pc.struct_field(arr, indices=[1, 1]) == b
+ assert pc.struct_field(arr, (1, 1)) == b
+ assert pc.struct_field(arr, [0]) == a
assert pc.struct_field(arr, []) == arr
- with pytest.raises(TypeError, match="an integer is required"):
- pc.struct_field(arr, indices=['a'])
+ with pytest.raises(pa.ArrowInvalid, match="No match for FieldRef"):
+ pc.struct_field(arr, 'foo')
+
+ with pytest.raises(pa.ArrowInvalid, match="No match for FieldRef"):
+ pc.struct_field(arr, '.c.foo')
+
+ # drill into a non-struct array and continue to ask for a field
+ with pytest.raises(pa.ArrowInvalid, match="No match for FieldRef"):
+ pc.struct_field(arr, '.a.foo')
# TODO: https://issues.apache.org/jira/browse/ARROW-14853
# assert pc.struct_field(arr) == arr
@@ -2863,6 +2881,7 @@ def test_expression_construction():
false = pc.scalar(False)
string = pc.scalar("string")
field = pc.field("field")
+ nested_mixed_types = pc.field(b"a", 1, "b")
nested_field = pc.field(("nested", "field"))
nested_field2 = pc.field("nested", "field")
@@ -2872,6 +2891,7 @@ def test_expression_construction():
field.cast(typ) == true
field.isin([1, 2])
+ nested_mixed_types.isin(["foo", "bar"])
nested_field.isin(["foo", "bar"])
nested_field2.isin(["foo", "bar"])