You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/12/23 12:23:43 UTC

[arrow-rs] branch master updated: Initial Mid-level `FlightClient` (#3378)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 17b3210af Initial Mid-level `FlightClient` (#3378)
17b3210af is described below

commit 17b3210af2ccd190489de9c641fd10f009abd45b
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Fri Dec 23 07:23:37 2022 -0500

    Initial Mid-level `FlightClient` (#3378)
    
    * Mid-level FlightClient
    
    * cleanup
    
    * fixup for use of Bytes
    
    * clippy
    
    * Apply suggestions from code review
    
    Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
    
    * fixup
    
    * BoxStream
    
    Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
---
 arrow-flight/src/client.rs          | 567 ++++++++++++++++++++++++++++++++++++
 arrow-flight/src/error.rs           |  59 ++++
 arrow-flight/src/lib.rs             |   7 +
 arrow-flight/tests/client.rs        | 309 ++++++++++++++++++++
 arrow-flight/tests/common/server.rs | 212 ++++++++++++++
 5 files changed, 1154 insertions(+)

diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs
new file mode 100644
index 000000000..0e75ac7c0
--- /dev/null
+++ b/arrow-flight/src/client.rs
@@ -0,0 +1,567 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{
+    flight_service_client::FlightServiceClient, utils::flight_data_to_arrow_batch,
+    FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, Ticket,
+};
+use arrow_array::{ArrayRef, RecordBatch};
+use arrow_schema::Schema;
+use bytes::Bytes;
+use futures::{future::ready, ready, stream, StreamExt};
+use std::{collections::HashMap, convert::TryFrom, pin::Pin, sync::Arc, task::Poll};
+use tonic::{metadata::MetadataMap, transport::Channel, Streaming};
+
+use crate::error::{FlightError, Result};
+
+/// A "Mid level" [Apache Arrow Flight](https://arrow.apache.org/docs/format/Flight.html) client.
+///
+/// [`FlightClient`] is intended as a convenience for interactions
+/// with Arrow Flight servers. For more direct control, such as access
+/// to the response headers, use  [`FlightServiceClient`] directly
+/// via methods such as [`Self::inner`] or [`Self::into_inner`].
+///
+/// # Example:
+/// ```no_run
+/// # async fn run() {
+/// # use arrow_flight::FlightClient;
+/// # use bytes::Bytes;
+/// use tonic::transport::Channel;
+/// let channel = Channel::from_static("http://localhost:1234")
+///   .connect()
+///   .await
+///   .expect("error connecting");
+///
+/// let mut client = FlightClient::new(channel);
+///
+/// // Send 'Hi' bytes as the handshake request to the server
+/// let response = client
+///   .handshake(Bytes::from("Hi"))
+///   .await
+///   .expect("error handshaking");
+///
+/// // Expect the server responded with 'Ho'
+/// assert_eq!(response, Bytes::from("Ho"));
+/// # }
+/// ```
+#[derive(Debug)]
+pub struct FlightClient {
+    /// Optional grpc header metadata to include with each request
+    metadata: MetadataMap,
+
+    /// The inner client
+    inner: FlightServiceClient<Channel>,
+}
+
+impl FlightClient {
+    /// Creates a client with the provided [`Channel`](tonic::transport::Channel)
+    pub fn new(channel: Channel) -> Self {
+        Self::new_from_inner(FlightServiceClient::new(channel))
+    }
+
+    /// Creates a new higher level client with the provided lower level client
+    pub fn new_from_inner(inner: FlightServiceClient<Channel>) -> Self {
+        Self {
+            metadata: MetadataMap::new(),
+            inner,
+        }
+    }
+
+    /// Return a reference to gRPC metadata included with each request
+    pub fn metadata(&self) -> &MetadataMap {
+        &self.metadata
+    }
+
+    /// Return a mutable reference to gRPC metadata included with each request
+    ///
+    /// These headers can be used, for example, to include
+    /// authorization or other application specific headers.
+    pub fn metadata_mut(&mut self) -> &mut MetadataMap {
+        &mut self.metadata
+    }
+
+    /// Add the specified header with value to all subsequent
+    /// requests. See [`Self::metadata_mut`] for fine grained control.
+    pub fn add_header(&mut self, key: &str, value: &str) -> Result<()> {
+        let key = tonic::metadata::MetadataKey::<_>::from_bytes(key.as_bytes())
+            .map_err(|e| FlightError::ExternalError(Box::new(e)))?;
+
+        let value = value
+            .parse()
+            .map_err(|e| FlightError::ExternalError(Box::new(e)))?;
+
+        // ignore previous value
+        self.metadata.insert(key, value);
+
+        Ok(())
+    }
+
+    /// Return a reference to the underlying tonic
+    /// [`FlightServiceClient`]
+    pub fn inner(&self) -> &FlightServiceClient<Channel> {
+        &self.inner
+    }
+
+    /// Return a mutable reference to the underlying tonic
+    /// [`FlightServiceClient`]
+    pub fn inner_mut(&mut self) -> &mut FlightServiceClient<Channel> {
+        &mut self.inner
+    }
+
+    /// Consume this client and return the underlying tonic
+    /// [`FlightServiceClient`]
+    pub fn into_inner(self) -> FlightServiceClient<Channel> {
+        self.inner
+    }
+
+    /// Perform an Arrow Flight handshake with the server, sending
+    /// `payload` as the [`HandshakeRequest`] payload and returning
+    /// the [`HandshakeResponse`](crate::HandshakeResponse)
+    /// bytes returned from the server
+    ///
+    /// See [`FlightClient`] docs for an example.
+    pub async fn handshake(&mut self, payload: impl Into<Bytes>) -> Result<Bytes> {
+        let request = HandshakeRequest {
+            protocol_version: 0,
+            payload: payload.into(),
+        };
+
+        // apply headers, etc
+        let request = self.make_request(stream::once(ready(request)));
+
+        let mut response_stream = self.inner.handshake(request).await?.into_inner();
+
+        if let Some(response) = response_stream.next().await.transpose()? {
+            // check if there is another response
+            if response_stream.next().await.is_some() {
+                return Err(FlightError::protocol(
+                    "Got unexpected second response from handshake",
+                ));
+            }
+
+            Ok(response.payload)
+        } else {
+            Err(FlightError::protocol("No response from handshake"))
+        }
+    }
+
+    /// Make a `DoGet` call to the server with the provided ticket,
+    /// returning a [`FlightRecordBatchStream`] for reading
+    /// [`RecordBatch`]es.
+    ///
+    /// # Example:
+    /// ```no_run
+    /// # async fn run() {
+    /// # use bytes::Bytes;
+    /// # use arrow_flight::FlightClient;
+    /// # use arrow_flight::Ticket;
+    /// # use arrow_array::RecordBatch;
+    /// # use tonic::transport::Channel;
+    /// # use futures::stream::TryStreamExt;
+    /// # let channel = Channel::from_static("http://localhost:1234")
+    /// #  .connect()
+    /// #  .await
+    /// #  .expect("error connecting");
+    /// # let ticket = Ticket { ticket: Bytes::from("foo") };
+    /// let mut client = FlightClient::new(channel);
+    ///
+    /// // Invoke a do_get request on the server with a previously
+    /// // received Ticket
+    ///
+    /// let response = client
+    ///    .do_get(ticket)
+    ///    .await
+    ///    .expect("error invoking do_get");
+    ///
+    /// // Use try_collect to get the RecordBatches from the server
+    /// let batches: Vec<RecordBatch> = response
+    ///    .try_collect()
+    ///    .await
+    ///    .expect("no stream errors");
+    /// # }
+    /// ```
+    pub async fn do_get(&mut self, ticket: Ticket) -> Result<FlightRecordBatchStream> {
+        let request = self.make_request(ticket);
+
+        let response = self.inner.do_get(request).await?.into_inner();
+
+        let flight_data_stream = FlightDataStream::new(response);
+        Ok(FlightRecordBatchStream::new(flight_data_stream))
+    }
+
+    /// Make a `GetFlightInfo` call to the server with the provided
+    /// [`FlightDescriptor`] and return the [`FlightInfo`] from the
+    /// server. The [`FlightInfo`] can be used with [`Self::do_get`]
+    /// to retrieve the requested batches.
+    ///
+    /// # Example:
+    /// ```no_run
+    /// # async fn run() {
+    /// # use arrow_flight::FlightClient;
+    /// # use arrow_flight::FlightDescriptor;
+    /// # use tonic::transport::Channel;
+    /// # let channel = Channel::from_static("http://localhost:1234")
+    /// #   .connect()
+    /// #   .await
+    /// #   .expect("error connecting");
+    /// let mut client = FlightClient::new(channel);
+    ///
+    /// // Send a 'CMD' request to the server
+    /// let request = FlightDescriptor::new_cmd(b"MOAR DATA".to_vec());
+    /// let flight_info = client
+    ///   .get_flight_info(request)
+    ///   .await
+    ///   .expect("error getting flight info");
+    ///
+    /// // retrieve the first endpoint from the returned flight info
+    /// let ticket = flight_info
+    ///   .endpoint[0]
+    ///   // Extract the ticket
+    ///   .ticket
+    ///   .clone()
+    ///   .expect("expected ticket");
+    ///
+    /// // Retrieve the corresponding RecordBatch stream with do_get
+    /// let data = client
+    ///   .do_get(ticket)
+    ///   .await
+    ///   .expect("error fetching data");
+    /// # }
+    /// ```
+    pub async fn get_flight_info(
+        &mut self,
+        descriptor: FlightDescriptor,
+    ) -> Result<FlightInfo> {
+        let request = self.make_request(descriptor);
+
+        let response = self.inner.get_flight_info(request).await?.into_inner();
+        Ok(response)
+    }
+
+    // TODO other methods
+    // list_flights
+    // get_schema
+    // do_put
+    // do_action
+    // list_actions
+    // do_exchange
+
+    /// return a Request, adding any configured metadata
+    fn make_request<T>(&self, t: T) -> tonic::Request<T> {
+        // Pass along metadata
+        let mut request = tonic::Request::new(t);
+        *request.metadata_mut() = self.metadata.clone();
+        request
+    }
+}
+
+/// A stream of [`RecordBatch`]es from an Arrow Flight server.
+///
+/// To access the lower level Flight messages directly, consider
+/// calling [`Self::into_inner`] and using the [`FlightDataStream`]
+/// directly.
+#[derive(Debug)]
+pub struct FlightRecordBatchStream {
+    inner: FlightDataStream,
+    got_schema: bool,
+}
+
+impl FlightRecordBatchStream {
+    pub fn new(inner: FlightDataStream) -> Self {
+        Self {
+            inner,
+            got_schema: false,
+        }
+    }
+
+    /// Has a message defining the schema been received yet?
+    pub fn got_schema(&self) -> bool {
+        self.got_schema
+    }
+
+    /// Consume self and return the wrapped [`FlightDataStream`]
+    pub fn into_inner(self) -> FlightDataStream {
+        self.inner
+    }
+}
+impl futures::Stream for FlightRecordBatchStream {
+    type Item = Result<RecordBatch>;
+
+    /// Returns the next [`RecordBatch`] available in this stream, or `None` if
+    /// there are no further results available.
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        loop {
+            let res = ready!(self.inner.poll_next_unpin(cx));
+            match res {
+                // Inner exhausted
+                None => {
+                    return Poll::Ready(None);
+                }
+                Some(Err(e)) => {
+                    return Poll::Ready(Some(Err(e)));
+                }
+                // translate data
+                Some(Ok(data)) => match data.payload {
+                    DecodedPayload::Schema(_) if self.got_schema => {
+                        return Poll::Ready(Some(Err(FlightError::protocol(
+                            "Unexpectedly saw multiple Schema messages in FlightData stream",
+                        ))));
+                    }
+                    DecodedPayload::Schema(_) => {
+                        self.got_schema = true;
+                        // Need next message, poll inner again
+                    }
+                    DecodedPayload::RecordBatch(batch) => {
+                        return Poll::Ready(Some(Ok(batch)));
+                    }
+                    DecodedPayload::None => {
+                        // Need next message
+                    }
+                },
+            }
+        }
+    }
+}
+
+/// Wrapper around a stream of [`FlightData`] that handles the details
+/// of decoding low level Flight messages into [`Schema`] and
+/// [`RecordBatch`]es, including details such as dictionaries.
+///
+/// # Protocol Details
+///
+/// The client handles flight messages as follows:
+///
+/// - **None:** This message has no effect. This is useful to
+///   transmit metadata without any actual payload.
+///
+/// - **Schema:** The schema is (re-)set. Dictionaries are cleared and
+///   the decoded schema is returned.
+///
+/// - **Dictionary Batch:** A new dictionary for a given column is registered. An existing
+///   dictionary for the same column will be overwritten. This
+///   message is NOT visible.
+///
+/// - **Record Batch:** Record batch is created based on the current
+///   schema and dictionaries. This fails if no schema was transmitted
+///   yet.
+///
+/// All other message types (at the time of writing: e.g. tensor and
+/// sparse tensor) lead to an error.
+///
+/// Example usecases
+///
+/// 1. Using this low level stream it is possible to receive a stream
+/// of RecordBatches in FlightData that have different schemas by
+/// handling multiple schema messages separately.
+#[derive(Debug)]
+pub struct FlightDataStream {
+    /// Underlying data stream
+    response: Streaming<FlightData>,
+    /// Decoding state
+    state: Option<FlightStreamState>,
+    /// seen the end of the inner stream?
+    done: bool,
+}
+
+impl FlightDataStream {
+    /// Create a new wrapper around the stream of FlightData
+    pub fn new(response: Streaming<FlightData>) -> Self {
+        Self {
+            state: None,
+            response,
+            done: false,
+        }
+    }
+
+    /// Extracts flight data from the next message, updating decoding
+    /// state as necessary.
+    fn extract_message(&mut self, data: FlightData) -> Result<Option<DecodedFlightData>> {
+        use arrow_ipc::MessageHeader;
+        let message = arrow_ipc::root_as_message(&data.data_header[..]).map_err(|e| {
+            FlightError::DecodeError(format!("Error decoding root message: {e}"))
+        })?;
+
+        match message.header_type() {
+            MessageHeader::NONE => Ok(Some(DecodedFlightData::new_none(data))),
+            MessageHeader::Schema => {
+                let schema = Schema::try_from(&data).map_err(|e| {
+                    FlightError::DecodeError(format!("Error decoding schema: {e}"))
+                })?;
+
+                let schema = Arc::new(schema);
+                let dictionaries_by_field = HashMap::new();
+
+                self.state = Some(FlightStreamState {
+                    schema: Arc::clone(&schema),
+                    dictionaries_by_field,
+                });
+                Ok(Some(DecodedFlightData::new_schema(data, schema)))
+            }
+            MessageHeader::DictionaryBatch => {
+                let state = if let Some(state) = self.state.as_mut() {
+                    state
+                } else {
+                    return Err(FlightError::protocol(
+                        "Received DictionaryBatch prior to Schema",
+                    ));
+                };
+
+                let buffer: arrow_buffer::Buffer = data.data_body.into();
+                let dictionary_batch =
+                    message.header_as_dictionary_batch().ok_or_else(|| {
+                        FlightError::protocol(
+                            "Could not get dictionary batch from DictionaryBatch message",
+                        )
+                    })?;
+
+                arrow_ipc::reader::read_dictionary(
+                    &buffer,
+                    dictionary_batch,
+                    &state.schema,
+                    &mut state.dictionaries_by_field,
+                    &message.version(),
+                )
+                .map_err(|e| {
+                    FlightError::DecodeError(format!(
+                        "Error decoding ipc dictionary: {e}"
+                    ))
+                })?;
+
+                // Updated internal state, but no decoded message
+                Ok(None)
+            }
+            MessageHeader::RecordBatch => {
+                let state = if let Some(state) = self.state.as_ref() {
+                    state
+                } else {
+                    return Err(FlightError::protocol(
+                        "Received RecordBatch prior to Schema",
+                    ));
+                };
+
+                let batch = flight_data_to_arrow_batch(
+                    &data,
+                    Arc::clone(&state.schema),
+                    &state.dictionaries_by_field,
+                )
+                .map_err(|e| {
+                    FlightError::DecodeError(format!(
+                        "Error decoding ipc RecordBatch: {e}"
+                    ))
+                })?;
+
+                Ok(Some(DecodedFlightData::new_record_batch(data, batch)))
+            }
+            other => {
+                let name = other.variant_name().unwrap_or("UNKNOWN");
+                Err(FlightError::protocol(format!("Unexpected message: {name}")))
+            }
+        }
+    }
+}
+
+impl futures::Stream for FlightDataStream {
+    type Item = Result<DecodedFlightData>;
+    /// Returns the result of decoding the next [`FlightData`] message
+    /// from the server, or `None` if there are no further results
+    /// available.
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        if self.done {
+            return Poll::Ready(None);
+        }
+        loop {
+            let res = ready!(self.response.poll_next_unpin(cx));
+
+            return Poll::Ready(match res {
+                None => {
+                    self.done = true;
+                    None // inner is exhausted
+                }
+                Some(data) => Some(match data {
+                    Err(e) => Err(FlightError::Tonic(e)),
+                    Ok(data) => match self.extract_message(data) {
+                        Ok(Some(extracted)) => Ok(extracted),
+                        Ok(None) => continue, // Need next input message
+                        Err(e) => Err(e),
+                    },
+                }),
+            });
+        }
+    }
+}
+
+/// tracks the state needed to reconstruct [`RecordBatch`]es from a
+/// streaming flight response.
+#[derive(Debug)]
+struct FlightStreamState {
+    schema: Arc<Schema>,
+    dictionaries_by_field: HashMap<i64, ArrayRef>,
+}
+
+/// FlightData and the decoded payload (Schema, RecordBatch), if any
+#[derive(Debug)]
+pub struct DecodedFlightData {
+    pub inner: FlightData,
+    pub payload: DecodedPayload,
+}
+
+impl DecodedFlightData {
+    pub fn new_none(inner: FlightData) -> Self {
+        Self {
+            inner,
+            payload: DecodedPayload::None,
+        }
+    }
+
+    pub fn new_schema(inner: FlightData, schema: Arc<Schema>) -> Self {
+        Self {
+            inner,
+            payload: DecodedPayload::Schema(schema),
+        }
+    }
+
+    pub fn new_record_batch(inner: FlightData, batch: RecordBatch) -> Self {
+        Self {
+            inner,
+            payload: DecodedPayload::RecordBatch(batch),
+        }
+    }
+
+    /// return the metadata field of the inner flight data
+    pub fn app_metadata(&self) -> &[u8] {
+        &self.inner.app_metadata
+    }
+}
+
+/// The result of decoding [`FlightData`]
+#[derive(Debug)]
+pub enum DecodedPayload {
+    /// None (no data was sent in the corresponding FlightData)
+    None,
+
+    /// A decoded Schema message
+    Schema(Arc<Schema>),
+
+    /// A decoded Record batch.
+    RecordBatch(RecordBatch),
+}
diff --git a/arrow-flight/src/error.rs b/arrow-flight/src/error.rs
new file mode 100644
index 000000000..fbb9efa44
--- /dev/null
+++ b/arrow-flight/src/error.rs
@@ -0,0 +1,59 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Errors for the Apache Arrow Flight crate
+#[derive(Debug)]
+pub enum FlightError {
+    /// Returned when functionality is not yet available.
+    NotYetImplemented(String),
+    /// Error from the underlying tonic library
+    Tonic(tonic::Status),
+    /// Some unexpected message was received
+    ProtocolError(String),
+    /// An error occurred during decoding
+    DecodeError(String),
+    /// Some other (opaque) error
+    ExternalError(Box<dyn std::error::Error + Send + Sync>),
+}
+
+impl FlightError {
+    pub fn protocol(message: impl Into<String>) -> Self {
+        Self::ProtocolError(message.into())
+    }
+
+    /// Wraps an external error in a `FlightError`.
+    pub fn from_external_error(error: Box<dyn std::error::Error + Send + Sync>) -> Self {
+        Self::ExternalError(error)
+    }
+}
+
+impl std::fmt::Display for FlightError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // TODO better format / error
+        write!(f, "{:?}", self)
+    }
+}
+
+impl std::error::Error for FlightError {}
+
+impl From<tonic::Status> for FlightError {
+    fn from(status: tonic::Status) -> Self {
+        Self::Tonic(status)
+    }
+}
+
+pub type Result<T> = std::result::Result<T, FlightError>;
diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs
index 051509fb1..f30cb5484 100644
--- a/arrow-flight/src/lib.rs
+++ b/arrow-flight/src/lib.rs
@@ -71,6 +71,13 @@ pub mod flight_service_server {
     pub use gen::flight_service_server::FlightServiceServer;
 }
 
