You are viewing a plain text version of this content. The canonical link to it appeared as a hyperlink ("here") in the original page and is not preserved in this plain-text version.
Posted to commits@arrow.apache.org by we...@apache.org on 2019/03/12 14:15:25 UTC

[arrow] branch master updated: ARROW-4796: [Flight/Python] Keep underlying Python object alive in FlightServerBase.do_get

This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new f05369e  ARROW-4796: [Flight/Python] Keep underlying Python object alive in FlightServerBase.do_get
f05369e is described below

commit f05369e6dc7cb5a36048cebb78d0d69b32a27b6f
Author: David Li <Da...@twosigma.com>
AuthorDate: Tue Mar 12 09:15:09 2019 -0500

    ARROW-4796: [Flight/Python] Keep underlying Python object alive in FlightServerBase.do_get
    
    Author: David Li <Da...@twosigma.com>
    
    Closes #3834 from lihalite/arrow-4796 and squashes the following commits:
    
    942b9a708 <David Li> Keep underlying Python object alive in FlightServerBase.do_get
---
 cpp/src/arrow/python/flight.cc              | 13 +++++
 cpp/src/arrow/python/flight.h               | 14 +++++
 python/pyarrow/_flight.pyx                  |  7 ++-
 python/pyarrow/includes/libarrow_flight.pxd |  8 +++
 python/pyarrow/tests/test_flight.py         | 80 +++++++++++++++++++++++++++++
 5 files changed, 121 insertions(+), 1 deletion(-)

diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc
index a8bae63..ec25d32 100644
--- a/cpp/src/arrow/python/flight.cc
+++ b/cpp/src/arrow/python/flight.cc
@@ -117,6 +117,19 @@ Status PyFlightResultStream::Next(std::unique_ptr<arrow::flight::Result>* result
   return CheckPyError();
 }
 
+PyFlightDataStream::PyFlightDataStream(
+    PyObject* data_source, std::unique_ptr<arrow::flight::FlightDataStream> stream)
+    : stream_(std::move(stream)) {
+  Py_INCREF(data_source);
+  data_source_.reset(data_source);
+}
+
+std::shared_ptr<arrow::Schema> PyFlightDataStream::schema() { return stream_->schema(); }
+
+Status PyFlightDataStream::Next(arrow::flight::FlightPayload* payload) {
+  return stream_->Next(payload);
+}
+
 Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema,
                         const arrow::flight::FlightDescriptor& descriptor,
                         const std::vector<arrow::flight::FlightEndpoint>& endpoints,
diff --git a/cpp/src/arrow/python/flight.h b/cpp/src/arrow/python/flight.h
index effd1a8..128784f 100644
--- a/cpp/src/arrow/python/flight.h
+++ b/cpp/src/arrow/python/flight.h
@@ -92,6 +92,20 @@ class ARROW_PYTHON_EXPORT PyFlightResultStream : public arrow::flight::ResultStr
   PyFlightResultStreamCallback callback_;
 };
 
+/// \brief A wrapper around a FlightDataStream that keeps alive a
+/// Python object backing it.
+class ARROW_PYTHON_EXPORT PyFlightDataStream : public arrow::flight::FlightDataStream {
+ public:
+  explicit PyFlightDataStream(PyObject* data_source,
+                              std::unique_ptr<arrow::flight::FlightDataStream> stream);
+  std::shared_ptr<arrow::Schema> schema() override;
+  Status Next(arrow::flight::FlightPayload* payload) override;
+
+ private:
+  OwnedRefNoGIL data_source_;
+  std::unique_ptr<arrow::flight::FlightDataStream> stream_;
+};
+
 ARROW_PYTHON_EXPORT
 Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema,
                         const arrow::flight::FlightDescriptor& descriptor,
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 532cf54..695513a 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -491,13 +491,18 @@ cdef void _do_put(void* self,
 cdef void _do_get(void* self, CTicket ticket,
                   unique_ptr[CFlightDataStream]* stream) except *:
     """Callback for implementing Flight servers in Python."""
+    cdef:
+        unique_ptr[CFlightDataStream] data_stream
+
     py_ticket = Ticket(ticket.ticket)
     result = (<object> self).do_get(py_ticket)
     if not isinstance(result, FlightDataStream):
         raise TypeError("FlightServerBase.do_get must return "
                         "a FlightDataStream")
-    stream[0] = unique_ptr[CFlightDataStream](
+    data_stream = unique_ptr[CFlightDataStream](
         (<FlightDataStream> result).to_stream())
+    stream[0] = unique_ptr[CFlightDataStream](
+        new CPyFlightDataStream(result, move(data_stream)))
 
 
 cdef void _do_action_result_next(void* self,
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index 0271f33..153f725 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -156,6 +156,11 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
         CPyFlightResultStream(object generator,
                               function[cb_result_next] callback)
 
+    cdef cppclass CPyFlightDataStream\
+            " arrow::py::flight::PyFlightDataStream"(CFlightDataStream):
+        CPyFlightDataStream(object data_source,
+                            unique_ptr[CFlightDataStream] stream)
+
     cdef CStatus CreateFlightInfo" arrow::py::flight::CreateFlightInfo"(
         shared_ptr[CSchema] schema,
         CFlightDescriptor& descriptor,
@@ -163,3 +168,6 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
         uint64_t total_records,
         uint64_t total_bytes,
         unique_ptr[CFlightInfo]* out)
+
+cdef extern from "<utility>" namespace "std":
+    unique_ptr[CFlightDataStream] move(unique_ptr[CFlightDataStream])
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
new file mode 100644
index 0000000..d225f77
--- /dev/null
+++ b/python/pyarrow/tests/test_flight.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import contextlib
+import socket
+import threading
+
+import pytest
+
+import pyarrow as pa
+
+
+flight = pytest.importorskip("pyarrow.flight")
+
+
+class ConstantFlightServer(flight.FlightServerBase):
+    """A Flight server that always returns the same data.
+
+    See ARROW-4796: this server implementation will segfault if Flight
+    does not properly hold a reference to the Table object.
+    """
+
+    def do_get(self, ticket):
+        data = [
+            pa.array([-10, -5, 0, 5, 10])
+        ]
+        table = pa.Table.from_arrays(data, names=['a'])
+        return flight.RecordBatchStream(table)
+
+
+@contextlib.contextmanager
+def flight_server(server_base, *args, **kwargs):
+    """Spawn a Flight server on a free port, shutting it down when done."""
+    # Find a free port
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    with contextlib.closing(sock) as sock:
+        sock.bind(('', 0))
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        port = sock.getsockname()[1]
+
+    server_instance = server_base(*args, **kwargs)
+
+    def _server_thread():
+        server_instance.run(port)
+
+    thread = threading.Thread(target=_server_thread, daemon=True)
+    thread.start()
+
+    yield port
+
+    server_instance.shutdown()
+    thread.join()
+
+
+def test_flight_do_get():
+    """Try a simple do_get call."""
+    data = [
+        pa.array([-10, -5, 0, 5, 10])
+    ]
+    table = pa.Table.from_arrays(data, names=['a'])
+
+    with flight_server(ConstantFlightServer) as server_port:
+        client = flight.FlightClient.connect('localhost', server_port)
+        data = client.do_get(flight.Ticket(b''), table.schema).read_all()
+        assert data.equals(table)