You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2022/05/31 16:27:10 UTC

[arrow] branch master updated: ARROW-16597: [Python][FlightRPC] Force server shutdown at interpreter exit

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new c9f7e28286 ARROW-16597: [Python][FlightRPC] Force server shutdown at interpreter exit
c9f7e28286 is described below

commit c9f7e28286cecf501bd0375535d9ad035145a0ff
Author: David Li <li...@gmail.com>
AuthorDate: Tue May 31 18:27:03 2022 +0200

    ARROW-16597: [Python][FlightRPC] Force server shutdown at interpreter exit
    
    Closes #13176 from lidavidm/arrow-16597
    
    Authored-by: David Li <li...@gmail.com>
    Signed-off-by: Antoine Pitrou <an...@python.org>
---
 python/pyarrow/_flight.pyx          | 39 ++++++++++++++++++++++++++++++++++---
 python/pyarrow/tests/arrow_16597.py | 37 +++++++++++++++++++++++++++++++++++
 python/pyarrow/tests/test_flight.py | 14 +++++++++++++
 3 files changed, 87 insertions(+), 3 deletions(-)

diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index bd2f2e35a9..b2a3b2dadb 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -25,6 +25,7 @@ import socket
 import time
 import threading
 import warnings
+import weakref
 
 from cython.operator cimport dereference as deref
 from cython.operator cimport postincrement
@@ -2516,6 +2517,33 @@ cdef class _ServerMiddlewareWrapper(ServerMiddleware):
             instance.call_completed(exception)
 
 
+cdef class _FlightServerFinalizer(_Weakrefable):
+    """
+    A finalizer that shuts down the server on destruction.
+
+    See ARROW-16597. If the server is still active at interpreter
+    exit, the process may segfault.
+    """
+
+    cdef:
+        shared_ptr[PyFlightServer] server  # co-owned with FlightServerBase
+
+    def finalize(self):
+        cdef:
+            PyFlightServer* server = self.server.get()
+            CStatus status
+        if server == NULL:  # no server held (unset or already reset)
+            return
+        try:
+            with nogil:  # release the GIL for the C++ calls
+                status = server.Shutdown()
+                if status.ok():  # skip Wait() if Shutdown() failed
+                    status = server.Wait()
+            check_flight_status(status)  # presumably raises if not OK
+        finally:
+            self.server.reset()  # drop our reference even on error
+
+
 cdef class FlightServerBase(_Weakrefable):
     """A Flight service definition.
 
@@ -2550,11 +2578,13 @@ cdef class FlightServerBase(_Weakrefable):
     """
 
     cdef:
-        unique_ptr[PyFlightServer] server
+        shared_ptr[PyFlightServer] server
+        object finalizer
 
     def __init__(self, location=None, auth_handler=None,
                  tls_certificates=None, verify_client=None,
                  root_certificates=None, middleware=None):
+        self.finalizer = None
         if isinstance(location, (bytes, str)):
             location = Location(location)
         elif isinstance(location, (tuple, type(None))):
@@ -2622,6 +2652,9 @@ cdef class FlightServerBase(_Weakrefable):
         self.server.reset(c_server)
         with nogil:
             check_flight_status(c_server.Init(deref(c_options)))
+        cdef _FlightServerFinalizer finalizer = _FlightServerFinalizer()
+        finalizer.server = self.server
+        self.finalizer = weakref.finalize(self, finalizer.finalize)
 
     @property
     def port(self):
@@ -2843,8 +2876,8 @@ cdef class FlightServerBase(_Weakrefable):
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
-        self.shutdown()
-        self.wait()
+        if self.finalizer:
+            self.finalizer()
 
 
 def connect(location, **kwargs):
diff --git a/python/pyarrow/tests/arrow_16597.py b/python/pyarrow/tests/arrow_16597.py
new file mode 100644
index 0000000000..7ab9d6cc94
--- /dev/null
+++ b/python/pyarrow/tests/arrow_16597.py
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# This file is called from a test in test_flight.py.
+import time
+
+import pyarrow as pa
+import pyarrow.flight as flight
+
+
+class Server(flight.FlightServerBase):
+    def do_put(self, context, descriptor, reader, writer):
+        time.sleep(1)  # presumably keeps the RPC active at exit; confirm
+        raise flight.FlightCancelledError("")
+
+
+if __name__ == "__main__":
+    server = Server("grpc://localhost:0")  # port 0; real port read below
+    client = flight.connect(f"grpc://localhost:{server.port}")
+    schema = pa.schema([])  # empty schema: no columns
+    writer, reader = client.do_put(
+        flight.FlightDescriptor.for_command(b""), schema)
+    writer.done_writing()  # no explicit server shutdown follows
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index 1805453e8a..afc9ce17e8 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -1994,6 +1994,11 @@ def test_interrupt():
         descriptor = flight.FlightDescriptor.for_command(b"echo")
         writer, reader = client.do_exchange(descriptor)
         test(reader.read_all)
+        try:
+            writer.close()
+        except (KeyboardInterrupt, flight.FlightCancelledError):
+            # Silence the Cancelled/Interrupt exception
+            pass
 
 
 def test_never_sends_data():
@@ -2157,3 +2162,12 @@ def test_write_error_propagation():
             writer.close()
         assert exc_info.value.extra_info == expected_info
         thread.join()
+
+
+def test_interpreter_shutdown():
+    """
+    Ensure that the gRPC server is stopped at interpreter shutdown.
+
+    See https://issues.apache.org/jira/browse/ARROW-16597.
+    """
+    util.invoke_script("arrow_16597.py")  # script added under tests/