You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/12/19 21:02:39 UTC

[arrow-rs] branch master updated: Use custom Any instead of prost_types (#3360)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 0f196b8da Use custom Any instead of prost_types (#3360)
0f196b8da is described below

commit 0f196b8dad7592ae139d17c4a8aa960b0e8731fa
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Mon Dec 19 21:02:34 2022 +0000

    Use custom Any instead of prost_types (#3360)
    
    * Use custom Any instead of prost_types
    
    * Remove unnecessary path prefix
---
 arrow-flight/Cargo.toml                    |  5 +--
 arrow-flight/examples/flight_sql_server.rs | 11 ++---
 arrow-flight/src/sql/client.rs             | 22 +++++-----
 arrow-flight/src/sql/mod.rs                | 67 ++++++++++++++++++------------
 arrow-flight/src/sql/server.rs             | 26 ++++++------
 5 files changed, 73 insertions(+), 58 deletions(-)

diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml
index 238e03f3c..847d77ca5 100644
--- a/arrow-flight/Cargo.toml
+++ b/arrow-flight/Cargo.toml
@@ -35,14 +35,13 @@ base64 = { version = "0.20", default-features = false, features = ["std"] }
 tonic = { version = "0.8", default-features = false, features = ["transport", "codegen", "prost"] }
 bytes = { version = "1", default-features = false }
 prost = { version = "0.11", default-features = false }
-prost-types = { version = "0.11.0", default-features = false, optional = true }
 prost-derive = { version = "0.11", default-features = false }
 tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "rt-multi-thread"] }
-futures = { version = "0.3", default-features = false, features = ["alloc"]}
+futures = { version = "0.3", default-features = false, features = ["alloc"] }
 
 [features]
 default = []
-flight-sql-experimental = ["prost-types"]
+flight-sql-experimental = []
 
 [dev-dependencies]
 arrow = { version = "29.0.0", path = "../arrow", features = ["prettyprint"] }
diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs
index 29e6c2c37..5adb5d59a 100644
--- a/arrow-flight/examples/flight_sql_server.rs
+++ b/arrow-flight/examples/flight_sql_server.rs
@@ -17,13 +17,14 @@
 
 use arrow_array::builder::StringBuilder;
 use arrow_array::{ArrayRef, RecordBatch};
-use arrow_flight::sql::{ActionCreatePreparedStatementResult, ProstMessageExt, SqlInfo};
+use arrow_flight::sql::{
+    ActionCreatePreparedStatementResult, Any, ProstMessageExt, SqlInfo,
+};
 use arrow_flight::{
     Action, FlightData, FlightEndpoint, HandshakeRequest, HandshakeResponse, IpcMessage,
     Location, SchemaAsIpc, Ticket,
 };
 use futures::{stream, Stream};
-use prost_types::Any;
 use std::fs;
 use std::pin::Pin;
 use std::sync::Arc;
