You are viewing a plain-text version of this content. The canonical link for it ("here" in the original page) was not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/01/05 19:41:34 UTC

[arrow-rs] branch master updated: Add tests for `FlightClient::{list_flights, list_actions, do_action, get_schema}` (#3463)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 2d2d0a3ba Add tests for `FlightClient::{list_flights, list_actions, do_action, get_schema}` (#3463)
2d2d0a3ba is described below

commit 2d2d0a3ba72efb5ee82324064f7c7678c2dd8336
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Thu Jan 5 14:41:28 2023 -0500

    Add tests for `FlightClient::{list_flights, list_actions, do_action, get_schema}` (#3463)
---
 arrow-flight/src/lib.rs             |   7 +
 arrow-flight/tests/client.rs        | 326 ++++++++++++++++++++++++++++++++++--
 arrow-flight/tests/common/server.rs | 165 +++++++++++++++---
 3 files changed, 466 insertions(+), 32 deletions(-)

diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs
index 87aeba1c1..3057735a6 100644
--- a/arrow-flight/src/lib.rs
+++ b/arrow-flight/src/lib.rs
@@ -454,6 +454,13 @@ impl Action {
     }
 }
 
+impl Result {
+    /// Create a new Result with the specified body
+    pub fn new(body: impl Into<Bytes>) -> Self {
+        Self { body: body.into() }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs
index 7537e46db..032dad049 100644
--- a/arrow-flight/tests/client.rs
+++ b/arrow-flight/tests/client.rs
@@ -23,9 +23,10 @@ mod common {
 use arrow_array::{RecordBatch, UInt64Array};
 use arrow_flight::{
     decode::FlightRecordBatchStream, encode::FlightDataEncoderBuilder,
-    error::FlightError, FlightClient, FlightData, FlightDescriptor, FlightInfo,
-    HandshakeRequest, HandshakeResponse, PutResult, Ticket,
+    error::FlightError, Action, ActionType, Criteria, Empty, FlightClient, FlightData,
+    FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, Ticket,
 };
+use arrow_schema::{DataType, Field, Schema};
 use bytes::Bytes;
 use common::server::TestFlightServer;
 use futures::{Future, StreamExt, TryStreamExt};
@@ -70,10 +71,9 @@ async fn test_handshake_error() {
     do_test(|test_server, mut client| async move {
         let request_payload = "foo-request-payload".to_string().into_bytes();
         let e = Status::unauthenticated("DENIED");
-        test_server.set_handshake_response(Err(e));
+        test_server.set_handshake_response(Err(e.clone()));
 
         let response = client.handshake(request_payload).await.unwrap_err();
-        let e = Status::unauthenticated("DENIED");
         expect_status(response, e);
     })
     .await;
@@ -134,10 +134,9 @@ async fn test_get_flight_info_error() {
         let request = FlightDescriptor::new_cmd(b"My Command".to_vec());
 
         let e = Status::unauthenticated("DENIED");
-        test_server.set_get_flight_info_response(Err(e));
+        test_server.set_get_flight_info_response(Err(e.clone()));
 
         let response = client.get_flight_info(request.clone()).await.unwrap_err();
-        let e = Status::unauthenticated("DENIED");
         expect_status(response, e);
     })
     .await;
@@ -213,7 +212,7 @@ async fn test_do_get_error_in_record_batch_stream() {
 
         let e = Status::data_loss("she's dead jim");
 
-        let expected_response = vec![Ok(batch), Err(FlightError::Tonic(e.clone()))];
+        let expected_response = vec![Ok(batch), Err(e.clone())];
 
         test_server.set_do_get_response(expected_response);
 
@@ -300,11 +299,13 @@ async fn test_do_put_error_stream() {
 
         let input_flight_data = test_flight_data().await;
 
+        let e = Status::invalid_argument("bad arg");
+
         let response = vec![
             Ok(PutResult {
                 app_metadata: Bytes::from("foo-metadata"),
             }),
-            Err(FlightError::Tonic(Status::invalid_argument("bad arg"))),
+            Err(e.clone()),
         ];
 
         test_server.set_do_put_response(response);
@@ -320,7 +321,6 @@ async fn test_do_put_error_stream() {
             Err(e) => e,
         };
 
-        let e = Status::invalid_argument("bad arg");
         expect_status(response, e);
         // server still got the request
         assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
@@ -404,6 +404,7 @@ async fn test_do_exchange_error_stream() {
 
         let input_flight_data = test_flight_data().await;
 
+        let e = Status::invalid_argument("the error");
         let response = test_flight_data2()
             .await
             .into_iter()
@@ -413,8 +414,7 @@ async fn test_do_exchange_error_stream() {
                     Ok(m)
                 } else {
                     // make all messages after the first an error
-                    let e = tonic::Status::invalid_argument("the error");
-                    Err(FlightError::Tonic(e))
+                    Err(e.clone())
                 }
             })
             .collect();
@@ -432,7 +432,6 @@ async fn test_do_exchange_error_stream() {
             Err(e) => e,
         };
 
-        let e = tonic::Status::invalid_argument("the error");
         expect_status(response, e);
         // server still got the request
         assert_eq!(
@@ -444,6 +443,309 @@ async fn test_do_exchange_error_stream() {
     .await;
 }
 
+#[tokio::test]
+async fn test_get_schema() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let schema = Schema::new(vec![Field::new("foo", DataType::Int64, true)]);
+
+        let request = FlightDescriptor::new_cmd("my command");
+        test_server.set_get_schema_response(Ok(schema.clone()));
+
+        let response = client
+            .get_schema(request.clone())
+            .await
+            .expect("error making request");
+
+        let expected_schema = schema;
+        let expected_request = request;
+
+        assert_eq!(response, expected_schema);
+        assert_eq!(
+            test_server.take_get_schema_request(),
+            Some(expected_request)
+        );
+
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_get_schema_error() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+        let request = FlightDescriptor::new_cmd("my command");
+
+        let e = Status::unauthenticated("DENIED");
+        test_server.set_get_schema_response(Err(e.clone()));
+
+        let response = client.get_schema(request).await.unwrap_err();
+        expect_status(response, e);
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_list_flights() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let infos = vec![
+            test_flight_info(&FlightDescriptor::new_cmd("foo")),
+            test_flight_info(&FlightDescriptor::new_cmd("bar")),
+        ];
+
+        let response = infos.iter().map(|i| Ok(i.clone())).collect();
+        test_server.set_list_flights_response(response);
+
+        let response_stream = client
+            .list_flights("query")
+            .await
+            .expect("error making request");
+
+        let expected_response = infos;
+        let response: Vec<_> = response_stream
+            .try_collect()
+            .await
+            .expect("Error streaming data");
+
+        let expected_request = Some(Criteria {
+            expression: "query".into(),
+        });
+
+        assert_eq!(response, expected_response);
+        assert_eq!(test_server.take_list_flights_request(), expected_request);
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_list_flights_error() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let response = client.list_flights("query").await;
+        let response = match response {
+            Ok(_) => panic!("unexpected success"),
+            Err(e) => e,
+        };
+
+        let e = Status::internal("No list_flights response configured");
+        expect_status(response, e);
+        // server still got the request
+        let expected_request = Some(Criteria {
+            expression: "query".into(),
+        });
+        assert_eq!(test_server.take_list_flights_request(), expected_request);
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_list_flights_error_in_stream() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let e = Status::data_loss("she's dead jim");
+
+        let response = vec![
+            Ok(test_flight_info(&FlightDescriptor::new_cmd("foo"))),
+            Err(e.clone()),
+        ];
+        test_server.set_list_flights_response(response);
+
+        let response_stream = client
+            .list_flights("other query")
+            .await
+            .expect("error making request");
+
+        let response: Result<Vec<_>, FlightError> = response_stream.try_collect().await;
+
+        let response = response.unwrap_err();
+        expect_status(response, e);
+        // server still got the request
+        let expected_request = Some(Criteria {
+            expression: "other query".into(),
+        });
+        assert_eq!(test_server.take_list_flights_request(), expected_request);
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_list_actions() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let actions = vec![
+            ActionType {
+                r#type: "type 1".into(),
+                description: "awesomeness".into(),
+            },
+            ActionType {
+                r#type: "type 2".into(),
+                description: "more awesomeness".into(),
+            },
+        ];
+
+        let response = actions.iter().map(|i| Ok(i.clone())).collect();
+        test_server.set_list_actions_response(response);
+
+        let response_stream = client.list_actions().await.expect("error making request");
+
+        let expected_response = actions;
+        let response: Vec<_> = response_stream
+            .try_collect()
+            .await
+            .expect("Error streaming data");
+
+        assert_eq!(response, expected_response);
+        assert_eq!(test_server.take_list_actions_request(), Some(Empty {}));
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_list_actions_error() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let response = client.list_actions().await;
+        let response = match response {
+            Ok(_) => panic!("unexpected success"),
+            Err(e) => e,
+        };
+
+        let e = Status::internal("No list_actions response configured");
+        expect_status(response, e);
+        // server still got the request
+        assert_eq!(test_server.take_list_actions_request(), Some(Empty {}));
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_list_actions_error_in_stream() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let e = Status::data_loss("she's dead jim");
+
+        let response = vec![
+            Ok(ActionType {
+                r#type: "type 1".into(),
+                description: "awesomeness".into(),
+            }),
+            Err(e.clone()),
+        ];
+        test_server.set_list_actions_response(response);
+
+        let response_stream = client.list_actions().await.expect("error making request");
+
+        let response: Result<Vec<_>, FlightError> = response_stream.try_collect().await;
+
+        let response = response.unwrap_err();
+        expect_status(response, e);
+        // server still got the request
+        assert_eq!(test_server.take_list_actions_request(), Some(Empty {}));
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_do_action() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let bytes = vec![Bytes::from("foo"), Bytes::from("blarg")];
+
+        let response = bytes
+            .iter()
+            .cloned()
+            .map(arrow_flight::Result::new)
+            .map(Ok)
+            .collect();
+        test_server.set_do_action_response(response);
+
+        let request = Action::new("action type", "action body");
+
+        let response_stream = client
+            .do_action(request.clone())
+            .await
+            .expect("error making request");
+
+        let expected_response = bytes;
+        let response: Vec<_> = response_stream
+            .try_collect()
+            .await
+            .expect("Error streaming data");
+
+        assert_eq!(response, expected_response);
+        assert_eq!(test_server.take_do_action_request(), Some(request));
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_do_action_error() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let request = Action::new("action type", "action body");
+
+        let response = client.do_action(request.clone()).await;
+        let response = match response {
+            Ok(_) => panic!("unexpected success"),
+            Err(e) => e,
+        };
+
+        let e = Status::internal("No do_action response configured");
+        expect_status(response, e);
+        // server still got the request
+        assert_eq!(test_server.take_do_action_request(), Some(request));
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
+#[tokio::test]
+async fn test_do_action_error_in_stream() {
+    do_test(|test_server, mut client| async move {
+        client.add_header("foo-header", "bar-header-value").unwrap();
+
+        let e = Status::data_loss("she's dead jim");
+
+        let request = Action::new("action type", "action body");
+
+        let response = vec![Ok(arrow_flight::Result::new("foo")), Err(e.clone())];
+        test_server.set_do_action_response(response);
+
+        let response_stream = client
+            .do_action(request.clone())
+            .await
+            .expect("error making request");
+
+        let response: Result<Vec<_>, FlightError> = response_stream.try_collect().await;
+
+        let response = response.unwrap_err();
+        expect_status(response, e);
+        // server still got the request
+        assert_eq!(test_server.take_do_action_request(), Some(request));
+        ensure_metadata(&client, &test_server);
+    })
+    .await;
+}
+
 async fn test_flight_data() -> Vec<FlightData> {
     let batch = RecordBatch::try_from_iter(vec![(
         "col",
diff --git a/arrow-flight/tests/common/server.rs b/arrow-flight/tests/common/server.rs
index 5060d9d0c..b87019d63 100644
--- a/arrow-flight/tests/common/server.rs
+++ b/arrow-flight/tests/common/server.rs
@@ -18,15 +18,15 @@
 use std::sync::{Arc, Mutex};
 
 use arrow_array::RecordBatch;
+use arrow_schema::Schema;
 use futures::{stream::BoxStream, StreamExt, TryStreamExt};
 use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming};
 
 use arrow_flight::{
     encode::FlightDataEncoderBuilder,
-    error::FlightError,
     flight_service_server::{FlightService, FlightServiceServer},
     Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
-    HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket,
+    HandshakeRequest, HandshakeResponse, PutResult, SchemaAsIpc, SchemaResult, Ticket,
 };
 
 #[derive(Debug, Clone)]
@@ -84,7 +84,7 @@ impl TestFlightServer {
     }
 
     /// Specify the response returned from the next call to `do_get`
-    pub fn set_do_get_response(&self, response: Vec<Result<RecordBatch, FlightError>>) {
+    pub fn set_do_get_response(&self, response: Vec<Result<RecordBatch, Status>>) {
         let mut state = self.state.lock().expect("mutex not poisoned");
         state.do_get_response.replace(response);
     }
@@ -99,7 +99,7 @@ impl TestFlightServer {
     }
 
     /// Specify the response returned from the next call to `do_put`
-    pub fn set_do_put_response(&self, response: Vec<Result<PutResult, FlightError>>) {
+    pub fn set_do_put_response(&self, response: Vec<Result<PutResult, Status>>) {
         let mut state = self.state.lock().expect("mutex not poisoned");
         state.do_put_response.replace(response);
     }
@@ -114,10 +114,7 @@ impl TestFlightServer {
     }
 
     /// Specify the response returned from the next call to `do_exchange`
-    pub fn set_do_exchange_response(
-        &self,
-        response: Vec<Result<FlightData, FlightError>>,
-    ) {
+    pub fn set_do_exchange_response(&self, response: Vec<Result<FlightData, Status>>) {
         let mut state = self.state.lock().expect("mutex not poisoned");
         state.do_exchange_response.replace(response);
     }
@@ -131,6 +128,69 @@ impl TestFlightServer {
             .take()
     }
 
+    /// Specify the response returned from the next call to `list_flights`
+    pub fn set_list_flights_response(&self, response: Vec<Result<FlightInfo, Status>>) {
+        let mut state = self.state.lock().expect("mutex not poisoned");
+        state.list_flights_response.replace(response);
+    }
+
+    /// Take and return the last list_flights request sent to the server.
+    pub fn take_list_flights_request(&self) -> Option<Criteria> {
+        self.state
+            .lock()
+            .expect("mutex not poisoned")
+            .list_flights_request
+            .take()
+    }
+
+    /// Specify the response returned from the next call to `get_schema`
+    pub fn set_get_schema_response(&self, response: Result<Schema, Status>) {
+        let mut state = self.state.lock().expect("mutex not poisoned");
+        state.get_schema_response.replace(response);
+    }
+
+    /// Take and return the last get_schema request sent to the server.
+    pub fn take_get_schema_request(&self) -> Option<FlightDescriptor> {
+        self.state
+            .lock()
+            .expect("mutex not poisoned")
+            .get_schema_request
+            .take()
+    }
+
+    /// Specify the response returned from the next call to `list_actions`
+    pub fn set_list_actions_response(&self, response: Vec<Result<ActionType, Status>>) {
+        let mut state = self.state.lock().expect("mutex not poisoned");
+        state.list_actions_response.replace(response);
+    }
+
+    /// Take and return the last list_actions request sent to the server.
+    pub fn take_list_actions_request(&self) -> Option<Empty> {
+        self.state
+            .lock()
+            .expect("mutex not poisoned")
+            .list_actions_request
+            .take()
+    }
+
+    /// Specify the response returned from the next call to `do_action`
+    pub fn set_do_action_response(
+        &self,
+        response: Vec<Result<arrow_flight::Result, Status>>,
+    ) {
+        let mut state = self.state.lock().expect("mutex not poisoned");
+        state.do_action_response.replace(response);
+    }
+
+    /// Take and return the last do_action request sent to the server.
+    pub fn take_do_action_request(&self) -> Option<Action> {
+        self.state
+            .lock()
+            .expect("mutex not poisoned")
+            .do_action_request
+            .take()
+    }
+
     /// Returns the last metadata from a request received by the server
     pub fn take_last_request_metadata(&self) -> Option<MetadataMap> {
         self.state
@@ -162,15 +222,31 @@ struct State {
     /// The last do_get request received
     pub do_get_request: Option<Ticket>,
     /// The next response returned from `do_get`
-    pub do_get_response: Option<Vec<Result<RecordBatch, FlightError>>>,
+    pub do_get_response: Option<Vec<Result<RecordBatch, Status>>>,
     /// The last do_put request received
     pub do_put_request: Option<Vec<FlightData>>,
     /// The next response returned from `do_put`
-    pub do_put_response: Option<Vec<Result<PutResult, FlightError>>>,
+    pub do_put_response: Option<Vec<Result<PutResult, Status>>>,
     /// The last do_exchange request received
     pub do_exchange_request: Option<Vec<FlightData>>,
     /// The next response returned from `do_exchange`
-    pub do_exchange_response: Option<Vec<Result<FlightData, FlightError>>>,
+    pub do_exchange_response: Option<Vec<Result<FlightData, Status>>>,
+    /// The last list_flights request received
+    pub list_flights_request: Option<Criteria>,
+    /// The next response returned from `list_flights`
+    pub list_flights_response: Option<Vec<Result<FlightInfo, Status>>>,
+    /// The last get_schema request received
+    pub get_schema_request: Option<FlightDescriptor>,
+    /// The next response returned from `get_schema`
+    pub get_schema_response: Option<Result<Schema, Status>>,
+    /// The last list_actions request received
+    pub list_actions_request: Option<Empty>,
+    /// The next response returned from `list_actions`
+    pub list_actions_response: Option<Vec<Result<ActionType, Status>>>,
+    /// The last do_action request received
+    pub do_action_request: Option<Action>,
+    /// The next response returned from `do_action`
+    pub do_action_response: Option<Vec<Result<arrow_flight::Result, Status>>>,
     /// The last request headers received
     pub last_request_metadata: Option<MetadataMap>,
 }
@@ -213,9 +289,21 @@ impl FlightService for TestFlightServer {
 
     async fn list_flights(
         &self,
-        _request: Request<Criteria>,
+        request: Request<Criteria>,
     ) -> Result<Response<Self::ListFlightsStream>, Status> {
-        Err(Status::unimplemented("Implement list_flights"))
+        self.save_metadata(&request);
+        let mut state = self.state.lock().expect("mutex not poisoned");
+
+        state.list_flights_request = Some(request.into_inner());
+
+        let flights: Vec<_> = state
+            .list_flights_response
+            .take()
+            .ok_or_else(|| Status::internal("No list_flights response configured"))?;
+
+        let flights_stream = futures::stream::iter(flights);
+
+        Ok(Response::new(flights_stream.boxed()))
     }
 
     async fn get_flight_info(
@@ -233,9 +321,22 @@ impl FlightService for TestFlightServer {
 
     async fn get_schema(
         &self,
-        _request: Request<FlightDescriptor>,
+        request: Request<FlightDescriptor>,
     ) -> Result<Response<SchemaResult>, Status> {
-        Err(Status::unimplemented("Implement get_schema"))
+        self.save_metadata(&request);
+        let mut state = self.state.lock().expect("mutex not poisoned");
+        state.get_schema_request = Some(request.into_inner());
+        let schema = state.get_schema_response.take().unwrap_or_else(|| {
+            Err(Status::internal("No get_schema response configured"))
+        })?;
+
+        // encode the schema
+        let options = arrow_ipc::writer::IpcWriteOptions::default();
+        let response: SchemaResult = SchemaAsIpc::new(&schema, &options)
+            .try_into()
+            .expect("Error encoding schema");
+
+        Ok(Response::new(response))
     }
 
     async fn do_get(
@@ -252,7 +353,7 @@ impl FlightService for TestFlightServer {
             .take()
             .ok_or_else(|| Status::internal("No do_get response configured"))?;
 
-        let batch_stream = futures::stream::iter(batches);
+        let batch_stream = futures::stream::iter(batches).map_err(Into::into);
 
         let stream = FlightDataEncoderBuilder::new()
             .build(batch_stream)
@@ -284,16 +385,40 @@ impl FlightService for TestFlightServer {
 
     async fn do_action(
         &self,
-        _request: Request<Action>,
+        request: Request<Action>,
     ) -> Result<Response<Self::DoActionStream>, Status> {
-        Err(Status::unimplemented("Implement do_action"))
+        self.save_metadata(&request);
+        let mut state = self.state.lock().expect("mutex not poisoned");
+
+        state.do_action_request = Some(request.into_inner());
+
+        let results: Vec<_> = state
+            .do_action_response
+            .take()
+            .ok_or_else(|| Status::internal("No do_action response configured"))?;
+
+        let results_stream = futures::stream::iter(results);
+
+        Ok(Response::new(results_stream.boxed()))
     }
 
     async fn list_actions(
         &self,
-        _request: Request<Empty>,
+        request: Request<Empty>,
     ) -> Result<Response<Self::ListActionsStream>, Status> {
-        Err(Status::unimplemented("Implement list_actions"))
+        self.save_metadata(&request);
+        let mut state = self.state.lock().expect("mutex not poisoned");
+
+        state.list_actions_request = Some(request.into_inner());
+
+        let actions: Vec<_> = state
+            .list_actions_response
+            .take()
+            .ok_or_else(|| Status::internal("No list_actions response configured"))?;
+
+        let action_stream = futures::stream::iter(actions);
+
+        Ok(Response::new(action_stream.boxed()))
     }
 
     async fn do_exchange(