You are viewing a plain-text version of this content. (The link to the canonical version was lost in the plain-text conversion.)
Posted to commits@arrow.apache.org by uw...@apache.org on 2018/04/25 16:39:58 UTC

[arrow] branch master updated: ARROW-2074: [Python] Infer lists of dicts as struct arrays

This is an automated email from the ASF dual-hosted git repository.

uwe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 3d7a5a6  ARROW-2074: [Python] Infer lists of dicts as struct arrays
3d7a5a6 is described below

commit 3d7a5a64732a8a6457f544006ac1c49c141332c5
Author: Antoine Pitrou <an...@python.org>
AuthorDate: Wed Apr 25 18:39:35 2018 +0200

    ARROW-2074: [Python] Infer lists of dicts as struct arrays
    
    Also refactor the type inference visitor and remove the superfluous separate SeqVisitor; improve inference visitor performance by 30%; and add a struct type inference benchmark.
    
    Author: Antoine Pitrou <an...@python.org>
    
    Closes #1935 from pitrou/ARROW-2074-infer-dict-lists and squashes the following commits:
    
    13ed6c30 <Antoine Pitrou> Fix tests on 2.7
    3baa2eac <Antoine Pitrou> ARROW-2074:  Infer lists of dicts as struct arrays
---
 cpp/src/arrow/python/builtin_convert.cc      | 242 ++++++++++++++-------------
 python/benchmarks/convert_builtins.py        |   3 +-
 python/pyarrow/tests/test_convert_builtin.py |  77 ++++++++-
 python/pyarrow/tests/test_types.py           |  15 +-
 python/pyarrow/types.pxi                     |   3 +
 5 files changed, 213 insertions(+), 127 deletions(-)

diff --git a/cpp/src/arrow/python/builtin_convert.cc b/cpp/src/arrow/python/builtin_convert.cc
index a1c379d..740c896 100644
--- a/cpp/src/arrow/python/builtin_convert.cc
+++ b/cpp/src/arrow/python/builtin_convert.cc
@@ -21,6 +21,7 @@
 
 #include <algorithm>
 #include <limits>
+#include <map>
 #include <sstream>
 #include <string>
 #include <utility>
@@ -49,9 +50,11 @@ Status InvalidConversion(PyObject* obj, const std::string& expected_types,
   return Status::OK();
 }
 
