You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by uw...@apache.org on 2018/04/25 16:39:58 UTC
[arrow] branch master updated: ARROW-2074: [Python] Infer lists of
dicts as struct arrays
This is an automated email from the ASF dual-hosted git repository.
uwe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 3d7a5a6 ARROW-2074: [Python] Infer lists of dicts as struct arrays
3d7a5a6 is described below
commit 3d7a5a64732a8a6457f544006ac1c49c141332c5
Author: Antoine Pitrou <an...@python.org>
AuthorDate: Wed Apr 25 18:39:35 2018 +0200
ARROW-2074: [Python] Infer lists of dicts as struct arrays
Also refactor the type inference visitor and remove the superfluous separate SeqVisitor; improve inference visitor performance by 30%; and add a struct type inference benchmark.
Author: Antoine Pitrou <an...@python.org>
Closes #1935 from pitrou/ARROW-2074-infer-dict-lists and squashes the following commits:
13ed6c30 <Antoine Pitrou> Fix tests on 2.7
3baa2eac <Antoine Pitrou> ARROW-2074: Infer lists of dicts as struct arrays
---
cpp/src/arrow/python/builtin_convert.cc | 242 ++++++++++++++-------------
python/benchmarks/convert_builtins.py | 3 +-
python/pyarrow/tests/test_convert_builtin.py | 77 ++++++++-
python/pyarrow/tests/test_types.py | 15 +-
python/pyarrow/types.pxi | 3 +
5 files changed, 213 insertions(+), 127 deletions(-)
diff --git a/cpp/src/arrow/python/builtin_convert.cc b/cpp/src/arrow/python/builtin_convert.cc
index a1c379d..740c896 100644
--- a/cpp/src/arrow/python/builtin_convert.cc
+++ b/cpp/src/arrow/python/builtin_convert.cc
@@ -21,6 +21,7 @@
#include <algorithm>
#include <limits>
+#include <map>
#include <sstream>
#include <string>
#include <utility>
@@ -49,9 +50,11 @@ Status InvalidConversion(PyObject* obj, const std::string& expected_types,
return Status::OK();
}
-class ScalarVisitor {
+class TypeInferrer {
+ // A type inference visitor for Python values
+
public:
- ScalarVisitor()
+ TypeInferrer()
: total_count_(0),
none_count_(0),
bool_count_(0),
@@ -62,14 +65,46 @@ class ScalarVisitor {
binary_count_(0),
unicode_count_(0),
decimal_count_(0),
+ list_count_(0),
+ struct_count_(0),
max_decimal_metadata_(std::numeric_limits<int32_t>::min(),
std::numeric_limits<int32_t>::min()),
decimal_type_() {
- PyAcquireGIL lock;
Status status = internal::ImportDecimalType(&decimal_type_);
DCHECK_OK(status);
}
+ // Infer value types from the elements of *obj*, passing each element to Visit(); numpy arrays and other sequences are accepted, anything else is a TypeError
+ Status VisitSequence(PyObject* obj) {
+ // numpy arrays: box each element along the first axis with PyArray_GETITEM
+ if (PyArray_Check(obj)) {
+ Py_ssize_t size = PySequence_Size(obj);
+ OwnedRef value_ref;  // holds the element object returned by PyArray_GETITEM
+
+ for (Py_ssize_t i = 0; i < size; ++i) {
+ auto array = reinterpret_cast<PyArrayObject*>(obj);
+ auto ptr = reinterpret_cast<const char*>(PyArray_GETPTR1(array, i));  // NOTE(review): GETPTR1 indexes only the first axis, so this presumably assumes a 1-D array
+
+ value_ref.reset(PyArray_GETITEM(array, ptr));
+ RETURN_IF_PYERROR();
+ RETURN_NOT_OK(Visit(value_ref.obj()));
+ }
+ } else if (PySequence_Check(obj)) {  // other sequences: iterate PySequence_Fast's list/tuple view
+ OwnedRef seq_ref(PySequence_Fast(obj, "Object is not a sequence or iterable"));
+ RETURN_IF_PYERROR();
+ PyObject* seq = seq_ref.obj();
+
+ Py_ssize_t size = PySequence_Fast_GET_SIZE(seq);
+ for (Py_ssize_t i = 0; i < size; ++i) {
+ PyObject* value = PySequence_Fast_GET_ITEM(seq, i);  // borrowed reference, kept alive by seq
+ RETURN_NOT_OK(Visit(value));
+ }
+ } else {
+ return Status::TypeError("Object is not a sequence or iterable");
+ }
+ return Status::OK();
+ }
+
Status Visit(PyObject* obj) {
++total_count_;
if (obj == Py_None || internal::PyFloat_IsNaN(obj)) {
@@ -103,6 +138,10 @@ class ScalarVisitor {
ss << type->ToString();
return Status::Invalid(ss.str());
}
+ } else if (PyList_Check(obj) || PyArray_Check(obj)) {
+ return VisitList(obj);
+ } else if (PyDict_Check(obj)) {
+ return VisitDict(obj);
} else if (PyObject_IsInstance(obj, decimal_type_.obj())) {
RETURN_NOT_OK(max_decimal_metadata_.Update(obj));
++decimal_count_;
@@ -118,14 +157,36 @@ class ScalarVisitor {
return Status::OK();
}
- std::shared_ptr<DataType> GetType() {
+ Status Validate() const {  // Reject mixes of container (list/struct) values with other non-null values, recursively
+ if (list_count_ > 0) {
+ if (list_count_ + none_count_ != total_count_) {  // lists may only be mixed with nulls (None/NaN)
+ return Status::Invalid("cannot mix list and non-list, non-null values");
+ }
+ RETURN_NOT_OK(list_inferrer_->Validate());  // then validate the list elements
+ } else if (struct_count_ > 0) {
+ if (struct_count_ + none_count_ != total_count_) {  // dicts may only be mixed with nulls (None/NaN)
+ return Status::Invalid("cannot mix struct and non-struct, non-null values");
+ }
+ for (const auto& it : struct_inferrers_) {  // then validate the values seen under each key
+ RETURN_NOT_OK(it.second.Validate());
+ }
+ }
+ return Status::OK();  // NOTE(review): mixes of scalar kinds (e.g. int and str) are not rejected here
+ }
+
+ std::shared_ptr<DataType> GetType() const {
// TODO(wesm): handling mixed-type cases
- if (decimal_count_) {
+ if (list_count_) {
+ auto value_type = list_inferrer_->GetType();
+ DCHECK(value_type != nullptr);
+ return list(value_type);
+ } else if (struct_count_) {
+ return GetStructType();
+ } else if (decimal_count_) {
return decimal(max_decimal_metadata_.precision(), max_decimal_metadata_.scale());
} else if (float_count_) {
return float64();
} else if (int_count_) {
- // TODO(wesm): tighter type later
return int64();
} else if (date_count_) {
return date64();
@@ -144,6 +205,53 @@ class ScalarVisitor {
int64_t total_count() const { return total_count_; }
+ protected:
+ Status VisitList(PyObject* obj) {  // Record one list (or numpy array) value; the elements of all lists go to one shared child inferrer
+ if (!list_inferrer_) {  // created lazily on the first list seen
+ list_inferrer_.reset(new TypeInferrer);
+ }
+ ++list_count_;
+ return list_inferrer_->VisitSequence(obj);  // NOTE(review): no recursion depth limit is visible; a list that contains itself appears to recurse without bound -- consider guarding with Py_EnterRecursiveCall/Py_LeaveRecursiveCall
+ }
+
+ Status VisitDict(PyObject* obj) {  // Record one dict value: each key's value is fed to that key's own child inferrer
+ PyObject* key_obj;
+ PyObject* value_obj;
+ Py_ssize_t pos = 0;
+
+ while (PyDict_Next(obj, &pos, &key_obj, &value_obj)) {  // key_obj/value_obj are borrowed references
+ std::string key;
+ if (PyUnicode_Check(key_obj)) {  // str (unicode) keys
+ RETURN_NOT_OK(internal::PyUnicode_AsStdString(key_obj, &key));
+ } else if (PyBytes_Check(key_obj)) {  // bytes keys are accepted as well
+ key = internal::PyBytes_AsStdString(key_obj);
+ } else {  // any other key type is rejected
+ std::stringstream ss;
+ ss << "Expected dict key of type str or bytes, got '" << Py_TYPE(key_obj)->tp_name
+ << "'";
+ return Status::TypeError(ss.str());
+ }
+ // Get or create the child inferrer for this key (std::map keeps keys sorted)
+ auto it = struct_inferrers_.find(key);
+ if (it == struct_inferrers_.end()) {
+ it = struct_inferrers_.insert(std::make_pair(key, TypeInferrer())).first;
+ }
+ TypeInferrer* visitor = &it->second;
+ RETURN_NOT_OK(visitor->Visit(value_obj));  // NOTE(review): no recursion depth limit is visible; a dict that contains itself appears to recurse without bound
+ }
+ ++struct_count_;  // counted only after every entry was visited successfully
+ return Status::OK();
+ }
+
+ std::shared_ptr<DataType> GetStructType() const {  // One field per observed dict key, in std::map (sorted key) order
+ std::vector<std::shared_ptr<Field>> fields;
+ for (const auto& it : struct_inferrers_) {
+ const auto struct_field = field(it.first, it.second.GetType());  // field type inferred from the values seen under that key
+ fields.emplace_back(struct_field);
+ }
+ return struct_(fields);
+ }
+
private:
int64_t total_count_;
int64_t none_count_;
@@ -155,6 +263,10 @@ class ScalarVisitor {
int64_t binary_count_;
int64_t unicode_count_;
int64_t decimal_count_;
+ int64_t list_count_;
+ std::unique_ptr<TypeInferrer> list_inferrer_;
+ int64_t struct_count_;
+ std::map<std::string, TypeInferrer> struct_inferrers_;
internal::DecimalMetadata max_decimal_metadata_;
@@ -163,116 +275,6 @@ class ScalarVisitor {
OwnedRefNoGIL decimal_type_;
};
-static constexpr int MAX_NESTING_LEVELS = 32;
-
-// SeqVisitor is used to infer the type.
-class SeqVisitor {
- public:
- SeqVisitor() : max_nesting_level_(0), max_observed_level_(0), nesting_histogram_() {
- std::fill(nesting_histogram_, nesting_histogram_ + MAX_NESTING_LEVELS, 0);
- }
-
- // co-recursive with VisitElem
- Status Visit(PyObject* obj, int level = 0) {
- max_nesting_level_ = std::max(max_nesting_level_, level);
-
- // Loop through a sequence
- if (!PySequence_Check(obj))
- return Status::TypeError("Object is not a sequence or iterable");
-
- Py_ssize_t size = PySequence_Size(obj);
- for (int64_t i = 0; i < size; ++i) {
- OwnedRef ref;
- if (PyArray_Check(obj)) {
- auto array = reinterpret_cast<PyArrayObject*>(obj);
- auto ptr = reinterpret_cast<const char*>(PyArray_GETPTR1(array, i));
-
- ref.reset(PyArray_GETITEM(array, ptr));
- RETURN_IF_PYERROR();
-
- RETURN_NOT_OK(VisitElem(ref, level));
- } else {
- ref.reset(PySequence_GetItem(obj, i));
- RETURN_IF_PYERROR();
- RETURN_NOT_OK(VisitElem(ref, level));
- }
- }
- return Status::OK();
- }
-
- std::shared_ptr<DataType> GetType() {
- // If all the non-list inputs were null (or there were no inputs)
- std::shared_ptr<DataType> result;
- if (scalars_.total_count() == 0) {
- // Lists of Lists of NULL
- result = null();
- } else {
- // Lists of Lists of [X]
- result = scalars_.GetType();
- }
- for (int i = 0; i < max_nesting_level_; ++i) {
- result = std::make_shared<ListType>(result);
- }
- return result;
- }
-
- Status Validate() const {
- if (scalars_.total_count() > 0) {
- if (num_nesting_levels() > 1) {
- return Status::Invalid("Mixed nesting levels not supported");
- // If the nesting goes deeper than the deepest scalar
- } else if (max_observed_level_ < max_nesting_level_) {
- return Status::Invalid("Mixed nesting levels not supported");
- }
- }
- return Status::OK();
- }
-
- // Returns the number of nesting levels which have scalar elements.
- int num_nesting_levels() const {
- int result = 0;
- for (int i = 0; i < MAX_NESTING_LEVELS; ++i) {
- if (nesting_histogram_[i] > 0) {
- ++result;
- }
- }
- return result;
- }
-
- private:
- ScalarVisitor scalars_;
-
- // Track observed
- // Deapest nesting level (irregardless of scalars)
- int max_nesting_level_;
- int max_observed_level_;
-
- // Number of scalar elements at each nesting level.
- // (TOOD: We really only need to know if a scalar is present, not the count).
- int nesting_histogram_[MAX_NESTING_LEVELS];
-
- // Visits a specific element (inner part of the loop).
- Status VisitElem(const OwnedRef& item_ref, int level) {
- DCHECK_NE(item_ref.obj(), NULLPTR);
- if (PyList_Check(item_ref.obj()) || PyArray_Check(item_ref.obj())) {
- RETURN_NOT_OK(Visit(item_ref.obj(), level + 1));
- } else if (PyDict_Check(item_ref.obj())) {
- return Status::NotImplemented("No type inference for dicts");
- } else {
- // We permit nulls at any level of nesting, but they aren't treated like
- // other scalar values as far as the checking for mixed nesting structure
- if (item_ref.obj() != Py_None) {
- ++nesting_histogram_[level];
- }
- if (level > max_observed_level_) {
- max_observed_level_ = level;
- }
- return scalars_.Visit(item_ref.obj());
- }
- return Status::OK();
- }
-};
-
// Convert *obj* to a sequence if necessary
// Fill *size* to its length. If >= 0 on entry, *size* is an upper size
// bound that may lead to truncation.
@@ -319,11 +321,11 @@ Status ConvertToSequenceAndInferSize(PyObject* obj, PyObject** seq, int64_t* siz
// Non-exhaustive type inference
Status InferArrowType(PyObject* obj, std::shared_ptr<DataType>* out_type) {
PyDateTime_IMPORT;
- SeqVisitor seq_visitor;
- RETURN_NOT_OK(seq_visitor.Visit(obj));
- RETURN_NOT_OK(seq_visitor.Validate());
+ TypeInferrer inferrer;
+ RETURN_NOT_OK(inferrer.VisitSequence(obj));
+ RETURN_NOT_OK(inferrer.Validate());
- *out_type = seq_visitor.GetType();
+ *out_type = inferrer.GetType();
if (*out_type == nullptr) {
return Status::TypeError("Unable to determine data type");
}
diff --git a/python/benchmarks/convert_builtins.py b/python/benchmarks/convert_builtins.py
index 91b15ec..48a38fa 100644
--- a/python/benchmarks/convert_builtins.py
+++ b/python/benchmarks/convert_builtins.py
@@ -51,8 +51,7 @@ class InferPyListToArray(object):
"""
size = 10 ** 5
types = ('int64', 'float64', 'bool', 'decimal', 'binary', 'ascii',
- 'unicode', 'int64 list')
- # TODO add 'struct' when supported
+ 'unicode', 'int64 list', 'struct')
param_names = ['type']
params = [types]
diff --git a/python/pyarrow/tests/test_convert_builtin.py b/python/pyarrow/tests/test_convert_builtin.py
index a18d183..7fb4301 100644
--- a/python/pyarrow/tests/test_convert_builtin.py
+++ b/python/pyarrow/tests/test_convert_builtin.py
@@ -51,6 +51,14 @@ class StrangeIterable:
return self.lst.__iter__()
+def check_struct_type(ty, expected):
+ """
+ Assert that *ty* is a struct type with the same fields as *expected*, ignoring field order.
+ """
+ assert pa.types.is_struct(ty)
+ assert set(ty) == set(expected)  # order-insensitive comparison; requires hashable Field objects
+
+
def test_iterable_types():
arr1 = pa.array(StrangeIterable([0, 1, 2, 3]))
arr2 = pa.array((0, 1, 2, 3))
@@ -479,11 +487,28 @@ def test_sequence_timestamp_from_int_with_unit():
pa.array([1, CustomClass()], type=pa.date64())
-def test_sequence_mixed_nesting_levels():
- pa.array([1, 2, None])
- pa.array([[1], [2], None])
- pa.array([[1], [2], [None]])
+def test_sequence_nesting_levels():
+ data = [1, 2, None]
+ arr = pa.array(data)
+ assert arr.type == pa.int64()
+ assert arr.to_pylist() == data
+
+ data = [[1], [2], None]
+ arr = pa.array(data)
+ assert arr.type == pa.list_(pa.int64())
+ assert arr.to_pylist() == data
+ data = [[1], [2, 3, 4], [None]]
+ arr = pa.array(data)
+ assert arr.type == pa.list_(pa.int64())
+ assert arr.to_pylist() == data
+
+ data = [None, [[None, 1]], [[2, 3, 4], None], [None]]
+ arr = pa.array(data)
+ assert arr.type == pa.list_(pa.list_(pa.int64()))
+ assert arr.to_pylist() == data
+
+ # Mixed nesting levels are rejected
with pytest.raises(pa.ArrowInvalid):
pa.array([1, 2, [1]])
@@ -649,6 +674,50 @@ def test_struct_from_mixed_sequence():
pa.array(data, type=ty)
+def test_struct_from_dicts_inference():
+ expected_type = pa.struct([pa.field('a', pa.int64()),
+ pa.field('b', pa.string()),
+ pa.field('c', pa.bool_())])
+ data = [{'a': 5, 'b': u'foo', 'c': True},
+ {'a': 6, 'b': u'bar', 'c': False}]
+ arr = pa.array(data)
+ check_struct_type(arr.type, expected_type)
+ assert arr.to_pylist() == data
+
+ # With omitted values
+ data = [{'a': 5, 'c': True},
+ None,
+ {},
+ {'a': None, 'b': u'bar'}]
+ expected = [{'a': 5, 'b': None, 'c': True},
+ None,
+ {'a': None, 'b': None, 'c': None},
+ {'a': None, 'b': u'bar', 'c': None}]
+ arr = pa.array(data)
+ check_struct_type(arr.type, expected_type)
+ assert arr.to_pylist() == expected
+
+ # Nested
+ expected_type = pa.struct([
+ pa.field('a', pa.struct([pa.field('aa', pa.list_(pa.int64())),
+ pa.field('ab', pa.bool_())])),
+ pa.field('b', pa.string())])
+ data = [{'a': {'aa': [5, 6], 'ab': True}, 'b': 'foo'},
+ {'a': {'aa': None, 'ab': False}, 'b': None},
+ {'a': None, 'b': 'bar'}]
+ arr = pa.array(data)
+ assert arr.to_pylist() == data
+
+ # Edge cases
+ arr = pa.array([{}])
+ assert arr.type == pa.struct([])
+ assert arr.to_pylist() == [{}]
+
+ # Mixing structs and scalars is rejected
+ with pytest.raises(pa.ArrowInvalid):
+ pa.array([1, {'a': 2}])
+
+
def test_structarray_from_arrays_coerce():
# ARROW-1706
ints = [None, 2, 3]
diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py
index 9d0f383..bb2a986 100644
--- a/python/pyarrow/tests/test_types.py
+++ b/python/pyarrow/tests/test_types.py
@@ -199,8 +199,9 @@ def test_types_hashable():
for i, type_ in enumerate(MANY_TYPES):
assert hash(type_) == hash(type_)
in_dict[type_] = i
- assert in_dict[type_] == i
assert len(in_dict) == len(MANY_TYPES)
+ for i, type_ in enumerate(MANY_TYPES):
+ assert in_dict[type_] == i
def test_types_picklable():
@@ -209,6 +210,18 @@ def test_types_picklable():
assert pickle.loads(data) == ty
+def test_fields_hashable():  # Field objects must be usable as dict keys
+ in_dict = {}
+ fields = [pa.field('a', pa.int64()),  # same name, different types ...
+ pa.field('a', pa.int32()),
+ pa.field('b', pa.int32())]  # ... and same type, different names
+ for i, field in enumerate(fields):
+ in_dict[field] = i
+ assert len(in_dict) == len(fields)  # no two fields collapsed onto one key
+ for i, field in enumerate(fields):  # look up only after all insertions
+ assert in_dict[field] == i
+
+
@pytest.mark.parametrize('t,check_func', [
(pa.date32(), types.is_date32),
(pa.date64(), types.is_date64),
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 4ed1443..850be23 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -414,6 +414,9 @@ cdef class Field:
def __repr__(self):
return self.__str__()
+ def __hash__(self):  # Make Field usable in sets and as a dict key
+ return hash((self.field.name(), self.type.id))  # hashes name and type id only (not nullability or type parameters); NOTE(review): assumes fields that compare equal always share both
+
property nullable:
def __get__(self):
--
To stop receiving notification emails like this one, please contact
uwe@apache.org.