You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by uw...@apache.org on 2019/04/22 20:17:21 UTC

[arrow] branch master updated: ARROW-4824: [Python] Fix error checking in read_csv()

This is an automated email from the ASF dual-hosted git repository.

uwe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new e441291  ARROW-4824: [Python] Fix error checking in read_csv()
e441291 is described below

commit e441291468a0d2082d72bd52406391ccc2172f3c
Author: Antoine Pitrou <an...@python.org>
AuthorDate: Mon Apr 22 22:17:08 2019 +0200

    ARROW-4824: [Python] Fix error checking in read_csv()
    
    Raise a TypeError when a text file is given to read_csv(), as only binary files are allowed.
    
    Also fix a systemic issue where Python exceptions could be swallowed by C++ destructors that could call back into Python and clear the current error status.  The solution is to define a "safe" facility to call into Python code from C++ without clobbering the current error status.
    
    Author: Antoine Pitrou <an...@python.org>
    
    Closes #4183 from pitrou/ARROW-4824-read-csv-error-message and squashes the following commits:
    
    ed3f931f <Antoine Pitrou> ARROW-4824:  Fix error checking in read_csv()
---
 cpp/src/arrow/python/common.h    |  19 +++++++
 cpp/src/arrow/python/flight.cc   |  84 +++++++++++++++++-------------
 cpp/src/arrow/python/io.cc       | 108 ++++++++++++++++++++++-----------------
 python/pyarrow/io.pxi            |   6 ++-
 python/pyarrow/tests/test_csv.py |  11 ++++
 python/pyarrow/tests/test_io.py  |   6 +++
 6 files changed, 149 insertions(+), 85 deletions(-)

diff --git a/cpp/src/arrow/python/common.h b/cpp/src/arrow/python/common.h
index 6e41bed..27661a9 100644
--- a/cpp/src/arrow/python/common.h
+++ b/cpp/src/arrow/python/common.h
@@ -54,6 +54,7 @@ ARROW_PYTHON_EXPORT Status PassPyError();
 
 #define PY_RETURN_IF_ERROR(CODE) ARROW_RETURN_NOT_OK(CheckPyError(CODE));
 