+/// Mid Level [`FlightClient`] for interacting with Arrow Flight servers
+pub mod client;
+pub use client::FlightClient;
+
+/// Common error types
+pub mod error;
+
 pub use gen::Action;
 pub use gen::ActionType;
 pub use gen::BasicAuth;
diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs
new file mode 100644
index 000000000..5bc1062f0
--- /dev/null
+++ b/arrow-flight/tests/client.rs
@@ -0,0 +1,309 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Integration test for "mid level" Client
+
+mod common {
+    pub mod server;
+}
+use arrow_flight::{
+    error::FlightError, FlightClient, FlightDescriptor, FlightInfo, HandshakeRequest,
+    HandshakeResponse,
+};
+use bytes::Bytes;
+use common::server::TestFlightServer;
+use futures::Future;
+use tokio::{net::TcpListener, task::JoinHandle};
+use tonic::{
+    transport::{Channel, Uri},
+    Status,
+};
+
+use std::{net::SocketAddr, time::Duration};
+
+const DEFAULT_TIMEOUT_SECONDS: u64 = 30;
+
+#[tokio::test]
+async fn test_handshake() {
+    do_test(|test_server, mut client| async move {
+        let request_payload = Bytes::from("foo");
+        let response_payload = Bytes::from("Bar");
+
+        let request = HandshakeRequest {
+            payload: request_payload.clone(),
+            protocol_version: 0,
+        };
+
+        let response = HandshakeResponse {
+            payload: response_payload.clone(),
+            protocol_version: 0,
+        };
+
+        test_server.set_handshake_response(Ok(response));
+        let response = client.handshake(request_payload).await.unwrap();
+        assert_eq!(response, response_payload);
+        assert_eq!(test_server.take_handshake_request(), Some(request));
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_handshake_error() {
+    do_test(|test_server, mut client| async move {
+        let request_payload = "foo".to_string().into_bytes();
+        let e = Status::unauthenticated("DENIED");
+        test_server.set_handshake_response(Err(e));
+
+        let response = client.handshake(request_payload).await.unwrap_err();
+        let e = Status::unauthenticated("DENIED");
+        expect_status(response, e);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_handshake_metadata() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo", "bar").unwrap();
+
+        let request_payload = Bytes::from("Blarg");
+        let response_payload = Bytes::from("Bazz");
+
+        let response = HandshakeResponse {
+            payload: response_payload.clone(),
+            protocol_version: 0,
+        };
+
+        test_server.set_handshake_response(Ok(response));
+        client.handshake(request_payload).await.unwrap();
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+/// Verifies that all headers sent from the client are in the request_metadata
+fn ensure_metadata(client: &FlightClient, test_server: &TestFlightServer) {
+    let client_metadata = client.metadata().clone().into_headers();
+    assert!(!client_metadata.is_empty());
+    let metadata = test_server
+        .take_last_request_metadata()
+        .expect("No headers in server")
+        .into_headers();
+
+    for (k, v) in &client_metadata {
+        assert_eq!(
+            metadata.get(k).as_ref(),
+            Some(&v),
+            "Missing / Mismatched metadata {:?} sent {:?} got {:?}",
+            k,
+            client_metadata,
+            metadata
+        );
+    }
+}
+
+fn test_flight_info(request: &FlightDescriptor) -> FlightInfo {
+    FlightInfo {
+        schema: Bytes::new(),
+        endpoint: vec![],
+        flight_descriptor: Some(request.clone()),
+        total_bytes: 123,
+        total_records: 456,
+    }
+}
+
+#[tokio::test]
+async fn test_get_flight_info() {
+    do_test(|test_server, mut client| async move {
+        let request = FlightDescriptor::new_cmd(b"My Command".to_vec());
+
+        let expected_response = test_flight_info(&request);
+        test_server.set_get_flight_info_response(Ok(expected_response.clone()));
+
+        let response = client.get_flight_info(request.clone()).await.unwrap();
+
+        assert_eq!(response, expected_response);
+        assert_eq!(test_server.take_get_flight_info_request(), Some(request));
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_get_flight_info_error() {
+    do_test(|test_server, mut client| async move {
+        let request = FlightDescriptor::new_cmd(b"My Command".to_vec());
+
+        let e = Status::unauthenticated("DENIED");
+        test_server.set_get_flight_info_response(Err(e));
+
+        let response = client.get_flight_info(request.clone()).await.unwrap_err();
+        let e = Status::unauthenticated("DENIED");
+        expect_status(response, e);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_get_flight_info_metadata() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo", "bar").unwrap();
+        let request = FlightDescriptor::new_cmd(b"My Command".to_vec());
+
+        let expected_response = test_flight_info(&request);
+        test_server.set_get_flight_info_response(Ok(expected_response));
+        client.get_flight_info(request.clone()).await.unwrap();
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+// TODO more negative  tests (like if there are endpoints defined, etc)
+
+// TODO test for do_get
+
+/// Runs the future returned by the function,  passing it a test server and client
+async fn do_test<F, Fut>(f: F)
+where
+    F: Fn(TestFlightServer, FlightClient) -> Fut,
+    Fut: Future<Output = ()>,
+{
+    let test_server = TestFlightServer::new();
+    let fixture = TestFixture::new(&test_server).await;
+    let client = FlightClient::new(fixture.channel().await);
+
+    // run the test function
+    f(test_server, client).await;
+
+    // cleanly shutdown the test fixture
+    fixture.shutdown_and_wait().await
+}
+
+fn expect_status(error: FlightError, expected: Status) {
+    let status = if let FlightError::Tonic(status) = error {
+        status
+    } else {
+        panic!("Expected FlightError::Tonic, got: {:?}", error);
+    };
+
+    assert_eq!(
+        status.code(),
+        expected.code(),
+        "Got {:?} want {:?}",
+        status,
+        expected
+    );
+    assert_eq!(
+        status.message(),
+        expected.message(),
+        "Got {:?} want {:?}",
+        status,
+        expected
+    );
+    assert_eq!(
+        status.details(),
+        expected.details(),
+        "Got {:?} want {:?}",
+        status,
+        expected
+    );
+}
+
+/// Creates and manages a running TestServer with a background task
+struct TestFixture {
+    /// channel to send shutdown command
+    shutdown: Option<tokio::sync::oneshot::Sender<()>>,
+
+    /// Address the server is listening on
+    addr: SocketAddr,
+
+    // handle for the server task
+    handle: Option<JoinHandle<Result<(), tonic::transport::Error>>>,
+}
+
+impl TestFixture {
+    /// create a new test fixture from the server
+    pub async fn new(test_server: &TestFlightServer) -> Self {
+        // let OS choose a free port
+        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let addr = listener.local_addr().unwrap();
+
+        println!("Listening on {addr}");
+
+        // prepare the shutdown channel
+        let (tx, rx) = tokio::sync::oneshot::channel();
+
+        let server_timeout = Duration::from_secs(DEFAULT_TIMEOUT_SECONDS);
+
+        let shutdown_future = async move {
+            rx.await.ok();
+        };
+
+        let serve_future = tonic::transport::Server::builder()
+            .timeout(server_timeout)
+            .add_service(test_server.service())
+            .serve_with_incoming_shutdown(
+                tokio_stream::wrappers::TcpListenerStream::new(listener),
+                shutdown_future,
+            );
+
+        // Run the server in its own background task
+        let handle = tokio::task::spawn(serve_future);
+
+        Self {
+            shutdown: Some(tx),
+            addr,
+            handle: Some(handle),
+        }
+    }
+
+    /// Return a [`Channel`] connected to the TestServer
+    pub async fn channel(&self) -> Channel {
+        let url = format!("http://{}", self.addr);
+        let uri: Uri = url.parse().expect("Valid URI");
+        Channel::builder(uri)
+            .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECONDS))
+            .connect()
+            .await
+            .expect("error connecting to server")
+    }
+
+    /// Stops the test server and waits for the server to shutdown
+    pub async fn shutdown_and_wait(mut self) {
+        if let Some(shutdown) = self.shutdown.take() {
+            shutdown.send(()).expect("server quit early");
+        }
+        if let Some(handle) = self.handle.take() {
+            println!("Waiting on server to finish");
+            handle
+                .await
+                .expect("task join error (panic?)")
+                .expect("Server Error found at shutdown");
+        }
+    }
+}
+
+impl Drop for TestFixture {
+    fn drop(&mut self) {
+        if let Some(shutdown) = self.shutdown.take() {
+            shutdown.send(()).ok();
+        }
+        if self.handle.is_some() {
+            // tests should properly clean up TestFixture
+            println!("TestFixture::Drop called prior to `shutdown_and_wait`");
+        }
+    }
+}
diff --git a/arrow-flight/tests/common/server.rs b/arrow-flight/tests/common/server.rs
new file mode 100644
index 000000000..f1cb140b6
--- /dev/null
+++ b/arrow-flight/tests/common/server.rs
@@ -0,0 +1,212 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::{Arc, Mutex};
+
+use futures::stream::BoxStream;
+use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming};
+
+use arrow_flight::{
+    flight_service_server::{FlightService, FlightServiceServer},
+    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
+    HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket,
+};
+
+#[derive(Debug, Clone)]
+/// Flight server for testing, with configurable responses
+pub struct TestFlightServer {
+    /// Configured responses and recorded requests, shared between all clones of this server
+    state: Arc<Mutex<State>>,
+}
+
+impl TestFlightServer {
+    /// Create a `TestFlightServer` with no responses configured
+    pub fn new() -> Self {
+        Self {
+            state: Arc::new(Mutex::new(State::new())),
+        }
+    }
+
+    /// Return a [`FlightServiceServer`] that can be used with a
+    /// [`Server`](tonic::transport::Server)
+    pub fn service(&self) -> FlightServiceServer<TestFlightServer> {
+        // wrap a clone of `self` (which shares `state`) for tonic
+        FlightServiceServer::new(self.clone())
+    }
+
+    /// Specify the response returned from the next call to `handshake`, replacing any response already set
+    pub fn set_handshake_response(&self, response: Result<HandshakeResponse, Status>) {
+        let mut state = self.state.lock().expect("mutex not poisoned");
+
+        state.handshake_response.replace(response);
+    }
+
+    /// Take and return the last handshake request sent to the server, leaving `None` in its place
+    pub fn take_handshake_request(&self) -> Option<HandshakeRequest> {
+        self.state
+            .lock()
+            .expect("mutex not poisoned")
+            .handshake_request
+            .take()
+    }
+
+    /// Specify the response returned from the next call to `get_flight_info`, replacing any response already set
+    pub fn set_get_flight_info_response(&self, response: Result<FlightInfo, Status>) {
+        let mut state = self.state.lock().expect("mutex not poisoned");
+
+        state.get_flight_info_response.replace(response);
+    }
+
+    /// Take and return the last get_flight_info request sent to the server, leaving `None` in its place
+    pub fn take_get_flight_info_request(&self) -> Option<FlightDescriptor> {
+        self.state
+            .lock()
+            .expect("mutex not poisoned")
+            .get_flight_info_request
+            .take()
+    }
+
+    /// Take and return the metadata of the last request received by the server, leaving `None` in its place
+    pub fn take_last_request_metadata(&self) -> Option<MetadataMap> {
+        self.state
+            .lock()
+            .expect("mutex not poisoned")
+            .last_request_metadata
+            .take()
+    }
+
+    /// Save a copy of the request's metadata as the last request metadata
+    fn save_metadata<T>(&self, request: &Request<T>) {
+        let metadata = request.metadata().clone();
+        let mut state = self.state.lock().expect("mutex not poisoned");
+        state.last_request_metadata = Some(metadata);
+    }
+}
+
+/// Mutable state for the `TestFlightServer`
+#[derive(Debug, Default)]
+struct State {
+    /// The last handshake request that was received
+    pub handshake_request: Option<HandshakeRequest>,
+    /// The next response to return from `handshake()`
+    pub handshake_response: Option<Result<HandshakeResponse, Status>>,
+    /// The last `get_flight_info` request received
+    pub get_flight_info_request: Option<FlightDescriptor>,
+    /// The next response to return from `get_flight_info()`
+    pub get_flight_info_response: Option<Result<FlightInfo, Status>>,
+    /// The last request headers received
+    pub last_request_metadata: Option<MetadataMap>,
+}
+
+impl State {
+    fn new() -> Self {
+        Default::default() // every field is an `Option`, so all start out as `None`
+    }
+}
+
+/// `FlightService` impl: `handshake` and `get_flight_info` return configured responses; every other RPC returns `Status::unimplemented`
+#[tonic::async_trait]
+impl FlightService for TestFlightServer {
+    type HandshakeStream = BoxStream<'static, Result<HandshakeResponse, Status>>;
+    type ListFlightsStream = BoxStream<'static, Result<FlightInfo, Status>>;
+    type DoGetStream = BoxStream<'static, Result<FlightData, Status>>;
+    type DoPutStream = BoxStream<'static, Result<PutResult, Status>>;
+    type DoActionStream = BoxStream<'static, Result<arrow_flight::Result, Status>>;
+    type ListActionsStream = BoxStream<'static, Result<ActionType, Status>>;
+    type DoExchangeStream = BoxStream<'static, Result<FlightData, Status>>;
+
+    async fn handshake(
+        &self,
+        request: Request<Streaming<HandshakeRequest>>,
+    ) -> Result<Response<Self::HandshakeStream>, Status> {
+        self.save_metadata(&request);
+        let handshake_request = request.into_inner().message().await?.unwrap(); // NOTE(review): panics if the stream yields no message
+
+        let mut state = self.state.lock().expect("mutex not poisoned");
+        state.handshake_request = Some(handshake_request); // recorded for `take_handshake_request`
+
+        let response = state.handshake_response.take().unwrap_or_else(|| {
+            Err(Status::internal("No handshake response configured"))
+        })?;
+
+        // turn the single configured response into a one-item streaming response
+        let output = futures::stream::iter(std::iter::once(Ok(response)));
+        Ok(Response::new(Box::pin(output) as Self::HandshakeStream))
+    }
+
+    async fn list_flights(
+        &self,
+        _request: Request<Criteria>,
+    ) -> Result<Response<Self::ListFlightsStream>, Status> {
+        Err(Status::unimplemented("Implement list_flights"))
+    }
+
+    async fn get_flight_info(
+        &self,
+        request: Request<FlightDescriptor>,
+    ) -> Result<Response<FlightInfo>, Status> {
+        self.save_metadata(&request);
+        let mut state = self.state.lock().expect("mutex not poisoned");
+        state.get_flight_info_request = Some(request.into_inner()); // recorded for `take_get_flight_info_request`
+        let response = state.get_flight_info_response.take().unwrap_or_else(|| {
+            Err(Status::internal("No get_flight_info response configured"))
+        })?;
+        Ok(Response::new(response))
+    }
+
+    async fn get_schema(
+        &self,
+        _request: Request<FlightDescriptor>,
+    ) -> Result<Response<SchemaResult>, Status> {
+        Err(Status::unimplemented("Implement get_schema"))
+    }
+
+    async fn do_get(
+        &self,
+        _request: Request<Ticket>,
+    ) -> Result<Response<Self::DoGetStream>, Status> {
+        Err(Status::unimplemented("Implement do_get"))
+    }
+
+    async fn do_put(
+        &self,
+        _request: Request<Streaming<FlightData>>,
+    ) -> Result<Response<Self::DoPutStream>, Status> {
+        Err(Status::unimplemented("Implement do_put"))
+    }
+
+    async fn do_action(
+        &self,
+        _request: Request<Action>,
+    ) -> Result<Response<Self::DoActionStream>, Status> {
+        Err(Status::unimplemented("Implement do_action"))
+    }
+
+    async fn list_actions(
+        &self,
+        _request: Request<Empty>,
+    ) -> Result<Response<Self::ListActionsStream>, Status> {
+        Err(Status::unimplemented("Implement list_actions"))
+    }
+
+    async fn do_exchange(
+        &self,
+        _request: Request<Streaming<FlightData>>,
+    ) -> Result<Response<Self::DoExchangeStream>, Status> {
+        Err(Status::unimplemented("Implement do_exchange"))
+    }
+}