You are viewing a plain text version of this content. The canonical link to the original archived message was not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/04/27 12:33:00 UTC

[arrow-rs] branch master updated: Don't hardcode port in FlightSQL tests (#4145)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 547512172 Don't hardcode port in FlightSQL tests (#4145)
547512172 is described below

commit 547512172737004321ff5a02145882e15a52df0d
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Thu Apr 27 08:32:53 2023 -0400

    Don't hardcode port in FlightSQL tests (#4145)
    
    * Don't hardcode port in FlightSQL tests
    
    * Remove sleep
---
 arrow-flight/examples/flight_sql_server.rs | 47 ++++++++++++++++++------------
 1 file changed, 28 insertions(+), 19 deletions(-)

diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs
index 675692aba..43154420d 100644
--- a/arrow-flight/examples/flight_sql_server.rs
+++ b/arrow-flight/examples/flight_sql_server.rs
@@ -546,22 +546,31 @@ impl ProstMessageExt for FetchResults {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use futures::TryStreamExt;
+    use futures::future::BoxFuture;
+    use futures::{FutureExt, TryStreamExt};
     use std::fs;
     use std::future::Future;
+    use std::net::SocketAddr;
     use std::time::Duration;
     use tempfile::NamedTempFile;
-    use tokio::net::{UnixListener, UnixStream};
-    use tokio::time::sleep;
+    use tokio::net::{TcpListener, UnixListener, UnixStream};
     use tokio_stream::wrappers::UnixListenerStream;
     use tonic::transport::{Channel, ClientTlsConfig};
 
     use arrow_cast::pretty::pretty_format_batches;
     use arrow_flight::sql::client::FlightSqlServiceClient;
     use arrow_flight::utils::flight_data_to_batches;
+    use tonic::transport::server::TcpIncoming;
     use tonic::transport::{Certificate, Endpoint};
     use tower::service_fn;
 
+    async fn bind_tcp() -> (TcpIncoming, SocketAddr) {
+        let listener = TcpListener::bind("0.0.0.0:0").await.unwrap();
+        let addr = listener.local_addr().unwrap();
+        let incoming = TcpIncoming::from_listener(listener, true, None).unwrap();
+        (incoming, addr)
+    }
+
     async fn client_with_uds(path: String) -> FlightSqlServiceClient<Channel> {
         let connector = service_fn(move |_| UnixStream::connect(path.clone()));
         let channel = Endpoint::try_from("http://example.com")
@@ -572,7 +581,10 @@ mod tests {
         FlightSqlServiceClient::new(channel)
     }
 
-    async fn create_https_server() -> Result<(), tonic::transport::Error> {
+    type ServeFut = BoxFuture<'static, Result<(), tonic::transport::Error>>;
+
+    async fn create_https_server(
+    ) -> Result<(ServeFut, SocketAddr), tonic::transport::Error> {
         let cert = std::fs::read_to_string("examples/data/server.pem").unwrap();
         let key = std::fs::read_to_string("examples/data/server.key").unwrap();
         let client_ca = std::fs::read_to_string("examples/data/client_ca.pem").unwrap();
@@ -581,20 +593,22 @@ mod tests {
             .identity(Identity::from_pem(&cert, &key))
             .client_ca_root(Certificate::from_pem(&client_ca));
 
-        let addr = "0.0.0.0:50051".parse().unwrap();
+        let (incoming, addr) = bind_tcp().await;
 
         let svc = FlightServiceServer::new(FlightSqlServiceImpl {});
 
-        Server::builder()
+        let serve = Server::builder()
             .tls_config(tls_config)
             .unwrap()
             .add_service(svc)
-            .serve(addr)
-            .await
+            .serve_with_incoming(incoming)
+            .boxed();
+
+        Ok((serve, addr))
     }
 
-    fn endpoint(addr: String) -> Result<Endpoint, ArrowError> {
-        let endpoint = Endpoint::new(addr)
+    fn endpoint(uri: String) -> Result<Endpoint, ArrowError> {
+        let endpoint = Endpoint::new(uri)
             .map_err(|_| ArrowError::IoError("Cannot create endpoint".to_string()))?
             .connect_timeout(Duration::from_secs(20))
             .timeout(Duration::from_secs(20))
@@ -609,11 +623,8 @@ mod tests {
 
     #[tokio::test]
     async fn test_select_https() {
-        tokio::spawn(async {
-            create_https_server().await.unwrap();
-        });
-
-        sleep(Duration::from_millis(2000)).await;
+        let (serve, addr) = create_https_server().await.unwrap();
+        let uri = format!("https://{}:{}", addr.ip(), addr.port());
 
         let request_future = async {
             let cert = std::fs::read_to_string("examples/data/client1.pem").unwrap();
@@ -624,10 +635,7 @@ mod tests {
                 .domain_name("localhost")
                 .ca_certificate(Certificate::from_pem(&server_ca))
                 .identity(Identity::from_pem(cert, key));
-            let endpoint = endpoint(String::from("https://127.0.0.1:50051"))
-                .unwrap()
-                .tls_config(tls_config)
-                .unwrap();
+            let endpoint = endpoint(uri).unwrap().tls_config(tls_config).unwrap();
             let channel = endpoint.connect().await.unwrap();
             let mut client = FlightSqlServiceClient::new(channel);
             let token = client.handshake("admin", "password").await.unwrap();
@@ -652,6 +660,7 @@ mod tests {
         };
 
         tokio::select! {
+            _ = serve => panic!("server finished"),
             _ = request_future => println!("Client finished!"),
         }
     }