+// A RAII-style helper that ensures the GIL is acquired inside a lexical block.
 class ARROW_PYTHON_EXPORT PyAcquireGIL {
  public:
   PyAcquireGIL() : acquired_gil_(false) { acquire(); }
@@ -81,6 +82,24 @@ class ARROW_PYTHON_EXPORT PyAcquireGIL {
   ARROW_DISALLOW_COPY_AND_ASSIGN(PyAcquireGIL);
 };
 
+// A helper to call safely into the Python interpreter from arbitrary C++ code.
+// The GIL is acquired, and the current thread's error status is preserved.
+template <typename Function>
+Status SafeCallIntoPython(Function&& func) {
+  PyAcquireGIL lock;
+  PyObject* exc_type;
+  PyObject* exc_value;
+  PyObject* exc_traceback;
+  PyErr_Fetch(&exc_type, &exc_value, &exc_traceback);
+  Status st = std::forward<Function>(func)();
+  // If the return Status is a "Python error", the current Python error status
+  // describes the error and shouldn't be clobbered.
+  if (!st.IsPythonError() && exc_type != NULLPTR) {
+    PyErr_Restore(exc_type, exc_value, exc_traceback);
+  }
+  return st;
+}
+
 #define PYARROW_IS_PY2 PY_MAJOR_VERSION <= 2
 
 // A RAII primitive that DECREFs the underlying PyObject* when it
diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc
index 33a063a..861aa8d 100644
--- a/cpp/src/arrow/python/flight.cc
+++ b/cpp/src/arrow/python/flight.cc
@@ -35,16 +35,18 @@ PyServerAuthHandler::PyServerAuthHandler(PyObject* handler,
 
 Status PyServerAuthHandler::Authenticate(arrow::flight::ServerAuthSender* outgoing,
                                          arrow::flight::ServerAuthReader* incoming) {
-  PyAcquireGIL lock;
-  vtable_.authenticate(handler_.obj(), outgoing, incoming);
-  return CheckPyError();
+  return SafeCallIntoPython([=] {
+    vtable_.authenticate(handler_.obj(), outgoing, incoming);
+    return CheckPyError();
+  });
 }
 
 Status PyServerAuthHandler::IsValid(const std::string& token,
                                     std::string* peer_identity) {
-  PyAcquireGIL lock;
-  vtable_.is_valid(handler_.obj(), token, peer_identity);
-  return CheckPyError();
+  return SafeCallIntoPython([=] {
+    vtable_.is_valid(handler_.obj(), token, peer_identity);
+    return CheckPyError();
+  });
 }
 
 PyClientAuthHandler::PyClientAuthHandler(PyObject* handler,
@@ -56,15 +58,17 @@ PyClientAuthHandler::PyClientAuthHandler(PyObject* handler,
 
 Status PyClientAuthHandler::Authenticate(arrow::flight::ClientAuthSender* outgoing,
                                          arrow::flight::ClientAuthReader* incoming) {
-  PyAcquireGIL lock;
-  vtable_.authenticate(handler_.obj(), outgoing, incoming);
-  return CheckPyError();
+  return SafeCallIntoPython([=] {
+    vtable_.authenticate(handler_.obj(), outgoing, incoming);
+    return CheckPyError();
+  });
 }
 
 Status PyClientAuthHandler::GetToken(std::string* token) {
-  PyAcquireGIL lock;
-  vtable_.get_token(handler_.obj(), token);
-  return CheckPyError();
+  return SafeCallIntoPython([=] {
+    vtable_.get_token(handler_.obj(), token);
+    return CheckPyError();
+  });
 }
 
 PyFlightServer::PyFlightServer(PyObject* server, PyFlightServerVtable vtable)
@@ -77,47 +81,53 @@ Status PyFlightServer::ListFlights(
     const arrow::flight::ServerCallContext& context,
     const arrow::flight::Criteria* criteria,
     std::unique_ptr<arrow::flight::FlightListing>* listings) {
-  PyAcquireGIL lock;
-  vtable_.list_flights(server_.obj(), context, criteria, listings);
-  return CheckPyError();
+  return SafeCallIntoPython([&] {
+    vtable_.list_flights(server_.obj(), context, criteria, listings);
+    return CheckPyError();
+  });
 }
 
 Status PyFlightServer::GetFlightInfo(const arrow::flight::ServerCallContext& context,
                                      const arrow::flight::FlightDescriptor& request,
                                      std::unique_ptr<arrow::flight::FlightInfo>* info) {
-  PyAcquireGIL lock;
-  vtable_.get_flight_info(server_.obj(), context, request, info);
-  return CheckPyError();
+  return SafeCallIntoPython([&] {
+    vtable_.get_flight_info(server_.obj(), context, request, info);
+    return CheckPyError();
+  });
 }
 
 Status PyFlightServer::DoGet(const arrow::flight::ServerCallContext& context,
                              const arrow::flight::Ticket& request,
                              std::unique_ptr<arrow::flight::FlightDataStream>* stream) {
-  PyAcquireGIL lock;
-  vtable_.do_get(server_.obj(), context, request, stream);
-  return CheckPyError();
+  return SafeCallIntoPython([&] {
+    vtable_.do_get(server_.obj(), context, request, stream);
+    return CheckPyError();
+  });
 }
 
 Status PyFlightServer::DoPut(const arrow::flight::ServerCallContext& context,
                              std::unique_ptr<arrow::flight::FlightMessageReader> reader) {
-  PyAcquireGIL lock;
-  vtable_.do_put(server_.obj(), context, std::move(reader));
-  return CheckPyError();
+  return SafeCallIntoPython([&] {
+    vtable_.do_put(server_.obj(), context, std::move(reader));
+    return CheckPyError();
+  });
 }
 
 Status PyFlightServer::DoAction(const arrow::flight::ServerCallContext& context,
                                 const arrow::flight::Action& action,
                                 std::unique_ptr<arrow::flight::ResultStream>* result) {
-  PyAcquireGIL lock;
-  vtable_.do_action(server_.obj(), context, action, result);
-  return CheckPyError();
+  return SafeCallIntoPython([&] {
+    vtable_.do_action(server_.obj(), context, action, result);
+    return CheckPyError();
+  });
 }
 
 Status PyFlightServer::ListActions(const arrow::flight::ServerCallContext& context,
                                    std::vector<arrow::flight::ActionType>* actions) {
-  PyAcquireGIL lock;
-  vtable_.list_actions(server_.obj(), context, actions);
-  return CheckPyError();
+  return SafeCallIntoPython([&] {
+    vtable_.list_actions(server_.obj(), context, actions);
+    return CheckPyError();
+  });
 }
 
 Status PyFlightServer::ServeWithSignals() {
@@ -159,9 +169,10 @@ PyFlightResultStream::PyFlightResultStream(PyObject* generator,
 }
 
 Status PyFlightResultStream::Next(std::unique_ptr<arrow::flight::Result>* result) {
-  PyAcquireGIL lock;
-  callback_(generator_.obj(), result);
-  return CheckPyError();
+  return SafeCallIntoPython([=] {
+    callback_(generator_.obj(), result);
+    return CheckPyError();
+  });
 }
 
 PyFlightDataStream::PyFlightDataStream(
@@ -188,9 +199,10 @@ PyGeneratorFlightDataStream::PyGeneratorFlightDataStream(
 std::shared_ptr<arrow::Schema> PyGeneratorFlightDataStream::schema() { return schema_; }
 
 Status PyGeneratorFlightDataStream::Next(arrow::flight::FlightPayload* payload) {
-  PyAcquireGIL lock;
-  callback_(generator_.obj(), payload);
-  return CheckPyError();
+  return SafeCallIntoPython([=] {
+    callback_(generator_.obj(), payload);
+    return CheckPyError();
+  });
 }
 
 Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema,
diff --git a/cpp/src/arrow/python/io.cc b/cpp/src/arrow/python/io.cc
index 2e8ebfe..fd16f67 100644
--- a/cpp/src/arrow/python/io.cc
+++ b/cpp/src/arrow/python/io.cc
@@ -136,80 +136,88 @@ PyReadableFile::PyReadableFile(PyObject* file) { file_.reset(new PythonFile(file
 PyReadableFile::~PyReadableFile() {}
 
 Status PyReadableFile::Close() {
-  PyAcquireGIL lock;
-  return file_->Close();
+  return SafeCallIntoPython([this]() { return file_->Close(); });
 }
 
 bool PyReadableFile::closed() const {
-  PyAcquireGIL lock;
-  return file_->closed();
+  bool res;
+  Status st = SafeCallIntoPython([this, &res]() {
+    res = file_->closed();
+    return Status::OK();
+  });
+  return res;
 }
 
 Status PyReadableFile::Seek(int64_t position) {
-  PyAcquireGIL lock;
-  return file_->Seek(position, 0);
+  return SafeCallIntoPython([=] { return file_->Seek(position, 0); });
 }
 
 Status PyReadableFile::Tell(int64_t* position) const {
-  PyAcquireGIL lock;
-  return file_->Tell(position);
+  return SafeCallIntoPython([=]() { return file_->Tell(position); });
 }
 
 Status PyReadableFile::Read(int64_t nbytes, int64_t* bytes_read, void* out) {
-  PyAcquireGIL lock;
-
-  PyObject* bytes_obj = NULL;
-  RETURN_NOT_OK(file_->Read(nbytes, &bytes_obj));
-  DCHECK(bytes_obj != NULL);
-
-  *bytes_read = PyBytes_GET_SIZE(bytes_obj);
-  std::memcpy(out, PyBytes_AS_STRING(bytes_obj), *bytes_read);
-  Py_XDECREF(bytes_obj);
+  return SafeCallIntoPython([=]() {
+    OwnedRef bytes;
+    RETURN_NOT_OK(file_->Read(nbytes, bytes.ref()));
+    PyObject* bytes_obj = bytes.obj();
+    DCHECK(bytes_obj != NULL);
+
+    if (!PyBytes_Check(bytes_obj)) {
+      return Status::TypeError(
+          "Python file read() should have returned a bytes object, got '",
+          Py_TYPE(bytes_obj)->tp_name, "' (did you open the file in binary mode?)");
+    }
 
-  return Status::OK();
+    *bytes_read = PyBytes_GET_SIZE(bytes_obj);
+    std::memcpy(out, PyBytes_AS_STRING(bytes_obj), *bytes_read);
+    return Status::OK();
+  });
 }
 
 Status PyReadableFile::Read(int64_t nbytes, std::shared_ptr<Buffer>* out) {
-  PyAcquireGIL lock;
-
-  OwnedRef bytes_obj;
-  RETURN_NOT_OK(file_->Read(nbytes, bytes_obj.ref()));
-  DCHECK(bytes_obj.obj() != NULL);
+  return SafeCallIntoPython([=]() {
+    OwnedRef bytes_obj;
+    RETURN_NOT_OK(file_->Read(nbytes, bytes_obj.ref()));
+    DCHECK(bytes_obj.obj() != NULL);
 
-  return PyBuffer::FromPyObject(bytes_obj.obj(), out);
+    return PyBuffer::FromPyObject(bytes_obj.obj(), out);
+  });
 }
 
 Status PyReadableFile::ReadAt(int64_t position, int64_t nbytes, int64_t* bytes_read,
                               void* out) {
   std::lock_guard<std::mutex> guard(file_->lock());
-  RETURN_NOT_OK(Seek(position));
-  return Read(nbytes, bytes_read, out);
+  return SafeCallIntoPython([=]() {
+    RETURN_NOT_OK(Seek(position));
+    return Read(nbytes, bytes_read, out);
+  });
 }
 
 Status PyReadableFile::ReadAt(int64_t position, int64_t nbytes,
                               std::shared_ptr<Buffer>* out) {
   std::lock_guard<std::mutex> guard(file_->lock());
-  RETURN_NOT_OK(Seek(position));
-  return Read(nbytes, out);
+  return SafeCallIntoPython([=]() {
+    RETURN_NOT_OK(Seek(position));
+    return Read(nbytes, out);
+  });
 }
 
 Status PyReadableFile::GetSize(int64_t* size) {
-  PyAcquireGIL lock;
+  return SafeCallIntoPython([=]() {
+    int64_t current_position = -1;
 
-  int64_t current_position = -1;
+    RETURN_NOT_OK(file_->Tell(&current_position));
+    RETURN_NOT_OK(file_->Seek(0, 2));
 
-  RETURN_NOT_OK(file_->Tell(&current_position));
+    int64_t file_size = -1;
+    RETURN_NOT_OK(file_->Tell(&file_size));
+    // Restore previous file position
+    RETURN_NOT_OK(file_->Seek(current_position, 0));
 
-  RETURN_NOT_OK(file_->Seek(0, 2));
-
-  int64_t file_size = -1;
-  RETURN_NOT_OK(file_->Tell(&file_size));
-
-  // Restore previous file position
-  RETURN_NOT_OK(file_->Seek(current_position, 0));
-
-  *size = file_size;
-  return Status::OK();
+    *size = file_size;
+    return Status::OK();
+  });
 }
 
 // ----------------------------------------------------------------------
@@ -222,13 +230,16 @@ PyOutputStream::PyOutputStream(PyObject* file) : position_(0) {
 PyOutputStream::~PyOutputStream() {}
 
 Status PyOutputStream::Close() {
-  PyAcquireGIL lock;
-  return file_->Close();
+  return SafeCallIntoPython([=]() { return file_->Close(); });
 }
 
 bool PyOutputStream::closed() const {
-  PyAcquireGIL lock;
-  return file_->closed();
+  bool res;
+  Status st = SafeCallIntoPython([this, &res]() {
+    res = file_->closed();
+    return Status::OK();
+  });
+  return res;
 }
 
 Status PyOutputStream::Tell(int64_t* position) const {
@@ -237,9 +248,10 @@ Status PyOutputStream::Tell(int64_t* position) const {
 }
 
 Status PyOutputStream::Write(const void* data, int64_t nbytes) {
-  PyAcquireGIL lock;
-  position_ += nbytes;
-  return file_->Write(data, nbytes);
+  return SafeCallIntoPython([=]() {
+    position_ += nbytes;
+    return file_->Write(data, nbytes);
+  });
 }
 
 // ----------------------------------------------------------------------
diff --git a/python/pyarrow/io.pxi b/python/pyarrow/io.pxi
index 8edffbe..f542b24 100644
--- a/python/pyarrow/io.pxi
+++ b/python/pyarrow/io.pxi
@@ -26,7 +26,7 @@ import sys
 import threading
 import time
 import warnings
-from io import BufferedIOBase, IOBase, UnsupportedOperation
+from io import BufferedIOBase, IOBase, TextIOBase, UnsupportedOperation
 
 from pyarrow.util import _stringify_path
 from pyarrow.compat import (
@@ -627,6 +627,10 @@ cdef class PythonFile(NativeFile):
                         raise TypeError("writable file expected")
             # (other duck-typed file-like objects are possible)
 
+        # If possible, check the file is a binary file
+        if isinstance(handle, TextIOBase):
+            raise TypeError("binary file expected, got text file")
+
         if kind == 'r':
             self.set_random_access_file(
                 shared_ptr[RandomAccessFile](new PyReadableFile(handle)))
diff --git a/python/pyarrow/tests/test_csv.py b/python/pyarrow/tests/test_csv.py
index 6beebf4..40a059f 100644
--- a/python/pyarrow/tests/test_csv.py
+++ b/python/pyarrow/tests/test_csv.py
@@ -184,6 +184,17 @@ class BaseTestCSVRead:
         assert table.num_columns == len(names)
         assert [c.name for c in table.columns] == names
 
+    def test_file_object(self):
+        data = b"a,b\n1,2\n"
+        expected_data = {'a': [1], 'b': [2]}
+        bio = io.BytesIO(data)
+        table = self.read_csv(bio)
+        assert table.to_pydict() == expected_data
+        # Text files not allowed
+        sio = io.StringIO(data.decode())
+        with pytest.raises(TypeError):
+            self.read_csv(sio)
+
     def test_header(self):
         rows = b"abc,def,gh\n"
         table = self.read_bytes(rows)
diff --git a/python/pyarrow/tests/test_io.py b/python/pyarrow/tests/test_io.py
index 8eb76ed..5f6ff44 100644
--- a/python/pyarrow/tests/test_io.py
+++ b/python/pyarrow/tests/test_io.py
@@ -83,6 +83,9 @@ def test_python_file_write():
     f.close()
     assert f.closed
 
+    with pytest.raises(TypeError, match="binary file expected"):
+        pa.PythonFile(StringIO())
+
 
 def test_python_file_read():
     data = b'some sample data'
@@ -113,6 +116,9 @@ def test_python_file_read():
     f.close()
     assert f.closed
 
+    with pytest.raises(TypeError, match="binary file expected"):
+        pa.PythonFile(StringIO(), mode='r')
+
 
 def test_python_file_readall():
     data = b'some sample data'