You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/06/28 22:09:24 UTC

[arrow-datafusion] branch master updated: Optimize count(*) with table statistics (#620)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 16a3db6  Optimize count(*) with table statistics (#620)
16a3db6 is described below

commit 16a3db64cb50a5f6e27a032c270d9de40dd2d5a5
Author: Daniël Heres <da...@gmail.com>
AuthorDate: Tue Jun 29 00:09:15 2021 +0200

    Optimize count(*) with table statistics (#620)
    
    * Optimize count(*) with table statistics
    
    * Optimize count(*) with table statistics
    
    * Fixes, simplification
    
    * Alias fix
    
    * Add member to table provider to return whether statistics are exact
    
    * Fix
    
    * Improve test
    
    * Naming changes
    
    * Add test for non-exact statistics
    
    * Generalize solution
    
    * Added tests
    
    * Fix name
---
 datafusion/src/datasource/datasource.rs          |   5 +
 datafusion/src/datasource/memory.rs              |   4 +
 datafusion/src/datasource/parquet.rs             |   4 +
 datafusion/src/execution/context.rs              |   4 +-
 datafusion/src/optimizer/aggregate_statistics.rs | 335 +++++++++++++++++++++++
 datafusion/src/optimizer/mod.rs                  |   1 +
 6 files changed, 352 insertions(+), 1 deletion(-)

diff --git a/datafusion/src/datasource/datasource.rs b/datafusion/src/datasource/datasource.rs
index 0349a49..b83aa4b 100644
--- a/datafusion/src/datasource/datasource.rs
+++ b/datafusion/src/datasource/datasource.rs
@@ -108,6 +108,11 @@ pub trait TableProvider: Sync + Send {
     /// Statistics should be optional because not all data sources can provide statistics.
     fn statistics(&self) -> Statistics;
 
+    /// Returns whether statistics provided are exact values or estimates
+    fn has_exact_statistics(&self) -> bool {
+        false
+    }
+
     /// Tests whether the table provider can make use of a filter expression
     /// to optimise data retrieval.
     fn supports_filter_pushdown(
diff --git a/datafusion/src/datasource/memory.rs b/datafusion/src/datasource/memory.rs
index af40480..a4dbfd6 100644
--- a/datafusion/src/datasource/memory.rs
+++ b/datafusion/src/datasource/memory.rs
@@ -216,6 +216,10 @@ impl TableProvider for MemTable {
     fn statistics(&self) -> Statistics {
         self.statistics.clone()
     }
+
+    fn has_exact_statistics(&self) -> bool {
+        true
+    }
 }
 
 #[cfg(test)]
diff --git a/datafusion/src/datasource/parquet.rs b/datafusion/src/datasource/parquet.rs
index fd14741..e53fbbd 100644
--- a/datafusion/src/datasource/parquet.rs
+++ b/datafusion/src/datasource/parquet.rs
@@ -102,6 +102,10 @@ impl TableProvider for ParquetTable {
     fn statistics(&self) -> Statistics {
         self.statistics.clone()
     }
+
+    fn has_exact_statistics(&self) -> bool {
+        true
+    }
 }
 
 #[cfg(test)]
diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index 318ea59..5c41ed2 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -22,7 +22,8 @@ use crate::{
         information_schema::CatalogWithInformationSchema,
     },
     optimizer::{
-        eliminate_limit::EliminateLimit, hash_build_probe_order::HashBuildProbeOrder,
+        aggregate_statistics::AggregateStatistics, eliminate_limit::EliminateLimit,
+        hash_build_probe_order::HashBuildProbeOrder,
     },
     physical_optimizer::optimizer::PhysicalOptimizerRule,
 };
@@ -639,6 +640,7 @@ impl Default for ExecutionConfig {
             optimizers: vec![
                 Arc::new(ConstantFolding::new()),
                 Arc::new(EliminateLimit::new()),
+                Arc::new(AggregateStatistics::new()),
                 Arc::new(ProjectionPushDown::new()),
                 Arc::new(FilterPushDown::new()),
                 Arc::new(SimplifyExpressions::new()),
diff --git a/datafusion/src/optimizer/aggregate_statistics.rs b/datafusion/src/optimizer/aggregate_statistics.rs
new file mode 100644
index 0000000..a20eafc
--- /dev/null
+++ b/datafusion/src/optimizer/aggregate_statistics.rs
@@ -0,0 +1,335 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Utilizing exact statistics from sources to avoid scanning data
+use std::{sync::Arc, vec};
+
+use crate::{
+    execution::context::ExecutionProps,
+    logical_plan::{col, DFField, DFSchema, Expr, LogicalPlan},
+    physical_plan::aggregates::AggregateFunction,
+    scalar::ScalarValue,
+};
+
+use super::{optimizer::OptimizerRule, utils};
+use crate::error::Result;
+
+/// Optimizer that uses available statistics for aggregate functions
+pub struct AggregateStatistics {}
+
+impl AggregateStatistics {
+    #[allow(missing_docs)]
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl OptimizerRule for AggregateStatistics {
+    fn optimize(
+        &self,
+        plan: &LogicalPlan,
+        execution_props: &ExecutionProps,
+    ) -> crate::error::Result<LogicalPlan> {
+        match plan {
+            // candidates: aggregates without GROUP BY (the input is checked below)
+            LogicalPlan::Aggregate {
+                input,
+                group_expr,
+                aggr_expr,
+                schema,
+            } if group_expr.is_empty() => {
+                // aggregate expressions that statistics cannot answer;
+                // they stay in an Aggregate node over the original input
+                let mut agg = vec![];
+                // literal replacements for aggregates answered from statistics
+                let mut projections = vec![];
+                if let Some(num_rows) = match input.as_ref() {
+                    LogicalPlan::TableScan { source, .. }
+                        if source.has_exact_statistics() =>
+                    {
+                        source.statistics().num_rows
+                    }
+                    _ => None,
+                } {
+                    for expr in aggr_expr {
+                        match expr {
+                            Expr::AggregateFunction {
+                                fun: AggregateFunction::Count,
+                                args,
+                                distinct: false,
+                            } if args
+                                == &[Expr::Literal(ScalarValue::UInt8(Some(1)))] =>
+                            {
+                                projections.push(Expr::Alias(
+                                    Box::new(Expr::Literal(ScalarValue::UInt64(Some(
+                                        num_rows as u64,
+                                    )))),
+                                    "COUNT(Uint8(1))".to_string(), // NOTE(review): "Uint8" vs "UInt8" casing; confirm
+                                ));
+                            }
+                            _ => {
+                                agg.push(expr.clone());
+                            }
+                        }
+                    }
+
+                    return Ok(if agg.is_empty() {
+                        // every aggregate was replaced, so the scan is not needed at all
+
+                        LogicalPlan::Projection {
+                            expr: projections,
+                            input: Arc::new(LogicalPlan::EmptyRelation {
+                                produce_one_row: true,
+                                schema: Arc::new(DFSchema::empty()),
+                            }),
+                            schema: schema.clone(),
+                        }
+                    } else if projections.is_empty() {
+                        // nothing could be replaced: keep the plan as it was
+                        plan.clone()
+                    } else {
+                        // Mixed case: project the constants next to the aggregates that remain
+                        let agg_fields = agg
+                            .iter()
+                            .map(|x| x.to_field(input.schema()))
+                            .collect::<Result<Vec<DFField>>>()?;
+                        let agg_schema = DFSchema::new(agg_fields)?;
+                        let cols = agg
+                            .iter()
+                            .map(|e| e.name(&agg_schema))
+                            .collect::<Result<Vec<String>>>()?;
+                        projections.extend(cols.iter().map(|x| col(x)));
+                        LogicalPlan::Projection {
+                            expr: projections,
+                            schema: schema.clone(),
+                            input: Arc::new(LogicalPlan::Aggregate {
+                                input: input.clone(),
+                                group_expr: vec![],
+                                aggr_expr: agg,
+                                schema: Arc::new(agg_schema),
+                            }),
+                        }
+                    });
+                }
+                Ok(plan.clone()) // returned as-is; this path does not visit the node's inputs
+            }
+            // Any other node: rebuild it from its optimized inputs
+            _ => {
+                let expr = plan.expressions();
+
+                // run this rule on every input of the node
+                let inputs = plan.inputs();
+                let new_inputs = inputs
+                    .iter()
+                    .map(|plan| self.optimize(plan, execution_props))
+                    .collect::<Result<Vec<_>>>()?;
+
+                utils::from_plan(plan, &expr, &new_inputs)
+            }
+        }
+    }
+
+    fn name(&self) -> &str {
+        "aggregate_statistics"
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use arrow::datatypes::{DataType, Field, Schema};
+
+    use crate::error::Result;
+    use crate::execution::context::ExecutionProps;
+    use crate::logical_plan::LogicalPlan;
+    use crate::optimizer::aggregate_statistics::AggregateStatistics;
+    use crate::optimizer::optimizer::OptimizerRule;
+    use crate::{
+        datasource::{datasource::Statistics, TableProvider},
+        logical_plan::Expr,
+    };
+
+    struct TestTableProvider {
+        num_rows: usize,
+        is_exact: bool,
+    }
+
+    impl TableProvider for TestTableProvider {
+        fn as_any(&self) -> &dyn std::any::Any {
+            unimplemented!()
+        }
+        fn schema(&self) -> arrow::datatypes::SchemaRef {
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)]))
+        }
+
+        fn scan(
+            &self,
+            _projection: &Option<Vec<usize>>,
+            _batch_size: usize,
+            _filters: &[Expr],
+            _limit: Option<usize>,
+        ) -> Result<std::sync::Arc<dyn crate::physical_plan::ExecutionPlan>> {
+            unimplemented!()
+        }
+        fn statistics(&self) -> crate::datasource::datasource::Statistics {
+            Statistics {
+                num_rows: Some(self.num_rows),
+                total_byte_size: None,
+                column_statistics: None,
+            }
+        }
+        fn has_exact_statistics(&self) -> bool {
+            self.is_exact
+        }
+    }
+
+    #[test]
+    fn optimize_count_using_statistics() -> Result<()> {
+        use crate::execution::context::ExecutionContext;
+        let mut ctx = ExecutionContext::new();
+        ctx.register_table(
+            "test",
+            Arc::new(TestTableProvider {
+                num_rows: 100,
+                is_exact: true,
+            }),
+        )
+        .unwrap();
+
+        let plan = ctx
+            .create_logical_plan("select count(*) from test")
+            .unwrap();
+        let expected = "\
+            Projection: #COUNT(UInt8(1))\
+            \n  Projection: UInt64(100) AS COUNT(Uint8(1))\
+            \n    EmptyRelation";
+
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn optimize_count_not_exact() -> Result<()> {
+        use crate::execution::context::ExecutionContext;
+        let mut ctx = ExecutionContext::new();
+        ctx.register_table(
+            "test",
+            Arc::new(TestTableProvider {
+                num_rows: 100,
+                is_exact: false,
+            }),
+        )
+        .unwrap();
+
+        let plan = ctx
+            .create_logical_plan("select count(*) from test")
+            .unwrap();
+        let expected = "\
+            Projection: #COUNT(UInt8(1))\
+            \n  Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
+            \n    TableScan: test projection=None";
+
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn optimize_count_sum() -> Result<()> {
+        use crate::execution::context::ExecutionContext;
+        let mut ctx = ExecutionContext::new();
+        ctx.register_table(
+            "test",
+            Arc::new(TestTableProvider {
+                num_rows: 100,
+                is_exact: true,
+            }),
+        )
+        .unwrap();
+
+        let plan = ctx
+            .create_logical_plan("select sum(a)/count(*) from test")
+            .unwrap();
+        let expected = "\
+            Projection: #SUM(test.a) Divide #COUNT(UInt8(1))\
+            \n  Projection: UInt64(100) AS COUNT(Uint8(1)), #SUM(test.a)\
+            \n    Aggregate: groupBy=[[]], aggr=[[SUM(#test.a)]]\
+            \n      TableScan: test projection=None";
+
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn optimize_count_group_by() -> Result<()> {
+        use crate::execution::context::ExecutionContext;
+        let mut ctx = ExecutionContext::new();
+        ctx.register_table(
+            "test",
+            Arc::new(TestTableProvider {
+                num_rows: 100,
+                is_exact: true,
+            }),
+        )
+        .unwrap();
+
+        let plan = ctx
+            .create_logical_plan("SELECT count(*), a FROM test GROUP BY a")
+            .unwrap();
+        let expected = "\
+            Projection: #COUNT(UInt8(1)), #test.a\
+            \n  Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(UInt8(1))]]\
+            \n    TableScan: test projection=None";
+
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    #[test]
+    fn optimize_count_filter() -> Result<()> {
+        use crate::execution::context::ExecutionContext;
+        let mut ctx = ExecutionContext::new();
+        ctx.register_table(
+            "test",
+            Arc::new(TestTableProvider {
+                num_rows: 100,
+                is_exact: true,
+            }),
+        )
+        .unwrap();
+
+        let plan = ctx
+            .create_logical_plan("SELECT count(*) FROM test WHERE a < 5")
+            .unwrap();
+        let expected = "\
+            Projection: #COUNT(UInt8(1))\
+            \n  Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
+            \n    Filter: #test.a Lt Int64(5)\
+            \n      TableScan: test projection=None";
+
+        assert_optimized_plan_eq(&plan, expected);
+        Ok(())
+    }
+
+    /// Optimizes `plan`, then checks its debug output and that its schema is unchanged.
+    fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) {
+        let opt = AggregateStatistics::new();
+        let optimized_plan = opt.optimize(plan, &ExecutionProps::new()).unwrap();
+        assert_eq!(format!("{:?}", optimized_plan), expected);
+        assert_eq!(plan.schema(), optimized_plan.schema());
+    }
+}
diff --git a/datafusion/src/optimizer/mod.rs b/datafusion/src/optimizer/mod.rs
index e360a54..6875847 100644
--- a/datafusion/src/optimizer/mod.rs
+++ b/datafusion/src/optimizer/mod.rs
@@ -18,6 +18,7 @@
 //! This module contains a query optimizer that operates against a logical plan and applies
 //! some simple rules to a logical plan, such as "Projection Push Down" and "Type Coercion".
 
+pub mod aggregate_statistics;
 pub mod constant_folding;
 pub mod eliminate_limit;
 pub mod filter_push_down;