You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/04/28 09:47:23 UTC

[arrow-rs] branch master updated: Better flight SQL example codes (#4144)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new b717b3939 Better flight SQL example codes (#4144)
b717b3939 is described below

commit b717b39393367d1de7577078c13b91c59a62d581
Author: sundyli <54...@qq.com>
AuthorDate: Fri Apr 28 02:47:17 2023 -0700

    Better flight SQL example codes (#4144)
    
    * Better flight sql example codes
    
    * Better flight sql example codes
    
    * feat: flight sql server enable tcp no delay
    
    * Remove unnecessary doc
    
    ---------
    
    Co-authored-by: Raphael Taylor-Davies <r....@googlemail.com>
---
 arrow-flight/examples/flight_sql_server.rs | 196 ++++++++++++++++-------------
 1 file changed, 107 insertions(+), 89 deletions(-)

diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs
index 43154420d..23d71090a 100644
--- a/arrow-flight/examples/flight_sql_server.rs
+++ b/arrow-flight/examples/flight_sql_server.rs
@@ -546,8 +546,7 @@ impl ProstMessageExt for FetchResults {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use futures::future::BoxFuture;
-    use futures::{FutureExt, TryStreamExt};
+    use futures::TryStreamExt;
     use std::fs;
     use std::future::Future;
     use std::net::SocketAddr;
@@ -571,42 +570,6 @@ mod tests {
         (incoming, addr)
     }
 
-    async fn client_with_uds(path: String) -> FlightSqlServiceClient<Channel> {
-        let connector = service_fn(move |_| UnixStream::connect(path.clone()));
-        let channel = Endpoint::try_from("http://example.com")
-            .unwrap()
-            .connect_with_connector(connector)
-            .await
-            .unwrap();
-        FlightSqlServiceClient::new(channel)
-    }
-
-    type ServeFut = BoxFuture<'static, Result<(), tonic::transport::Error>>;
-
-    async fn create_https_server(
-    ) -> Result<(ServeFut, SocketAddr), tonic::transport::Error> {
-        let cert = std::fs::read_to_string("examples/data/server.pem").unwrap();
-        let key = std::fs::read_to_string("examples/data/server.key").unwrap();
-        let client_ca = std::fs::read_to_string("examples/data/client_ca.pem").unwrap();
-
-        let tls_config = ServerTlsConfig::new()
-            .identity(Identity::from_pem(&cert, &key))
-            .client_ca_root(Certificate::from_pem(&client_ca));
-
-        let (incoming, addr) = bind_tcp().await;
-
-        let svc = FlightServiceServer::new(FlightSqlServiceImpl {});
-
-        let serve = Server::builder()
-            .tls_config(tls_config)
-            .unwrap()
-            .add_service(svc)
-            .serve_with_incoming(incoming)
-            .boxed();
-
-        Ok((serve, addr))
-    }
-
     fn endpoint(uri: String) -> Result<Endpoint, ArrowError> {
         let endpoint = Endpoint::new(uri)
             .map_err(|_| ArrowError::IoError("Cannot create endpoint".to_string()))?
@@ -621,56 +584,12 @@ mod tests {
         Ok(endpoint)
     }
 
-    #[tokio::test]
-    async fn test_select_https() {
-        let (serve, addr) = create_https_server().await.unwrap();
-        let uri = format!("https://{}:{}", addr.ip(), addr.port());
-
-        let request_future = async {
-            let cert = std::fs::read_to_string("examples/data/client1.pem").unwrap();
-            let key = std::fs::read_to_string("examples/data/client1.key").unwrap();
-            let server_ca = std::fs::read_to_string("examples/data/ca.pem").unwrap();
-
-            let tls_config = ClientTlsConfig::new()
-                .domain_name("localhost")
-                .ca_certificate(Certificate::from_pem(&server_ca))
-                .identity(Identity::from_pem(cert, key));
-            let endpoint = endpoint(uri).unwrap().tls_config(tls_config).unwrap();
-            let channel = endpoint.connect().await.unwrap();
-            let mut client = FlightSqlServiceClient::new(channel);
-            let token = client.handshake("admin", "password").await.unwrap();
-            client.set_token(String::from_utf8(token.to_vec()).unwrap());
-            println!("Auth succeeded with token: {:?}", token);
-            let mut stmt = client.prepare("select 1;".to_string()).await.unwrap();
-            let flight_info = stmt.execute().await.unwrap();
-            let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone();
-            let flight_data = client.do_get(ticket).await.unwrap();
-            let flight_data: Vec<FlightData> = flight_data.try_collect().await.unwrap();
-            let batches = flight_data_to_batches(&flight_data).unwrap();
-            let res = pretty_format_batches(batches.as_slice()).unwrap();
-            let expected = r#"
-+-------------------+
-| salutation        |
-+-------------------+
-| Hello, FlightSQL! |
-+-------------------+"#
-                .trim()
-                .to_string();
-            assert_eq!(res.to_string(), expected);
-        };
-
-        tokio::select! {
-            _ = serve => panic!("server finished"),
-            _ = request_future => println!("Client finished!"),
-        }
-    }
-
     async fn auth_client(client: &mut FlightSqlServiceClient<Channel>) {
         let token = client.handshake("admin", "password").await.unwrap();
         client.set_token(String::from_utf8(token.to_vec()).unwrap());
     }
 
-    async fn test_client<F, C>(f: F)
+    async fn test_uds_client<F, C>(f: F)
     where
         F: FnOnce(FlightSqlServiceClient<Channel>) -> C,
         C: Future<Output = ()>,
@@ -682,14 +601,91 @@ mod tests {
         let uds = UnixListener::bind(path.clone()).unwrap();
         let stream = UnixListenerStream::new(uds);
 
-        // We would just listen on TCP, but it seems impossible to know when tonic is ready to serve
         let service = FlightSqlServiceImpl {};
         let serve_future = Server::builder()
             .add_service(FlightServiceServer::new(service))
             .serve_with_incoming(stream);
 
         let request_future = async {
-            let client = client_with_uds(path).await;
+            let connector = service_fn(move |_| UnixStream::connect(path.clone()));
+            let channel = Endpoint::try_from("http://example.com")
+                .unwrap()
+                .connect_with_connector(connector)
+                .await
+                .unwrap();
+            let client = FlightSqlServiceClient::new(channel);
+            f(client).await
+        };
+
+        tokio::select! {
+            _ = serve_future => panic!("server returned first"),
+            _ = request_future => println!("Client finished!"),
+        }
+    }
+
+    async fn test_http_client<F, C>(f: F)
+    where
+        F: FnOnce(FlightSqlServiceClient<Channel>) -> C,
+        C: Future<Output = ()>,
+    {
+        let (incoming, addr) = bind_tcp().await;
+        let uri = format!("http://{}:{}", addr.ip(), addr.port());
+
+        let service = FlightSqlServiceImpl {};
+        let serve_future = Server::builder()
+            .add_service(FlightServiceServer::new(service))
+            .serve_with_incoming(incoming);
+
+        let request_future = async {
+            let endpoint = endpoint(uri).unwrap();
+            let channel = endpoint.connect().await.unwrap();
+            let client = FlightSqlServiceClient::new(channel);
+            f(client).await
+        };
+
+        tokio::select! {
+            _ = serve_future => panic!("server returned first"),
+            _ = request_future => println!("Client finished!"),
+        }
+    }
+
+    async fn test_https_client<F, C>(f: F)
+    where
+        F: FnOnce(FlightSqlServiceClient<Channel>) -> C,
+        C: Future<Output = ()>,
+    {
+        let cert = std::fs::read_to_string("examples/data/server.pem").unwrap();
+        let key = std::fs::read_to_string("examples/data/server.key").unwrap();
+        let client_ca = std::fs::read_to_string("examples/data/client_ca.pem").unwrap();
+
+        let tls_config = ServerTlsConfig::new()
+            .identity(Identity::from_pem(&cert, &key))
+            .client_ca_root(Certificate::from_pem(&client_ca));
+
+        let (incoming, addr) = bind_tcp().await;
+        let uri = format!("https://{}:{}", addr.ip(), addr.port());
+
+        let svc = FlightServiceServer::new(FlightSqlServiceImpl {});
+
+        let serve_future = Server::builder()
+            .tls_config(tls_config)
+            .unwrap()
+            .add_service(svc)
+            .serve_with_incoming(incoming);
+
+        let request_future = async {
+            let cert = std::fs::read_to_string("examples/data/client1.pem").unwrap();
+            let key = std::fs::read_to_string("examples/data/client1.key").unwrap();
+            let server_ca = std::fs::read_to_string("examples/data/ca.pem").unwrap();
+
+            let tls_config = ClientTlsConfig::new()
+                .domain_name("localhost")
+                .ca_certificate(Certificate::from_pem(&server_ca))
+                .identity(Identity::from_pem(cert, key));
+
+            let endpoint = endpoint(uri).unwrap().tls_config(tls_config).unwrap();
+            let channel = endpoint.connect().await.unwrap();
+            let client = FlightSqlServiceClient::new(channel);
             f(client).await
         };
 
@@ -699,16 +695,38 @@ mod tests {
         }
     }
 
+    async fn test_all_clients<F, C>(task: F)
+    where
+        F: FnOnce(FlightSqlServiceClient<Channel>) -> C + Copy,
+        C: Future<Output = ()>,
+    {
+        println!("testing uds client");
+        test_uds_client(task).await;
+        println!("=======");
+
+        println!("testing http client");
+        test_http_client(task).await;
+        println!("=======");
+
+        println!("testing https client");
+        test_https_client(task).await;
+        println!("=======");
+    }
+
     #[tokio::test]
-    async fn test_select_1() {
-        test_client(|mut client| async move {
+    async fn test_select() {
+        test_all_clients(|mut client| async move {
             auth_client(&mut client).await;
+
             let mut stmt = client.prepare("select 1;".to_string()).await.unwrap();
+
             let flight_info = stmt.execute().await.unwrap();
+
             let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone();
             let flight_data = client.do_get(ticket).await.unwrap();
             let flight_data: Vec<FlightData> = flight_data.try_collect().await.unwrap();
             let batches = flight_data_to_batches(&flight_data).unwrap();
+
             let res = pretty_format_batches(batches.as_slice()).unwrap();
             let expected = r#"
 +-------------------+
@@ -725,7 +743,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_execute_update() {
-        test_client(|mut client| async move {
+        test_all_clients(|mut client| async move {
             auth_client(&mut client).await;
             let res = client
                 .execute_update("creat table test(a int);".to_string())
@@ -738,7 +756,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_auth() {
-        test_client(|mut client| async move {
+        test_all_clients(|mut client| async move {
             // no handshake
             assert!(client
                 .prepare("select 1;".to_string())