You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2023/11/07 13:57:16 UTC

(arrow-datafusion) branch main updated: Fix incorrect results in COUNT(*) queries with LIMIT (#8049)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 06fd26b86d Fix incorrect results in COUNT(*) queries with LIMIT (#8049)
06fd26b86d is described below

commit 06fd26b86dd8e5269966be8658862e5a5a12f948
Author: Mark Sirek <mw...@gmail.com>
AuthorDate: Tue Nov 7 05:57:10 2023 -0800

    Fix incorrect results in COUNT(*) queries with LIMIT (#8049)
    
    Co-authored-by: Mark Sirek <si...@cockroachlabs.com>
---
 datafusion/physical-plan/src/limit.rs          | 159 +++++++++++++++++++++----
 datafusion/sqllogictest/test_files/explain.slt |   2 +-
 datafusion/sqllogictest/test_files/limit.slt   |  85 +++++++++++++
 datafusion/sqllogictest/test_files/window.slt  |  12 +-
 4 files changed, 232 insertions(+), 26 deletions(-)

diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs
index 945dad16b7..c8427f9bc2 100644
--- a/datafusion/physical-plan/src/limit.rs
+++ b/datafusion/physical-plan/src/limit.rs
@@ -188,21 +188,11 @@ impl ExecutionPlan for GlobalLimitExec {
     fn statistics(&self) -> Result<Statistics> {
         let input_stats = self.input.statistics()?;
         let skip = self.skip;
-        // the maximum row number needs to be fetched
-        let max_row_num = self
-            .fetch
-            .map(|fetch| {
-                if fetch >= usize::MAX - skip {
-                    usize::MAX
-                } else {
-                    fetch + skip
-                }
-            })
-            .unwrap_or(usize::MAX);
         let col_stats = Statistics::unknown_column(&self.schema());
+        let fetch = self.fetch.unwrap_or(usize::MAX);
 
-        let fetched_row_number_stats = Statistics {
-            num_rows: Precision::Exact(max_row_num),
+        let mut fetched_row_number_stats = Statistics {
+            num_rows: Precision::Exact(fetch),
             column_statistics: col_stats.clone(),
             total_byte_size: Precision::Absent,
         };
@@ -218,23 +208,55 @@ impl ExecutionPlan for GlobalLimitExec {
             } => {
                 if nr <= skip {
                     // if all input data will be skipped, return 0
-                    Statistics {
+                    let mut skip_all_rows_stats = Statistics {
                         num_rows: Precision::Exact(0),
                         column_statistics: col_stats,
                         total_byte_size: Precision::Absent,
+                    };
+                    if !input_stats.num_rows.is_exact().unwrap_or(false) {
+                        // The input stats are inexact, so the output stats must be too.
+                        skip_all_rows_stats = skip_all_rows_stats.into_inexact();
                     }
-                } else if nr <= max_row_num {
-                    // if the input does not reach the "fetch" globally, return input stats
+                    skip_all_rows_stats
+                } else if nr <= fetch && self.skip == 0 {
+                    // if the input does not reach the "fetch" globally, and "skip" is zero
+                    // (meaning the input and output are identical), return input stats.
+                    // Can input_stats still be used, but adjusted, in the "skip != 0" case?
                     input_stats
+                } else if nr - skip <= fetch {
+                    // after "skip" input rows are skipped, the number of remaining rows is less than or
+                    // equal to the "fetch" value, so `num_rows` must equal the remaining rows
+                    let remaining_rows: usize = nr - skip;
+                    let mut skip_some_rows_stats = Statistics {
+                        num_rows: Precision::Exact(remaining_rows),
+                        column_statistics: col_stats.clone(),
+                        total_byte_size: Precision::Absent,
+                    };
+                    if !input_stats.num_rows.is_exact().unwrap_or(false) {
+                        // The input stats are inexact, so the output stats must be too.
+                        skip_some_rows_stats = skip_some_rows_stats.into_inexact();
+                    }
+                    skip_some_rows_stats
                 } else {
-                    // if the input is greater than the "fetch", the num_row will be the "fetch",
+                    // if the input is greater than "fetch+skip", the num_rows will be the "fetch",
                     // but we won't be able to predict the other statistics
+                    if !input_stats.num_rows.is_exact().unwrap_or(false)
+                        || self.fetch.is_none()
+                    {
+                        // If the input stats are inexact, the output stats must be too.
+                        // If the fetch value is `usize::MAX` because no LIMIT was specified,
+                        // we also can't represent it as an exact value.
+                        fetched_row_number_stats =
+                            fetched_row_number_stats.into_inexact();
+                    }
                     fetched_row_number_stats
                 }
             }
             _ => {
-                // the result output row number will always be no greater than the limit number
-                fetched_row_number_stats
+                // The result output `num_rows` will always be no greater than the limit number.
+                // Should `num_rows` be marked as `Absent` here when the `fetch` value is large,
+                // as the actual `num_rows` may be far away from the `fetch` value?
+                fetched_row_number_stats.into_inexact()
             }
         };
         Ok(stats)
@@ -552,7 +574,10 @@ mod tests {
     use crate::common::collect;
     use crate::{common, test};
 
+    use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
     use arrow_schema::Schema;
+    use datafusion_physical_expr::expressions::col;
+    use datafusion_physical_expr::PhysicalExpr;
 
     #[tokio::test]
     async fn limit() -> Result<()> {
@@ -712,7 +737,7 @@ mod tests {
     }
 
     #[tokio::test]
-    async fn skip_3_fetch_10() -> Result<()> {
+    async fn skip_3_fetch_10_stats() -> Result<()> {
         // there are total of 100 rows, we skipped 3 rows (offset = 3)
         let row_count = skip_and_fetch(3, Some(10)).await?;
         assert_eq!(row_count, 10);
@@ -748,7 +773,58 @@ mod tests {
         assert_eq!(row_count, Precision::Exact(10));
 
         let row_count = row_number_statistics_for_global_limit(5, Some(10)).await?;
-        assert_eq!(row_count, Precision::Exact(15));
+        assert_eq!(row_count, Precision::Exact(10));
+
+        let row_count = row_number_statistics_for_global_limit(400, Some(10)).await?;
+        assert_eq!(row_count, Precision::Exact(0));
+
+        let row_count = row_number_statistics_for_global_limit(398, Some(10)).await?;
+        assert_eq!(row_count, Precision::Exact(2));
+
+        let row_count = row_number_statistics_for_global_limit(398, Some(1)).await?;
+        assert_eq!(row_count, Precision::Exact(1));
+
+        let row_count = row_number_statistics_for_global_limit(398, None).await?;
+        assert_eq!(row_count, Precision::Exact(2));
+
+        let row_count =
+            row_number_statistics_for_global_limit(0, Some(usize::MAX)).await?;
+        assert_eq!(row_count, Precision::Exact(400));
+
+        let row_count =
+            row_number_statistics_for_global_limit(398, Some(usize::MAX)).await?;
+        assert_eq!(row_count, Precision::Exact(2));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(0, Some(10)).await?;
+        assert_eq!(row_count, Precision::Inexact(10));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(5, Some(10)).await?;
+        assert_eq!(row_count, Precision::Inexact(10));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(400, Some(10)).await?;
+        assert_eq!(row_count, Precision::Inexact(0));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(398, Some(10)).await?;
+        assert_eq!(row_count, Precision::Inexact(2));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(398, Some(1)).await?;
+        assert_eq!(row_count, Precision::Inexact(1));
+
+        let row_count = row_number_inexact_statistics_for_global_limit(398, None).await?;
+        assert_eq!(row_count, Precision::Inexact(2));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(0, Some(usize::MAX)).await?;
+        assert_eq!(row_count, Precision::Inexact(400));
+
+        let row_count =
+            row_number_inexact_statistics_for_global_limit(398, Some(usize::MAX)).await?;
+        assert_eq!(row_count, Precision::Inexact(2));
 
         Ok(())
     }
@@ -776,6 +852,47 @@ mod tests {
         Ok(offset.statistics()?.num_rows)
     }
 
+    pub fn build_group_by(
+        input_schema: &SchemaRef,
+        columns: Vec<String>,
+    ) -> PhysicalGroupBy {
+        let mut group_by_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
+        for column in columns.iter() {
+            group_by_expr.push((col(column, input_schema).unwrap(), column.to_string()));
+        }
+        PhysicalGroupBy::new_single(group_by_expr.clone())
+    }
+
+    async fn row_number_inexact_statistics_for_global_limit(
+        skip: usize,
+        fetch: Option<usize>,
+    ) -> Result<Precision<usize>> {
+        let num_partitions = 4;
+        let csv = test::scan_partitioned(num_partitions);
+
+        assert_eq!(csv.output_partitioning().partition_count(), num_partitions);
+
+        // Adding a "GROUP BY i" changes the input stats from Exact to Inexact.
+        let agg = AggregateExec::try_new(
+            AggregateMode::Final,
+            build_group_by(&csv.schema().clone(), vec!["i".to_string()]),
+            vec![],
+            vec![None],
+            vec![None],
+            csv.clone(),
+            csv.schema().clone(),
+        )?;
+        let agg_exec: Arc<dyn ExecutionPlan> = Arc::new(agg);
+
+        let offset = GlobalLimitExec::new(
+            Arc::new(CoalescePartitionsExec::new(agg_exec)),
+            skip,
+            fetch,
+        );
+
+        Ok(offset.statistics()?.num_rows)
+    }
+
     async fn row_number_statistics_for_local_limit(
         num_partitions: usize,
         fetch: usize,
diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt
index 066a31590c..40a6d43574 100644
--- a/datafusion/sqllogictest/test_files/explain.slt
+++ b/datafusion/sqllogictest/test_files/explain.slt
@@ -273,7 +273,7 @@ query TT
 EXPLAIN SELECT a, b, c FROM simple_explain_test limit 10;
 ----
 physical_plan
-GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(10), Bytes=Absent]
+GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Inexact(10), Bytes=Absent]
 --CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], limit=10, has_header=true, statistics=[Rows=Absent, Bytes=Absent]
 
 # Parquet scan with statistics collected
diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt
index 253ca8f335..21248ddbd8 100644
--- a/datafusion/sqllogictest/test_files/limit.slt
+++ b/datafusion/sqllogictest/test_files/limit.slt
@@ -294,6 +294,91 @@ query T
 SELECT c1 FROM aggregate_test_100 LIMIT 1 OFFSET 101
 ----
 
+#
+# global limit statistics test
+#
+
+statement ok
+CREATE TABLE IF NOT EXISTS t1 (a INT) AS VALUES(1),(2),(3),(4),(5),(6),(7),(8),(9),(10);
+
+# The aggregate does not need to be computed because the input statistics are exact and
+# the number of rows is less than the skip value (OFFSET).
+query TT
+EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11);
+----
+logical_plan
+Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Limit: skip=11, fetch=3
+----TableScan: t1 projection=[], fetch=14
+physical_plan
+ProjectionExec: expr=[0 as COUNT(*)]
+--EmptyExec: produce_one_row=true
+
+query I
+SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11);
+----
+0
+
+# The aggregate does not need to be computed because the input statistics are exact and
+# the number of rows is less than or equal to the "fetch+skip" value (LIMIT+OFFSET).
+query TT
+EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8);
+----
+logical_plan
+Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Limit: skip=8, fetch=3
+----TableScan: t1 projection=[], fetch=11
+physical_plan
+ProjectionExec: expr=[2 as COUNT(*)]
+--EmptyExec: produce_one_row=true
+
+query I
+SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8);
+----
+2
+
+# The aggregate does not need to be computed because the input statistics are exact and
+# an OFFSET, but no LIMIT, is specified.
+query TT
+EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 OFFSET 8);
+----
+logical_plan
+Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Limit: skip=8, fetch=None
+----TableScan: t1 projection=[]
+physical_plan
+ProjectionExec: expr=[2 as COUNT(*)]
+--EmptyExec: produce_one_row=true
+
+query I
+SELECT COUNT(*) FROM (SELECT a FROM t1 OFFSET 8);
+----
+2
+
+# The aggregate needs to be computed because the input statistics are inexact.
+query TT
+EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6);
+----
+logical_plan
+Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Limit: skip=6, fetch=3
+----Filter: t1.a > Int32(3)
+------TableScan: t1 projection=[a]
+physical_plan
+AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)]
+--CoalescePartitionsExec
+----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)]
+------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+--------GlobalLimitExec: skip=6, fetch=3
+----------CoalesceBatchesExec: target_batch_size=8192
+------------FilterExec: a@0 > 3
+--------------MemoryExec: partitions=1, partition_sizes=[1]
+
+query I
+SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6);
+----
+1
+
 ########
 # Clean up after the test
 ########
diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt
index c7060433d9..2eb0576d55 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -2010,10 +2010,14 @@ Projection: ARRAY_AGG(aggregate_test_100.c13) AS array_agg1
 --------TableScan: aggregate_test_100 projection=[c13]
 physical_plan
 ProjectionExec: expr=[ARRAY_AGG(aggregate_test_100.c13)@0 as array_agg1]
---AggregateExec: mode=Single, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)]
-----GlobalLimitExec: skip=0, fetch=1
-------SortExec: TopK(fetch=1), expr=[c13@0 ASC NULLS LAST]
---------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c13], has_header=true
+--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)]
+----CoalescePartitionsExec
+------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)]
+--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+----------GlobalLimitExec: skip=0, fetch=1
+------------SortExec: TopK(fetch=1), expr=[c13@0 ASC NULLS LAST]
+--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c13], has_header=true
+
 
 query ?
 SELECT ARRAY_AGG(c13) as array_agg1 FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1)