You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this copy.
Posted to commits@arrow.apache.org by we...@apache.org on 2019/03/12 14:15:25 UTC
[arrow] branch master updated: ARROW-4796: [Flight/Python] Keep underlying Python object alive in FlightServerBase.do_get
This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new f05369e ARROW-4796: [Flight/Python] Keep underlying Python object alive in FlightServerBase.do_get
f05369e is described below
commit f05369e6dc7cb5a36048cebb78d0d69b32a27b6f
Author: David Li <Da...@twosigma.com>
AuthorDate: Tue Mar 12 09:15:09 2019 -0500
ARROW-4796: [Flight/Python] Keep underlying Python object alive in FlightServerBase.do_get
Author: David Li <Da...@twosigma.com>
Closes #3834 from lihalite/arrow-4796 and squashes the following commits:
942b9a708 <David Li> Keep underlying Python object alive in FlightServerBase.do_get
---
cpp/src/arrow/python/flight.cc | 13 +++++
cpp/src/arrow/python/flight.h | 14 +++++
python/pyarrow/_flight.pyx | 7 ++-
python/pyarrow/includes/libarrow_flight.pxd | 8 +++
python/pyarrow/tests/test_flight.py | 80 +++++++++++++++++++++++++++++
5 files changed, 121 insertions(+), 1 deletion(-)
diff --git a/cpp/src/arrow/python/flight.cc b/cpp/src/arrow/python/flight.cc
index a8bae63..ec25d32 100644
--- a/cpp/src/arrow/python/flight.cc
+++ b/cpp/src/arrow/python/flight.cc
@@ -117,6 +117,19 @@ Status PyFlightResultStream::Next(std::unique_ptr<arrow::flight::Result>* result
return CheckPyError();
}
+PyFlightDataStream::PyFlightDataStream(
+ PyObject* data_source, std::unique_ptr<arrow::flight::FlightDataStream> stream)
+ : stream_(std::move(stream)) {
+ Py_INCREF(data_source);
+ data_source_.reset(data_source);
+}
+
+std::shared_ptr<arrow::Schema> PyFlightDataStream::schema() { return stream_->schema(); }
+
+Status PyFlightDataStream::Next(arrow::flight::FlightPayload* payload) {
+ return stream_->Next(payload);
+}
+
Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema,
const arrow::flight::FlightDescriptor& descriptor,
const std::vector<arrow::flight::FlightEndpoint>& endpoints,
diff --git a/cpp/src/arrow/python/flight.h b/cpp/src/arrow/python/flight.h
index effd1a8..128784f 100644
--- a/cpp/src/arrow/python/flight.h
+++ b/cpp/src/arrow/python/flight.h
@@ -92,6 +92,20 @@ class ARROW_PYTHON_EXPORT PyFlightResultStream : public arrow::flight::ResultStr
PyFlightResultStreamCallback callback_;
};
+/// \brief A wrapper around a FlightDataStream that keeps alive a
+/// Python object backing it.
+class ARROW_PYTHON_EXPORT PyFlightDataStream : public arrow::flight::FlightDataStream {
+ public:
+ explicit PyFlightDataStream(PyObject* data_source,
+ std::unique_ptr<arrow::flight::FlightDataStream> stream);
+ std::shared_ptr<arrow::Schema> schema() override;
+ Status Next(arrow::flight::FlightPayload* payload) override;
+
+ private:
+ OwnedRefNoGIL data_source_;
+ std::unique_ptr<arrow::flight::FlightDataStream> stream_;
+};
+
ARROW_PYTHON_EXPORT
Status CreateFlightInfo(const std::shared_ptr<arrow::Schema>& schema,
const arrow::flight::FlightDescriptor& descriptor,
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 532cf54..695513a 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -491,13 +491,18 @@ cdef void _do_put(void* self,
cdef void _do_get(void* self, CTicket ticket,
unique_ptr[CFlightDataStream]* stream) except *:
"""Callback for implementing Flight servers in Python."""
+ cdef:
+ unique_ptr[CFlightDataStream] data_stream
+
py_ticket = Ticket(ticket.ticket)
result = (<object> self).do_get(py_ticket)
if not isinstance(result, FlightDataStream):
raise TypeError("FlightServerBase.do_get must return "
"a FlightDataStream")
- stream[0] = unique_ptr[CFlightDataStream](
+ data_stream = unique_ptr[CFlightDataStream](
(<FlightDataStream> result).to_stream())
+ stream[0] = unique_ptr[CFlightDataStream](
+ new CPyFlightDataStream(result, move(data_stream)))
cdef void _do_action_result_next(void* self,
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index 0271f33..153f725 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -156,6 +156,11 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
CPyFlightResultStream(object generator,
function[cb_result_next] callback)
+ cdef cppclass CPyFlightDataStream\
+ " arrow::py::flight::PyFlightDataStream"(CFlightDataStream):
+ CPyFlightDataStream(object data_source,
+ unique_ptr[CFlightDataStream] stream)
+
cdef CStatus CreateFlightInfo" arrow::py::flight::CreateFlightInfo"(
shared_ptr[CSchema] schema,
CFlightDescriptor& descriptor,
@@ -163,3 +168,6 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
uint64_t total_records,
uint64_t total_bytes,
unique_ptr[CFlightInfo]* out)
+
+cdef extern from "<utility>" namespace "std":
+ unique_ptr[CFlightDataStream] move(unique_ptr[CFlightDataStream])
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
new file mode 100644
index 0000000..d225f77
--- /dev/null
+++ b/python/pyarrow/tests/test_flight.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import contextlib
+import socket
+import threading
+
+import pytest
+
+import pyarrow as pa
+
+
+flight = pytest.importorskip("pyarrow.flight")
+
+
+class ConstantFlightServer(flight.FlightServerBase):
+ """A Flight server that always returns the same data.
+
+ See ARROW-4796: this server implementation will segfault if Flight
+ does not properly hold a reference to the Table object.
+ """
+
+ def do_get(self, ticket):
+ data = [
+ pa.array([-10, -5, 0, 5, 10])
+ ]
+ table = pa.Table.from_arrays(data, names=['a'])
+ return flight.RecordBatchStream(table)
+
+
+@contextlib.contextmanager
+def flight_server(server_base, *args, **kwargs):
+ """Spawn a Flight server on a free port, shutting it down when done."""
+ # Find a free port
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ with contextlib.closing(sock) as sock:
+ sock.bind(('', 0))
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ port = sock.getsockname()[1]
+
+ server_instance = server_base(*args, **kwargs)
+
+ def _server_thread():
+ server_instance.run(port)
+
+ thread = threading.Thread(target=_server_thread, daemon=True)
+ thread.start()
+
+ yield port
+
+ server_instance.shutdown()
+ thread.join()
+
+
+def test_flight_do_get():
+ """Try a simple do_get call."""
+ data = [
+ pa.array([-10, -5, 0, 5, 10])
+ ]
+ table = pa.Table.from_arrays(data, names=['a'])
+
+ with flight_server(ConstantFlightServer) as server_port:
+ client = flight.FlightClient.connect('localhost', server_port)
+ data = client.do_get(flight.Ticket(b''), table.schema).read_all()
+ assert data.equals(table)