You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/23 12:23:43 UTC
[arrow-rs] branch master updated: Initial Mid-level `FlightClient` (#3378)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 17b3210af Initial Mid-level `FlightClient` (#3378)
17b3210af is described below
commit 17b3210af2ccd190489de9c641fd10f009abd45b
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Fri Dec 23 07:23:37 2022 -0500
Initial Mid-level `FlightClient` (#3378)
* Mid-level FlightClient
* cleanup
* fixup for use of Bytes
* clippy
* Apply suggestions from code review
Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
* fixup
* BoxStream
Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
---
arrow-flight/src/client.rs | 567 ++++++++++++++++++++++++++++++++++++
arrow-flight/src/error.rs | 59 ++++
arrow-flight/src/lib.rs | 7 +
arrow-flight/tests/client.rs | 309 ++++++++++++++++++++
arrow-flight/tests/common/server.rs | 212 ++++++++++++++
5 files changed, 1154 insertions(+)
diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs
new file mode 100644
index 000000000..0e75ac7c0
--- /dev/null
+++ b/arrow-flight/src/client.rs
@@ -0,0 +1,567 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{
+ flight_service_client::FlightServiceClient, utils::flight_data_to_arrow_batch,
+ FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, Ticket,
+};
+use arrow_array::{ArrayRef, RecordBatch};
+use arrow_schema::Schema;
+use bytes::Bytes;
+use futures::{future::ready, ready, stream, StreamExt};
+use std::{collections::HashMap, convert::TryFrom, pin::Pin, sync::Arc, task::Poll};
+use tonic::{metadata::MetadataMap, transport::Channel, Streaming};
+
+use crate::error::{FlightError, Result};
+
+/// A "Mid level" [Apache Arrow Flight](https://arrow.apache.org/docs/format/Flight.html) client.
+///
+/// [`FlightClient`] is intended as a convenience for interactions
+/// with Arrow Flight servers. For more direct control, such as access
+/// to the response headers, use [`FlightServiceClient`] directly
+/// via methods such as [`Self::inner`] or [`Self::into_inner`].
+///
+/// # Example:
+/// ```no_run
+/// # async fn run() {
+/// # use arrow_flight::FlightClient;
+/// # use bytes::Bytes;
+/// use tonic::transport::Channel;
+/// let channel = Channel::from_static("http://localhost:1234")
+/// .connect()
+/// .await
+/// .expect("error connecting");
+///
+/// let mut client = FlightClient::new(channel);
+///
+/// // Send 'Hi' bytes as the handshake request to the server
+/// let response = client
+/// .handshake(Bytes::from("Hi"))
+/// .await
+/// .expect("error handshaking");
+///
+/// // Expect the server responded with 'Ho'
+/// assert_eq!(response, Bytes::from("Ho"));
+/// # }
+/// ```
+#[derive(Debug)]
+pub struct FlightClient {
+ /// Optional grpc header metadata to include with each request
+ metadata: MetadataMap,
+
+ /// The inner client
+ inner: FlightServiceClient<Channel>,
+}
+
+impl FlightClient {
+ /// Creates a client with the provided [`Channel`](tonic::transport::Channel)
+ pub fn new(channel: Channel) -> Self {
+ Self::new_from_inner(FlightServiceClient::new(channel))
+ }
+
+ /// Creates a new higher level client with the provided lower level client
+ pub fn new_from_inner(inner: FlightServiceClient<Channel>) -> Self {
+ Self {
+ metadata: MetadataMap::new(),
+ inner,
+ }
+ }
+
+ /// Return a reference to gRPC metadata included with each request
+ pub fn metadata(&self) -> &MetadataMap {
+ &self.metadata
+ }
+
+ /// Return a mutable reference to gRPC metadata included with each request
+ ///
+ /// These headers can be used, for example, to include
+ /// authorization or other application specific headers.
+ pub fn metadata_mut(&mut self) -> &mut MetadataMap {
+ &mut self.metadata
+ }
+
+ /// Add the specified header with value to all subsequent requests, returning an
+ /// error if either cannot be parsed. See [`Self::metadata_mut`] for fine grained control.
+ pub fn add_header(&mut self, key: &str, value: &str) -> Result<()> {
+ let key = tonic::metadata::MetadataKey::<_>::from_bytes(key.as_bytes())
+ .map_err(|e| FlightError::ExternalError(Box::new(e)))?;
+
+ let value = value
+ .parse()
+ .map_err(|e| FlightError::ExternalError(Box::new(e)))?;
+
+ // any previous value returned by `insert` for this key is intentionally ignored
+ self.metadata.insert(key, value);
+
+ Ok(())
+ }
+
+ /// Return a reference to the underlying tonic
+ /// [`FlightServiceClient`]
+ pub fn inner(&self) -> &FlightServiceClient<Channel> {
+ &self.inner
+ }
+
+ /// Return a mutable reference to the underlying tonic
+ /// [`FlightServiceClient`]
+ pub fn inner_mut(&mut self) -> &mut FlightServiceClient<Channel> {
+ &mut self.inner
+ }
+
+ /// Consume this client and return the underlying tonic
+ /// [`FlightServiceClient`]
+ pub fn into_inner(self) -> FlightServiceClient<Channel> {
+ self.inner
+ }
+
+ /// Perform an Arrow Flight handshake with the server, sending
+ /// `payload` as the [`HandshakeRequest`] payload and returning
+ /// the [`HandshakeResponse`](crate::HandshakeResponse)
+ /// bytes returned from the server
+ ///
+ /// See [`FlightClient`] docs for an example.
+ pub async fn handshake(&mut self, payload: impl Into<Bytes>) -> Result<Bytes> {
+ let request = HandshakeRequest {
+ protocol_version: 0,
+ payload: payload.into(),
+ };
+
+ // apply headers, etc
+ let request = self.make_request(stream::once(ready(request)));
+
+ let mut response_stream = self.inner.handshake(request).await?.into_inner();
+
+ if let Some(response) = response_stream.next().await.transpose()? {
+ // check if there is another response
+ if response_stream.next().await.is_some() {
+ return Err(FlightError::protocol(
+ "Got unexpected second response from handshake",
+ ));
+ }
+
+ Ok(response.payload)
+ } else {
+ Err(FlightError::protocol("No response from handshake"))
+ }
+ }
+
+ /// Make a `DoGet` call to the server with the provided ticket,
+ /// returning a [`FlightRecordBatchStream`] for reading
+ /// [`RecordBatch`]es.
+ ///
+ /// # Example:
+ /// ```no_run
+ /// # async fn run() {
+ /// # use bytes::Bytes;
+ /// # use arrow_flight::FlightClient;
+ /// # use arrow_flight::Ticket;
+ /// # use arrow_array::RecordBatch;
+ /// # use tonic::transport::Channel;
+ /// # use futures::stream::TryStreamExt;
+ /// # let channel = Channel::from_static("http://localhost:1234")
+ /// # .connect()
+ /// # .await
+ /// # .expect("error connecting");
+ /// # let ticket = Ticket { ticket: Bytes::from("foo") };
+ /// let mut client = FlightClient::new(channel);
+ ///
+ /// // Invoke a do_get request on the server with a previously
+ /// // received Ticket
+ ///
+ /// let response = client
+ /// .do_get(ticket)
+ /// .await
+ /// .expect("error invoking do_get");
+ ///
+ /// // Use try_collect to get the RecordBatches from the server
+ /// let batches: Vec<RecordBatch> = response
+ /// .try_collect()
+ /// .await
+ /// .expect("no stream errors");
+ /// # }
+ /// ```
+ pub async fn do_get(&mut self, ticket: Ticket) -> Result<FlightRecordBatchStream> {
+ let request = self.make_request(ticket);
+
+ let response = self.inner.do_get(request).await?.into_inner();
+
+ let flight_data_stream = FlightDataStream::new(response);
+ Ok(FlightRecordBatchStream::new(flight_data_stream))
+ }
+
+ /// Make a `GetFlightInfo` call to the server with the provided
+ /// [`FlightDescriptor`] and return the [`FlightInfo`] from the
+ /// server. The [`FlightInfo`] can be used with [`Self::do_get`]
+ /// to retrieve the requested batches.
+ ///
+ /// # Example:
+ /// ```no_run
+ /// # async fn run() {
+ /// # use arrow_flight::FlightClient;
+ /// # use arrow_flight::FlightDescriptor;
+ /// # use tonic::transport::Channel;
+ /// # let channel = Channel::from_static("http://localhost:1234")
+ /// # .connect()
+ /// # .await
+ /// # .expect("error connecting");
+ /// let mut client = FlightClient::new(channel);
+ ///
+ /// // Send a 'CMD' request to the server
+ /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec());
+ /// let flight_info = client
+ /// .get_flight_info(request)
+ /// .await
+ /// .expect("error getting flight info");
+ ///
+ /// // retrieve the first endpoint from the returned flight info
+ /// let ticket = flight_info
+ /// .endpoint[0]
+ /// // Extract the ticket
+ /// .ticket
+ /// .clone()
+ /// .expect("expected ticket");
+ ///
+ /// // Retrieve the corresponding RecordBatch stream with do_get
+ /// let data = client
+ /// .do_get(ticket)
+ /// .await
+ /// .expect("error fetching data");
+ /// # }
+ /// ```
+ pub async fn get_flight_info(
+ &mut self,
+ descriptor: FlightDescriptor,
+ ) -> Result<FlightInfo> {
+ let request = self.make_request(descriptor);
+
+ let response = self.inner.get_flight_info(request).await?.into_inner();
+ Ok(response)
+ }
+
+ // TODO other methods
+ // list_flights
+ // get_schema
+ // do_put
+ // do_action
+ // list_actions
+ // do_exchange
+
+ /// return a Request, adding any configured metadata
+ fn make_request<T>(&self, t: T) -> tonic::Request<T> {
+ // Pass along metadata
+ let mut request = tonic::Request::new(t);
+ *request.metadata_mut() = self.metadata.clone();
+ request
+ }
+}
+
+/// A stream of [`RecordBatch`]es from an Arrow Flight server.
+///
+/// To access the lower level Flight messages directly, consider
+/// calling [`Self::into_inner`] and using the [`FlightDataStream`]
+/// directly.
+#[derive(Debug)]
+pub struct FlightRecordBatchStream {
+ inner: FlightDataStream,
+ got_schema: bool,
+}
+
+impl FlightRecordBatchStream {
+ pub fn new(inner: FlightDataStream) -> Self {
+ Self {
+ inner,
+ got_schema: false,
+ }
+ }
+
+ /// Has a message defining the schema been received yet?
+ pub fn got_schema(&self) -> bool {
+ self.got_schema
+ }
+
+ /// Consume self and return the wrapped [`FlightDataStream`]
+ pub fn into_inner(self) -> FlightDataStream {
+ self.inner
+ }
+}
+impl futures::Stream for FlightRecordBatchStream {
+ type Item = Result<RecordBatch>;
+
+ /// Returns the next [`RecordBatch`] available in this stream, or `None` if
+ /// there are no further results available.
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Option<Result<RecordBatch>>> {
+ loop {
+ let res = ready!(self.inner.poll_next_unpin(cx));
+ match res {
+ // Inner exhausted
+ None => {
+ return Poll::Ready(None);
+ }
+ Some(Err(e)) => {
+ return Poll::Ready(Some(Err(e)));
+ }
+ // translate data
+ Some(Ok(data)) => match data.payload {
+ DecodedPayload::Schema(_) if self.got_schema => {
+ return Poll::Ready(Some(Err(FlightError::protocol(
+ "Unexpectedly saw multiple Schema messages in FlightData stream",
+ ))));
+ }
+ DecodedPayload::Schema(_) => {
+ self.got_schema = true;
+ // Need next message, poll inner again
+ }
+ DecodedPayload::RecordBatch(batch) => {
+ return Poll::Ready(Some(Ok(batch)));
+ }
+ DecodedPayload::None => {
+ // Need next message
+ }
+ },
+ }
+ }
+ }
+}
+
+/// Wrapper around a stream of [`FlightData`] that handles the details
+/// of decoding low level Flight messages into [`Schema`] and
+/// [`RecordBatch`]es, including details such as dictionaries.
+///
+/// # Protocol Details
+///
+/// The client handles flight messages as follows:
+///
+/// - **None:** This message has no effect. This is useful to
+/// transmit metadata without any actual payload.
+///
+/// - **Schema:** The schema is (re-)set. Dictionaries are cleared and
+/// the decoded schema is returned.
+///
+/// - **Dictionary Batch:** A new dictionary for a given column is registered. An existing
+/// dictionary for the same column will be overwritten. This
+/// message is NOT returned to the caller.
+///
+/// - **Record Batch:** Record batch is created based on the current
+/// schema and dictionaries. This fails if no schema was transmitted
+/// yet.
+///
+/// All other message types (at the time of writing: e.g. tensor and
+/// sparse tensor) lead to an error.
+///
+/// # Example use cases
+///
+/// 1. Using this low level stream it is possible to receive a stream
+/// of RecordBatches in FlightData that have different schemas by
+/// handling multiple schema messages separately.
+#[derive(Debug)]
+pub struct FlightDataStream {
+ /// Underlying data stream
+ response: Streaming<FlightData>,
+ /// Decoding state
+ state: Option<FlightStreamState>,
+ /// seen the end of the inner stream?
+ done: bool,
+}
+
+impl FlightDataStream {
+ /// Create a new wrapper around the stream of FlightData
+ pub fn new(response: Streaming<FlightData>) -> Self {
+ Self {
+ state: None,
+ response,
+ done: false,
+ }
+ }
+
+ /// Extracts flight data from the next message, updating decoding
+ /// state as necessary.
+ fn extract_message(&mut self, data: FlightData) -> Result<Option<DecodedFlightData>> {
+ use arrow_ipc::MessageHeader;
+ let message = arrow_ipc::root_as_message(&data.data_header[..]).map_err(|e| {
+ FlightError::DecodeError(format!("Error decoding root message: {e}"))
+ })?;
+
+ match message.header_type() {
+ MessageHeader::NONE => Ok(Some(DecodedFlightData::new_none(data))),
+ MessageHeader::Schema => {
+ let schema = Schema::try_from(&data).map_err(|e| {
+ FlightError::DecodeError(format!("Error decoding schema: {e}"))
+ })?;
+
+ let schema = Arc::new(schema);
+ let dictionaries_by_field = HashMap::new();
+
+ self.state = Some(FlightStreamState {
+ schema: Arc::clone(&schema),
+ dictionaries_by_field,
+ });
+ Ok(Some(DecodedFlightData::new_schema(data, schema)))
+ }
+ MessageHeader::DictionaryBatch => {
+ let state = if let Some(state) = self.state.as_mut() {
+ state
+ } else {
+ return Err(FlightError::protocol(
+ "Received DictionaryBatch prior to Schema",
+ ));
+ };
+
+ let buffer: arrow_buffer::Buffer = data.data_body.into();
+ let dictionary_batch =
+ message.header_as_dictionary_batch().ok_or_else(|| {
+ FlightError::protocol(
+ "Could not get dictionary batch from DictionaryBatch message",
+ )
+ })?;
+
+ arrow_ipc::reader::read_dictionary(
+ &buffer,
+ dictionary_batch,
+ &state.schema,
+ &mut state.dictionaries_by_field,
+ &message.version(),
+ )
+ .map_err(|e| {
+ FlightError::DecodeError(format!(
+ "Error decoding ipc dictionary: {e}"
+ ))
+ })?;
+
+ // Updated internal state, but no decoded message
+ Ok(None)
+ }
+ MessageHeader::RecordBatch => {
+ let state = if let Some(state) = self.state.as_ref() {
+ state
+ } else {
+ return Err(FlightError::protocol(
+ "Received RecordBatch prior to Schema",
+ ));
+ };
+
+ let batch = flight_data_to_arrow_batch(
+ &data,
+ Arc::clone(&state.schema),
+ &state.dictionaries_by_field,
+ )
+ .map_err(|e| {
+ FlightError::DecodeError(format!(
+ "Error decoding ipc RecordBatch: {e}"
+ ))
+ })?;
+
+ Ok(Some(DecodedFlightData::new_record_batch(data, batch)))
+ }
+ other => {
+ let name = other.variant_name().unwrap_or("UNKNOWN");
+ Err(FlightError::protocol(format!("Unexpected message: {name}")))
+ }
+ }
+ }
+}
+
+impl futures::Stream for FlightDataStream {
+ type Item = Result<DecodedFlightData>;
+ /// Returns the result of decoding the next [`FlightData`] message
+ /// from the server, or `None` if there are no further results
+ /// available.
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ if self.done {
+ return Poll::Ready(None);
+ }
+ loop {
+ let res = ready!(self.response.poll_next_unpin(cx));
+
+ return Poll::Ready(match res {
+ None => {
+ self.done = true;
+ None // inner is exhausted
+ }
+ Some(data) => Some(match data {
+ Err(e) => Err(FlightError::Tonic(e)),
+ Ok(data) => match self.extract_message(data) {
+ Ok(Some(extracted)) => Ok(extracted),
+ Ok(None) => continue, // Need next input message
+ Err(e) => Err(e),
+ },
+ }),
+ });
+ }
+ }
+}
+
+/// tracks the state needed to reconstruct [`RecordBatch`]es from a
+/// streaming flight response.
+#[derive(Debug)]
+struct FlightStreamState {
+ schema: Arc<Schema>,
+ dictionaries_by_field: HashMap<i64, ArrayRef>,
+}
+
+/// FlightData and the decoded payload (Schema, RecordBatch), if any
+#[derive(Debug)]
+pub struct DecodedFlightData {
+ pub inner: FlightData,
+ pub payload: DecodedPayload,
+}
+
+impl DecodedFlightData {
+ pub fn new_none(inner: FlightData) -> Self {
+ Self {
+ inner,
+ payload: DecodedPayload::None,
+ }
+ }
+
+ pub fn new_schema(inner: FlightData, schema: Arc<Schema>) -> Self {
+ Self {
+ inner,
+ payload: DecodedPayload::Schema(schema),
+ }
+ }
+
+ pub fn new_record_batch(inner: FlightData, batch: RecordBatch) -> Self {
+ Self {
+ inner,
+ payload: DecodedPayload::RecordBatch(batch),
+ }
+ }
+
+ /// return the metadata field of the inner flight data
+ pub fn app_metadata(&self) -> &[u8] {
+ &self.inner.app_metadata
+ }
+}
+
+/// The result of decoding [`FlightData`]
+#[derive(Debug)]
+pub enum DecodedPayload {
+ /// None (no data was sent in the corresponding FlightData)
+ None,
+
+ /// A decoded Schema message
+ Schema(Arc<Schema>),
+
+ /// A decoded Record batch.
+ RecordBatch(RecordBatch),
+}
diff --git a/arrow-flight/src/error.rs b/arrow-flight/src/error.rs
new file mode 100644
index 000000000..fbb9efa44
--- /dev/null
+++ b/arrow-flight/src/error.rs
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Errors for the Apache Arrow Flight crate
+#[derive(Debug)]
+pub enum FlightError {
+ /// Returned when functionality is not yet available.
+ NotYetImplemented(String),
+ /// Error from the underlying tonic library
+ Tonic(tonic::Status),
+ /// Some unexpected message was received
+ ProtocolError(String),
+ /// An error occurred during decoding
+ DecodeError(String),
+ /// Some other (opaque) error
+ ExternalError(Box<dyn std::error::Error + Send + Sync>),
+}
+
+impl FlightError {
+ pub fn protocol(message: impl Into<String>) -> Self {
+ Self::ProtocolError(message.into())
+ }
+
+ /// Wraps an external error in a [`FlightError::ExternalError`].
+ pub fn from_external_error(error: Box<dyn std::error::Error + Send + Sync>) -> Self {
+ Self::ExternalError(error)
+ }
+}
+
+impl std::fmt::Display for FlightError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ // TODO better format / error
+ write!(f, "{:?}", self)
+ }
+}
+
+impl std::error::Error for FlightError {}
+
+impl From<tonic::Status> for FlightError {
+ fn from(status: tonic::Status) -> Self {
+ Self::Tonic(status)
+ }
+}
+
+pub type Result<T> = std::result::Result<T, FlightError>;
diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs
index 051509fb1..f30cb5484 100644
--- a/arrow-flight/src/lib.rs
+++ b/arrow-flight/src/lib.rs
@@ -71,6 +71,13 @@ pub mod flight_service_server {
pub use gen::flight_service_server::FlightServiceServer;
}
+/// Mid Level [`FlightClient`]
+pub mod client;
+pub use client::FlightClient;
+
+/// Common error types
+pub mod error;
+
pub use gen::Action;
pub use gen::ActionType;
pub use gen::BasicAuth;
diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs
new file mode 100644
index 000000000..5bc1062f0
--- /dev/null
+++ b/arrow-flight/tests/client.rs
@@ -0,0 +1,309 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Integration test for "mid level" Client
+
+mod common {
+ pub mod server;
+}
+use arrow_flight::{
+ error::FlightError, FlightClient, FlightDescriptor, FlightInfo, HandshakeRequest,
+ HandshakeResponse,
+};
+use bytes::Bytes;
+use common::server::TestFlightServer;
+use futures::Future;
+use tokio::{net::TcpListener, task::JoinHandle};
+use tonic::{
+ transport::{Channel, Uri},
+ Status,
+};
+
+use std::{net::SocketAddr, time::Duration};
+
+const DEFAULT_TIMEOUT_SECONDS: u64 = 30;
+
+#[tokio::test]
+async fn test_handshake() {
+ do_test(|test_server, mut client| async move {
+ let request_payload = Bytes::from("foo");
+ let response_payload = Bytes::from("Bar");
+
+ let request = HandshakeRequest {
+ payload: request_payload.clone(),
+ protocol_version: 0,
+ };
+
+ let response = HandshakeResponse {
+ payload: response_payload.clone(),
+ protocol_version: 0,
+ };
+
+ test_server.set_handshake_response(Ok(response));
+ let response = client.handshake(request_payload).await.unwrap();
+ assert_eq!(response, response_payload);
+ assert_eq!(test_server.take_handshake_request(), Some(request));
+ })
+ .await;
+}
+
+#[tokio::test]
+async fn test_handshake_error() {
+ do_test(|test_server, mut client| async move {
+ let request_payload = "foo".to_string().into_bytes();
+ let e = Status::unauthenticated("DENIED");
+ test_server.set_handshake_response(Err(e));
+
+ let response = client.handshake(request_payload).await.unwrap_err();
+ let e = Status::unauthenticated("DENIED");
+ expect_status(response, e);
+ })
+ .await;
+}
+
+#[tokio::test]
+async fn test_handshake_metadata() {
+ do_test(|test_server, mut client| async move {
+ client.add_header("foo", "bar").unwrap();
+
+ let request_payload = Bytes::from("Blarg");
+ let response_payload = Bytes::from("Bazz");
+
+ let response = HandshakeResponse {
+ payload: response_payload.clone(),
+ protocol_version: 0,
+ };
+
+ test_server.set_handshake_response(Ok(response));
+ client.handshake(request_payload).await.unwrap();
+ ensure_metadata(&client, &test_server);
+ })
+ .await;
+}
+
+/// Verifies that all headers sent from the client are in the request_metadata
+fn ensure_metadata(client: &FlightClient, test_server: &TestFlightServer) {
+ let client_metadata = client.metadata().clone().into_headers();
+ assert!(!client_metadata.is_empty());
+ let metadata = test_server
+ .take_last_request_metadata()
+ .expect("No headers in server")
+ .into_headers();
+
+ for (k, v) in &client_metadata {
+ assert_eq!(
+ metadata.get(k).as_ref(),
+ Some(&v),
+ "Missing / Mismatched metadata {:?} sent {:?} got {:?}",
+ k,
+ client_metadata,
+ metadata
+ );
+ }
+}
+
+fn test_flight_info(request: &FlightDescriptor) -> FlightInfo {
+ FlightInfo {
+ schema: Bytes::new(),
+ endpoint: vec![],
+ flight_descriptor: Some(request.clone()),
+ total_bytes: 123,
+ total_records: 456,
+ }
+}
+
+#[tokio::test]
+async fn test_get_flight_info() {
+ do_test(|test_server, mut client| async move {
+ let request = FlightDescriptor::new_cmd(b"My Command".to_vec());
+
+ let expected_response = test_flight_info(&request);
+ test_server.set_get_flight_info_response(Ok(expected_response.clone()));
+
+ let response = client.get_flight_info(request.clone()).await.unwrap();
+
+ assert_eq!(response, expected_response);
+ assert_eq!(test_server.take_get_flight_info_request(), Some(request));
+ })
+ .await;
+}
+
+#[tokio::test]
+async fn test_get_flight_info_error() {
+ do_test(|test_server, mut client| async move {
+ let request = FlightDescriptor::new_cmd(b"My Command".to_vec());
+
+ let e = Status::unauthenticated("DENIED");
+ test_server.set_get_flight_info_response(Err(e));
+
+ let response = client.get_flight_info(request.clone()).await.unwrap_err();
+ let e = Status::unauthenticated("DENIED");
+ expect_status(response, e);
+ })
+ .await;
+}
+
+#[tokio::test]
+async fn test_get_flight_info_metadata() {
+ do_test(|test_server, mut client| async move {
+ client.add_header("foo", "bar").unwrap();
+ let request = FlightDescriptor::new_cmd(b"My Command".to_vec());
+
+ let expected_response = test_flight_info(&request);
+ test_server.set_get_flight_info_response(Ok(expected_response));
+ client.get_flight_info(request.clone()).await.unwrap();
+ ensure_metadata(&client, &test_server);
+ })
+ .await;
+}
+
+// TODO more negative tests (like if there are endpoints defined, etc)
+
+// TODO test for do_get
+
+/// Runs the future returned by the function, passing it a test server and client
+async fn do_test<F, Fut>(f: F)
+where
+ F: Fn(TestFlightServer, FlightClient) -> Fut,
+ Fut: Future<Output = ()>,
+{
+ let test_server = TestFlightServer::new();
+ let fixture = TestFixture::new(&test_server).await;
+ let client = FlightClient::new(fixture.channel().await);
+
+ // run the test function
+ f(test_server, client).await;
+
+ // cleanly shutdown the test fixture
+ fixture.shutdown_and_wait().await
+}
+
+fn expect_status(error: FlightError, expected: Status) {
+ let status = if let FlightError::Tonic(status) = error {
+ status
+ } else {
+ panic!("Expected FlightError::Tonic, got: {:?}", error);
+ };
+
+ assert_eq!(
+ status.code(),
+ expected.code(),
+ "Got {:?} want {:?}",
+ status,
+ expected
+ );
+ assert_eq!(
+ status.message(),
+ expected.message(),
+ "Got {:?} want {:?}",
+ status,
+ expected
+ );
+ assert_eq!(
+ status.details(),
+ expected.details(),
+ "Got {:?} want {:?}",
+ status,
+ expected
+ );
+}
+
+/// Creates and manages a running TestServer with a background task
+struct TestFixture {
+ /// channel to send shutdown command
+ shutdown: Option<tokio::sync::oneshot::Sender<()>>,
+
+ /// Address the server is listening on
+ addr: SocketAddr,
+
+ /// Handle for the background server task
+ handle: Option<JoinHandle<Result<(), tonic::transport::Error>>>,
+}
+
+impl TestFixture {
+ /// create a new test fixture from the server
+ pub async fn new(test_server: &TestFlightServer) -> Self {
+ // let OS choose a free port
+ let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+ let addr = listener.local_addr().unwrap();
+
+ println!("Listening on {addr}");
+
+ // prepare the shutdown channel
+ let (tx, rx) = tokio::sync::oneshot::channel();
+
+ let server_timeout = Duration::from_secs(DEFAULT_TIMEOUT_SECONDS);
+
+ let shutdown_future = async move {
+ rx.await.ok();
+ };
+
+ let serve_future = tonic::transport::Server::builder()
+ .timeout(server_timeout)
+ .add_service(test_server.service())
+ .serve_with_incoming_shutdown(
+ tokio_stream::wrappers::TcpListenerStream::new(listener),
+ shutdown_future,
+ );
+
+ // Run the server in its own background task
+ let handle = tokio::task::spawn(serve_future);
+
+ Self {
+ shutdown: Some(tx),
+ addr,
+ handle: Some(handle),
+ }
+ }
+
+ /// Return a [`Channel`] connected to the TestServer
+ pub async fn channel(&self) -> Channel {
+ let url = format!("http://{}", self.addr);
+ let uri: Uri = url.parse().expect("Valid URI");
+ Channel::builder(uri)
+ .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECONDS))
+ .connect()
+ .await
+ .expect("error connecting to server")
+ }
+
+ /// Stops the test server and waits for the server to shutdown
+ pub async fn shutdown_and_wait(mut self) {
+ if let Some(shutdown) = self.shutdown.take() {
+ shutdown.send(()).expect("server quit early");
+ }
+ if let Some(handle) = self.handle.take() {
+ println!("Waiting on server to finish");
+ handle
+ .await
+ .expect("task join error (panic?)")
+ .expect("Server Error found at shutdown");
+ }
+ }
+}
+
+impl Drop for TestFixture {
+ fn drop(&mut self) {
+ if let Some(shutdown) = self.shutdown.take() {
+ shutdown.send(()).ok();
+ }
+ if self.handle.is_some() {
+ // tests should properly clean up TestFixture
+ println!("TestFixture::Drop called prior to `shutdown_and_wait`");
+ }
+ }
+}
diff --git a/arrow-flight/tests/common/server.rs b/arrow-flight/tests/common/server.rs
new file mode 100644
index 000000000..f1cb140b6
--- /dev/null
+++ b/arrow-flight/tests/common/server.rs
@@ -0,0 +1,212 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::{Arc, Mutex};
+
+use futures::stream::BoxStream;
+use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming};
+
+use arrow_flight::{
+ flight_service_server::{FlightService, FlightServiceServer},
+ Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
+ HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket,
+};
+
+/// Flight server for testing, with configurable responses.
+///
+/// The state lives behind an `Arc<Mutex<_>>`, so every clone of a
+/// `TestFlightServer` (including the clone made by `service()`) reads and
+/// writes the same state.
+#[derive(Debug, Clone)]
+pub struct TestFlightServer {
+ /// Shared state holding the configured responses and recorded requests
+ state: Arc<Mutex<State>>,
+}
+
+impl TestFlightServer {
+    /// Create a `TestFlightServer` with no responses configured and no
+    /// requests recorded.
+    pub fn new() -> Self {
+        let state = Arc::new(Mutex::new(State::new()));
+        Self { state }
+    }
+
+    /// Return a [`FlightServiceServer`] that can be used with a
+    /// [`Server`](tonic::transport::Server).
+    ///
+    /// The service wraps a clone of this server, which shares its state.
+    pub fn service(&self) -> FlightServiceServer<TestFlightServer> {
+        // wrap up tonic goop
+        let server = self.clone();
+        FlightServiceServer::new(server)
+    }
+
+    /// Specify the response returned from the next call to `handshake`,
+    /// replacing any response that was set but not yet consumed.
+    pub fn set_handshake_response(&self, response: Result<HandshakeResponse, Status>) {
+        self.lock_state().handshake_response = Some(response);
+    }
+
+    /// Take and return the last handshake request sent to the server,
+    /// leaving `None` in its place.
+    pub fn take_handshake_request(&self) -> Option<HandshakeRequest> {
+        let mut state = self.lock_state();
+        state.handshake_request.take()
+    }
+
+    /// Specify the response returned from the next call to
+    /// `get_flight_info`, replacing any response that was set but not yet
+    /// consumed.
+    pub fn set_get_flight_info_response(&self, response: Result<FlightInfo, Status>) {
+        self.lock_state().get_flight_info_response = Some(response);
+    }
+
+    /// Take and return the last `get_flight_info` request sent to the
+    /// server, leaving `None` in its place.
+    pub fn take_get_flight_info_request(&self) -> Option<FlightDescriptor> {
+        let mut state = self.lock_state();
+        state.get_flight_info_request.take()
+    }
+
+    /// Take and return the metadata of the last request received by the
+    /// server, leaving `None` in its place.
+    pub fn take_last_request_metadata(&self) -> Option<MetadataMap> {
+        let mut state = self.lock_state();
+        state.last_request_metadata.take()
+    }
+
+    /// Save a copy of the metadata of `request` as the last request metadata
+    fn save_metadata<T>(&self, request: &Request<T>) {
+        // Clone first so the lock is only held for the assignment
+        let metadata = request.metadata().clone();
+        self.lock_state().last_request_metadata = Some(metadata);
+    }
+
+    /// Lock the shared state.
+    ///
+    /// # Panics
+    ///
+    /// Panics if the mutex has been poisoned.
+    fn lock_state(&self) -> std::sync::MutexGuard<'_, State> {
+        self.state.lock().expect("mutex not poisoned")
+    }
+}
+
+/// Mutable state for the `TestFlightServer`
+///
+/// Every field is an `Option`, so the derived `Default` starts with all of
+/// them set to `None`.
+#[derive(Debug, Default)]
+struct State {
+ /// The last handshake request that was received
+ pub handshake_request: Option<HandshakeRequest>,
+ /// The next response to return from `handshake()`
+ pub handshake_response: Option<Result<HandshakeResponse, Status>>,
+ /// The last `get_flight_info` request received
+ pub get_flight_info_request: Option<FlightDescriptor>,
+ /// The next response to return from `get_flight_info`
+ pub get_flight_info_response: Option<Result<FlightInfo, Status>>,
+ /// The metadata (headers) of the last request received
+ pub last_request_metadata: Option<MetadataMap>,
+}
+
+impl State {
+    /// Create a `State` with every field set to its default value.
+    fn new() -> Self {
+        Self::default()
+    }
+}
+
+/// Implement the FlightService trait
+///
+/// Only `handshake` and `get_flight_info` are backed by configurable
+/// responses; every other method returns `Status::unimplemented`.
+#[tonic::async_trait]
+impl FlightService for TestFlightServer {
+    type HandshakeStream = BoxStream<'static, Result<HandshakeResponse, Status>>;
+    type ListFlightsStream = BoxStream<'static, Result<FlightInfo, Status>>;
+    type DoGetStream = BoxStream<'static, Result<FlightData, Status>>;
+    type DoPutStream = BoxStream<'static, Result<PutResult, Status>>;
+    type DoActionStream = BoxStream<'static, Result<arrow_flight::Result, Status>>;
+    type ListActionsStream = BoxStream<'static, Result<ActionType, Status>>;
+    type DoExchangeStream = BoxStream<'static, Result<FlightData, Status>>;
+
+    /// Record the request metadata and the first handshake message, then
+    /// reply with the configured handshake response as a one-item stream.
+    ///
+    /// # Errors
+    ///
+    /// * the error produced while reading the first message, if any
+    /// * `Status::invalid_argument` if the request stream ends without a
+    ///   message
+    /// * the configured `Err` response, or `Status::internal` when no
+    ///   response has been configured
+    async fn handshake(
+        &self,
+        request: Request<Streaming<HandshakeRequest>>,
+    ) -> Result<Response<Self::HandshakeStream>, Status> {
+        self.save_metadata(&request);
+
+        // The stream may end before any message arrives; report that as an
+        // error status rather than panicking inside the request handler.
+        let handshake_request = request
+            .into_inner()
+            .message()
+            .await?
+            .ok_or_else(|| {
+                Status::invalid_argument("No handshake request received")
+            })?;
+
+        // Lock only after the await above, so the guard is never held
+        // across an await point.
+        let mut state = self.state.lock().expect("mutex not poisoned");
+        state.handshake_request = Some(handshake_request);
+
+        let response = state.handshake_response.take().unwrap_or_else(|| {
+            Err(Status::internal("No handshake response configured"))
+        })?;
+
+        // turn into a streaming response
+        let output = futures::stream::iter(std::iter::once(Ok(response)));
+        Ok(Response::new(Box::pin(output) as Self::HandshakeStream))
+    }
+
+    /// Not supported by the test server: always returns
+    /// `Status::unimplemented`.
+    async fn list_flights(
+        &self,
+        _request: Request<Criteria>,
+    ) -> Result<Response<Self::ListFlightsStream>, Status> {
+        Err(Status::unimplemented("Implement list_flights"))
+    }
+
+    /// Record the request metadata and descriptor, then reply with the
+    /// configured `get_flight_info` response.
+    ///
+    /// # Errors
+    ///
+    /// Returns the configured `Err` response, or `Status::internal` when no
+    /// response has been configured.
+    async fn get_flight_info(
+        &self,
+        request: Request<FlightDescriptor>,
+    ) -> Result<Response<FlightInfo>, Status> {
+        self.save_metadata(&request);
+        let mut state = self.state.lock().expect("mutex not poisoned");
+        state.get_flight_info_request = Some(request.into_inner());
+        let response = state.get_flight_info_response.take().unwrap_or_else(|| {
+            Err(Status::internal("No get_flight_info response configured"))
+        })?;
+        Ok(Response::new(response))
+    }
+
+    /// Not supported by the test server: always returns
+    /// `Status::unimplemented`.
+    async fn get_schema(
+        &self,
+        _request: Request<FlightDescriptor>,
+    ) -> Result<Response<SchemaResult>, Status> {
+        Err(Status::unimplemented("Implement get_schema"))
+    }
+
+    /// Not supported by the test server: always returns
+    /// `Status::unimplemented`.
+    async fn do_get(
+        &self,
+        _request: Request<Ticket>,
+    ) -> Result<Response<Self::DoGetStream>, Status> {
+        Err(Status::unimplemented("Implement do_get"))
+    }
+
+    /// Not supported by the test server: always returns
+    /// `Status::unimplemented`.
+    async fn do_put(
+        &self,
+        _request: Request<Streaming<FlightData>>,
+    ) -> Result<Response<Self::DoPutStream>, Status> {
+        Err(Status::unimplemented("Implement do_put"))
+    }
+
+    /// Not supported by the test server: always returns
+    /// `Status::unimplemented`.
+    async fn do_action(
+        &self,
+        _request: Request<Action>,
+    ) -> Result<Response<Self::DoActionStream>, Status> {
+        Err(Status::unimplemented("Implement do_action"))
+    }
+
+    /// Not supported by the test server: always returns
+    /// `Status::unimplemented`.
+    async fn list_actions(
+        &self,
+        _request: Request<Empty>,
+    ) -> Result<Response<Self::ListActionsStream>, Status> {
+        Err(Status::unimplemented("Implement list_actions"))
+    }
+
+    /// Not supported by the test server: always returns
+    /// `Status::unimplemented`.
+    async fn do_exchange(
+        &self,
+        _request: Request<Streaming<FlightData>>,
+    ) -> Result<Response<Self::DoExchangeStream>, Status> {
+        Err(Status::unimplemented("Implement do_exchange"))
+    }
+}