You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/01/26 10:12:48 UTC

[arrow-rs] branch master updated: Support sending schemas for empty streams (#3594)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 902a17d7d Support sending schemas for empty streams (#3594)
902a17d7d is described below

commit 902a17d7d3817ef9030adeb535fd5951b9f72590
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Thu Jan 26 11:12:42 2023 +0100

    Support sending schemas for empty streams (#3594)
    
    * Support sending schemas for empty streams
    
    * comments
    
    * clippy
    
    * Restore got_schema, return references
    
    * Review comments
    
    * revert unnecessary change
    
    * Update arrow-flight/src/decode.rs
    
    Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
    
    Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
---
 arrow-flight/src/decode.rs          | 32 ++++++++++--------
 arrow-flight/src/encode.rs          | 66 +++++++++++++++++++++++++++----------
 arrow-flight/tests/encode_decode.rs | 27 +++++++++++++++
 3 files changed, 94 insertions(+), 31 deletions(-)

diff --git a/arrow-flight/src/decode.rs b/arrow-flight/src/decode.rs
index cab52a434..fe132e3e8 100644
--- a/arrow-flight/src/decode.rs
+++ b/arrow-flight/src/decode.rs
@@ -17,7 +17,7 @@
 
 use crate::{utils::flight_data_to_arrow_batch, FlightData};
 use arrow_array::{ArrayRef, RecordBatch};
-use arrow_schema::Schema;
+use arrow_schema::{Schema, SchemaRef};
 use bytes::Bytes;
 use futures::{ready, stream::BoxStream, Stream, StreamExt};
 use std::{
@@ -82,16 +82,12 @@ use crate::error::{FlightError, Result};
 #[derive(Debug)]
 pub struct FlightRecordBatchStream {
     inner: FlightDataDecoder,
-    got_schema: bool,
 }
 
 impl FlightRecordBatchStream {
     /// Create a new [`FlightRecordBatchStream`] from a decoded stream
     pub fn new(inner: FlightDataDecoder) -> Self {
-        Self {
-            inner,
-            got_schema: false,
-        }
+        Self { inner }
     }
 
     /// Create a new [`FlightRecordBatchStream`] from a stream of [`FlightData`]
@@ -101,13 +97,18 @@ impl FlightRecordBatchStream {
     {
         Self {
             inner: FlightDataDecoder::new(inner),
-            got_schema: false,
         }
     }
 
     /// Has a message defining the schema been received yet?
+    #[deprecated = "use schema().is_some() instead"]
     pub fn got_schema(&self) -> bool {
-        self.got_schema
+        self.schema().is_some()
+    }
+
+    /// Return schema for the stream, if it has been received
+    pub fn schema(&self) -> Option<&SchemaRef> {
+        self.inner.schema()
     }
 
     /// Consume self and return the wrapped [`FlightDataDecoder`]
@@ -125,6 +126,7 @@ impl futures::Stream for FlightRecordBatchStream {
         cx: &mut std::task::Context<'_>,
     ) -> Poll<Option<Result<RecordBatch>>> {
         loop {
+            let had_schema = self.schema().is_some();
             let res = ready!(self.inner.poll_next_unpin(cx));
             match res {
                 // Inner exhausted
@@ -136,13 +138,12 @@ impl futures::Stream for FlightRecordBatchStream {
                 }
                 // translate data
                 Some(Ok(data)) => match data.payload {
-                    DecodedPayload::Schema(_) if self.got_schema => {
+                    DecodedPayload::Schema(_) if had_schema => {
                         return Poll::Ready(Some(Err(FlightError::protocol(
                             "Unexpectedly saw multiple Schema messages in FlightData stream",
                         ))));
                     }
                     DecodedPayload::Schema(_) => {
-                        self.got_schema = true;
                         // Need next message, poll inner again
                     }
                     DecodedPayload::RecordBatch(batch) => {
@@ -219,6 +220,11 @@ impl FlightDataDecoder {
         }
     }
 
+    /// Returns the current schema for this stream
+    pub fn schema(&self) -> Option<&SchemaRef> {
+        self.state.as_ref().map(|state| &state.schema)
+    }
+
     /// Extracts flight data from the next message, updating decoding
     /// state as necessary.
     fn extract_message(&mut self, data: FlightData) -> Result<Option<DecodedFlightData>> {
@@ -343,7 +349,7 @@ impl futures::Stream for FlightDataDecoder {
 /// streaming flight response.
 #[derive(Debug)]
 struct FlightStreamState {
-    schema: Arc<Schema>,
+    schema: SchemaRef,
     dictionaries_by_field: HashMap<i64, ArrayRef>,
 }
 
@@ -362,7 +368,7 @@ impl DecodedFlightData {
         }
     }
 
-    pub fn new_schema(inner: FlightData, schema: Arc<Schema>) -> Self {
+    pub fn new_schema(inner: FlightData, schema: SchemaRef) -> Self {
         Self {
             inner,
             payload: DecodedPayload::Schema(schema),
@@ -389,7 +395,7 @@ pub enum DecodedPayload {
     None,
 
     /// A decoded Schema message
-    Schema(Arc<Schema>),
+    Schema(SchemaRef),
 
     /// A decoded Record batch.
     RecordBatch(RecordBatch),
diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs
index c130a2d7e..2f06ee58f 100644
--- a/arrow-flight/src/encode.rs
+++ b/arrow-flight/src/encode.rs
@@ -70,6 +70,8 @@ pub struct FlightDataEncoderBuilder {
     options: IpcWriteOptions,
     /// Metadata to add to the schema message
     app_metadata: Bytes,
+    /// Optional schema, if known before data.
+    schema: Option<SchemaRef>,
 }
 
 /// Default target size for encoded [`FlightData`].
@@ -84,6 +86,7 @@ impl Default for FlightDataEncoderBuilder {
             max_flight_data_size: GRPC_TARGET_MAX_FLIGHT_SIZE_BYTES,
             options: IpcWriteOptions::default(),
             app_metadata: Bytes::new(),
+            schema: None,
         }
     }
 }
@@ -122,6 +125,15 @@ impl FlightDataEncoderBuilder {
         self
     }
 
+    /// Specify a schema for the RecordBatches being sent. If a schema
+    /// is not specified, an encoded Schema message will be sent when
+    /// the first [`RecordBatch`], if any, is encoded. Some clients
+    /// expect a Schema message even if there is no data sent.
+    pub fn with_schema(mut self, schema: SchemaRef) -> Self {
+        self.schema = Some(schema);
+        self
+    }
+
     /// Return a [`Stream`](futures::Stream) of [`FlightData`],
     /// consuming self. More details on [`FlightDataEncoder`]
     pub fn build<S>(self, input: S) -> FlightDataEncoder
@@ -132,9 +144,16 @@ impl FlightDataEncoderBuilder {
             max_flight_data_size,
             options,
             app_metadata,
+            schema,
         } = self;
 
-        FlightDataEncoder::new(input.boxed(), max_flight_data_size, options, app_metadata)
+        FlightDataEncoder::new(
+            input.boxed(),
+            schema,
+            max_flight_data_size,
+            options,
+            app_metadata,
+        )
     }
 }
 
@@ -162,11 +181,12 @@ pub struct FlightDataEncoder {
 impl FlightDataEncoder {
     fn new(
         inner: BoxStream<'static, Result<RecordBatch>>,
+        schema: Option<SchemaRef>,
         max_flight_data_size: usize,
         options: IpcWriteOptions,
         app_metadata: Bytes,
     ) -> Self {
-        Self {
+        let mut encoder = Self {
             inner,
             schema: None,
             max_flight_data_size,
@@ -174,7 +194,13 @@ impl FlightDataEncoder {
             app_metadata: Some(app_metadata),
             queue: VecDeque::new(),
             done: false,
+        };
+
+        // If schema is known up front, enqueue it immediately
+        if let Some(schema) = schema {
+            encoder.encode_schema(&schema);
         }
+        encoder
     }
 
     /// Place the `FlightData` in the queue to send
@@ -189,26 +215,30 @@ impl FlightDataEncoder {
         }
     }
 
+    /// Encodes schema as a [`FlightData`] in self.queue.
+    /// Updates `self.schema` and returns the new schema
+    fn encode_schema(&mut self, schema: &SchemaRef) -> SchemaRef {
+        // The first message is the schema message, and all
+        // batches have the same schema
+        let schema = Arc::new(prepare_schema_for_flight(schema));
+        let mut schema_flight_data = self.encoder.encode_schema(&schema);
+
+        // attach any metadata requested
+        if let Some(app_metadata) = self.app_metadata.take() {
+            schema_flight_data.app_metadata = app_metadata;
+        }
+        self.queue_message(schema_flight_data);
+        // remember schema
+        self.schema = Some(schema.clone());
+        schema
+    }
+
     /// Encodes batch into one or more `FlightData` messages in self.queue
     fn encode_batch(&mut self, batch: RecordBatch) -> Result<()> {
         let schema = match &self.schema {
             Some(schema) => schema.clone(),
-            None => {
-                let batch_schema = batch.schema();
-                // The first message is the schema message, and all
-                // batches have the same schema
-                let schema = Arc::new(prepare_schema_for_flight(&batch_schema));
-                let mut schema_flight_data = self.encoder.encode_schema(&schema);
-
-                // attach any metadata requested
-                if let Some(app_metadata) = self.app_metadata.take() {
-                    schema_flight_data.app_metadata = app_metadata;
-                }
-                self.queue_message(schema_flight_data);
-                // remember schema
-                self.schema = Some(schema.clone());
-                schema
-            }
+            // encode the schema if this is the first time we have seen it
+            None => self.encode_schema(&batch.schema()),
         };
 
         // encode the batch
diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs
index 0aa987687..1990e5b0c 100644
--- a/arrow-flight/tests/encode_decode.rs
+++ b/arrow-flight/tests/encode_decode.rs
@@ -96,6 +96,33 @@ async fn test_dictionary_many() {
     .await;
 }
 
+#[tokio::test]
+async fn test_zero_batches_no_schema() {
+    let stream = FlightDataEncoderBuilder::default().build(futures::stream::iter(vec![]));
+
+    let mut decoder = FlightRecordBatchStream::new_from_flight_data(stream);
+    assert!(decoder.schema().is_none());
+    // No batches come out
+    assert!(decoder.next().await.is_none());
+    // schema has not been received
+    assert!(decoder.schema().is_none());
+}
+
+#[tokio::test]
+async fn test_zero_batches_schema_specified() {
+    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]));
+    let stream = FlightDataEncoderBuilder::default()
+        .with_schema(schema.clone())
+        .build(futures::stream::iter(vec![]));
+
+    let mut decoder = FlightRecordBatchStream::new_from_flight_data(stream);
+    assert!(decoder.schema().is_none());
+    // No batches come out
+    assert!(decoder.next().await.is_none());
+    // But schema has been received correctly
+    assert_eq!(decoder.schema(), Some(&schema));
+}
+
 #[tokio::test]
 async fn test_app_metadata() {
     let input_batch_stream = futures::stream::iter(vec![Ok(make_primative_batch(78))]);