You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ap...@apache.org on 2020/04/02 11:04:49 UTC

[arrow] branch master updated: ARROW-8304: [Flight][Python] Fix client example with TLS

This is an automated email from the ASF dual-hosted git repository.

apitrou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 5ab4930  ARROW-8304: [Flight][Python] Fix client example with TLS
5ab4930 is described below

commit 5ab4930bb2642e8798c30ad02c315fa6330f18f0
Author: Antoine Pitrou <an...@python.org>
AuthorDate: Thu Apr 2 13:04:28 2020 +0200

    ARROW-8304: [Flight][Python] Fix client example with TLS
    
    The `get` command did not use the appropriate TLS root certificates when fetching a Flight from its endpoints.
    
    Also fix style in the Python examples and configure `archery lint` to check them.
    
    Closes #6808 from pitrou/ARROW-8304-py-flight-client-tls
    
    Lead-authored-by: Antoine Pitrou <an...@python.org>
    Co-authored-by: Ravindra Wagh <ra...@cambridgesemantics.com>
    Signed-off-by: Antoine Pitrou <an...@python.org>
---
 dev/archery/archery/utils/lint.py |  5 +++--
 python/examples/flight/client.py  | 23 ++++++++++++-----------
 python/examples/flight/server.py  | 19 ++++++++++++-------
 3 files changed, 27 insertions(+), 20 deletions(-)

diff --git a/dev/archery/archery/utils/lint.py b/dev/archery/archery/utils/lint.py
index 4890b6e..03219e9 100644
--- a/dev/archery/archery/utils/lint.py
+++ b/dev/archery/archery/utils/lint.py
@@ -116,8 +116,9 @@ def python_linter(src):
         return
 
     setup_py = os.path.join(src.python, "setup.py")
-    yield LintResult.from_cmd(flake8(setup_py, src.pyarrow, src.dev,
-                                     check=False))
+    yield LintResult.from_cmd(flake8(setup_py, src.pyarrow,
+                                     os.path.join(src.python, "examples"),
+                                     src.dev, check=False))
     config = os.path.join(src.python, ".flake8.cython")
     yield LintResult.from_cmd(flake8("--config=" + config, src.pyarrow,
                                      check=False))
diff --git a/python/examples/flight/client.py b/python/examples/flight/client.py
index 8dd3efd..958017a 100644
--- a/python/examples/flight/client.py
+++ b/python/examples/flight/client.py
@@ -25,7 +25,7 @@ import pyarrow.flight
 import pyarrow.csv as csv
 
 
-def list_flights(args, client):
+def list_flights(args, client, connection_args={}):
     print('Flights\n=======')
     for flight in client.list_flights():
         descriptor = flight.descriptor
@@ -60,7 +60,7 @@ def list_flights(args, client):
         print('---')
 
 
-def do_action(args, client):
+def do_action(args, client, connection_args={}):
     try:
         buf = pyarrow.allocate_buffer(0)
         action = pyarrow.flight.Action(args.action_type, buf)
@@ -71,19 +71,19 @@ def do_action(args, client):
         print("Error calling action:", e)
 
 
-def push_data(args, client):
+def push_data(args, client, connection_args={}):
     print('File Name:', args.file)
     my_table = csv.read_csv(args.file)
-    print ('Table rows=', str(len(my_table)))
+    print('Table rows=', str(len(my_table)))
     df = my_table.to_pandas()
     print(df.head())
     writer, _ = client.do_put(
-        pyarrow.flight.FlightDescriptor.for_path(args.file), my_table.schema)    
+        pyarrow.flight.FlightDescriptor.for_path(args.file), my_table.schema)
     writer.write_table(my_table)
     writer.close()
 
 
-def get_flight(args, client):
+def get_flight(args, client, connection_args={}):
     if args.path:
         descriptor = pyarrow.flight.FlightDescriptor.for_path(*args.path)
     else:
@@ -94,7 +94,8 @@ def get_flight(args, client):
         print('Ticket:', endpoint.ticket)
         for location in endpoint.locations:
             print(location)
-            get_client = pyarrow.flight.FlightClient(location)
+            get_client = pyarrow.flight.FlightClient(location,
+                                                     **connection_args)
             reader = get_client.do_get(endpoint.ticket)
             df = reader.read_pandas()
             print(df)
@@ -129,8 +130,8 @@ def main():
     cmd_put.set_defaults(action='put')
     _add_common_arguments(cmd_put)
     cmd_put.add_argument('file', type=str,
-                        help="CSV file to upload.")
-                        
+                         help="CSV file to upload.")
+
     cmd_get = subcommands.add_parser('get')
     cmd_get.set_defaults(action='get')
     _add_common_arguments(cmd_get)
@@ -161,7 +162,7 @@ def main():
             with open(args.tls_roots, "rb") as root_certs:
                 connection_args["tls_root_certs"] = root_certs.read()
     client = pyarrow.flight.FlightClient(f"{scheme}://{host}:{port}",
-                                                 **connection_args)
+                                         **connection_args)
     while True:
         try:
             action = pyarrow.flight.Action("healthcheck", b"")
@@ -171,7 +172,7 @@ def main():
         except pyarrow.ArrowIOError as e:
             if "Deadline" in str(e):
                 print("Server is not ready, waiting...")
-    commands[args.action](args, client)
+    commands[args.action](args, client, connection_args)
 
 
 if __name__ == '__main__':
diff --git a/python/examples/flight/server.py b/python/examples/flight/server.py
index 93f0695..91a745c 100644
--- a/python/examples/flight/server.py
+++ b/python/examples/flight/server.py
@@ -27,8 +27,10 @@ import pyarrow.flight
 
 
 class FlightServer(pyarrow.flight.FlightServerBase):
-    def __init__(self, host="localhost", location=None, tls_certificates=None, auth_handler=None):
-        super(FlightServer, self).__init__(location, auth_handler, tls_certificates)
+    def __init__(self, host="localhost", location=None,
+                 tls_certificates=None, auth_handler=None):
+        super(FlightServer, self).__init__(
+            location, auth_handler, tls_certificates)
         self.flights = {}
         self.host = host
         self.tls_certificates = tls_certificates
@@ -40,13 +42,16 @@ class FlightServer(pyarrow.flight.FlightServerBase):
 
     def _make_flight_info(self, key, descriptor, table):
         if self.tls_certificates:
-            location = pyarrow.flight.Location.for_grpc_tls(self.host, self.port)
+            location = pyarrow.flight.Location.for_grpc_tls(
+                self.host, self.port)
         else:
-            location = pyarrow.flight.Location.for_grpc_tcp(self.host, self.port)
-        endpoints = [pyarrow.flight.FlightEndpoint(repr(key), [location]),]
+            location = pyarrow.flight.Location.for_grpc_tcp(
+                self.host, self.port)
+        endpoints = [pyarrow.flight.FlightEndpoint(repr(key), [location]), ]
 
         mock_sink = pyarrow.MockOutputStream()
-        stream_writer = pyarrow.RecordBatchStreamWriter(mock_sink, table.schema)
+        stream_writer = pyarrow.RecordBatchStreamWriter(
+            mock_sink, table.schema)
         stream_writer.write_table(table)
         stream_writer.close()
         data_size = mock_sink.size()
@@ -139,6 +144,6 @@ def main():
     print("Serving on", location)
     server.serve()
 
+
 if __name__ == '__main__':
     main()
-