You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by uw...@apache.org on 2018/01/28 15:47:51 UTC

[arrow] branch master updated: ARROW-1999: [Python] Type checking in `from_numpy_dtype`

This is an automated email from the ASF dual-hosted git repository.

uwe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new edde5c1  ARROW-1999: [Python] Type checking in `from_numpy_dtype`
edde5c1 is described below

commit edde5c19ad5c441429ae80cce32db80bd52ed364
Author: Jim Crist <ji...@gmail.com>
AuthorDate: Sun Jan 28 16:47:41 2018 +0100

    ARROW-1999: [Python] Type checking in `from_numpy_dtype`
    
    - Adds type checking to the C++ `NumPyDtypeToArrow` and `GetTensorType` to
      ensure `dtype` is actually a dtype object.
    - Adds conversion of non-dtype objects in `pa.from_numpy_dtype`.
    - Adds tests to check a wider variety of inputs to
      `pa.from_numpy_dtype`, and to ensure proper errors are raised.
    
    Author: Jim Crist <ji...@gmail.com>
    
    Closes #1523 from jcrist/from_numpy_dtype and squashes the following commits:
    
    e9de101 [Jim Crist] Type checking in `from_numpy_dtype`
---
 cpp/src/arrow/python/numpy_convert.cc |  6 ++++++
 python/pyarrow/tests/test_schema.py   | 27 ++++++++++++++++++++++++++-
 python/pyarrow/types.pxi              |  1 +
 3 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/cpp/src/arrow/python/numpy_convert.cc b/cpp/src/arrow/python/numpy_convert.cc
index 9ed2d73..124745e 100644
--- a/cpp/src/arrow/python/numpy_convert.cc
+++ b/cpp/src/arrow/python/numpy_convert.cc
@@ -84,6 +84,9 @@ NumPyBuffer::~NumPyBuffer() { Py_XDECREF(arr_); }
     break;
 
 Status GetTensorType(PyObject* dtype, std::shared_ptr<DataType>* out) {
+  if (!PyArray_DescrCheck(dtype)) {
+    return Status::TypeError("Did not pass numpy.dtype object");
+  }
   PyArray_Descr* descr = reinterpret_cast<PyArray_Descr*>(dtype);
   int type_num = cast_npy_type_compat(descr->type_num);
 
@@ -145,6 +148,9 @@ Status GetNumPyType(const DataType& type, int* type_num) {
 }
 
 Status NumPyDtypeToArrow(PyObject* dtype, std::shared_ptr<DataType>* out) {
+  if (!PyArray_DescrCheck(dtype)) {
+    return Status::TypeError("Did not pass numpy.dtype object");
+  }
   PyArray_Descr* descr = reinterpret_cast<PyArray_Descr*>(dtype);
 
   int type_num = cast_npy_type_compat(descr->type_num);
diff --git a/python/pyarrow/tests/test_schema.py b/python/pyarrow/tests/test_schema.py
index dbca139..90efe3f 100644
--- a/python/pyarrow/tests/test_schema.py
+++ b/python/pyarrow/tests/test_schema.py
@@ -154,8 +154,21 @@ def test_time_types():
         pa.time64('s')
 
 
-def test_type_from_numpy_dtype_timestamps():
+def test_from_numpy_dtype():
     cases = [
+        (np.dtype('bool'), pa.bool_()),
+        (np.dtype('int8'), pa.int8()),
+        (np.dtype('int16'), pa.int16()),
+        (np.dtype('int32'), pa.int32()),
+        (np.dtype('int64'), pa.int64()),
+        (np.dtype('uint8'), pa.uint8()),
+        (np.dtype('uint16'), pa.uint16()),
+        (np.dtype('uint32'), pa.uint32()),
+        (np.dtype('float16'), pa.float16()),
+        (np.dtype('float32'), pa.float32()),
+        (np.dtype('float64'), pa.float64()),
+        (np.dtype('U'), pa.string()),
+        (np.dtype('S'), pa.binary()),
         (np.dtype('datetime64[s]'), pa.timestamp('s')),
         (np.dtype('datetime64[ms]'), pa.timestamp('ms')),
         (np.dtype('datetime64[us]'), pa.timestamp('us')),
@@ -166,6 +179,18 @@ def test_type_from_numpy_dtype_timestamps():
         result = pa.from_numpy_dtype(dt)
         assert result == pt
 
+    # Things convertible to numpy dtypes work
+    assert pa.from_numpy_dtype('U') == pa.string()
+    assert pa.from_numpy_dtype(np.unicode) == pa.string()
+    assert pa.from_numpy_dtype('int32') == pa.int32()
+    assert pa.from_numpy_dtype(bool) == pa.bool_()
+
+    with pytest.raises(NotImplementedError):
+        pa.from_numpy_dtype(np.dtype('O'))
+
+    with pytest.raises(TypeError):
+        pa.from_numpy_dtype('not_convertible_to_dtype')
+
 
 def test_field():
     t = pa.string()
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 1563b57..a3cbeef 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -1207,6 +1207,7 @@ def from_numpy_dtype(object dtype):
     Convert NumPy dtype to pyarrow.DataType
     """
     cdef shared_ptr[CDataType] c_type
+    dtype = np.dtype(dtype)
     with nogil:
         check_status(NumPyDtypeToArrow(dtype, &c_type))
 

-- 
To stop receiving notification emails like this one, please contact
uwe@apache.org.