Posted to commits@arrow.apache.org by tu...@apache.org on 2023/01/30 17:03:57 UTC

[arrow-rs] branch master updated: Add limit to ArrowReaderBuilder to push limit down to parquet reader (#3633)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new a76ea1c8e Add limit to ArrowReaderBuilder to push limit down to parquet reader (#3633)
a76ea1c8e is described below

commit a76ea1c8e2a40902300185ce8122a19f85ff550b
Author: Dan Harris <13...@users.noreply.github.com>
AuthorDate: Mon Jan 30 19:03:50 2023 +0200

    Add limit to ArrowReaderBuilder to push limit down to parquet reader (#3633)
    
    * Add limit to ArrowReaderBuilder to push limit down to parquet reader
    
    * Update parquet/src/arrow/arrow_reader/mod.rs
    
    Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
    
    * pr comments
    
    * Apply limit to entire file instead of each row group
    
    ---------
    
    Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
---
 parquet/src/arrow/arrow_reader/mod.rs       |  58 ++++++++++-
 parquet/src/arrow/arrow_reader/selection.rs |  82 +++++++++++++++-
 parquet/src/arrow/async_reader/mod.rs       | 147 +++++++++++++++++++++++++++-
 3 files changed, 282 insertions(+), 5 deletions(-)
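
For context, a minimal sketch of the new sync API this commit adds. The file
name `data.parquet` and the surrounding setup are illustrative assumptions,
not part of the commit:

    use std::fs::File;
    use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let file = File::open("data.parquet")?;

        // Read at most 100 rows; the limit is pushed down into the
        // parquet reader rather than applied after decoding.
        let reader = ParquetRecordBatchReaderBuilder::try_new(file)?
            .with_batch_size(1024)
            .with_limit(100)
            .build()?;

        for batch in reader {
            println!("read {} rows", batch?.num_rows());
        }
        Ok(())
    }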

diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs
index 87165ef8e..c4b645da7 100644
--- a/parquet/src/arrow/arrow_reader/mod.rs
+++ b/parquet/src/arrow/arrow_reader/mod.rs
@@ -69,6 +69,8 @@ pub struct ArrowReaderBuilder<T> {
     pub(crate) filter: Option<RowFilter>,
 
     pub(crate) selection: Option<RowSelection>,
+
+    pub(crate) limit: Option<usize>,
 }
 
 impl<T> ArrowReaderBuilder<T> {
@@ -98,6 +100,7 @@ impl<T> ArrowReaderBuilder<T> {
             projection: ProjectionMask::all(),
             filter: None,
             selection: None,
+            limit: None,
         })
     }
 
@@ -167,6 +170,17 @@ impl<T> ArrowReaderBuilder<T> {
             ..self
         }
     }
+
+    /// Provide a limit to the number of rows to be read
+    ///
+    /// The limit will be applied after any [`Self::with_row_selection`] and [`Self::with_row_filter`],
+    /// allowing it to limit the final set of rows decoded after any pushed-down predicates
+    pub fn with_limit(self, limit: usize) -> Self {
+        Self {
+            limit: Some(limit),
+            ..self
+        }
+    }
 }
 
 /// Arrow reader api.
@@ -453,6 +467,19 @@ impl<T: ChunkReader + 'static> ArrowReaderBuilder<SyncReader<T>> {
             selection = Some(RowSelection::from(vec![]));
         }
 
+        // If a limit is defined, apply it to the final `RowSelection`
+        if let Some(limit) = self.limit {
+            selection = Some(
+                selection
+                    .map(|selection| selection.limit(limit))
+                    .unwrap_or_else(|| {
+                        RowSelection::from(vec![RowSelector::select(
+                            limit.min(reader.num_rows()),
+                        )])
+                    }),
+            );
+        }
+
         Ok(ParquetRecordBatchReader::new(
             batch_size,
             array_reader,
@@ -1215,6 +1242,8 @@ mod tests {
         row_selections: Option<(RowSelection, usize)>,
         /// row filter
         row_filter: Option<Vec<bool>>,
+        /// limit
+        limit: Option<usize>,
     }
 
     /// Manually implement this to avoid printing entire contents of row_selections and row_filter
@@ -1233,6 +1262,7 @@ mod tests {
                 .field("encoding", &self.encoding)
                 .field("row_selections", &self.row_selections.is_some())
                 .field("row_filter", &self.row_filter.is_some())
+                .field("limit", &self.limit)
                 .finish()
         }
     }
@@ -1252,6 +1282,7 @@ mod tests {
                 encoding: Encoding::PLAIN,
                 row_selections: None,
                 row_filter: None,
+                limit: None,
             }
         }
     }
@@ -1323,6 +1354,13 @@ mod tests {
             }
         }
 
+        fn with_limit(self, limit: usize) -> Self {
+            Self {
+                limit: Some(limit),
+                ..self
+            }
+        }
+
         fn writer_props(&self) -> WriterProperties {
             let builder = WriterProperties::builder()
                 .set_data_pagesize_limit(self.max_data_page_size)
@@ -1381,6 +1419,14 @@ mod tests {
             TestOptions::new(2, 256, 127).with_null_percent(0),
             // Test optional with nulls
             TestOptions::new(2, 256, 93).with_null_percent(25),
+            // Test with limit of 0
+            TestOptions::new(4, 100, 25).with_limit(0),
+            // Test with limit of 50
+            TestOptions::new(4, 100, 25).with_limit(50),
+            // Test with limit equal to number of rows
+            TestOptions::new(4, 100, 25).with_limit(100),
+            // Test with limit larger than number of rows
+            TestOptions::new(4, 100, 25).with_limit(101),
             // Test with no page-level statistics
             TestOptions::new(2, 256, 91)
                 .with_null_percent(25)
@@ -1423,6 +1469,11 @@ mod tests {
             TestOptions::new(2, 256, 93)
                 .with_null_percent(25)
                 .with_row_selections(),
+            // Test optional with nulls, row selections and a limit
+            TestOptions::new(2, 256, 93)
+                .with_null_percent(25)
+                .with_row_selections()
+                .with_limit(10),
             // Test filter
 
             // Test with row filter
@@ -1592,7 +1643,7 @@ mod tests {
             }
         };
 
-        let expected_data = match opts.row_filter {
+        let mut expected_data = match opts.row_filter {
             Some(filter) => {
                 let expected_data = expected_data
                     .into_iter()
@@ -1622,6 +1673,11 @@ mod tests {
             None => expected_data,
         };
 
+        if let Some(limit) = opts.limit {
+            builder = builder.with_limit(limit);
+            expected_data = expected_data.into_iter().take(limit).collect();
+        }
+
         let mut record_reader = builder
             .with_batch_size(opts.record_batch_size)
             .build()
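
Because the limit is applied after any pushed-down predicates, combining it
with a row filter yields the first `limit` rows that pass the filter. A
hedged sketch of that combination; the trivial keep-everything predicate and
the file name are assumptions for illustration:

    use std::fs::File;
    use arrow_array::BooleanArray;
    use parquet::arrow::arrow_reader::{
        ArrowPredicateFn, ParquetRecordBatchReaderBuilder, RowFilter,
    };
    use parquet::arrow::ProjectionMask;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let file = File::open("data.parquet")?;
        let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;

        // A placeholder predicate that keeps every row; a real one would
        // evaluate a comparison against the projected columns.
        let mask = ProjectionMask::leaves(builder.parquet_schema(), vec![0]);
        let filter = RowFilter::new(vec![Box::new(ArrowPredicateFn::new(
            mask,
            |batch| Ok(BooleanArray::from(vec![true; batch.num_rows()])),
        ))]);

        // At most 10 rows that survive the filter are decoded.
        let reader = builder
            .with_row_filter(filter)
            .with_limit(10)
            .build()?;

        for batch in reader {
            println!("read {} rows", batch?.num_rows());
        }
        Ok(())
    }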
diff --git a/parquet/src/arrow/arrow_reader/selection.rs b/parquet/src/arrow/arrow_reader/selection.rs
index 03c7e01e0..d2af4516d 100644
--- a/parquet/src/arrow/arrow_reader/selection.rs
+++ b/parquet/src/arrow/arrow_reader/selection.rs
@@ -19,6 +19,7 @@ use arrow_array::{Array, BooleanArray};
 use arrow_select::filter::SlicesIterator;
 use std::cmp::Ordering;
 use std::collections::VecDeque;
+use std::mem;
 use std::ops::Range;
 
 /// [`RowSelection`] is a collection of [`RowSelector`] used to skip rows when
@@ -111,7 +112,7 @@ impl RowSelection {
     }
 
     /// Creates a [`RowSelection`] from an iterator of consecutive ranges to keep
-    fn from_consecutive_ranges<I: Iterator<Item = Range<usize>>>(
+    pub(crate) fn from_consecutive_ranges<I: Iterator<Item = Range<usize>>>(
         ranges: I,
         total_rows: usize,
     ) -> Self {
@@ -371,6 +372,32 @@ impl RowSelection {
         self
     }
 
+    /// Limit this [`RowSelection`] to only select `limit` rows
+    pub(crate) fn limit(mut self, mut limit: usize) -> Self {
+        let mut new_selectors = Vec::with_capacity(self.selectors.len());
+        for mut selection in mem::take(&mut self.selectors) {
+            if limit == 0 {
+                break;
+            }
+
+            if !selection.skip {
+                if selection.row_count >= limit {
+                    selection.row_count = limit;
+                    new_selectors.push(selection);
+                    break;
+                } else {
+                    limit -= selection.row_count;
+                    new_selectors.push(selection);
+                }
+            } else {
+                new_selectors.push(selection);
+            }
+        }
+
+        self.selectors = new_selectors;
+        self
+    }
+
     /// Returns an iterator over the [`RowSelector`]s for this
     /// [`RowSelection`].
     pub fn iter(&self) -> impl Iterator<Item = &RowSelector> {
@@ -841,6 +868,59 @@ mod tests {
         assert_eq!(selectors, round_tripped);
     }
 
+    #[test]
+    fn test_limit() {
+        // Limit to existing limit should no-op
+        let selection =
+            RowSelection::from(vec![RowSelector::select(10), RowSelector::skip(90)]);
+        let limited = selection.limit(10);
+        assert_eq!(RowSelection::from(vec![RowSelector::select(10)]), limited);
+
+        let selection = RowSelection::from(vec![
+            RowSelector::select(10),
+            RowSelector::skip(10),
+            RowSelector::select(10),
+            RowSelector::skip(10),
+            RowSelector::select(10),
+        ]);
+
+        let limited = selection.clone().limit(5);
+        let expected = vec![RowSelector::select(5)];
+        assert_eq!(limited.selectors, expected);
+
+        let limited = selection.clone().limit(15);
+        let expected = vec![
+            RowSelector::select(10),
+            RowSelector::skip(10),
+            RowSelector::select(5),
+        ];
+        assert_eq!(limited.selectors, expected);
+
+        let limited = selection.clone().limit(0);
+        let expected = vec![];
+        assert_eq!(limited.selectors, expected);
+
+        let limited = selection.clone().limit(30);
+        let expected = vec![
+            RowSelector::select(10),
+            RowSelector::skip(10),
+            RowSelector::select(10),
+            RowSelector::skip(10),
+            RowSelector::select(10),
+        ];
+        assert_eq!(limited.selectors, expected);
+
+        let limited = selection.limit(100);
+        let expected = vec![
+            RowSelector::select(10),
+            RowSelector::skip(10),
+            RowSelector::select(10),
+            RowSelector::skip(10),
+            RowSelector::select(10),
+        ];
+        assert_eq!(limited.selectors, expected);
+    }
+
     #[test]
     fn test_scan_ranges() {
         let index = vec![
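
The truncation above walks the selectors in order, counting only selected
rows against the limit and cutting the selector that crosses it. A
self-contained restatement of that loop (the real method is pub(crate) on
RowSelection; the Selector type here is a stand-in for RowSelector):

    #[derive(Debug, Clone, Copy, PartialEq)]
    struct Selector {
        row_count: usize,
        skip: bool,
    }

    fn limit(selectors: Vec<Selector>, mut limit: usize) -> Vec<Selector> {
        let mut out = Vec::with_capacity(selectors.len());
        for mut s in selectors {
            if limit == 0 {
                break;
            }
            if !s.skip {
                if s.row_count >= limit {
                    // Truncate the selector that crosses the limit and stop.
                    s.row_count = limit;
                    out.push(s);
                    break;
                }
                limit -= s.row_count;
            }
            out.push(s);
        }
        out
    }

    fn main() {
        let select = |n| Selector { row_count: n, skip: false };
        let skip = |n| Selector { row_count: n, skip: true };
        // [select 10, skip 10, select 10] limited to 15 rows keeps
        // [select 10, skip 10, select 5], matching test_limit above.
        assert_eq!(
            limit(vec![select(10), skip(10), select(10)], 15),
            vec![select(10), skip(10), select(5)]
        );
    }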
diff --git a/parquet/src/arrow/async_reader/mod.rs b/parquet/src/arrow/async_reader/mod.rs
index 0397df206..71f95e07a 100644
--- a/parquet/src/arrow/async_reader/mod.rs
+++ b/parquet/src/arrow/async_reader/mod.rs
@@ -99,7 +99,7 @@ use arrow_schema::SchemaRef;
 use crate::arrow::array_reader::{build_array_reader, RowGroupCollection};
 use crate::arrow::arrow_reader::{
     evaluate_predicate, selects_any, ArrowReaderBuilder, ArrowReaderOptions,
-    ParquetRecordBatchReader, RowFilter, RowSelection,
+    ParquetRecordBatchReader, RowFilter, RowSelection, RowSelector,
 };
 use crate::arrow::schema::ParquetField;
 use crate::arrow::ProjectionMask;
@@ -352,6 +352,7 @@ impl<T: AsyncFileReader + Send + 'static> ArrowReaderBuilder<AsyncReader<T>> {
         Ok(ParquetRecordBatchStream {
             metadata: self.metadata,
             batch_size,
+            limit: self.limit,
             row_groups,
             projection: self.projection,
             selection: self.selection,
@@ -389,6 +390,7 @@ where
         mut selection: Option<RowSelection>,
         projection: ProjectionMask,
         batch_size: usize,
+        limit: Option<usize>,
     ) -> ReadResult<T> {
         // TODO: calling build_array multiple times is wasteful
 
@@ -430,6 +432,17 @@ where
             return Ok((self, None));
         }
 
+        // If a limit is defined, apply it to the final `RowSelection`
+        if let Some(limit) = limit {
+            selection = Some(
+                selection
+                    .map(|selection| selection.limit(limit))
+                    .unwrap_or_else(|| {
+                        RowSelection::from(vec![RowSelector::select(limit)])
+                    }),
+            );
+        }
+
         row_group
             .fetch(&mut self.input, &projection, selection.as_ref())
             .await?;
@@ -479,6 +492,8 @@ pub struct ParquetRecordBatchStream<T> {
 
     batch_size: usize,
 
+    limit: Option<usize>,
+
     selection: Option<RowSelection>,
 
     /// This is an option so it can be moved into a future
@@ -519,7 +534,12 @@ where
         loop {
             match &mut self.state {
                 StreamState::Decoding(batch_reader) => match batch_reader.next() {
-                    Some(Ok(batch)) => return Poll::Ready(Some(Ok(batch))),
+                    Some(Ok(batch)) => {
+                        if let Some(limit) = self.limit.as_mut() {
+                            *limit -= batch.num_rows();
+                        }
+                        return Poll::Ready(Some(Ok(batch)));
+                    }
                     Some(Err(e)) => {
                         self.state = StreamState::Error;
                         return Poll::Ready(Some(Err(ParquetError::ArrowError(
@@ -548,6 +568,7 @@ where
                             selection,
                             self.projection.clone(),
                             self.batch_size,
+                            self.limit,
                         )
                         .boxed();
 
@@ -803,6 +824,7 @@ mod tests {
     use crate::arrow::ArrowWriter;
     use crate::file::footer::parse_metadata;
     use crate::file::page_index::index_reader;
+    use crate::file::properties::WriterProperties;
     use arrow::error::Result as ArrowResult;
     use arrow_array::{Array, ArrayRef, Int32Array, StringArray};
     use futures::TryStreamExt;
@@ -943,6 +965,70 @@ mod tests {
         assert_eq!(async_batches, sync_batches);
     }
 
+    #[tokio::test]
+    async fn test_async_reader_with_limit() {
+        let testdata = arrow::util::test_util::parquet_test_data();
+        let path = format!("{testdata}/alltypes_tiny_pages_plain.parquet");
+        let data = Bytes::from(std::fs::read(path).unwrap());
+
+        let metadata = parse_metadata(&data).unwrap();
+        let metadata = Arc::new(metadata);
+
+        assert_eq!(metadata.num_row_groups(), 1);
+
+        let async_reader = TestReader {
+            data: data.clone(),
+            metadata: metadata.clone(),
+            requests: Default::default(),
+        };
+
+        let options = ArrowReaderOptions::new().with_page_index(true);
+        let builder =
+            ParquetRecordBatchStreamBuilder::new_with_options(async_reader, options)
+                .await
+                .unwrap();
+
+        // The builder should have page and offset indexes loaded now
+        let metadata_with_index = builder.metadata();
+
+        // Check offset indexes are present for all columns
+        for rg in metadata_with_index.row_groups() {
+            let page_locations =
+                rg.page_offset_index().expect("expected page offset index");
+            assert_eq!(page_locations.len(), rg.columns().len())
+        }
+
+        // Check page indexes are present for all columns
+        let page_indexes = metadata_with_index
+            .page_indexes()
+            .expect("expected page indexes");
+        for (idx, rg) in metadata_with_index.row_groups().iter().enumerate() {
+            assert_eq!(page_indexes[idx].len(), rg.columns().len())
+        }
+
+        let mask = ProjectionMask::leaves(builder.parquet_schema(), vec![1, 2]);
+        let stream = builder
+            .with_projection(mask.clone())
+            .with_batch_size(1024)
+            .with_limit(1)
+            .build()
+            .unwrap();
+
+        let async_batches: Vec<_> = stream.try_collect().await.unwrap();
+
+        let sync_batches = ParquetRecordBatchReaderBuilder::try_new(data)
+            .unwrap()
+            .with_projection(mask)
+            .with_batch_size(1024)
+            .with_limit(1)
+            .build()
+            .unwrap()
+            .collect::<ArrowResult<Vec<_>>>()
+            .unwrap();
+
+        assert_eq!(async_batches, sync_batches);
+    }
+
     #[tokio::test]
     async fn test_async_reader_skip_pages() {
         let testdata = arrow::util::test_util::parquet_test_data();
@@ -1204,6 +1290,61 @@ mod tests {
         assert_eq!(requests.lock().unwrap().len(), 3);
     }
 
+    #[tokio::test]
+    async fn test_limit_multiple_row_groups() {
+        let a = StringArray::from_iter_values(["a", "b", "b", "b", "c", "c"]);
+        let b = StringArray::from_iter_values(["1", "2", "3", "4", "5", "6"]);
+        let c = Int32Array::from_iter(0..6);
+        let data = RecordBatch::try_from_iter([
+            ("a", Arc::new(a) as ArrayRef),
+            ("b", Arc::new(b) as ArrayRef),
+            ("c", Arc::new(c) as ArrayRef),
+        ])
+        .unwrap();
+
+        let mut buf = Vec::with_capacity(1024);
+        let props = WriterProperties::builder()
+            .set_max_row_group_size(3)
+            .build();
+        let mut writer =
+            ArrowWriter::try_new(&mut buf, data.schema(), Some(props)).unwrap();
+        writer.write(&data).unwrap();
+        writer.close().unwrap();
+
+        let data: Bytes = buf.into();
+        let metadata = parse_metadata(&data).unwrap();
+
+        assert_eq!(metadata.num_row_groups(), 2);
+
+        let test = TestReader {
+            data,
+            metadata: Arc::new(metadata),
+            requests: Default::default(),
+        };
+
+        let stream = ParquetRecordBatchStreamBuilder::new(test)
+            .await
+            .unwrap()
+            .with_batch_size(1024)
+            .with_limit(4)
+            .build()
+            .unwrap();
+
+        let batches: Vec<_> = stream.try_collect().await.unwrap();
+        // Expect one batch for each row group
+        assert_eq!(batches.len(), 2);
+
+        let batch = &batches[0];
+        // First batch should contain all rows
+        assert_eq!(batch.num_rows(), 3);
+        assert_eq!(batch.num_columns(), 3);
+
+        let batch = &batches[1];
+        // Second batch should trigger the limit and only have one row
+        assert_eq!(batch.num_rows(), 1);
+        assert_eq!(batch.num_columns(), 3);
+    }
+
     #[tokio::test]
     async fn test_row_filter_with_index() {
         let testdata = arrow::util::test_util::parquet_test_data();
@@ -1330,7 +1471,7 @@ mod tests {
         let selection = RowSelection::from(selectors);
 
         let (_factory, _reader) = reader_factory
-            .read_row_group(0, Some(selection), projection.clone(), 48)
+            .read_row_group(0, Some(selection), projection.clone(), 48, None)
             .await
             .expect("reading row group");