@@ -124,7 +125,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
     async fn do_get_fallback(
         &self,
         _request: Request<Ticket>,
-        _message: prost_types::Any,
+        _message: Any,
     ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
         let batch =
             Self::fake_result().map_err(|e| status!("Could not fake a result", e))?;
@@ -474,9 +475,9 @@ impl ProstMessageExt for FetchResults {
     }
 
     fn as_any(&self) -> Any {
-        prost_types::Any {
+        Any {
             type_url: FetchResults::type_url().to_string(),
-            value: ::prost::Message::encode_to_vec(self),
+            value: ::prost::Message::encode_to_vec(self).into(),
         }
     }
 }
diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs
index fa6691793..74039027e 100644
--- a/arrow-flight/src/sql/client.rs
+++ b/arrow-flight/src/sql/client.rs
@@ -23,11 +23,12 @@ use crate::flight_service_client::FlightServiceClient;
 use crate::sql::server::{CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT};
 use crate::sql::{
     ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
-    ActionCreatePreparedStatementResult, CommandGetCatalogs, CommandGetCrossReference,
-    CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
-    CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
-    CommandPreparedStatementQuery, CommandStatementQuery, CommandStatementUpdate,
-    DoPutUpdateResult, ProstAnyExt, ProstMessageExt, SqlInfo,
+    ActionCreatePreparedStatementResult, Any, CommandGetCatalogs,
+    CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
+    CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo,
+    CommandGetTableTypes, CommandGetTables, CommandPreparedStatementQuery,
+    CommandStatementQuery, CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt,
+    SqlInfo,
 };
 use crate::{
     Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest,
@@ -177,8 +178,8 @@ impl FlightSqlServiceClient {
             .await
             .map_err(status_to_arrow_error)?
             .unwrap();
-        let any: prost_types::Any = prost::Message::decode(&*result.app_metadata)
-            .map_err(decode_error_to_arrow_error)?;
+        let any =
+            Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
         let result: DoPutUpdateResult = any.unpack()?.unwrap();
         Ok(result.record_count)
     }
@@ -298,8 +299,7 @@ impl FlightSqlServiceClient {
             .await
             .map_err(status_to_arrow_error)?
             .unwrap();
-        let any: prost_types::Any =
-            prost::Message::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
+        let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
         let prepared_result: ActionCreatePreparedStatementResult = any.unpack()?.unwrap();
         let dataset_schema = match prepared_result.dataset_schema.len() {
             0 => Schema::empty(),
@@ -384,8 +384,8 @@ impl PreparedStatement<Channel> {
             .await
             .map_err(status_to_arrow_error)?
             .unwrap();
-        let any: prost_types::Any = Message::decode(&*result.app_metadata)
-            .map_err(decode_error_to_arrow_error)?;
+        let any =
+            Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
         let result: DoPutUpdateResult = any.unpack()?.unwrap();
         Ok(result.record_count)
     }
diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs
index 0ddc64c55..88dc6cde9 100644
--- a/arrow-flight/src/sql/mod.rs
+++ b/arrow-flight/src/sql/mod.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use arrow_schema::ArrowError;
+use bytes::Bytes;
 use prost::Message;
 
 mod gen {
@@ -66,8 +67,8 @@ pub trait ProstMessageExt: prost::Message + Default {
     /// type_url for this Message
     fn type_url() -> &'static str;
 
-    /// Convert this Message to prost_types::Any
-    fn as_any(&self) -> prost_types::Any;
+    /// Convert this Message to [`Any`]
+    fn as_any(&self) -> Any;
 }
 
 macro_rules! prost_message_ext {
@@ -78,10 +79,10 @@ macro_rules! prost_message_ext {
                     concat!("type.googleapis.com/arrow.flight.protocol.sql.", stringify!($name))
                 }
 
-                fn as_any(&self) -> prost_types::Any {
-                    prost_types::Any {
+                fn as_any(&self) -> Any {
+                    Any {
                         type_url: <$name>::type_url().to_string(),
-                        value: self.encode_to_vec(),
+                        value: self.encode_to_vec().into(),
                     }
                 }
             }
@@ -111,30 +112,44 @@ prost_message_ext!(
     TicketStatementQuery,
 );
 
-/// ProstAnyExt are useful utility methods for prost_types::Any
-/// The API design is inspired by [rust-protobuf](https://github.com/stepancheg/rust-protobuf/blob/master/protobuf/src/well_known_types_util/any.rs)
-pub trait ProstAnyExt {
-    /// Check if `Any` contains a message of given type.
-    fn is<M: ProstMessageExt>(&self) -> bool;
-
-    /// Extract a message from this `Any`.
-    ///
-    /// # Returns
-    ///
-    /// * `Ok(None)` when message type mismatch
-    /// * `Err` when parse failed
-    fn unpack<M: ProstMessageExt>(&self) -> Result<Option<M>, ArrowError>;
-
-    /// Pack any message into `prost_types::Any` value.
-    fn pack<M: ProstMessageExt>(message: &M) -> Result<prost_types::Any, ArrowError>;
+/// An implementation of the protobuf [`Any`] message type
+///
+/// Encoded protobuf messages are not self-describing, nor do they contain any
+/// information on the schema of the encoded payload. Consequently, to decode a
+/// protobuf message, a client must know the exact schema of that message.
+///
+/// This presents a problem for loosely typed APIs, where the set of possible message
+/// payloads is open-ended, and therefore cannot be enumerated as variants in a [oneof].
+///
+/// One solution is [`Any`], where the encoded payload is paired with a `type_url`
+/// identifying the type of the encoded message, and the resulting pair is itself encoded.
+///
+/// Clients can then decode the outer [`Any`], inspect the `type_url` and, if it is
+/// a type they recognise, proceed to decode the embedded message `value`.
+///
+/// [`Any`]: https://developers.google.com/protocol-buffers/docs/proto3#any
+/// [oneof]: https://developers.google.com/protocol-buffers/docs/proto3#oneof
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct Any {
+    /// A URL/resource name that uniquely identifies the type of the serialized
+    /// protocol buffer message. This string must contain at least
+    /// one "/" character. The last segment of the URL's path must represent
+    /// the fully qualified name of the type (as in
+    /// `path/google.protobuf.Duration`). The name should be in a canonical form
+    /// (e.g., leading "." is not accepted).
+    #[prost(string, tag = "1")]
+    pub type_url: String,
+    /// Must be a valid serialized protocol buffer of the above specified type.
+    #[prost(bytes = "bytes", tag = "2")]
+    pub value: Bytes,
 }
 
-impl ProstAnyExt for prost_types::Any {
-    fn is<M: ProstMessageExt>(&self) -> bool {
+impl Any {
+    pub fn is<M: ProstMessageExt>(&self) -> bool {
         M::type_url() == self.type_url
     }
 
-    fn unpack<M: ProstMessageExt>(&self) -> Result<Option<M>, ArrowError> {
+    pub fn unpack<M: ProstMessageExt>(&self) -> Result<Option<M>, ArrowError> {
         if !self.is::<M>() {
             return Ok(None);
         }
@@ -144,7 +159,7 @@ impl ProstAnyExt for prost_types::Any {
         Ok(Some(m))
     }
 
-    fn pack<M: ProstMessageExt>(message: &M) -> Result<prost_types::Any, ArrowError> {
+    pub fn pack<M: ProstMessageExt>(message: &M) -> Result<Any, ArrowError> {
         Ok(message.as_any())
     }
 }
@@ -170,7 +185,7 @@ mod tests {
         let query = CommandStatementQuery {
             query: "select 1".to_string(),
         };
-        let any = prost_types::Any::pack(&query).unwrap();
+        let any = Any::pack(&query).unwrap();
         assert!(any.is::<CommandStatementQuery>());
         let unpack_query: CommandStatementQuery = any.unpack().unwrap().unwrap();
         assert_eq!(query, unpack_query);
diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs
index ec48d7cfe..fdf9c9133 100644
--- a/arrow-flight/src/sql/server.rs
+++ b/arrow-flight/src/sql/server.rs
@@ -17,6 +17,7 @@
 
 use std::pin::Pin;
 
+use crate::sql::Any;
 use futures::Stream;
 use prost::Message;
 use tonic::{Request, Response, Status, Streaming};
@@ -32,7 +33,7 @@ use super::{
     CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
     CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
     CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery,
-    CommandStatementUpdate, DoPutUpdateResult, ProstAnyExt, ProstMessageExt, SqlInfo,
+    CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo,
     TicketStatementQuery,
 };
 
@@ -63,7 +64,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
     async fn do_get_fallback(
         &self,
         _request: Request<Ticket>,
-        message: prost_types::Any,
+        message: Any,
     ) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
         Err(Status::unimplemented(format!(
             "do_get: The defined request is invalid: {}",
@@ -311,8 +312,8 @@ where
         &self,
         request: Request<FlightDescriptor>,
     ) -> Result<Response<FlightInfo>, Status> {
-        let message: prost_types::Any =
-            Message::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?;
+        let message =
+            Any::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?;
 
         if message.is::<CommandStatementQuery>() {
             let token = message
@@ -411,10 +412,10 @@ where
         &self,
         request: Request<Ticket>,
     ) -> Result<Response<Self::DoGetStream>, Status> {
-        let msg: prost_types::Any = Message::decode(&*request.get_ref().ticket)
+        let msg: Any = Message::decode(&*request.get_ref().ticket)
             .map_err(decode_error_to_status)?;
 
-        fn unpack<T: ProstMessageExt>(msg: prost_types::Any) -> Result<T, Status> {
+        fn unpack<T: ProstMessageExt>(msg: Any) -> Result<T, Status> {
             msg.unpack()
                 .map_err(arrow_error_to_status)?
                 .ok_or_else(|| Status::internal("Expected a command, but found none."))
@@ -462,9 +463,8 @@ where
         mut request: Request<Streaming<FlightData>>,
     ) -> Result<Response<Self::DoPutStream>, Status> {
         let cmd = request.get_mut().message().await?.unwrap();
-        let message: prost_types::Any =
-            Message::decode(&*cmd.flight_descriptor.unwrap().cmd)
-                .map_err(decode_error_to_status)?;
+        let message = Any::decode(&*cmd.flight_descriptor.unwrap().cmd)
+            .map_err(decode_error_to_status)?;
         if message.is::<CommandStatementUpdate>() {
             let token = message
                 .unpack()
@@ -536,8 +536,8 @@ where
         request: Request<Action>,
     ) -> Result<Response<Self::DoActionStream>, Status> {
         if request.get_ref().r#type == CREATE_PREPARED_STATEMENT {
-            let any: prost_types::Any = Message::decode(&*request.get_ref().body)
-                .map_err(decode_error_to_status)?;
+            let any =
+                Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
 
             let cmd: ActionCreatePreparedStatementRequest = any
                 .unpack()
@@ -556,8 +556,8 @@ where
             return Ok(Response::new(Box::pin(output)));
         }
         if request.get_ref().r#type == CLOSE_PREPARED_STATEMENT {
-            let any: prost_types::Any = Message::decode(&*request.get_ref().body)
-                .map_err(decode_error_to_status)?;
+            let any =
+                Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
 
             let cmd: ActionClosePreparedStatementRequest = any
                 .unpack()