You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/09 14:41:08 UTC
[arrow-rs] branch master updated: FlightSQL Client & integration test (#3207)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new d18827d28 FlightSQL Client & integration test (#3207)
d18827d28 is described below
commit d18827d28e4149c81a7e3a3c86aae3fdedb87305
Author: Brent Gardner <br...@spaceandtime.io>
AuthorDate: Fri Dec 9 07:41:01 2022 -0700
FlightSQL Client & integration test (#3207)
* squash
* Undo nightly clippy advice
* PR feedback
* PR feedback
* PR feedback
* PR feedback
* Formatting
---
arrow-flight/Cargo.toml | 4 +
arrow-flight/examples/flight_sql_server.rs | 229 +++++++++++--
arrow-flight/examples/server.rs | 18 +-
arrow-flight/src/sql/client.rs | 531 +++++++++++++++++++++++++++++
arrow-flight/src/sql/mod.rs | 1 +
arrow-flight/src/sql/server.rs | 4 +-
arrow-flight/src/utils.rs | 52 ++-
7 files changed, 804 insertions(+), 35 deletions(-)
diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml
index 77881a70f..35f70669c 100644
--- a/arrow-flight/Cargo.toml
+++ b/arrow-flight/Cargo.toml
@@ -45,6 +45,10 @@ default = []
flight-sql-experimental = ["prost-types"]
[dev-dependencies]
+arrow = { version = "28.0.0", path = "../arrow", features = ["prettyprint"] }
+tempfile = "3.3"
+tokio-stream = { version = "0.1", features = ["net"] }
+tower = "0.4.13"
[build-dependencies]
# Pin specific version of the tonic-build dependencies to avoid auto-generated
diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs
index aa0d40711..29e6c2c37 100644
--- a/arrow-flight/examples/flight_sql_server.rs
+++ b/arrow-flight/examples/flight_sql_server.rs
@@ -15,13 +15,27 @@
// specific language governing permissions and limitations
// under the License.
-use arrow_flight::sql::{ActionCreatePreparedStatementResult, SqlInfo};
-use arrow_flight::{Action, FlightData, HandshakeRequest, HandshakeResponse, Ticket};
-use futures::Stream;
+use arrow_array::builder::StringBuilder;
+use arrow_array::{ArrayRef, RecordBatch};
+use arrow_flight::sql::{ActionCreatePreparedStatementResult, ProstMessageExt, SqlInfo};
+use arrow_flight::{
+ Action, FlightData, FlightEndpoint, HandshakeRequest, HandshakeResponse, IpcMessage,
+ Location, SchemaAsIpc, Ticket,
+};
+use futures::{stream, Stream};
+use prost_types::Any;
+use std::fs;
use std::pin::Pin;
-use tonic::transport::Server;
+use std::sync::Arc;
+use tempfile::NamedTempFile;
+use tokio::net::{UnixListener, UnixStream};
+use tokio_stream::wrappers::UnixListenerStream;
+use tonic::transport::{Endpoint, Server};
use tonic::{Request, Response, Status, Streaming};
+use arrow_flight::flight_descriptor::DescriptorType;
+use arrow_flight::sql::client::FlightSqlServiceClient;
+use arrow_flight::utils::batches_to_flight_data;
use arrow_flight::{
flight_service_server::FlightService,
flight_service_server::FlightServiceServer,
@@ -36,10 +50,28 @@ use arrow_flight::{
},
FlightDescriptor, FlightInfo,
};
+use arrow_ipc::writer::IpcWriteOptions;
+use arrow_schema::{ArrowError, DataType, Field, Schema};
+
+macro_rules! status {
+ ($desc:expr, $err:expr) => {
+ Status::internal(format!("{}: {} at {}:{}", $desc, $err, file!(), line!()))
+ };
+}
#[derive(Clone)]
pub struct FlightSqlServiceImpl {}
+impl FlightSqlServiceImpl {
+ fn fake_result() -> Result<RecordBatch, ArrowError> {
+ let schema = Schema::new(vec![Field::new("salutation", DataType::Utf8, false)]);
+ let mut builder = StringBuilder::new();
+ builder.append_value("Hello, FlightSQL!");
+ let cols = vec![Arc::new(builder.finish()) as ArrayRef];
+ RecordBatch::try_new(Arc::new(schema), cols)
+ }
+}
+
#[tonic::async_trait]
impl FlightSqlService for FlightSqlServiceImpl {
type FlightService = FlightSqlServiceImpl;
@@ -57,7 +89,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
.get("authorization")
.ok_or(Status::invalid_argument("authorization field not present"))?
.to_str()
- .map_err(|_| Status::invalid_argument("authorization not parsable"))?;
+ .map_err(|e| status!("authorization not parsable", e))?;
if !authorization.starts_with(basic) {
Err(Status::invalid_argument(format!(
"Auth type not implemented: {}",
@@ -66,20 +98,20 @@ impl FlightSqlService for FlightSqlServiceImpl {
}
let base64 = &authorization[basic.len()..];
let bytes = base64::decode(base64)
- .map_err(|_| Status::invalid_argument("authorization not parsable"))?;
+ .map_err(|e| status!("authorization not decodable", e))?;
let str = String::from_utf8(bytes)
- .map_err(|_| Status::invalid_argument("authorization not parsable"))?;
+ .map_err(|e| status!("authorization not parsable", e))?;
let parts: Vec<_> = str.split(":").collect();
- if parts.len() != 2 {
- Err(Status::invalid_argument(format!(
- "Invalid authorization header"
- )))?;
- }
- let user = parts[0];
- let pass = parts[1];
- if user != "admin" || pass != "password" {
+ let (user, pass) = match parts.as_slice() {
+ [user, pass] => (user, pass),
+ _ => Err(Status::invalid_argument(
+ "Invalid authorization header".to_string(),
+ ))?,
+ };
+ if user != &"admin" || pass != &"password" {
Err(Status::unauthenticated("Invalid credentials!"))?
}
+
let result = HandshakeResponse {
protocol_version: 0,
payload: "random_uuid_token".as_bytes().to_vec(),
@@ -89,7 +121,26 @@ impl FlightSqlService for FlightSqlServiceImpl {
return Ok(Response::new(Box::pin(output)));
}
- // get_flight_info
+ async fn do_get_fallback(
+ &self,
+ _request: Request<Ticket>,
+ _message: prost_types::Any,
+ ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
+ let batch =
+ Self::fake_result().map_err(|e| status!("Could not fake a result", e))?;
+ let schema = (*batch.schema()).clone();
+ let batches = vec![batch];
+ let flight_data = batches_to_flight_data(schema, batches)
+ .map_err(|e| status!("Could not convert batches", e))?
+ .into_iter()
+ .map(Ok);
+
+ let stream: Pin<Box<dyn Stream<Item = Result<FlightData, Status>> + Send>> =
+ Box::pin(stream::iter(flight_data));
+ let resp = Response::new(stream);
+ Ok(resp)
+ }
+
async fn get_flight_info_statement(
&self,
_query: CommandStatementQuery,
@@ -102,12 +153,49 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn get_flight_info_prepared_statement(
&self,
- _query: CommandPreparedStatementQuery,
+ cmd: CommandPreparedStatementQuery,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
- Err(Status::unimplemented(
- "get_flight_info_prepared_statement not implemented",
- ))
+ let handle = String::from_utf8(cmd.prepared_statement_handle)
+ .map_err(|e| status!("Unable to parse handle", e))?;
+ let batch =
+ Self::fake_result().map_err(|e| status!("Could not fake a result", e))?;
+ let schema = (*batch.schema()).clone();
+ let num_rows = batch.num_rows();
+ let num_bytes = batch.get_array_memory_size();
+ let loc = Location {
+ uri: "grpc+tcp://127.0.0.1".to_string(),
+ };
+ let fetch = FetchResults {
+ handle: handle.to_string(),
+ };
+ let buf = ::prost::Message::encode_to_vec(&fetch.as_any());
+ let ticket = Ticket { ticket: buf };
+ let endpoint = FlightEndpoint {
+ ticket: Some(ticket),
+ location: vec![loc],
+ };
+ let endpoints = vec![endpoint];
+
+ let message = SchemaAsIpc::new(&schema, &IpcWriteOptions::default())
+ .try_into()
+ .map_err(|e| status!("Unable to serialize schema", e))?;
+ let IpcMessage(schema_bytes) = message;
+
+ let flight_desc = FlightDescriptor {
+ r#type: DescriptorType::Cmd.into(),
+ cmd: vec![],
+ path: vec![],
+ };
+ let info = FlightInfo {
+ schema: schema_bytes,
+ flight_descriptor: Some(flight_desc),
+ endpoint: endpoints,
+ total_records: num_rows as i64,
+ total_bytes: num_bytes as i64,
+ };
+ let resp = Response::new(info);
+ Ok(resp)
}
async fn get_flight_info_catalogs(
@@ -328,20 +416,33 @@ impl FlightSqlService for FlightSqlServiceImpl {
))
}
- // do_action
async fn do_action_create_prepared_statement(
&self,
_query: ActionCreatePreparedStatementRequest,
_request: Request<Action>,
) -> Result<ActionCreatePreparedStatementResult, Status> {
- Err(Status::unimplemented("Not yet implemented"))
+ let handle = "some_uuid";
+ let schema = Self::fake_result()
+ .map_err(|e| status!("Error getting result schema", e))?
+ .schema();
+ let message = SchemaAsIpc::new(&schema, &IpcWriteOptions::default())
+ .try_into()
+ .map_err(|e| status!("Unable to serialize schema", e))?;
+ let IpcMessage(schema_bytes) = message;
+ let res = ActionCreatePreparedStatementResult {
+ prepared_statement_handle: handle.as_bytes().to_vec(),
+ dataset_schema: schema_bytes,
+ parameter_schema: vec![], // TODO: parameters
+ };
+ Ok(res)
}
+
async fn do_action_close_prepared_statement(
&self,
_query: ActionClosePreparedStatementRequest,
_request: Request<Action>,
) {
- unimplemented!("Not yet implemented")
+ unimplemented!("Implement do_action_close_prepared_statement")
}
async fn register_sql_info(&self, _id: i32, _result: &SqlInfo) {}
@@ -360,3 +461,85 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(())
}
+
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct FetchResults {
+ #[prost(string, tag = "1")]
+ pub handle: ::prost::alloc::string::String,
+}
+
+impl ProstMessageExt for FetchResults {
+ fn type_url() -> &'static str {
+ "type.googleapis.com/arrow.flight.protocol.sql.FetchResults"
+ }
+
+ fn as_any(&self) -> Any {
+ prost_types::Any {
+ type_url: FetchResults::type_url().to_string(),
+ value: ::prost::Message::encode_to_vec(self),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use futures::TryStreamExt;
+
+ use arrow::util::pretty::pretty_format_batches;
+ use arrow_flight::utils::flight_data_to_batches;
+ use tower::service_fn;
+
+ async fn client_with_uds(path: String) -> FlightSqlServiceClient {
+ let connector = service_fn(move |_| UnixStream::connect(path.clone()));
+ let channel = Endpoint::try_from("https://example.com")
+ .unwrap()
+ .connect_with_connector(connector)
+ .await
+ .unwrap();
+ FlightSqlServiceClient::new(channel)
+ }
+
+ #[tokio::test]
+ async fn test_select_1() {
+ let file = NamedTempFile::new().unwrap();
+ let path = file.into_temp_path().to_str().unwrap().to_string();
+ let _ = fs::remove_file(path.clone());
+
+ let uds = UnixListener::bind(path.clone()).unwrap();
+ let stream = UnixListenerStream::new(uds);
+
+ // We would just listen on TCP, but it seems impossible to know when tonic is ready to serve
+ let service = FlightSqlServiceImpl {};
+ let serve_future = Server::builder()
+ .add_service(FlightServiceServer::new(service))
+ .serve_with_incoming(stream);
+
+ let request_future = async {
+ let mut client = client_with_uds(path).await;
+ let token = client.handshake("admin", "password").await.unwrap();
+ println!("Auth succeeded with token: {:?}", token);
+ let mut stmt = client.prepare("select 1;".to_string()).await.unwrap();
+ let flight_info = stmt.execute().await.unwrap();
+ let ticket = flight_info.endpoint[0].ticket.as_ref().unwrap().clone();
+ let flight_data = client.do_get(ticket).await.unwrap();
+ let flight_data: Vec<FlightData> = flight_data.try_collect().await.unwrap();
+ let batches = flight_data_to_batches(&flight_data).unwrap();
+ let res = pretty_format_batches(batches.as_slice()).unwrap();
+ let expected = r#"
++-------------------+
+| salutation |
++-------------------+
+| Hello, FlightSQL! |
++-------------------+"#
+ .trim()
+ .to_string();
+ assert_eq!(res.to_string(), expected);
+ };
+
+ tokio::select! {
+ _ = serve_future => panic!("server returned first"),
+ _ = request_future => println!("Client finished!"),
+ }
+ }
+}
diff --git a/arrow-flight/examples/server.rs b/arrow-flight/examples/server.rs
index 75d053787..1d473103a 100644
--- a/arrow-flight/examples/server.rs
+++ b/arrow-flight/examples/server.rs
@@ -58,63 +58,63 @@ impl FlightService for FlightServiceImpl {
&self,
_request: Request<Streaming<HandshakeRequest>>,
) -> Result<Response<Self::HandshakeStream>, Status> {
- Err(Status::unimplemented("Not yet implemented"))
+ Err(Status::unimplemented("Implement handshake"))
}
async fn list_flights(
&self,
_request: Request<Criteria>,
) -> Result<Response<Self::ListFlightsStream>, Status> {
- Err(Status::unimplemented("Not yet implemented"))
+ Err(Status::unimplemented("Implement list_flights"))
}
async fn get_flight_info(
&self,
_request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
- Err(Status::unimplemented("Not yet implemented"))
+ Err(Status::unimplemented("Implement get_flight_info"))
}
async fn get_schema(
&self,
_request: Request<FlightDescriptor>,
) -> Result<Response<SchemaResult>, Status> {
- Err(Status::unimplemented("Not yet implemented"))
+ Err(Status::unimplemented("Implement get_schema"))
}
async fn do_get(
&self,
_request: Request<Ticket>,
) -> Result<Response<Self::DoGetStream>, Status> {
- Err(Status::unimplemented("Not yet implemented"))
+ Err(Status::unimplemented("Implement do_get"))
}
async fn do_put(
&self,
_request: Request<Streaming<FlightData>>,
) -> Result<Response<Self::DoPutStream>, Status> {
- Err(Status::unimplemented("Not yet implemented"))
+ Err(Status::unimplemented("Implement do_put"))
}
async fn do_action(
&self,
_request: Request<Action>,
) -> Result<Response<Self::DoActionStream>, Status> {
- Err(Status::unimplemented("Not yet implemented"))
+ Err(Status::unimplemented("Implement do_action"))
}
async fn list_actions(
&self,
_request: Request<Empty>,
) -> Result<Response<Self::ListActionsStream>, Status> {
- Err(Status::unimplemented("Not yet implemented"))
+ Err(Status::unimplemented("Implement list_actions"))
}
async fn do_exchange(
&self,
_request: Request<Streaming<FlightData>>,
) -> Result<Response<Self::DoExchangeStream>, Status> {
- Err(Status::unimplemented("Not yet implemented"))
+ Err(Status::unimplemented("Implement do_exchange"))
}
}
diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs
new file mode 100644
index 000000000..fa6691793
--- /dev/null
+++ b/arrow-flight/src/sql/client.rs
@@ -0,0 +1,531 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::Duration;
+
+use crate::flight_service_client::FlightServiceClient;
+use crate::sql::server::{CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT};
+use crate::sql::{
+ ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
+ ActionCreatePreparedStatementResult, CommandGetCatalogs, CommandGetCrossReference,
+ CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
+ CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
+ CommandPreparedStatementQuery, CommandStatementQuery, CommandStatementUpdate,
+ DoPutUpdateResult, ProstAnyExt, ProstMessageExt, SqlInfo,
+};
+use crate::{
+ Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest,
+ HandshakeResponse, IpcMessage, Ticket,
+};
+use arrow_array::RecordBatch;
+use arrow_buffer::Buffer;
+use arrow_ipc::convert::fb_to_schema;
+use arrow_ipc::reader::read_record_batch;
+use arrow_ipc::{root_as_message, MessageHeader};
+use arrow_schema::{ArrowError, Schema, SchemaRef};
+use futures::{stream, TryStreamExt};
+use prost::Message;
+use tokio::sync::{Mutex, MutexGuard};
+use tonic::transport::{Channel, Endpoint};
+use tonic::Streaming;
+
+/// A `FlightSqlServiceClient` is an endpoint for retrieving or storing Arrow data
+/// via the FlightSQL protocol.
+#[derive(Debug, Clone)]
+pub struct FlightSqlServiceClient {
+ token: Option<String>,
+ flight_client: Arc<Mutex<FlightServiceClient<Channel>>>,
+}
+
+/// A FlightSql protocol client that can run queries against FlightSql servers
+/// This client is in the "experimental" stage. It is not guaranteed to follow the spec in all instances.
+/// Github issues are welcomed.
+impl FlightSqlServiceClient {
+ /// Creates a new FlightSql Client that connects via TCP to a server
+ pub async fn new_with_endpoint(host: &str, port: u16) -> Result<Self, ArrowError> {
+ let addr = format!("http://{}:{}", host, port);
+ let endpoint = Endpoint::new(addr)
+ .map_err(|_| ArrowError::IoError("Cannot create endpoint".to_string()))?
+ .connect_timeout(Duration::from_secs(20))
+ .timeout(Duration::from_secs(20))
+ .tcp_nodelay(true) // Disable Nagle's Algorithm since we don't want packets to wait
+ .tcp_keepalive(Option::Some(Duration::from_secs(3600)))
+ .http2_keep_alive_interval(Duration::from_secs(300))
+ .keep_alive_timeout(Duration::from_secs(20))
+ .keep_alive_while_idle(true);
+ let channel = endpoint.connect().await.map_err(|e| {
+ ArrowError::IoError(format!("Cannot connect to endpoint: {}", e))
+ })?;
+ Ok(Self::new(channel))
+ }
+
+ /// Creates a new FlightSql client that connects to a server over an arbitrary tonic `Channel`
+ pub fn new(channel: Channel) -> Self {
+ let flight_client = FlightServiceClient::new(channel);
+ FlightSqlServiceClient {
+ token: None,
+ flight_client: Arc::new(Mutex::new(flight_client)),
+ }
+ }
+
+ fn mut_client(
+ &mut self,
+ ) -> Result<MutexGuard<FlightServiceClient<Channel>>, ArrowError> {
+ self.flight_client
+ .try_lock()
+ .map_err(|_| ArrowError::IoError("Unable to lock client".to_string()))
+ }
+
+ async fn get_flight_info_for_command<M: ProstMessageExt>(
+ &mut self,
+ cmd: M,
+ ) -> Result<FlightInfo, ArrowError> {
+ let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
+ let fi = self
+ .mut_client()?
+ .get_flight_info(descriptor)
+ .await
+ .map_err(status_to_arrow_error)?
+ .into_inner();
+ Ok(fi)
+ }
+
+ /// Execute a query on the server.
+ pub async fn execute(&mut self, query: String) -> Result<FlightInfo, ArrowError> {
+ let cmd = CommandStatementQuery { query };
+ self.get_flight_info_for_command(cmd).await
+ }
+
+ /// Perform a `handshake` with the server, passing credentials and establishing a session
+ /// Returns arbitrary auth/handshake info binary blob
+ pub async fn handshake(
+ &mut self,
+ username: &str,
+ password: &str,
+ ) -> Result<Vec<u8>, ArrowError> {
+ let cmd = HandshakeRequest {
+ protocol_version: 0,
+ payload: vec![],
+ };
+ let mut req = tonic::Request::new(stream::iter(vec![cmd]));
+ let val = base64::encode(format!("{}:{}", username, password));
+ let val = format!("Basic {}", val)
+ .parse()
+ .map_err(|_| ArrowError::ParseError("Cannot parse header".to_string()))?;
+ req.metadata_mut().insert("authorization", val);
+ let resp = self
+ .mut_client()?
+ .handshake(req)
+ .await
+ .map_err(|e| ArrowError::IoError(format!("Can't handshake {}", e)))?;
+ if let Some(auth) = resp.metadata().get("authorization") {
+ let auth = auth.to_str().map_err(|_| {
+ ArrowError::ParseError("Can't read auth header".to_string())
+ })?;
+ let bearer = "Bearer ";
+ if !auth.starts_with(bearer) {
+ Err(ArrowError::ParseError("Invalid auth header!".to_string()))?;
+ }
+ let auth = auth[bearer.len()..].to_string();
+ self.token = Some(auth);
+ }
+ let responses: Vec<HandshakeResponse> =
+ resp.into_inner().try_collect().await.map_err(|_| {
+ ArrowError::ParseError("Can't collect responses".to_string())
+ })?;
+ let resp = match responses.as_slice() {
+ [resp] => resp,
+ [] => Err(ArrowError::ParseError("No handshake response".to_string()))?,
+ _ => Err(ArrowError::ParseError(
+ "Multiple handshake responses".to_string(),
+ ))?,
+ };
+ Ok(resp.payload.clone())
+ }
+
+ /// Execute an update query on the server, and return the number of records affected
+ pub async fn execute_update(&mut self, query: String) -> Result<i64, ArrowError> {
+ let cmd = CommandStatementUpdate { query };
+ let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
+ let mut result = self
+ .mut_client()?
+ .do_put(stream::iter(vec![FlightData {
+ flight_descriptor: Some(descriptor),
+ ..Default::default()
+ }]))
+ .await
+ .map_err(status_to_arrow_error)?
+ .into_inner();
+ let result = result
+ .message()
+ .await
+ .map_err(status_to_arrow_error)?
+ .unwrap();
+ let any: prost_types::Any = prost::Message::decode(&*result.app_metadata)
+ .map_err(decode_error_to_arrow_error)?;
+ let result: DoPutUpdateResult = any.unpack()?.unwrap();
+ Ok(result.record_count)
+ }
+
+ /// Request a list of catalogs as tabular FlightInfo results
+ pub async fn get_catalogs(&mut self) -> Result<FlightInfo, ArrowError> {
+ self.get_flight_info_for_command(CommandGetCatalogs {})
+ .await
+ }
+
+ /// Request a list of database schemas as tabular FlightInfo results
+ pub async fn get_db_schemas(
+ &mut self,
+ request: CommandGetDbSchemas,
+ ) -> Result<FlightInfo, ArrowError> {
+ self.get_flight_info_for_command(request).await
+ }
+
+ /// Given a flight ticket, request the corresponding data stream. Returns a stream of `FlightData` messages
+ pub async fn do_get(
+ &mut self,
+ ticket: Ticket,
+ ) -> Result<Streaming<FlightData>, ArrowError> {
+ Ok(self
+ .mut_client()?
+ .do_get(ticket)
+ .await
+ .map_err(status_to_arrow_error)?
+ .into_inner())
+ }
+
+ /// Request a list of tables.
+ pub async fn get_tables(
+ &mut self,
+ request: CommandGetTables,
+ ) -> Result<FlightInfo, ArrowError> {
+ self.get_flight_info_for_command(request).await
+ }
+
+ /// Request the primary keys for a table.
+ pub async fn get_primary_keys(
+ &mut self,
+ request: CommandGetPrimaryKeys,
+ ) -> Result<FlightInfo, ArrowError> {
+ self.get_flight_info_for_command(request).await
+ }
+
+ /// Retrieves a description about the foreign key columns that reference the
+ /// primary key columns of the given table.
+ pub async fn get_exported_keys(
+ &mut self,
+ request: CommandGetExportedKeys,
+ ) -> Result<FlightInfo, ArrowError> {
+ self.get_flight_info_for_command(request).await
+ }
+
+ /// Retrieves the foreign key columns for the given table.
+ pub async fn get_imported_keys(
+ &mut self,
+ request: CommandGetImportedKeys,
+ ) -> Result<FlightInfo, ArrowError> {
+ self.get_flight_info_for_command(request).await
+ }
+
+ /// Retrieves a description of the foreign key columns in the given foreign key
+ /// table that reference the primary key or the columns representing a unique
+ /// constraint of the parent table (could be the same or a different table).
+ pub async fn get_cross_reference(
+ &mut self,
+ request: CommandGetCrossReference,
+ ) -> Result<FlightInfo, ArrowError> {
+ self.get_flight_info_for_command(request).await
+ }
+
+ /// Request a list of table types.
+ pub async fn get_table_types(&mut self) -> Result<FlightInfo, ArrowError> {
+ self.get_flight_info_for_command(CommandGetTableTypes {})
+ .await
+ }
+
+ /// Request a list of SQL information.
+ pub async fn get_sql_info(
+ &mut self,
+ sql_infos: Vec<SqlInfo>,
+ ) -> Result<FlightInfo, ArrowError> {
+ let request = CommandGetSqlInfo {
+ info: sql_infos.iter().map(|sql_info| *sql_info as u32).collect(),
+ };
+ self.get_flight_info_for_command(request).await
+ }
+
+ /// Create a prepared statement object.
+ pub async fn prepare(
+ &mut self,
+ query: String,
+ ) -> Result<PreparedStatement<Channel>, ArrowError> {
+ let cmd = ActionCreatePreparedStatementRequest { query };
+ let action = Action {
+ r#type: CREATE_PREPARED_STATEMENT.to_string(),
+ body: cmd.as_any().encode_to_vec(),
+ };
+ let mut req = tonic::Request::new(action);
+ if let Some(token) = &self.token {
+ let val = format!("Bearer {}", token).parse().map_err(|_| {
+ ArrowError::IoError("Statement already closed.".to_string())
+ })?;
+ req.metadata_mut().insert("authorization", val);
+ }
+ let mut result = self
+ .mut_client()?
+ .do_action(req)
+ .await
+ .map_err(status_to_arrow_error)?
+ .into_inner();
+ let result = result
+ .message()
+ .await
+ .map_err(status_to_arrow_error)?
+ .unwrap();
+ let any: prost_types::Any =
+ prost::Message::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
+ let prepared_result: ActionCreatePreparedStatementResult = any.unpack()?.unwrap();
+ let dataset_schema = match prepared_result.dataset_schema.len() {
+ 0 => Schema::empty(),
+ _ => Schema::try_from(IpcMessage(prepared_result.dataset_schema))?,
+ };
+ let parameter_schema = match prepared_result.parameter_schema.len() {
+ 0 => Schema::empty(),
+ _ => Schema::try_from(IpcMessage(prepared_result.parameter_schema))?,
+ };
+ Ok(PreparedStatement::new(
+ self.flight_client.clone(),
+ prepared_result.prepared_statement_handle,
+ dataset_schema,
+ parameter_schema,
+ ))
+ }
+
+ /// Explicitly shut down and clean up the client.
+ pub async fn close(&mut self) -> Result<(), ArrowError> {
+ Ok(())
+ }
+}
+
+/// A PreparedStatement
+#[derive(Debug, Clone)]
+pub struct PreparedStatement<T> {
+ flight_client: Arc<Mutex<FlightServiceClient<T>>>,
+ parameter_binding: Option<RecordBatch>,
+ handle: Vec<u8>,
+ dataset_schema: Schema,
+ parameter_schema: Schema,
+}
+
+impl PreparedStatement<Channel> {
+ pub(crate) fn new(
+ client: Arc<Mutex<FlightServiceClient<Channel>>>,
+ handle: Vec<u8>,
+ dataset_schema: Schema,
+ parameter_schema: Schema,
+ ) -> Self {
+ PreparedStatement {
+ flight_client: client,
+ parameter_binding: None,
+ handle,
+ dataset_schema,
+ parameter_schema,
+ }
+ }
+
+ /// Executes the prepared statement query on the server.
+ pub async fn execute(&mut self) -> Result<FlightInfo, ArrowError> {
+ let cmd = CommandPreparedStatementQuery {
+ prepared_statement_handle: self.handle.clone(),
+ };
+ let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
+ let result = self
+ .mut_client()?
+ .get_flight_info(descriptor)
+ .await
+ .map_err(status_to_arrow_error)?
+ .into_inner();
+ Ok(result)
+ }
+
+ /// Executes the prepared statement update query on the server.
+ pub async fn execute_update(&mut self) -> Result<i64, ArrowError> {
+ let cmd = CommandPreparedStatementQuery {
+ prepared_statement_handle: self.handle.clone(),
+ };
+ let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec());
+ let mut result = self
+ .mut_client()?
+ .do_put(stream::iter(vec![FlightData {
+ flight_descriptor: Some(descriptor),
+ ..Default::default()
+ }]))
+ .await
+ .map_err(status_to_arrow_error)?
+ .into_inner();
+ let result = result
+ .message()
+ .await
+ .map_err(status_to_arrow_error)?
+ .unwrap();
+ let any: prost_types::Any = Message::decode(&*result.app_metadata)
+ .map_err(decode_error_to_arrow_error)?;
+ let result: DoPutUpdateResult = any.unpack()?.unwrap();
+ Ok(result.record_count)
+ }
+
+ /// Retrieve the parameter schema from the query.
+ pub fn parameter_schema(&self) -> Result<&Schema, ArrowError> {
+ Ok(&self.parameter_schema)
+ }
+
+ /// Retrieve the ResultSet schema from the query.
+ pub fn dataset_schema(&self) -> Result<&Schema, ArrowError> {
+ Ok(&self.dataset_schema)
+ }
+
+ /// Set a RecordBatch that contains the parameters that will be bound.
+ pub fn set_parameters(
+ &mut self,
+ parameter_binding: RecordBatch,
+ ) -> Result<(), ArrowError> {
+ self.parameter_binding = Some(parameter_binding);
+ Ok(())
+ }
+
+ /// Close the prepared statement, so that this PreparedStatement can no longer be used
+ /// and the server can free up any resources.
+ pub async fn close(mut self) -> Result<(), ArrowError> {
+ let cmd = ActionClosePreparedStatementRequest {
+ prepared_statement_handle: self.handle.clone(),
+ };
+ let action = Action {
+ r#type: CLOSE_PREPARED_STATEMENT.to_string(),
+ body: cmd.as_any().encode_to_vec(),
+ };
+ let _ = self
+ .mut_client()?
+ .do_action(action)
+ .await
+ .map_err(status_to_arrow_error)?;
+ Ok(())
+ }
+
+ fn mut_client(
+ &mut self,
+ ) -> Result<MutexGuard<FlightServiceClient<Channel>>, ArrowError> {
+ self.flight_client
+ .try_lock()
+ .map_err(|_| ArrowError::IoError("Unable to lock client".to_string()))
+ }
+}
+
+fn decode_error_to_arrow_error(err: prost::DecodeError) -> ArrowError {
+ ArrowError::IoError(err.to_string())
+}
+
+fn status_to_arrow_error(status: tonic::Status) -> ArrowError {
+ ArrowError::IoError(format!("{:?}", status))
+}
+
+// A polymorphic structure to natively represent different types of data contained in `FlightData`
+pub enum ArrowFlightData {
+ RecordBatch(RecordBatch),
+ Schema(Schema),
+}
+
+/// Extract `Schema` or `RecordBatch`es from the `FlightData` wire representation
+pub fn arrow_data_from_flight_data(
+ flight_data: FlightData,
+ arrow_schema_ref: &SchemaRef,
+) -> Result<ArrowFlightData, ArrowError> {
+ let ipc_message = root_as_message(&flight_data.data_header[..]).map_err(|err| {
+ ArrowError::ParseError(format!("Unable to get root as message: {:?}", err))
+ })?;
+
+ match ipc_message.header_type() {
+ MessageHeader::RecordBatch => {
+ let ipc_record_batch =
+ ipc_message.header_as_record_batch().ok_or_else(|| {
+ ArrowError::ComputeError(
+ "Unable to convert flight data header to a record batch"
+ .to_string(),
+ )
+ })?;
+
+ let dictionaries_by_field = HashMap::new();
+ let record_batch = read_record_batch(
+ &Buffer::from(&flight_data.data_body),
+ ipc_record_batch,
+ arrow_schema_ref.clone(),
+ &dictionaries_by_field,
+ None,
+ &ipc_message.version(),
+ )?;
+ Ok(ArrowFlightData::RecordBatch(record_batch))
+ }
+ MessageHeader::Schema => {
+ let ipc_schema = ipc_message.header_as_schema().ok_or_else(|| {
+ ArrowError::ComputeError(
+ "Unable to convert flight data header to a schema".to_string(),
+ )
+ })?;
+
+ let arrow_schema = fb_to_schema(ipc_schema);
+ Ok(ArrowFlightData::Schema(arrow_schema))
+ }
+ MessageHeader::DictionaryBatch => {
+ let _ = ipc_message.header_as_dictionary_batch().ok_or_else(|| {
+ ArrowError::ComputeError(
+ "Unable to convert flight data header to a dictionary batch"
+ .to_string(),
+ )
+ })?;
+ Err(ArrowError::NotYetImplemented(
+ "no idea on how to convert an ipc dictionary batch to an arrow type"
+ .to_string(),
+ ))
+ }
+ MessageHeader::Tensor => {
+ let _ = ipc_message.header_as_tensor().ok_or_else(|| {
+ ArrowError::ComputeError(
+ "Unable to convert flight data header to a tensor".to_string(),
+ )
+ })?;
+ Err(ArrowError::NotYetImplemented(
+ "no idea on how to convert an ipc tensor to an arrow type".to_string(),
+ ))
+ }
+ MessageHeader::SparseTensor => {
+ let _ = ipc_message.header_as_sparse_tensor().ok_or_else(|| {
+ ArrowError::ComputeError(
+ "Unable to convert flight data header to a sparse tensor".to_string(),
+ )
+ })?;
+ Err(ArrowError::NotYetImplemented(
+ "no idea on how to convert an ipc sparse tensor to an arrow type"
+ .to_string(),
+ ))
+ }
+ _ => Err(ArrowError::ComputeError(format!(
+ "Unable to convert message with header_type: '{:?}' to arrow data",
+ ipc_message.header_type()
+ ))),
+ }
+}
diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs
index a5d4c4c34..0ddc64c55 100644
--- a/arrow-flight/src/sql/mod.rs
+++ b/arrow-flight/src/sql/mod.rs
@@ -58,6 +58,7 @@ pub use gen::SupportedSqlGrammar;
pub use gen::TicketStatementQuery;
pub use gen::UpdateDeleteRules;
+pub mod client;
pub mod server;
/// ProstMessageExt are useful utility methods for prost::Message types
diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs
index d78474849..ec48d7cfe 100644
--- a/arrow-flight/src/sql/server.rs
+++ b/arrow-flight/src/sql/server.rs
@@ -36,8 +36,8 @@ use super::{
TicketStatementQuery,
};
-static CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement";
-static CLOSE_PREPARED_STATEMENT: &str = "ClosePreparedStatement";
+pub(crate) static CREATE_PREPARED_STATEMENT: &str = "CreatePreparedStatement";
+pub(crate) static CLOSE_PREPARED_STATEMENT: &str = "ClosePreparedStatement";
/// Implements FlightSqlService to handle the flight sql protocol
#[tonic::async_trait]
diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs
index 49f9c47db..392d41c83 100644
--- a/arrow-flight/src/utils.rs
+++ b/arrow-flight/src/utils.rs
@@ -19,10 +19,12 @@
use crate::{FlightData, IpcMessage, SchemaAsIpc, SchemaResult};
use std::collections::HashMap;
+use std::sync::Arc;
use arrow_array::{ArrayRef, RecordBatch};
use arrow_buffer::Buffer;
-use arrow_ipc::{reader, writer, writer::IpcWriteOptions};
+use arrow_ipc::convert::fb_to_schema;
+use arrow_ipc::{reader, root_as_message, writer, writer::IpcWriteOptions};
use arrow_schema::{ArrowError, Schema, SchemaRef};
/// Convert a `RecordBatch` to a vector of `FlightData` representing the bytes of the dictionaries
@@ -44,6 +46,32 @@ pub fn flight_data_from_arrow_batch(
(flight_dictionaries, flight_batch)
}
+/// Convert a slice of wire protocol `FlightData`s into a vector of `RecordBatch`es
+pub fn flight_data_to_batches(
+ flight_data: &[FlightData],
+) -> Result<Vec<RecordBatch>, ArrowError> {
+ let schema = flight_data.get(0).ok_or_else(|| {
+ ArrowError::CastError("Need at least one FlightData for schema".to_string())
+ })?;
+ let message = root_as_message(&schema.data_header[..])
+ .map_err(|_| ArrowError::CastError("Cannot get root as message".to_string()))?;
+
+ let ipc_schema: arrow_ipc::Schema = message.header_as_schema().ok_or_else(|| {
+ ArrowError::CastError("Cannot get header as Schema".to_string())
+ })?;
+ let schema = fb_to_schema(ipc_schema);
+ let schema = Arc::new(schema);
+
+ let mut batches = vec![];
+ let dictionaries_by_id = HashMap::new();
+ for datum in flight_data[1..].iter() {
+ let batch =
+ flight_data_to_arrow_batch(datum, schema.clone(), &dictionaries_by_id)?;
+ batches.push(batch);
+ }
+ Ok(batches)
+}
+
/// Convert `FlightData` (with supplied schema and dictionaries) to an arrow `RecordBatch`.
pub fn flight_data_to_arrow_batch(
data: &FlightData,
@@ -111,3 +139,25 @@ pub fn ipc_message_from_arrow_schema(
let IpcMessage(vals) = message;
Ok(vals)
}
+
+/// Convert `RecordBatch`es to wire protocol `FlightData`s
+pub fn batches_to_flight_data(
+ schema: Schema,
+ batches: Vec<RecordBatch>,
+) -> Result<Vec<FlightData>, ArrowError> {
+ let options = IpcWriteOptions::default();
+ let schema_flight_data: FlightData = SchemaAsIpc::new(&schema, &options).into();
+ let mut dictionaries = vec![];
+ let mut flight_data = vec![];
+ for batch in batches.iter() {
+ let (flight_dictionaries, flight_datum) =
+ flight_data_from_arrow_batch(batch, &options);
+ dictionaries.extend(flight_dictionaries);
+ flight_data.push(flight_datum);
+ }
+ let mut stream = vec![schema_flight_data];
+ stream.extend(dictionaries.into_iter());
+ stream.extend(flight_data.into_iter());
+ let flight_data: Vec<_> = stream.into_iter().collect();
+ Ok(flight_data)
+}