Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/10 19:37:03 UTC

[arrow-datafusion] branch master updated: Add another method to collect referenced columns from an expression (#4153)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 509c82c6d Add another method to collect referenced columns from an expression (#4153)
509c82c6d is described below

commit 509c82c6d624bb63531f67531195b562a241c854
Author: ygf11 <ya...@gmail.com>
AuthorDate: Fri Nov 11 03:36:58 2022 +0800

    Add another method to collect referenced columns from an expression (#4153)
---
 datafusion/core/src/physical_optimizer/pruning.rs  |  7 ++---
 .../file_format/parquet/page_filter.rs             |  6 ++--
 datafusion/core/src/physical_plan/planner.rs       |  7 ++---
 datafusion/expr/src/expr.rs                        | 33 ++++++++++++++++++++++
 datafusion/expr/src/logical_plan/builder.rs        | 10 ++-----
 datafusion/optimizer/src/filter_push_down.rs       | 15 +++-------
 datafusion/sql/src/planner.rs                      |  6 ++--
 7 files changed, 49 insertions(+), 35 deletions(-)
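For readers skimming the diff: below is a minimal sketch of the new `Expr::to_columns` API and the `expr_to_columns` call pattern it replaces at the touched call sites. This example is not part of the commit itself; it assumes a DataFusion build that includes this change and mirrors the unit test added in datafusion/expr/src/expr.rs.

    use std::collections::HashSet;

    use datafusion_common::{Column, Result};
    use datafusion_expr::utils::expr_to_columns;
    use datafusion_expr::{col, lit};

    fn main() -> Result<()> {
        // `a + b + 1` references two distinct columns, "a" and "b".
        let expr = col("a") + col("b") + lit(1);

        // New style added by this commit: one call on the expression itself,
        // returning the referenced columns as a HashSet<Column>.
        let columns = expr.to_columns()?;
        assert_eq!(2, columns.len());
        assert!(columns.contains(&Column::from_name("a")));
        assert!(columns.contains(&Column::from_name("b")));

        // Old style that the call sites in this diff replace: allocate the
        // accumulator by hand and pass it to the free function.
        let mut accum: HashSet<Column> = HashSet::new();
        expr_to_columns(&expr, &mut accum)?;
        assert_eq!(accum, columns);

        Ok(())
    }

As the diff shows, `to_columns` is a thin wrapper around the existing `expr_to_columns` utility, so callers no longer need to create the `HashSet` accumulator themselves.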

diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs
index a47815377..a53e30227 100644
--- a/datafusion/core/src/physical_optimizer/pruning.rs
+++ b/datafusion/core/src/physical_optimizer/pruning.rs
@@ -48,7 +48,6 @@ use arrow::{
 use datafusion_common::{downcast_value, ScalarValue};
 use datafusion_expr::expr::{BinaryExpr, Cast};
 use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
-use datafusion_expr::utils::expr_to_columns;
 use datafusion_expr::{binary_expr, cast, try_cast, ExprSchemable};
 use datafusion_physical_expr::create_physical_expr;
 use log::trace;
@@ -445,10 +444,8 @@ impl<'a> PruningExpressionBuilder<'a> {
         required_columns: &'a mut RequiredStatColumns,
     ) -> Result<Self> {
         // find column name; input could be a more complicated expression
-        let mut left_columns = HashSet::<Column>::new();
-        expr_to_columns(left, &mut left_columns)?;
-        let mut right_columns = HashSet::<Column>::new();
-        expr_to_columns(right, &mut right_columns)?;
+        let left_columns = left.to_columns()?;
+        let right_columns = right.to_columns()?;
         let (column_expr, scalar_expr, columns, correct_operator) =
             match (left_columns.len(), right_columns.len()) {
                 (1, 0) => (left, right, left_columns, op),
diff --git a/datafusion/core/src/physical_plan/file_format/parquet/page_filter.rs b/datafusion/core/src/physical_plan/file_format/parquet/page_filter.rs
index 95c93151a..5f31a6a49 100644
--- a/datafusion/core/src/physical_plan/file_format/parquet/page_filter.rs
+++ b/datafusion/core/src/physical_plan/file_format/parquet/page_filter.rs
@@ -20,7 +20,6 @@
 use arrow::array::{BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array};
 use arrow::{array::ArrayRef, datatypes::SchemaRef, error::ArrowError};
 use datafusion_common::{Column, DataFusionError, Result};
-use datafusion_expr::utils::expr_to_columns;
 use datafusion_optimizer::utils::split_conjunction;
 use log::{debug, error, trace};
 use parquet::{
@@ -32,7 +31,7 @@ use parquet::{
     },
     format::PageLocation,
 };
-use std::collections::{HashSet, VecDeque};
+use std::collections::VecDeque;
 use std::sync::Arc;
 
 use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics};
@@ -286,8 +285,7 @@ fn extract_page_index_push_down_predicates(
         predicates
             .into_iter()
             .try_for_each::<_, Result<()>>(|predicate| {
-                let mut columns: HashSet<Column> = HashSet::new();
-                expr_to_columns(predicate, &mut columns)?;
+                let columns = predicate.to_columns()?;
                 if columns.len() == 1 {
                     one_col_expr.push(predicate);
                 }
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 729649e7a..d2c148c3f 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -63,7 +63,7 @@ use datafusion_expr::expr::{
     Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like,
 };
 use datafusion_expr::expr_rewriter::unnormalize_cols;
-use datafusion_expr::utils::{expand_wildcard, expr_to_columns};
+use datafusion_expr::utils::expand_wildcard;
 use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
 use datafusion_optimizer::utils::unalias;
 use datafusion_physical_expr::expressions::Literal;
@@ -72,7 +72,7 @@ use futures::future::BoxFuture;
 use futures::{FutureExt, StreamExt, TryStreamExt};
 use itertools::Itertools;
 use log::{debug, trace};
-use std::collections::{HashMap, HashSet};
+use std::collections::HashMap;
 use std::fmt::Write;
 use std::sync::Arc;
 
@@ -875,8 +875,7 @@ impl DefaultPhysicalPlanner {
                     let join_filter = match filter {
                         Some(expr) => {
                             // Extract columns from filter expression
-                            let mut cols = HashSet::new();
-                            expr_to_columns(expr, &mut cols)?;
+                            let cols = expr.to_columns()?;
 
                             // Collect left & right field indices
                             let left_field_indices = cols.iter()
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 2017d1e8d..ecab1afd2 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -21,6 +21,7 @@ use crate::aggregate_function;
 use crate::built_in_function;
 use crate::expr_fn::binary_expr;
 use crate::logical_plan::Subquery;
+use crate::utils::expr_to_columns;
 use crate::window_frame;
 use crate::window_function;
 use crate::AggregateUDF;
@@ -30,6 +31,7 @@ use arrow::datatypes::DataType;
 use datafusion_common::Result;
 use datafusion_common::{plan_err, Column};
 use datafusion_common::{DataFusionError, ScalarValue};
+use std::collections::HashSet;
 use std::fmt;
 use std::fmt::{Display, Formatter, Write};
 use std::hash::{BuildHasher, Hash, Hasher};
@@ -685,6 +687,14 @@ impl Expr {
             _ => plan_err!(format!("Could not coerce '{}' into Column!", self)),
         }
     }
+
+    /// Return all referenced columns of this expression.
+    pub fn to_columns(&self) -> Result<HashSet<Column>> {
+        let mut using_columns = HashSet::new();
+        expr_to_columns(self, &mut using_columns)?;
+
+        Ok(using_columns)
+    }
 }
 
 impl Not for Expr {
@@ -1277,6 +1287,7 @@ mod test {
     use crate::expr_fn::col;
     use crate::{case, lit, Expr};
     use arrow::datatypes::DataType;
+    use datafusion_common::Column;
     use datafusion_common::{Result, ScalarValue};
 
     #[test]
@@ -1327,4 +1338,26 @@ mod test {
         assert!(exp2 > exp3);
         assert!(exp3 < exp2);
     }
+
+    #[test]
+    fn test_collect_expr() -> Result<()> {
+        // single column
+        {
+            let expr = &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64));
+            let columns = expr.to_columns()?;
+            assert_eq!(1, columns.len());
+            assert!(columns.contains(&Column::from_name("a")));
+        }
+
+        // multiple columns
+        {
+            let expr = col("a") + col("b") + lit(1);
+            let columns = expr.to_columns()?;
+            assert_eq!(2, columns.len());
+            assert!(columns.contains(&Column::from_name("a")));
+            assert!(columns.contains(&Column::from_name("b")));
+        }
+
+        Ok(())
+    }
 }
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 6b83e449b..0782302d7 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -31,7 +31,7 @@ use crate::{
         Window,
     },
     utils::{
-        can_hash, expand_qualified_wildcard, expand_wildcard, expr_to_columns,
+        can_hash, expand_qualified_wildcard, expand_wildcard,
         group_window_expr_by_sort_keys,
     },
     Expr, ExprSchemable, TableSource,
@@ -43,10 +43,7 @@ use datafusion_common::{
 };
 use std::any::Any;
 use std::convert::TryFrom;
-use std::{
-    collections::{HashMap, HashSet},
-    sync::Arc,
-};
+use std::{collections::HashMap, sync::Arc};
 
 /// Default table name for unnamed table
 pub const UNNAMED_TABLE: &str = "?table?";
@@ -378,8 +375,7 @@ impl LogicalPlanBuilder {
             .clone()
             .into_iter()
             .try_for_each::<_, Result<()>>(|expr| {
-                let mut columns: HashSet<Column> = HashSet::new();
-                expr_to_columns(&expr, &mut columns)?;
+                let columns = expr.to_columns()?;
 
                 columns.into_iter().for_each(|c| {
                     if schema.field_from_column(&c).is_err() {
diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs
index 0539a8962..674910cd1 100644
--- a/datafusion/optimizer/src/filter_push_down.rs
+++ b/datafusion/optimizer/src/filter_push_down.rs
@@ -324,8 +324,7 @@ fn extract_or_clauses_for_join(
             // If nothing can be extracted from any sub clauses, do nothing for this OR clause.
             if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
                 let predicate = or(left_expr, right_expr);
-                let mut columns: HashSet<Column> = HashSet::new();
-                expr_to_columns(&predicate, &mut columns).ok().unwrap();
+                let columns = predicate.to_columns().ok().unwrap();
 
                 exprs.push(predicate);
                 expr_columns.push(columns);
@@ -388,8 +387,7 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Ex
             }
         }
         _ => {
-            let mut columns: HashSet<Column> = HashSet::new();
-            expr_to_columns(expr, &mut columns).ok().unwrap();
+            let columns = expr.to_columns().ok().unwrap();
 
             if schema_columns
                 .intersection(&columns)
@@ -541,8 +539,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
             utils::split_conjunction_owned(predicate)
                 .into_iter()
                 .try_for_each::<_, Result<()>>(|predicate| {
-                    let mut columns: HashSet<Column> = HashSet::new();
-                    expr_to_columns(&predicate, &mut columns)?;
+                    let columns = predicate.to_columns()?;
                     state.filters.push((predicate, columns));
                     Ok(())
                 })?;
@@ -664,11 +661,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
 
                     predicates
                         .into_iter()
-                        .map(|e| {
-                            let mut accum = HashSet::new();
-                            expr_to_columns(e, &mut accum)?;
-                            Ok((e.clone(), accum))
-                        })
+                        .map(|e| Ok((e.clone(), e.to_columns()?)))
                         .collect::<Result<Vec<_>>>()
                 })
                 .unwrap_or_else(|| Ok(vec![]))?;
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 83f0c6ab9..dacd4af87 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -31,8 +31,7 @@ use datafusion_expr::logical_plan::{
 };
 use datafusion_expr::utils::{
     can_hash, expand_qualified_wildcard, expand_wildcard, expr_as_column_expr,
-    expr_to_columns, find_aggregate_exprs, find_column_exprs, find_window_exprs,
-    COUNT_STAR_EXPANSION,
+    find_aggregate_exprs, find_column_exprs, find_window_exprs, COUNT_STAR_EXPANSION,
 };
 use datafusion_expr::{
     and, col, lit, AggregateFunction, AggregateUDF, Expr, ExprSchemable, GetIndexedField,
@@ -690,8 +689,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
                 let join_filter = filter
                     .into_iter()
                     .map(|expr| {
-                        let mut using_columns = HashSet::new();
-                        expr_to_columns(&expr, &mut using_columns)?;
+                        let using_columns = expr.to_columns()?;
 
                         normalize_col_with_schemas(
                             expr,