You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and was not preserved in this copy.
Posted to commits@arrow.apache.org by av...@apache.org on 2023/09/20 20:37:39 UTC

[arrow-datafusion] branch main updated: Fix panic in TopK (#7609)

This is an automated email from the ASF dual-hosted git repository.

avantgardner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 6ca1c44c3e Fix panic in TopK (#7609)
6ca1c44c3e is described below

commit 6ca1c44c3e7eb846c371534b6775a246bb086893
Author: Brent Gardner <bg...@squarelabs.net>
AuthorDate: Wed Sep 20 14:37:32 2023 -0600

    Fix panic in TopK (#7609)
    
    Fix panic in TopK (#7609)
---
 .../src/physical_optimizer/topk_aggregation.rs     |  7 ++--
 datafusion/physical-plan/src/aggregates/mod.rs     |  3 +-
 datafusion/sqllogictest/test_files/aggregate.slt   | 48 ++++++++++++++++++----
 3 files changed, 46 insertions(+), 12 deletions(-)

diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs
index 4789226d7a..572e796a8b 100644
--- a/datafusion/core/src/physical_optimizer/topk_aggregation.rs
+++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs
@@ -30,6 +30,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::Result;
 use datafusion_physical_expr::expressions::Column;
 use datafusion_physical_expr::PhysicalSortExpr;
+use itertools::Itertools;
 use std::sync::Arc;
 
 /// An optimizer rule that passes a `limit` hint to aggregations if the whole result is not needed
@@ -51,7 +52,7 @@ impl TopKAggregation {
         if desc != order.options.descending {
             return None;
         }
-        let group_key = aggr.group_expr().expr().first()?;
+        let group_key = aggr.group_expr().expr().iter().exactly_one().ok()?;
         let kt = group_key.0.data_type(&aggr.input().schema()).ok()?;
         if !kt.is_primitive() && kt != DataType::Utf8 {
             return None;
@@ -85,9 +86,9 @@ impl TopKAggregation {
         let sort = plan.as_any().downcast_ref::<SortExec>()?;
 
         let children = sort.children();
-        let child = children.first()?;
+        let child = children.iter().exactly_one().ok()?;
         let order = sort.output_ordering()?;
-        let order = order.first()?;
+        let order = order.iter().exactly_one().ok()?;
         let limit = sort.fetch()?;
 
         let is_cardinality_preserving = |plan: Arc<dyn ExecutionPlan>| {
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index 7c7a593c48..58dbb252fe 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -42,6 +42,7 @@ use datafusion_physical_expr::{
     PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement,
 };
 
+use itertools::Itertools;
 use std::any::Any;
 use std::collections::HashMap;
 use std::sync::Arc;
@@ -785,7 +786,7 @@ impl AggregateExec {
 
     /// Finds the DataType and SortDirection for this Aggregate, if there is one
     pub fn get_minmax_desc(&self) -> Option<(Field, bool)> {
-        let agg_expr = self.aggr_expr.as_slice().first()?;
+        let agg_expr = self.aggr_expr.iter().exactly_one().ok()?;
         if let Some(max) = agg_expr.as_any().downcast_ref::<Max>() {
             Some((max.field().ok()?, true))
         } else if let Some(min) = agg_expr.as_any().downcast_ref::<Min>() {
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 0680aa6323..d0e41b12b8 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -2314,14 +2314,14 @@ NULL
 
 # TopK aggregation
 statement ok
-CREATE TABLE traces(trace_id varchar, timestamp bigint) AS VALUES
-(NULL, 0),
-('a', NULL),
-('a', 1),
-('b', 0),
-('c', 1),
-('c', 2),
-('b', 3);
+CREATE TABLE traces(trace_id varchar, timestamp bigint, other bigint) AS VALUES
+(NULL, 0, 0),
+('a', NULL, NULL),
+('a', 1, 1),
+('b', 0, 0),
+('c', 1, 1),
+('c', 2, 2),
+('b', 3, 3);
 
 statement ok
 set datafusion.optimizer.enable_topk_aggregation = false;
@@ -2362,6 +2362,22 @@ b 0
 c 1
 a 1
 
+query TII
+select trace_id, other, MIN(timestamp) from traces group by trace_id, other order by MIN(timestamp) asc limit 4;
+----
+b 0 0
+NULL 0 0
+c 1 1
+a 1 1
+
+query TII
+select trace_id, MIN(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) limit 4;
+----
+NULL 0 0
+b 0 0
+c 1 1
+a 1 1
+
 statement ok
 set datafusion.optimizer.enable_topk_aggregation = true;
 
@@ -2471,6 +2487,22 @@ NULL 0
 b 0
 c 1
 
+query TII
+select trace_id, other, MIN(timestamp) from traces group by trace_id, other order by MIN(timestamp) asc limit 4;
+----
+b 0 0
+NULL 0 0
+c 1 1
+a 1 1
+
+query TII
+select trace_id, MIN(other), MIN(timestamp) from traces group by trace_id order by MIN(timestamp), MIN(other) limit 4;
+----
+NULL 0 0
+b 0 0
+c 1 1
+a 1 1
+
 #
 # regr_*() tests
 #