-class ScalarVisitor {
+class TypeInferrer {
+  // A type inference visitor for Python values
+
  public:
-  ScalarVisitor()
+  TypeInferrer()
       : total_count_(0),
         none_count_(0),
         bool_count_(0),
@@ -62,14 +65,46 @@ class ScalarVisitor {
         binary_count_(0),
         unicode_count_(0),
         decimal_count_(0),
+        list_count_(0),
+        struct_count_(0),
         max_decimal_metadata_(std::numeric_limits<int32_t>::min(),
                               std::numeric_limits<int32_t>::min()),
         decimal_type_() {
-    PyAcquireGIL lock;
     Status status = internal::ImportDecimalType(&decimal_type_);
     DCHECK_OK(status);
   }
 
+  // Infer value type from a sequence of values
+  Status VisitSequence(PyObject* obj) {
+    // Loop through a sequence
+    if (PyArray_Check(obj)) {
+      Py_ssize_t size = PySequence_Size(obj);
+      OwnedRef value_ref;
+
+      for (Py_ssize_t i = 0; i < size; ++i) {
+        auto array = reinterpret_cast<PyArrayObject*>(obj);
+        auto ptr = reinterpret_cast<const char*>(PyArray_GETPTR1(array, i));
+
+        value_ref.reset(PyArray_GETITEM(array, ptr));
+        RETURN_IF_PYERROR();
+        RETURN_NOT_OK(Visit(value_ref.obj()));
+      }
+    } else if (PySequence_Check(obj)) {
+      OwnedRef seq_ref(PySequence_Fast(obj, "Object is not a sequence or iterable"));
+      RETURN_IF_PYERROR();
+      PyObject* seq = seq_ref.obj();
+
+      Py_ssize_t size = PySequence_Fast_GET_SIZE(seq);
+      for (Py_ssize_t i = 0; i < size; ++i) {
+        PyObject* value = PySequence_Fast_GET_ITEM(seq, i);
+        RETURN_NOT_OK(Visit(value));
+      }
+    } else {
+      return Status::TypeError("Object is not a sequence or iterable");
+    }
+    return Status::OK();
+  }
+
   Status Visit(PyObject* obj) {
     ++total_count_;
     if (obj == Py_None || internal::PyFloat_IsNaN(obj)) {
@@ -103,6 +138,10 @@ class ScalarVisitor {
         ss << type->ToString();
         return Status::Invalid(ss.str());
       }
+    } else if (PyList_Check(obj) || PyArray_Check(obj)) {
+      return VisitList(obj);
+    } else if (PyDict_Check(obj)) {
+      return VisitDict(obj);
     } else if (PyObject_IsInstance(obj, decimal_type_.obj())) {
       RETURN_NOT_OK(max_decimal_metadata_.Update(obj));
       ++decimal_count_;
@@ -118,14 +157,36 @@ class ScalarVisitor {
     return Status::OK();
   }
 
-  std::shared_ptr<DataType> GetType() {
+  Status Validate() const {
+    if (list_count_ > 0) {
+      if (list_count_ + none_count_ != total_count_) {
+        return Status::Invalid("cannot mix list and non-list, non-null values");
+      }
+      RETURN_NOT_OK(list_inferrer_->Validate());
+    } else if (struct_count_ > 0) {
+      if (struct_count_ + none_count_ != total_count_) {
+        return Status::Invalid("cannot mix struct and non-struct, non-null values");
+      }
+      for (const auto& it : struct_inferrers_) {
+        RETURN_NOT_OK(it.second.Validate());
+      }
+    }
+    return Status::OK();
+  }
+
+  std::shared_ptr<DataType> GetType() const {
     // TODO(wesm): handling mixed-type cases
-    if (decimal_count_) {
+    if (list_count_) {
+      auto value_type = list_inferrer_->GetType();
+      DCHECK(value_type != nullptr);
+      return list(value_type);
+    } else if (struct_count_) {
+      return GetStructType();
+    } else if (decimal_count_) {
       return decimal(max_decimal_metadata_.precision(), max_decimal_metadata_.scale());
     } else if (float_count_) {
       return float64();
     } else if (int_count_) {
-      // TODO(wesm): tighter type later
       return int64();
     } else if (date_count_) {
       return date64();
@@ -144,6 +205,53 @@ class ScalarVisitor {
 
   int64_t total_count() const { return total_count_; }
 
+ protected:
+  Status VisitList(PyObject* obj) {
+    if (!list_inferrer_) {
+      list_inferrer_.reset(new TypeInferrer);
+    }
+    ++list_count_;
+    return list_inferrer_->VisitSequence(obj);
+  }
+
+  Status VisitDict(PyObject* obj) {
+    PyObject* key_obj;
+    PyObject* value_obj;
+    Py_ssize_t pos = 0;
+
+    while (PyDict_Next(obj, &pos, &key_obj, &value_obj)) {
+      std::string key;
+      if (PyUnicode_Check(key_obj)) {
+        RETURN_NOT_OK(internal::PyUnicode_AsStdString(key_obj, &key));
+      } else if (PyBytes_Check(key_obj)) {
+        key = internal::PyBytes_AsStdString(key_obj);
+      } else {
+        std::stringstream ss;
+        ss << "Expected dict key of type str or bytes, got '" << Py_TYPE(key_obj)->tp_name
+           << "'";
+        return Status::TypeError(ss.str());
+      }
+      // Get or create visitor for this key
+      auto it = struct_inferrers_.find(key);
+      if (it == struct_inferrers_.end()) {
+        it = struct_inferrers_.insert(std::make_pair(key, TypeInferrer())).first;
+      }
+      TypeInferrer* visitor = &it->second;
+      RETURN_NOT_OK(visitor->Visit(value_obj));
+    }
+    ++struct_count_;
+    return Status::OK();
+  }
+
+  std::shared_ptr<DataType> GetStructType() const {
+    std::vector<std::shared_ptr<Field>> fields;
+    for (const auto& it : struct_inferrers_) {
+      const auto struct_field = field(it.first, it.second.GetType());
+      fields.emplace_back(struct_field);
+    }
+    return struct_(fields);
+  }
+
  private:
   int64_t total_count_;
   int64_t none_count_;
@@ -155,6 +263,10 @@ class ScalarVisitor {
   int64_t binary_count_;
   int64_t unicode_count_;
   int64_t decimal_count_;
+  int64_t list_count_;
+  std::unique_ptr<TypeInferrer> list_inferrer_;
+  int64_t struct_count_;
+  std::map<std::string, TypeInferrer> struct_inferrers_;
 
   internal::DecimalMetadata max_decimal_metadata_;
 
@@ -163,116 +275,6 @@ class ScalarVisitor {
   OwnedRefNoGIL decimal_type_;
 };
 
-static constexpr int MAX_NESTING_LEVELS = 32;
-
-// SeqVisitor is used to infer the type.
-class SeqVisitor {
- public:
-  SeqVisitor() : max_nesting_level_(0), max_observed_level_(0), nesting_histogram_() {
-    std::fill(nesting_histogram_, nesting_histogram_ + MAX_NESTING_LEVELS, 0);
-  }
-
-  // co-recursive with VisitElem
-  Status Visit(PyObject* obj, int level = 0) {
-    max_nesting_level_ = std::max(max_nesting_level_, level);
-
-    // Loop through a sequence
-    if (!PySequence_Check(obj))
-      return Status::TypeError("Object is not a sequence or iterable");
-
-    Py_ssize_t size = PySequence_Size(obj);
-    for (int64_t i = 0; i < size; ++i) {
-      OwnedRef ref;
-      if (PyArray_Check(obj)) {
-        auto array = reinterpret_cast<PyArrayObject*>(obj);
-        auto ptr = reinterpret_cast<const char*>(PyArray_GETPTR1(array, i));
-
-        ref.reset(PyArray_GETITEM(array, ptr));
-        RETURN_IF_PYERROR();
-
-        RETURN_NOT_OK(VisitElem(ref, level));
-      } else {
-        ref.reset(PySequence_GetItem(obj, i));
-        RETURN_IF_PYERROR();
-        RETURN_NOT_OK(VisitElem(ref, level));
-      }
-    }
-    return Status::OK();
-  }
-
-  std::shared_ptr<DataType> GetType() {
-    // If all the non-list inputs were null (or there were no inputs)
-    std::shared_ptr<DataType> result;
-    if (scalars_.total_count() == 0) {
-      // Lists of Lists of NULL
-      result = null();
-    } else {
-      // Lists of Lists of [X]
-      result = scalars_.GetType();
-    }
-    for (int i = 0; i < max_nesting_level_; ++i) {
-      result = std::make_shared<ListType>(result);
-    }
-    return result;
-  }
-
-  Status Validate() const {
-    if (scalars_.total_count() > 0) {
-      if (num_nesting_levels() > 1) {
-        return Status::Invalid("Mixed nesting levels not supported");
-        // If the nesting goes deeper than the deepest scalar
-      } else if (max_observed_level_ < max_nesting_level_) {
-        return Status::Invalid("Mixed nesting levels not supported");
-      }
-    }
-    return Status::OK();
-  }
-
-  // Returns the number of nesting levels which have scalar elements.
-  int num_nesting_levels() const {
-    int result = 0;
-    for (int i = 0; i < MAX_NESTING_LEVELS; ++i) {
-      if (nesting_histogram_[i] > 0) {
-        ++result;
-      }
-    }
-    return result;
-  }
-
- private:
-  ScalarVisitor scalars_;
-
-  // Track observed
-  // Deapest nesting level (irregardless of scalars)
-  int max_nesting_level_;
-  int max_observed_level_;
-
-  // Number of scalar elements at each nesting level.
-  // (TOOD: We really only need to know if a scalar is present, not the count).
-  int nesting_histogram_[MAX_NESTING_LEVELS];
-
-  // Visits a specific element (inner part of the loop).
-  Status VisitElem(const OwnedRef& item_ref, int level) {
-    DCHECK_NE(item_ref.obj(), NULLPTR);
-    if (PyList_Check(item_ref.obj()) || PyArray_Check(item_ref.obj())) {
-      RETURN_NOT_OK(Visit(item_ref.obj(), level + 1));
-    } else if (PyDict_Check(item_ref.obj())) {
-      return Status::NotImplemented("No type inference for dicts");
-    } else {
-      // We permit nulls at any level of nesting, but they aren't treated like
-      // other scalar values as far as the checking for mixed nesting structure
-      if (item_ref.obj() != Py_None) {
-        ++nesting_histogram_[level];
-      }
-      if (level > max_observed_level_) {
-        max_observed_level_ = level;
-      }
-      return scalars_.Visit(item_ref.obj());
-    }
-    return Status::OK();
-  }
-};
-
 // Convert *obj* to a sequence if necessary
 // Fill *size* to its length.  If >= 0 on entry, *size* is an upper size
 // bound that may lead to truncation.
@@ -319,11 +321,11 @@ Status ConvertToSequenceAndInferSize(PyObject* obj, PyObject** seq, int64_t* siz
 // Non-exhaustive type inference
 Status InferArrowType(PyObject* obj, std::shared_ptr<DataType>* out_type) {
   PyDateTime_IMPORT;
-  SeqVisitor seq_visitor;
-  RETURN_NOT_OK(seq_visitor.Visit(obj));
-  RETURN_NOT_OK(seq_visitor.Validate());
+  TypeInferrer inferrer;
+  RETURN_NOT_OK(inferrer.VisitSequence(obj));
+  RETURN_NOT_OK(inferrer.Validate());
 
-  *out_type = seq_visitor.GetType();
+  *out_type = inferrer.GetType();
   if (*out_type == nullptr) {
     return Status::TypeError("Unable to determine data type");
   }
diff --git a/python/benchmarks/convert_builtins.py b/python/benchmarks/convert_builtins.py
index 91b15ec..48a38fa 100644
--- a/python/benchmarks/convert_builtins.py
+++ b/python/benchmarks/convert_builtins.py
@@ -51,8 +51,7 @@ class InferPyListToArray(object):
     """
     size = 10 ** 5
     types = ('int64', 'float64', 'bool', 'decimal', 'binary', 'ascii',
-             'unicode', 'int64 list')
-    # TODO add 'struct' when supported
+             'unicode', 'int64 list', 'struct')
 
     param_names = ['type']
     params = [types]
diff --git a/python/pyarrow/tests/test_convert_builtin.py b/python/pyarrow/tests/test_convert_builtin.py
index a18d183..7fb4301 100644
--- a/python/pyarrow/tests/test_convert_builtin.py
+++ b/python/pyarrow/tests/test_convert_builtin.py
@@ -51,6 +51,14 @@ class StrangeIterable:
         return self.lst.__iter__()
 
 
+def check_struct_type(ty, expected):
+    """
+    Check a struct type is as expected, but not taking order into account.
+    """
+    assert pa.types.is_struct(ty)
+    assert set(ty) == set(expected)
+
+
 def test_iterable_types():
     arr1 = pa.array(StrangeIterable([0, 1, 2, 3]))
     arr2 = pa.array((0, 1, 2, 3))
@@ -479,11 +487,28 @@ def test_sequence_timestamp_from_int_with_unit():
         pa.array([1, CustomClass()], type=pa.date64())
 
 
-def test_sequence_mixed_nesting_levels():
-    pa.array([1, 2, None])
-    pa.array([[1], [2], None])
-    pa.array([[1], [2], [None]])
+def test_sequence_nesting_levels():
+    data = [1, 2, None]
+    arr = pa.array(data)
+    assert arr.type == pa.int64()
+    assert arr.to_pylist() == data
+
+    data = [[1], [2], None]
+    arr = pa.array(data)
+    assert arr.type == pa.list_(pa.int64())
+    assert arr.to_pylist() == data
 
+    data = [[1], [2, 3, 4], [None]]
+    arr = pa.array(data)
+    assert arr.type == pa.list_(pa.int64())
+    assert arr.to_pylist() == data
+
+    data = [None, [[None, 1]], [[2, 3, 4], None], [None]]
+    arr = pa.array(data)
+    assert arr.type == pa.list_(pa.list_(pa.int64()))
+    assert arr.to_pylist() == data
+
+    # Mixed nesting levels are rejected
     with pytest.raises(pa.ArrowInvalid):
         pa.array([1, 2, [1]])
 
@@ -649,6 +674,50 @@ def test_struct_from_mixed_sequence():
         pa.array(data, type=ty)
 
 
+def test_struct_from_dicts_inference():
+    expected_type = pa.struct([pa.field('a', pa.int64()),
+                               pa.field('b', pa.string()),
+                               pa.field('c', pa.bool_())])
+    data = [{'a': 5, 'b': u'foo', 'c': True},
+            {'a': 6, 'b': u'bar', 'c': False}]
+    arr = pa.array(data)
+    check_struct_type(arr.type, expected_type)
+    assert arr.to_pylist() == data
+
+    # With omitted values
+    data = [{'a': 5, 'c': True},
+            None,
+            {},
+            {'a': None, 'b': u'bar'}]
+    expected = [{'a': 5, 'b': None, 'c': True},
+                None,
+                {'a': None, 'b': None, 'c': None},
+                {'a': None, 'b': u'bar', 'c': None}]
+    arr = pa.array(data)
+    check_struct_type(arr.type, expected_type)
+    assert arr.to_pylist() == expected
+
+    # Nested
+    expected_type = pa.struct([
+        pa.field('a', pa.struct([pa.field('aa', pa.list_(pa.int64())),
+                                 pa.field('ab', pa.bool_())])),
+        pa.field('b', pa.string())])
+    data = [{'a': {'aa': [5, 6], 'ab': True}, 'b': 'foo'},
+            {'a': {'aa': None, 'ab': False}, 'b': None},
+            {'a': None, 'b': 'bar'}]
+    arr = pa.array(data)
+    assert arr.to_pylist() == data
+
+    # Edge cases
+    arr = pa.array([{}])
+    assert arr.type == pa.struct([])
+    assert arr.to_pylist() == [{}]
+
+    # Mixing structs and scalars is rejected
+    with pytest.raises(pa.ArrowInvalid):
+        pa.array([1, {'a': 2}])
+
+
 def test_structarray_from_arrays_coerce():
     # ARROW-1706
     ints = [None, 2, 3]
diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py
index 9d0f383..bb2a986 100644
--- a/python/pyarrow/tests/test_types.py
+++ b/python/pyarrow/tests/test_types.py
@@ -199,8 +199,9 @@ def test_types_hashable():
     for i, type_ in enumerate(MANY_TYPES):
         assert hash(type_) == hash(type_)
         in_dict[type_] = i
-        assert in_dict[type_] == i
     assert len(in_dict) == len(MANY_TYPES)
+    for i, type_ in enumerate(MANY_TYPES):
+        assert in_dict[type_] == i
 
 
 def test_types_picklable():
@@ -209,6 +210,18 @@ def test_types_picklable():
         assert pickle.loads(data) == ty
 
 
+def test_fields_hashable():
+    in_dict = {}
+    fields = [pa.field('a', pa.int64()),
+              pa.field('a', pa.int32()),
+              pa.field('b', pa.int32())]
+    for i, field in enumerate(fields):
+        in_dict[field] = i
+    assert len(in_dict) == len(fields)
+    for i, field in enumerate(fields):
+        assert in_dict[field] == i
+
+
 @pytest.mark.parametrize('t,check_func', [
     (pa.date32(), types.is_date32),
     (pa.date64(), types.is_date64),
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 4ed1443..850be23 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -414,6 +414,9 @@ cdef class Field:
     def __repr__(self):
         return self.__str__()
 
+    def __hash__(self):
+        return hash((self.field.name(), self.type.id))
+
     property nullable:
 
         def __get__(self):

-- 
To stop receiving notification emails like this one, please contact
uwe@apache.org.