Posted to commits@arrow.apache.org by al...@apache.org on 2023/11/07 13:57:16 UTC
(arrow-datafusion) branch main updated: Fix incorrect results in COUNT(*) queries with LIMIT (#8049)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 06fd26b86d Fix incorrect results in COUNT(*) queries with LIMIT (#8049)
06fd26b86d is described below
commit 06fd26b86dd8e5269966be8658862e5a5a12f948
Author: Mark Sirek <mw...@gmail.com>
AuthorDate: Tue Nov 7 05:57:10 2023 -0800
Fix incorrect results in COUNT(*) queries with LIMIT (#8049)
Co-authored-by: Mark Sirek <si...@cockroachlabs.com>
---
datafusion/physical-plan/src/limit.rs | 159 +++++++++++++++++++++----
datafusion/sqllogictest/test_files/explain.slt | 2 +-
datafusion/sqllogictest/test_files/limit.slt | 85 +++++++++++++
datafusion/sqllogictest/test_files/window.slt | 12 +-
4 files changed, 232 insertions(+), 26 deletions(-)
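Why the old statistics were wrong: GlobalLimitExec::statistics() reported num_rows as Precision::Exact(fetch + skip) even when the input row count was only an estimate, or when no LIMIT was given at all. The statistics-based COUNT(*) optimization replaces the aggregate with a constant whenever the plan's row count is exact (see the new limit.slt tests below, where the aggregate collapses to a ProjectionExec over an EmptyExec), so an exact-but-wrong count produced incorrect query results. The patch demotes such counts to Precision::Inexact. A minimal, hedged sketch of the demotion the fix relies on (a simplified stand-in, not DataFusion's actual generic Precision type):

    #[derive(Clone, Debug, PartialEq)]
    enum Precision {
        Exact(usize),   // the value is known to be correct
        Inexact(usize), // the value is only an estimate
        Absent,         // nothing is known
    }

    impl Precision {
        // The patch calls the Statistics-level analogue, into_inexact(),
        // which demotes every statistic of a plan node at once.
        fn into_inexact(self) -> Self {
            match self {
                Precision::Exact(v) => Precision::Inexact(v),
                other => other,
            }
        }
    }

    fn main() {
        // A count derived from untrustworthy inputs must be demoted, so
        // the COUNT(*) optimization no longer treats it as ground truth:
        assert_eq!(Precision::Exact(10).into_inexact(), Precision::Inexact(10));
    }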
diff --git a/datafusion/physical-plan/src/limit.rs b/datafusion/physical-plan/src/limit.rs
index 945dad16b7..c8427f9bc2 100644
--- a/datafusion/physical-plan/src/limit.rs
+++ b/datafusion/physical-plan/src/limit.rs
@@ -188,21 +188,11 @@ impl ExecutionPlan for GlobalLimitExec {
fn statistics(&self) -> Result<Statistics> {
let input_stats = self.input.statistics()?;
let skip = self.skip;
- // the maximum row number needs to be fetched
- let max_row_num = self
- .fetch
- .map(|fetch| {
- if fetch >= usize::MAX - skip {
- usize::MAX
- } else {
- fetch + skip
- }
- })
- .unwrap_or(usize::MAX);
let col_stats = Statistics::unknown_column(&self.schema());
+ let fetch = self.fetch.unwrap_or(usize::MAX);
- let fetched_row_number_stats = Statistics {
- num_rows: Precision::Exact(max_row_num),
+ let mut fetched_row_number_stats = Statistics {
+ num_rows: Precision::Exact(fetch),
column_statistics: col_stats.clone(),
total_byte_size: Precision::Absent,
};
@@ -218,23 +208,55 @@ impl ExecutionPlan for GlobalLimitExec {
} => {
if nr <= skip {
// if all input data will be skipped, return 0
- Statistics {
+ let mut skip_all_rows_stats = Statistics {
num_rows: Precision::Exact(0),
column_statistics: col_stats,
total_byte_size: Precision::Absent,
+ };
+ if !input_stats.num_rows.is_exact().unwrap_or(false) {
+ // The input stats are inexact, so the output stats must be too.
+ skip_all_rows_stats = skip_all_rows_stats.into_inexact();
}
- } else if nr <= max_row_num {
- // if the input does not reach the "fetch" globally, return input stats
+ skip_all_rows_stats
+ } else if nr <= fetch && self.skip == 0 {
+ // if the input does not reach the "fetch" globally, and "skip" is zero
+ // (meaning the input and output are identical), return input stats.
+ // Can input_stats still be used, but adjusted, in the "skip != 0" case?
input_stats
+ } else if nr - skip <= fetch {
+ // after "skip" input rows are skipped, the remaining rows are less than or equal to the
+ // "fetch" values, so `num_rows` must equal the remaining rows
+ let remaining_rows: usize = nr - skip;
+ let mut skip_some_rows_stats = Statistics {
+ num_rows: Precision::Exact(remaining_rows),
+ column_statistics: col_stats.clone(),
+ total_byte_size: Precision::Absent,
+ };
+ if !input_stats.num_rows.is_exact().unwrap_or(false) {
+ // The input stats are inexact, so the output stats must be too.
+ skip_some_rows_stats = skip_some_rows_stats.into_inexact();
+ }
+ skip_some_rows_stats
} else {
- // if the input is greater than the "fetch", the num_row will be the "fetch",
+ // if the input is greater than "fetch+skip", the num_rows will be the "fetch",
// but we won't be able to predict the other statistics
+ if !input_stats.num_rows.is_exact().unwrap_or(false)
+ || self.fetch.is_none()
+ {
+ // If the input stats are inexact, the output stats must be too.
+ // If the fetch value is `usize::MAX` because no LIMIT was specified,
+ // we also can't represent it as an exact value.
+ fetched_row_number_stats =
+ fetched_row_number_stats.into_inexact();
+ }
fetched_row_number_stats
}
}
_ => {
- // the result output row number will always be no greater than the limit number
- fetched_row_number_stats
+ // The result output `num_rows` will always be no greater than the limit number.
+ // Should `num_rows` be marked as `Absent` here when the `fetch` value is large,
+ // as the actual `num_rows` may be far away from the `fetch` value?
+ fetched_row_number_stats.into_inexact()
}
};
Ok(stats)
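The branch structure above can be restated as a standalone function over just the row count. A minimal hedged sketch follows (the `Precision` stand-in and `limit_num_rows` are invented for illustration and are not DataFusion APIs; the real code also carries column and byte-size statistics):

    #[derive(Debug, PartialEq)]
    enum Precision {
        Exact(usize),
        Inexact(usize),
        Absent,
    }

    fn limit_num_rows(input: Precision, skip: usize, fetch: Option<usize>) -> Precision {
        use Precision::*;
        let (nr, is_exact) = match input {
            Exact(n) => (n, true),
            Inexact(n) => (n, false),
            // No input row count at all: promise at most `fetch` rows,
            // and only as an estimate (usize::MAX stands in for
            // "no LIMIT", as in the patch).
            Absent => return Inexact(fetch.unwrap_or(usize::MAX)),
        };
        let wrap = |n: usize, e: bool| if e { Exact(n) } else { Inexact(n) };
        if nr <= skip {
            // every input row is skipped
            wrap(0, is_exact)
        } else if nr - skip <= fetch.unwrap_or(usize::MAX) {
            // the fetch is never reached: output whatever survives the skip
            wrap(nr - skip, is_exact)
        } else {
            // the fetch is reached; `fetch` must be Some(_) to get here
            wrap(fetch.unwrap(), is_exact)
        }
    }

    fn main() {
        // These mirror the new unit tests below (400 input rows in the scan):
        assert_eq!(limit_num_rows(Precision::Exact(400), 398, Some(10)), Precision::Exact(2));
        assert_eq!(limit_num_rows(Precision::Exact(400), 400, Some(10)), Precision::Exact(0));
        assert_eq!(limit_num_rows(Precision::Inexact(400), 0, Some(10)), Precision::Inexact(10));
    }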
@@ -552,7 +574,10 @@ mod tests {
use crate::common::collect;
use crate::{common, test};
+ use crate::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use arrow_schema::Schema;
+ use datafusion_physical_expr::expressions::col;
+ use datafusion_physical_expr::PhysicalExpr;
#[tokio::test]
async fn limit() -> Result<()> {
@@ -712,7 +737,7 @@ mod tests {
}
#[tokio::test]
- async fn skip_3_fetch_10() -> Result<()> {
+ async fn skip_3_fetch_10_stats() -> Result<()> {
// there are a total of 100 rows; we skipped 3 rows (offset = 3)
let row_count = skip_and_fetch(3, Some(10)).await?;
assert_eq!(row_count, 10);
@@ -748,7 +773,58 @@ mod tests {
assert_eq!(row_count, Precision::Exact(10));
let row_count = row_number_statistics_for_global_limit(5, Some(10)).await?;
- assert_eq!(row_count, Precision::Exact(15));
+ assert_eq!(row_count, Precision::Exact(10));
+
+ let row_count = row_number_statistics_for_global_limit(400, Some(10)).await?;
+ assert_eq!(row_count, Precision::Exact(0));
+
+ let row_count = row_number_statistics_for_global_limit(398, Some(10)).await?;
+ assert_eq!(row_count, Precision::Exact(2));
+
+ let row_count = row_number_statistics_for_global_limit(398, Some(1)).await?;
+ assert_eq!(row_count, Precision::Exact(1));
+
+ let row_count = row_number_statistics_for_global_limit(398, None).await?;
+ assert_eq!(row_count, Precision::Exact(2));
+
+ let row_count =
+ row_number_statistics_for_global_limit(0, Some(usize::MAX)).await?;
+ assert_eq!(row_count, Precision::Exact(400));
+
+ let row_count =
+ row_number_statistics_for_global_limit(398, Some(usize::MAX)).await?;
+ assert_eq!(row_count, Precision::Exact(2));
+
+ let row_count =
+ row_number_inexact_statistics_for_global_limit(0, Some(10)).await?;
+ assert_eq!(row_count, Precision::Inexact(10));
+
+ let row_count =
+ row_number_inexact_statistics_for_global_limit(5, Some(10)).await?;
+ assert_eq!(row_count, Precision::Inexact(10));
+
+ let row_count =
+ row_number_inexact_statistics_for_global_limit(400, Some(10)).await?;
+ assert_eq!(row_count, Precision::Inexact(0));
+
+ let row_count =
+ row_number_inexact_statistics_for_global_limit(398, Some(10)).await?;
+ assert_eq!(row_count, Precision::Inexact(2));
+
+ let row_count =
+ row_number_inexact_statistics_for_global_limit(398, Some(1)).await?;
+ assert_eq!(row_count, Precision::Inexact(1));
+
+ let row_count = row_number_inexact_statistics_for_global_limit(398, None).await?;
+ assert_eq!(row_count, Precision::Inexact(2));
+
+ let row_count =
+ row_number_inexact_statistics_for_global_limit(0, Some(usize::MAX)).await?;
+ assert_eq!(row_count, Precision::Inexact(400));
+
+ let row_count =
+ row_number_inexact_statistics_for_global_limit(398, Some(usize::MAX)).await?;
+ assert_eq!(row_count, Precision::Inexact(2));
Ok(())
}
@@ -776,6 +852,47 @@ mod tests {
Ok(offset.statistics()?.num_rows)
}
+ pub fn build_group_by(
+ input_schema: &SchemaRef,
+ columns: Vec<String>,
+ ) -> PhysicalGroupBy {
+ let mut group_by_expr: Vec<(Arc<dyn PhysicalExpr>, String)> = vec![];
+ for column in columns.iter() {
+ group_by_expr.push((col(column, input_schema).unwrap(), column.to_string()));
+ }
+ PhysicalGroupBy::new_single(group_by_expr.clone())
+ }
+
+ async fn row_number_inexact_statistics_for_global_limit(
+ skip: usize,
+ fetch: Option<usize>,
+ ) -> Result<Precision<usize>> {
+ let num_partitions = 4;
+ let csv = test::scan_partitioned(num_partitions);
+
+ assert_eq!(csv.output_partitioning().partition_count(), num_partitions);
+
+ // Adding a "GROUP BY i" changes the input stats from Exact to Inexact.
+ let agg = AggregateExec::try_new(
+ AggregateMode::Final,
+ build_group_by(&csv.schema().clone(), vec!["i".to_string()]),
+ vec![],
+ vec![None],
+ vec![None],
+ csv.clone(),
+ csv.schema().clone(),
+ )?;
+ let agg_exec: Arc<dyn ExecutionPlan> = Arc::new(agg);
+
+ let offset = GlobalLimitExec::new(
+ Arc::new(CoalescePartitionsExec::new(agg_exec)),
+ skip,
+ fetch,
+ );
+
+ Ok(offset.statistics()?.num_rows)
+ }
+
async fn row_number_statistics_for_local_limit(
num_partitions: usize,
fetch: usize,
diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt
index 066a31590c..40a6d43574 100644
--- a/datafusion/sqllogictest/test_files/explain.slt
+++ b/datafusion/sqllogictest/test_files/explain.slt
@@ -273,7 +273,7 @@ query TT
EXPLAIN SELECT a, b, c FROM simple_explain_test limit 10;
----
physical_plan
-GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Exact(10), Bytes=Absent]
+GlobalLimitExec: skip=0, fetch=10, statistics=[Rows=Inexact(10), Bytes=Absent]
--CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], limit=10, has_header=true, statistics=[Rows=Absent, Bytes=Absent]
# Parquet scan with statistics collected
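This expectation change follows the catch-all arm in limit.rs above: the CsvExec input reports Rows=Absent, so the limit can promise at most its fetch of 10 rows, and only as an estimate. In terms of the hedged limit_num_rows sketch from earlier:

    // unknown input row count, skip 0, fetch 10
    assert_eq!(limit_num_rows(Precision::Absent, 0, Some(10)), Precision::Inexact(10));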
diff --git a/datafusion/sqllogictest/test_files/limit.slt b/datafusion/sqllogictest/test_files/limit.slt
index 253ca8f335..21248ddbd8 100644
--- a/datafusion/sqllogictest/test_files/limit.slt
+++ b/datafusion/sqllogictest/test_files/limit.slt
@@ -294,6 +294,91 @@ query T
SELECT c1 FROM aggregate_test_100 LIMIT 1 OFFSET 101
----
+#
+# global limit statistics test
+#
+
+statement ok
+CREATE TABLE IF NOT EXISTS t1 (a INT) AS VALUES(1),(2),(3),(4),(5),(6),(7),(8),(9),(10);
+
+# The aggregate does not need to be computed because the input statistics are exact and
+# the number of rows is less than the skip value (OFFSET).
+query TT
+EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11);
+----
+logical_plan
+Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Limit: skip=11, fetch=3
+----TableScan: t1 projection=[], fetch=14
+physical_plan
+ProjectionExec: expr=[0 as COUNT(*)]
+--EmptyExec: produce_one_row=true
+
+query I
+SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 11);
+----
+0
+
+# The aggregate does not need to be computed because the input statistics are exact and
+# the number of rows is less than or equal to the "fetch+skip" value (LIMIT+OFFSET).
+query TT
+EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8);
+----
+logical_plan
+Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Limit: skip=8, fetch=3
+----TableScan: t1 projection=[], fetch=11
+physical_plan
+ProjectionExec: expr=[2 as COUNT(*)]
+--EmptyExec: produce_one_row=true
+
+query I
+SELECT COUNT(*) FROM (SELECT a FROM t1 LIMIT 3 OFFSET 8);
+----
+2
+
+# The aggregate does not need to be computed because the input statistics are exact and
+# an OFFSET, but no LIMIT, is specified.
+query TT
+EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 OFFSET 8);
+----
+logical_plan
+Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Limit: skip=8, fetch=None
+----TableScan: t1 projection=[]
+physical_plan
+ProjectionExec: expr=[2 as COUNT(*)]
+--EmptyExec: produce_one_row=true
+
+query I
+SELECT COUNT(*) FROM (SELECT a FROM t1 OFFSET 8);
+----
+2
+
+# The aggregate needs to be computed because the input statistics are inexact.
+query TT
+EXPLAIN SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6);
+----
+logical_plan
+Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1)) AS COUNT(*)]]
+--Limit: skip=6, fetch=3
+----Filter: t1.a > Int32(3)
+------TableScan: t1 projection=[a]
+physical_plan
+AggregateExec: mode=Final, gby=[], aggr=[COUNT(*)]
+--CoalescePartitionsExec
+----AggregateExec: mode=Partial, gby=[], aggr=[COUNT(*)]
+------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+--------GlobalLimitExec: skip=6, fetch=3
+----------CoalesceBatchesExec: target_batch_size=8192
+------------FilterExec: a@0 > 3
+--------------MemoryExec: partitions=1, partition_sizes=[1]
+
+query I
+SELECT COUNT(*) FROM (SELECT a FROM t1 WHERE a > 3 LIMIT 3 OFFSET 6);
+----
+1
+
########
# Clean up after the test
########
diff --git a/datafusion/sqllogictest/test_files/window.slt b/datafusion/sqllogictest/test_files/window.slt
index c7060433d9..2eb0576d55 100644
--- a/datafusion/sqllogictest/test_files/window.slt
+++ b/datafusion/sqllogictest/test_files/window.slt
@@ -2010,10 +2010,14 @@ Projection: ARRAY_AGG(aggregate_test_100.c13) AS array_agg1
--------TableScan: aggregate_test_100 projection=[c13]
physical_plan
ProjectionExec: expr=[ARRAY_AGG(aggregate_test_100.c13)@0 as array_agg1]
---AggregateExec: mode=Single, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)]
-----GlobalLimitExec: skip=0, fetch=1
-------SortExec: TopK(fetch=1), expr=[c13@0 ASC NULLS LAST]
---------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c13], has_header=true
+--AggregateExec: mode=Final, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)]
+----CoalescePartitionsExec
+------AggregateExec: mode=Partial, gby=[], aggr=[ARRAY_AGG(aggregate_test_100.c13)]
+--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
+----------GlobalLimitExec: skip=0, fetch=1
+------------SortExec: TopK(fetch=1), expr=[c13@0 ASC NULLS LAST]
+--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c13], has_header=true
+
query ?
SELECT ARRAY_AGG(c13) as array_agg1 FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1)