You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/04/28 09:47:23 UTC
[arrow-rs] branch master updated: Better flight SQL example codes (#4144)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new b717b3939 Better flight SQL example codes (#4144)
b717b3939 is described below
commit b717b39393367d1de7577078c13b91c59a62d581
Author: sundyli <54...@qq.com>
AuthorDate: Fri Apr 28 02:47:17 2023 -0700
Better flight SQL example codes (#4144)
* Better flight sql example codes
* Better flight sql example codes
* feat: flight sql server enable TCP no-delay (TCP_NODELAY)
* Remove unnecessary doc
---------
Co-authored-by: Raphael Taylor-Davies <r....@googlemail.com>
---
arrow-flight/examples/flight_sql_server.rs | 196 ++++++++++++++++-------------
1 file changed, 107 insertions(+), 89 deletions(-)
diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs
index 43154420d..23d71090a 100644
--- a/arrow-flight/examples/flight_sql_server.rs
+++ b/arrow-flight/examples/flight_sql_server.rs
@@ -546,8 +546,7 @@ impl ProstMessageExt for FetchResults {
#[cfg(test)]
mod tests {
use super::*;
- use futures::future::BoxFuture;
- use futures::{FutureExt, TryStreamExt};
+ use futures::TryStreamExt;
use std::fs;
use std::future::Future;
use std::net::SocketAddr;
@@ -571,42 +570,6 @@ mod tests {
(incoming, addr)
}
- async fn client_with_uds(path: String) -> FlightSqlServiceClient<Channel> {
- let connector = service_fn(move |_| UnixStream::connect(path.clone()));
- let channel = Endpoint::try_from("http://example.com")
- .unwrap()
- .connect_with_connector(connector)
- .await
- .unwrap();
- FlightSqlServiceClient::new(channel)
- }
-
- type ServeFut = BoxFuture<'static, Result<(), tonic::transport::Error>>;
-
- async fn create_https_server(
- ) -> Result<(ServeFut, SocketAddr), tonic::transport::Error> {
- let cert = std::fs::read_to_string("examples/data/server.pem").unwrap();
- let key = std::fs::read_to_string("examples/data/server.key").unwrap();
- let client_ca = std::fs::read_to_string("examples/data/client_ca.pem").unwrap();
-
- let tls_config = ServerTlsConfig::new()
- .identity(Identity::from_pem(&cert, &key))
- .client_ca_root(Certificate::from_pem(&client_ca));
-
- let (incoming, addr) = bind_tcp().await;
-
- let svc = FlightServiceServer::new(FlightSqlServiceImpl {});
-
- let serve = Server::builder()
- .tls_config(tls_config)
- .unwrap()
- .add_service(svc)
- .serve_with_incoming(incoming)
- .boxed();
-
- Ok((serve, addr))
- }
-
fn endpoint(uri: String) -> Result<Endpoint, ArrowError> {
let endpoint = Endpoint::new(uri)
.map_err(|_| ArrowError::IoError("Cannot create endpoint".to_string()))?
@@ -621,56 +584,12 @@ mod tests {
Ok(endpoint)
}
- #[tokio::test]
- async fn test_select_https() {
- let (serve, addr) = create_https_server().await.unwrap();
- let uri = format!("https://{}:{}", addr.ip(), addr.port());
-
- let request_future = async {
- let cert = std::fs::read_to_string("examples/data/client1.pem").unwrap();
- let key = std::fs::read_to_string("examples/data/client1.key").unwrap();
- let server_ca = std::fs::read_to_string("examples/data/ca.pem").unwrap();
-
- let tls_config = ClientTlsConfig::new()
- .domain_name("localhost")
- .ca_certificate(Certificate::from_pem(&server_ca))
- .identity(Identity::from_pem(cert, key));
- let endpoint = endpoint(uri).unwrap().tls_config(tls_config).unwrap();
- let channel = endpoint.connect().await.unwrap();
- let mut client = FlightSqlServiceClient::new(channel);
- let token = client.handshake("admin", "password").await.unwrap();
- client.set_token(String::from_utf8(token.to_vec()).unwrap());
- println!("Auth succeeded with token: {:?}", token);
- let mut stmt = client.prepare("select 1;".to_string()).await.unwrap();
- let flight_info = stmt.execute().await.unwrap();
- let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone();
- let flight_data = client.do_get(ticket).await.unwrap();
- let flight_data: Vec<FlightData> = flight_data.try_collect().await.unwrap();
- let batches = flight_data_to_batches(&flight_data).unwrap();
- let res = pretty_format_batches(batches.as_slice()).unwrap();
- let expected = r#"
-+-------------------+
-| salutation |
-+-------------------+
-| Hello, FlightSQL! |
-+-------------------+"#
- .trim()
- .to_string();
- assert_eq!(res.to_string(), expected);
- };
-
- tokio::select! {
- _ = serve => panic!("server finished"),
- _ = request_future => println!("Client finished!"),
- }
- }
-
async fn auth_client(client: &mut FlightSqlServiceClient<Channel>) {
let token = client.handshake("admin", "password").await.unwrap();
client.set_token(String::from_utf8(token.to_vec()).unwrap());
}
- async fn test_client<F, C>(f: F)
+ async fn test_uds_client<F, C>(f: F)
where
F: FnOnce(FlightSqlServiceClient<Channel>) -> C,
C: Future<Output = ()>,
@@ -682,14 +601,91 @@ mod tests {
let uds = UnixListener::bind(path.clone()).unwrap();
let stream = UnixListenerStream::new(uds);
- // We would just listen on TCP, but it seems impossible to know when tonic is ready to serve
let service = FlightSqlServiceImpl {};
let serve_future = Server::builder()
.add_service(FlightServiceServer::new(service))
.serve_with_incoming(stream);
let request_future = async {
- let client = client_with_uds(path).await;
+ let connector = service_fn(move |_| UnixStream::connect(path.clone()));
+ let channel = Endpoint::try_from("http://example.com")
+ .unwrap()
+ .connect_with_connector(connector)
+ .await
+ .unwrap();
+ let client = FlightSqlServiceClient::new(channel);
+ f(client).await
+ };
+
+ tokio::select! {
+ _ = serve_future => panic!("server returned first"),
+ _ = request_future => println!("Client finished!"),
+ }
+ }
+
+ async fn test_http_client<F, C>(f: F)
+ where
+ F: FnOnce(FlightSqlServiceClient<Channel>) -> C,
+ C: Future<Output = ()>,
+ {
+ let (incoming, addr) = bind_tcp().await;
+ let uri = format!("http://{}:{}", addr.ip(), addr.port());
+
+ let service = FlightSqlServiceImpl {};
+ let serve_future = Server::builder()
+ .add_service(FlightServiceServer::new(service))
+ .serve_with_incoming(incoming);
+
+ let request_future = async {
+ let endpoint = endpoint(uri).unwrap();
+ let channel = endpoint.connect().await.unwrap();
+ let client = FlightSqlServiceClient::new(channel);
+ f(client).await
+ };
+
+ tokio::select! {
+ _ = serve_future => panic!("server returned first"),
+ _ = request_future => println!("Client finished!"),
+ }
+ }
+
+ async fn test_https_client<F, C>(f: F)
+ where
+ F: FnOnce(FlightSqlServiceClient<Channel>) -> C,
+ C: Future<Output = ()>,
+ {
+ let cert = std::fs::read_to_string("examples/data/server.pem").unwrap();
+ let key = std::fs::read_to_string("examples/data/server.key").unwrap();
+ let client_ca = std::fs::read_to_string("examples/data/client_ca.pem").unwrap();
+
+ let tls_config = ServerTlsConfig::new()
+ .identity(Identity::from_pem(&cert, &key))
+ .client_ca_root(Certificate::from_pem(&client_ca));
+
+ let (incoming, addr) = bind_tcp().await;
+ let uri = format!("https://{}:{}", addr.ip(), addr.port());
+
+ let svc = FlightServiceServer::new(FlightSqlServiceImpl {});
+
+ let serve_future = Server::builder()
+ .tls_config(tls_config)
+ .unwrap()
+ .add_service(svc)
+ .serve_with_incoming(incoming);
+
+ let request_future = async {
+ let cert = std::fs::read_to_string("examples/data/client1.pem").unwrap();
+ let key = std::fs::read_to_string("examples/data/client1.key").unwrap();
+ let server_ca = std::fs::read_to_string("examples/data/ca.pem").unwrap();
+
+ let tls_config = ClientTlsConfig::new()
+ .domain_name("localhost")
+ .ca_certificate(Certificate::from_pem(&server_ca))
+ .identity(Identity::from_pem(cert, key));
+
+ let endpoint = endpoint(uri).unwrap().tls_config(tls_config).unwrap();
+ let channel = endpoint.connect().await.unwrap();
+ let client = FlightSqlServiceClient::new(channel);
f(client).await
};
@@ -699,16 +695,38 @@ mod tests {
}
}
+ async fn test_all_clients<F, C>(task: F)
+ where
+ F: FnOnce(FlightSqlServiceClient<Channel>) -> C + Copy,
+ C: Future<Output = ()>,
+ {
+ println!("testing uds client");
+ test_uds_client(task).await;
+ println!("=======");
+
+ println!("testing http client");
+ test_http_client(task).await;
+ println!("=======");
+
+ println!("testing https client");
+ test_https_client(task).await;
+ println!("=======");
+ }
+
#[tokio::test]
- async fn test_select_1() {
- test_client(|mut client| async move {
+ async fn test_select() {
+ test_all_clients(|mut client| async move {
auth_client(&mut client).await;
+
let mut stmt = client.prepare("select 1;".to_string()).await.unwrap();
+
let flight_info = stmt.execute().await.unwrap();
+
let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone();
let flight_data = client.do_get(ticket).await.unwrap();
let flight_data: Vec<FlightData> = flight_data.try_collect().await.unwrap();
let batches = flight_data_to_batches(&flight_data).unwrap();
+
let res = pretty_format_batches(batches.as_slice()).unwrap();
let expected = r#"
+-------------------+
@@ -725,7 +743,7 @@ mod tests {
#[tokio::test]
async fn test_execute_update() {
- test_client(|mut client| async move {
+ test_all_clients(|mut client| async move {
auth_client(&mut client).await;
let res = client
.execute_update("creat table test(a int);".to_string())
@@ -738,7 +756,7 @@ mod tests {
#[tokio::test]
async fn test_auth() {
- test_client(|mut client| async move {
+ test_all_clients(|mut client| async move {
// no handshake
assert!(client
.prepare("select 1;".to_string())