You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/12/19 21:02:39 UTC
[arrow-rs] branch master updated: Use custom Any instead of prost_types (#3360)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 0f196b8da Use custom Any instead of prost_types (#3360)
0f196b8da is described below
commit 0f196b8dad7592ae139d17c4a8aa960b0e8731fa
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Mon Dec 19 21:02:34 2022 +0000
Use custom Any instead of prost_types (#3360)
* Use custom Any instead of prost_types
* Remove unnecessary path prefix
---
arrow-flight/Cargo.toml | 5 +--
arrow-flight/examples/flight_sql_server.rs | 11 ++---
arrow-flight/src/sql/client.rs | 22 +++++-----
arrow-flight/src/sql/mod.rs | 67 ++++++++++++++++++------------
arrow-flight/src/sql/server.rs | 26 ++++++------
5 files changed, 73 insertions(+), 58 deletions(-)
diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml
index 238e03f3c..847d77ca5 100644
--- a/arrow-flight/Cargo.toml
+++ b/arrow-flight/Cargo.toml
@@ -35,14 +35,13 @@ base64 = { version = "0.20", default-features = false, features = ["std"] }
tonic = { version = "0.8", default-features = false, features = ["transport", "codegen", "prost"] }
bytes = { version = "1", default-features = false }
prost = { version = "0.11", default-features = false }
-prost-types = { version = "0.11.0", default-features = false, optional = true }
prost-derive = { version = "0.11", default-features = false }
tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "rt-multi-thread"] }
-futures = { version = "0.3", default-features = false, features = ["alloc"]}
+futures = { version = "0.3", default-features = false, features = ["alloc"] }
[features]
default = []
-flight-sql-experimental = ["prost-types"]
+flight-sql-experimental = []
[dev-dependencies]
arrow = { version = "29.0.0", path = "../arrow", features = ["prettyprint"] }
diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs
index 29e6c2c37..5adb5d59a 100644
--- a/arrow-flight/examples/flight_sql_server.rs
+++ b/arrow-flight/examples/flight_sql_server.rs
@@ -17,13 +17,14 @@
use arrow_array::builder::StringBuilder;
use arrow_array::{ArrayRef, RecordBatch};
-use arrow_flight::sql::{ActionCreatePreparedStatementResult, ProstMessageExt, SqlInfo};
+use arrow_flight::sql::{
+ ActionCreatePreparedStatementResult, Any, ProstMessageExt, SqlInfo,
+};
use arrow_flight::{
Action, FlightData, FlightEndpoint, HandshakeRequest, HandshakeResponse, IpcMessage,
Location, SchemaAsIpc, Ticket,
};
use futures::{stream, Stream};
-use prost_types::Any;
use std::fs;
use std::pin::Pin;
use std::sync::Arc;
@@ -124,7 +125,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_get_fallback(
&self,
_request: Request<Ticket>,
- _message: prost_types::Any,
+ _message: Any,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
let batch =
Self::fake_result().map_err(|e| status!("Could not fake a result", e))?;
@@ -474,9 +475,9 @@ impl ProstMessageExt for FetchResults {
}
fn as_any(&self) -> Any {
- prost_types::Any {
+ Any {
type_url: FetchResults::type_url().to_string(),
- value: ::prost::Message::encode_to_vec(self),
+ value: ::prost::Message::encode_to_vec(self).into(),
}
}
}
diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs
index fa6691793..74039027e 100644
--- a/arrow-flight/src/sql/client.rs
+++ b/arrow-flight/src/sql/client.rs
@@ -23,11 +23,12 @@ use crate::flight_service_client::FlightServiceClient;
use crate::sql::server::{CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT};
use crate::sql::{
ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
- ActionCreatePreparedStatementResult, CommandGetCatalogs, CommandGetCrossReference,
- CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
- CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
- CommandPreparedStatementQuery, CommandStatementQuery, CommandStatementUpdate,
- DoPutUpdateResult, ProstAnyExt, ProstMessageExt, SqlInfo,
+ ActionCreatePreparedStatementResult, Any, CommandGetCatalogs,
+ CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
+ CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo,
+ CommandGetTableTypes, CommandGetTables, CommandPreparedStatementQuery,
+ CommandStatementQuery, CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt,
+ SqlInfo,
};
use crate::{
Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest,
@@ -177,8 +178,8 @@ impl FlightSqlServiceClient {
.await
.map_err(status_to_arrow_error)?
.unwrap();
- let any: prost_types::Any = prost::Message::decode(&*result.app_metadata)
- .map_err(decode_error_to_arrow_error)?;
+ let any =
+ Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
let result: DoPutUpdateResult = any.unpack()?.unwrap();
Ok(result.record_count)
}
@@ -298,8 +299,7 @@ impl FlightSqlServiceClient {
.await
.map_err(status_to_arrow_error)?
.unwrap();
- let any: prost_types::Any =
- prost::Message::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
+ let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
let prepared_result: ActionCreatePreparedStatementResult = any.unpack()?.unwrap();
let dataset_schema = match prepared_result.dataset_schema.len() {
0 => Schema::empty(),
@@ -384,8 +384,8 @@ impl PreparedStatement<Channel> {
.await
.map_err(status_to_arrow_error)?
.unwrap();
- let any: prost_types::Any = Message::decode(&*result.app_metadata)
- .map_err(decode_error_to_arrow_error)?;
+ let any =
+ Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
let result: DoPutUpdateResult = any.unpack()?.unwrap();
Ok(result.record_count)
}
diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs
index 0ddc64c55..88dc6cde9 100644
--- a/arrow-flight/src/sql/mod.rs
+++ b/arrow-flight/src/sql/mod.rs
@@ -16,6 +16,7 @@
// under the License.
use arrow_schema::ArrowError;
+use bytes::Bytes;
use prost::Message;
mod gen {
@@ -66,8 +67,8 @@ pub trait ProstMessageExt: prost::Message + Default {
/// type_url for this Message
fn type_url() -> &'static str;
- /// Convert this Message to prost_types::Any
- fn as_any(&self) -> prost_types::Any;
+ /// Convert this Message to [`Any`]
+ fn as_any(&self) -> Any;
}
macro_rules! prost_message_ext {
@@ -78,10 +79,10 @@ macro_rules! prost_message_ext {
concat!("type.googleapis.com/arrow.flight.protocol.sql.", stringify!($name))
}
- fn as_any(&self) -> prost_types::Any {
- prost_types::Any {
+ fn as_any(&self) -> Any {
+ Any {
type_url: <$name>::type_url().to_string(),
- value: self.encode_to_vec(),
+ value: self.encode_to_vec().into(),
}
}
}
@@ -111,30 +112,44 @@ prost_message_ext!(
TicketStatementQuery,
);
-/// ProstAnyExt are useful utility methods for prost_types::Any
-/// The API design is inspired by [rust-protobuf](https://github.com/stepancheg/rust-protobuf/blob/master/protobuf/src/well_known_types_util/any.rs)
-pub trait ProstAnyExt {
- /// Check if `Any` contains a message of given type.
- fn is<M: ProstMessageExt>(&self) -> bool;
-
- /// Extract a message from this `Any`.
- ///
- /// # Returns
- ///
- /// * `Ok(None)` when message type mismatch
- /// * `Err` when parse failed
- fn unpack<M: ProstMessageExt>(&self) -> Result<Option<M>, ArrowError>;
-
- /// Pack any message into `prost_types::Any` value.
- fn pack<M: ProstMessageExt>(message: &M) -> Result<prost_types::Any, ArrowError>;
+/// An implementation of the protobuf [`Any`] message type
+///
+/// Encoded protobuf messages are not self-describing, nor contain any information
+/// on the schema of the encoded payload. Consequently to decode a protobuf a client
+/// must know the exact schema of the message.
+///
+/// This presents a problem for loosely typed APIs, where the exact message payloads
+/// are not enumerable, and therefore cannot be enumerated as variants in a [oneof].
+///
+/// One solution is [`Any`] where the encoded payload is paired with a `type_url`
+/// identifying the type of encoded message, and the resulting combination encoded.
+///
+/// Clients can then decode the outer [`Any`], inspect the `type_url` and if it is
+/// a type they recognise, proceed to decode the embedded message `value`
+///
+/// [`Any`]: https://developers.google.com/protocol-buffers/docs/proto3#any
+/// [oneof]: https://developers.google.com/protocol-buffers/docs/proto3#oneof
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct Any {
+ /// A URL/resource name that uniquely identifies the type of the serialized
+ /// protocol buffer message. This string must contain at least
+ /// one "/" character. The last segment of the URL's path must represent
+ /// the fully qualified name of the type (as in
+ /// `path/google.protobuf.Duration`). The name should be in a canonical form
+ /// (e.g., leading "." is not accepted).
+ #[prost(string, tag = "1")]
+ pub type_url: String,
+ /// Must be a valid serialized protocol buffer of the above specified type.
+ #[prost(bytes = "bytes", tag = "2")]
+ pub value: Bytes,
}
-impl ProstAnyExt for prost_types::Any {
- fn is<M: ProstMessageExt>(&self) -> bool {
+impl Any {
+ pub fn is<M: ProstMessageExt>(&self) -> bool {
M::type_url() == self.type_url
}
- fn unpack<M: ProstMessageExt>(&self) -> Result<Option<M>, ArrowError> {
+ pub fn unpack<M: ProstMessageExt>(&self) -> Result<Option<M>, ArrowError> {
if !self.is::<M>() {
return Ok(None);
}
@@ -144,7 +159,7 @@ impl ProstAnyExt for prost_types::Any {
Ok(Some(m))
}
- fn pack<M: ProstMessageExt>(message: &M) -> Result<prost_types::Any, ArrowError> {
+ pub fn pack<M: ProstMessageExt>(message: &M) -> Result<Any, ArrowError> {
Ok(message.as_any())
}
}
@@ -170,7 +185,7 @@ mod tests {
let query = CommandStatementQuery {
query: "select 1".to_string(),
};
- let any = prost_types::Any::pack(&query).unwrap();
+ let any = Any::pack(&query).unwrap();
assert!(any.is::<CommandStatementQuery>());
let unpack_query: CommandStatementQuery = any.unpack().unwrap().unwrap();
assert_eq!(query, unpack_query);
diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs
index ec48d7cfe..fdf9c9133 100644
--- a/arrow-flight/src/sql/server.rs
+++ b/arrow-flight/src/sql/server.rs
@@ -17,6 +17,7 @@
use std::pin::Pin;
+use crate::sql::Any;
use futures::Stream;
use prost::Message;
use tonic::{Request, Response, Status, Streaming};
@@ -32,7 +33,7 @@ use super::{
CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery,
- CommandStatementUpdate, DoPutUpdateResult, ProstAnyExt, ProstMessageExt, SqlInfo,
+ CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo,
TicketStatementQuery,
};
@@ -63,7 +64,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
async fn do_get_fallback(
&self,
_request: Request<Ticket>,
- message: prost_types::Any,
+ message: Any,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
Err(Status::unimplemented(format!(
"do_get: The defined request is invalid: {}",
@@ -311,8 +312,8 @@ where
&self,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
- let message: prost_types::Any =
- Message::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?;
+ let message =
+ Any::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?;
if message.is::<CommandStatementQuery>() {
let token = message
@@ -411,10 +412,10 @@ where
&self,
request: Request<Ticket>,
) -> Result<Response<Self::DoGetStream>, Status> {
- let msg: prost_types::Any = Message::decode(&*request.get_ref().ticket)
+ let msg: Any = Message::decode(&*request.get_ref().ticket)
.map_err(decode_error_to_status)?;
- fn unpack<T: ProstMessageExt>(msg: prost_types::Any) -> Result<T, Status> {
+ fn unpack<T: ProstMessageExt>(msg: Any) -> Result<T, Status> {
msg.unpack()
.map_err(arrow_error_to_status)?
.ok_or_else(|| Status::internal("Expected a command, but found none."))
@@ -462,9 +463,8 @@ where
mut request: Request<Streaming<FlightData>>,
) -> Result<Response<Self::DoPutStream>, Status> {
let cmd = request.get_mut().message().await?.unwrap();
- let message: prost_types::Any =
- Message::decode(&*cmd.flight_descriptor.unwrap().cmd)
- .map_err(decode_error_to_status)?;
+ let message = Any::decode(&*cmd.flight_descriptor.unwrap().cmd)
+ .map_err(decode_error_to_status)?;
if message.is::<CommandStatementUpdate>() {
let token = message
.unpack()
@@ -536,8 +536,8 @@ where
request: Request<Action>,
) -> Result<Response<Self::DoActionStream>, Status> {
if request.get_ref().r#type == CREATE_PREPARED_STATEMENT {
- let any: prost_types::Any = Message::decode(&*request.get_ref().body)
- .map_err(decode_error_to_status)?;
+ let any =
+ Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
let cmd: ActionCreatePreparedStatementRequest = any
.unpack()
@@ -556,8 +556,8 @@ where
return Ok(Response::new(Box::pin(output)));
}
if request.get_ref().r#type == CLOSE_PREPARED_STATEMENT {
- let any: prost_types::Any = Message::decode(&*request.get_ref().body)
- .map_err(decode_error_to_status)?;
+ let any =
+ Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;
let cmd: ActionClosePreparedStatementRequest = any
.unpack()