You are viewing a plain text version of this content. The canonical link for it is available in the Apache mailing-list archives (link lost in plain-text extraction).
Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/10 19:37:03 UTC
[arrow-datafusion] branch master updated: Add another method to collect referenced columns from an expression (#4153)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 509c82c6d Add another method to collect referenced columns from an expression (#4153)
509c82c6d is described below
commit 509c82c6d624bb63531f67531195b562a241c854
Author: ygf11 <ya...@gmail.com>
AuthorDate: Fri Nov 11 03:36:58 2022 +0800
Add another method to collect referenced columns from an expression (#4153)
---
datafusion/core/src/physical_optimizer/pruning.rs | 7 ++---
.../file_format/parquet/page_filter.rs | 6 ++--
datafusion/core/src/physical_plan/planner.rs | 7 ++---
datafusion/expr/src/expr.rs | 33 ++++++++++++++++++++++
datafusion/expr/src/logical_plan/builder.rs | 10 ++-----
datafusion/optimizer/src/filter_push_down.rs | 15 +++-------
datafusion/sql/src/planner.rs | 6 ++--
7 files changed, 49 insertions(+), 35 deletions(-)
diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs
index a47815377..a53e30227 100644
--- a/datafusion/core/src/physical_optimizer/pruning.rs
+++ b/datafusion/core/src/physical_optimizer/pruning.rs
@@ -48,7 +48,6 @@ use arrow::{
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_expr::expr::{BinaryExpr, Cast};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
-use datafusion_expr::utils::expr_to_columns;
use datafusion_expr::{binary_expr, cast, try_cast, ExprSchemable};
use datafusion_physical_expr::create_physical_expr;
use log::trace;
@@ -445,10 +444,8 @@ impl<'a> PruningExpressionBuilder<'a> {
required_columns: &'a mut RequiredStatColumns,
) -> Result<Self> {
// find column name; input could be a more complicated expression
- let mut left_columns = HashSet::<Column>::new();
- expr_to_columns(left, &mut left_columns)?;
- let mut right_columns = HashSet::<Column>::new();
- expr_to_columns(right, &mut right_columns)?;
+ let left_columns = left.to_columns()?;
+ let right_columns = right.to_columns()?;
let (column_expr, scalar_expr, columns, correct_operator) =
match (left_columns.len(), right_columns.len()) {
(1, 0) => (left, right, left_columns, op),
diff --git a/datafusion/core/src/physical_plan/file_format/parquet/page_filter.rs b/datafusion/core/src/physical_plan/file_format/parquet/page_filter.rs
index 95c93151a..5f31a6a49 100644
--- a/datafusion/core/src/physical_plan/file_format/parquet/page_filter.rs
+++ b/datafusion/core/src/physical_plan/file_format/parquet/page_filter.rs
@@ -20,7 +20,6 @@
use arrow::array::{BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array};
use arrow::{array::ArrayRef, datatypes::SchemaRef, error::ArrowError};
use datafusion_common::{Column, DataFusionError, Result};
-use datafusion_expr::utils::expr_to_columns;
use datafusion_optimizer::utils::split_conjunction;
use log::{debug, error, trace};
use parquet::{
@@ -32,7 +31,7 @@ use parquet::{
},
format::PageLocation,
};
-use std::collections::{HashSet, VecDeque};
+use std::collections::VecDeque;
use std::sync::Arc;
use crate::physical_optimizer::pruning::{PruningPredicate, PruningStatistics};
@@ -286,8 +285,7 @@ fn extract_page_index_push_down_predicates(
predicates
.into_iter()
.try_for_each::<_, Result<()>>(|predicate| {
- let mut columns: HashSet<Column> = HashSet::new();
- expr_to_columns(predicate, &mut columns)?;
+ let columns = predicate.to_columns()?;
if columns.len() == 1 {
one_col_expr.push(predicate);
}
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index 729649e7a..d2c148c3f 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -63,7 +63,7 @@ use datafusion_expr::expr::{
Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like,
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
-use datafusion_expr::utils::{expand_wildcard, expr_to_columns};
+use datafusion_expr::utils::expand_wildcard;
use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
use datafusion_optimizer::utils::unalias;
use datafusion_physical_expr::expressions::Literal;
@@ -72,7 +72,7 @@ use futures::future::BoxFuture;
use futures::{FutureExt, StreamExt, TryStreamExt};
use itertools::Itertools;
use log::{debug, trace};
-use std::collections::{HashMap, HashSet};
+use std::collections::HashMap;
use std::fmt::Write;
use std::sync::Arc;
@@ -875,8 +875,7 @@ impl DefaultPhysicalPlanner {
let join_filter = match filter {
Some(expr) => {
// Extract columns from filter expression
- let mut cols = HashSet::new();
- expr_to_columns(expr, &mut cols)?;
+ let cols = expr.to_columns()?;
// Collect left & right field indices
let left_field_indices = cols.iter()
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 2017d1e8d..ecab1afd2 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -21,6 +21,7 @@ use crate::aggregate_function;
use crate::built_in_function;
use crate::expr_fn::binary_expr;
use crate::logical_plan::Subquery;
+use crate::utils::expr_to_columns;
use crate::window_frame;
use crate::window_function;
use crate::AggregateUDF;
@@ -30,6 +31,7 @@ use arrow::datatypes::DataType;
use datafusion_common::Result;
use datafusion_common::{plan_err, Column};
use datafusion_common::{DataFusionError, ScalarValue};
+use std::collections::HashSet;
use std::fmt;
use std::fmt::{Display, Formatter, Write};
use std::hash::{BuildHasher, Hash, Hasher};
@@ -685,6 +687,14 @@ impl Expr {
_ => plan_err!(format!("Could not coerce '{}' into Column!", self)),
}
}
+
+ /// Return all referenced columns of this expression.
+ pub fn to_columns(&self) -> Result<HashSet<Column>> {
+ let mut using_columns = HashSet::new();
+ expr_to_columns(self, &mut using_columns)?;
+
+ Ok(using_columns)
+ }
}
impl Not for Expr {
@@ -1277,6 +1287,7 @@ mod test {
use crate::expr_fn::col;
use crate::{case, lit, Expr};
use arrow::datatypes::DataType;
+ use datafusion_common::Column;
use datafusion_common::{Result, ScalarValue};
#[test]
@@ -1327,4 +1338,26 @@ mod test {
assert!(exp2 > exp3);
assert!(exp3 < exp2);
}
+
+ #[test]
+ fn test_collect_expr() -> Result<()> {
+ // single column
+ {
+ let expr = &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64));
+ let columns = expr.to_columns()?;
+ assert_eq!(1, columns.len());
+ assert!(columns.contains(&Column::from_name("a")));
+ }
+
+ // multiple columns
+ {
+ let expr = col("a") + col("b") + lit(1);
+ let columns = expr.to_columns()?;
+ assert_eq!(2, columns.len());
+ assert!(columns.contains(&Column::from_name("a")));
+ assert!(columns.contains(&Column::from_name("b")));
+ }
+
+ Ok(())
+ }
}
diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs
index 6b83e449b..0782302d7 100644
--- a/datafusion/expr/src/logical_plan/builder.rs
+++ b/datafusion/expr/src/logical_plan/builder.rs
@@ -31,7 +31,7 @@ use crate::{
Window,
},
utils::{
- can_hash, expand_qualified_wildcard, expand_wildcard, expr_to_columns,
+ can_hash, expand_qualified_wildcard, expand_wildcard,
group_window_expr_by_sort_keys,
},
Expr, ExprSchemable, TableSource,
@@ -43,10 +43,7 @@ use datafusion_common::{
};
use std::any::Any;
use std::convert::TryFrom;
-use std::{
- collections::{HashMap, HashSet},
- sync::Arc,
-};
+use std::{collections::HashMap, sync::Arc};
/// Default table name for unnamed table
pub const UNNAMED_TABLE: &str = "?table?";
@@ -378,8 +375,7 @@ impl LogicalPlanBuilder {
.clone()
.into_iter()
.try_for_each::<_, Result<()>>(|expr| {
- let mut columns: HashSet<Column> = HashSet::new();
- expr_to_columns(&expr, &mut columns)?;
+ let columns = expr.to_columns()?;
columns.into_iter().for_each(|c| {
if schema.field_from_column(&c).is_err() {
diff --git a/datafusion/optimizer/src/filter_push_down.rs b/datafusion/optimizer/src/filter_push_down.rs
index 0539a8962..674910cd1 100644
--- a/datafusion/optimizer/src/filter_push_down.rs
+++ b/datafusion/optimizer/src/filter_push_down.rs
@@ -324,8 +324,7 @@ fn extract_or_clauses_for_join(
// If nothing can be extracted from any sub clauses, do nothing for this OR clause.
if let (Some(left_expr), Some(right_expr)) = (left_expr, right_expr) {
let predicate = or(left_expr, right_expr);
- let mut columns: HashSet<Column> = HashSet::new();
- expr_to_columns(&predicate, &mut columns).ok().unwrap();
+ let columns = predicate.to_columns().ok().unwrap();
exprs.push(predicate);
expr_columns.push(columns);
@@ -388,8 +387,7 @@ fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Ex
}
}
_ => {
- let mut columns: HashSet<Column> = HashSet::new();
- expr_to_columns(expr, &mut columns).ok().unwrap();
+ let columns = expr.to_columns().ok().unwrap();
if schema_columns
.intersection(&columns)
@@ -541,8 +539,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
utils::split_conjunction_owned(predicate)
.into_iter()
.try_for_each::<_, Result<()>>(|predicate| {
- let mut columns: HashSet<Column> = HashSet::new();
- expr_to_columns(&predicate, &mut columns)?;
+ let columns = predicate.to_columns()?;
state.filters.push((predicate, columns));
Ok(())
})?;
@@ -664,11 +661,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
predicates
.into_iter()
- .map(|e| {
- let mut accum = HashSet::new();
- expr_to_columns(e, &mut accum)?;
- Ok((e.clone(), accum))
- })
+ .map(|e| Ok((e.clone(), e.to_columns()?)))
.collect::<Result<Vec<_>>>()
})
.unwrap_or_else(|| Ok(vec![]))?;
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 83f0c6ab9..dacd4af87 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -31,8 +31,7 @@ use datafusion_expr::logical_plan::{
};
use datafusion_expr::utils::{
can_hash, expand_qualified_wildcard, expand_wildcard, expr_as_column_expr,
- expr_to_columns, find_aggregate_exprs, find_column_exprs, find_window_exprs,
- COUNT_STAR_EXPANSION,
+ find_aggregate_exprs, find_column_exprs, find_window_exprs, COUNT_STAR_EXPANSION,
};
use datafusion_expr::{
and, col, lit, AggregateFunction, AggregateUDF, Expr, ExprSchemable, GetIndexedField,
@@ -690,8 +689,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let join_filter = filter
.into_iter()
.map(|expr| {
- let mut using_columns = HashSet::new();
- expr_to_columns(&expr, &mut using_columns)?;
+ let using_columns = expr.to_columns()?;
normalize_col_with_schemas(
expr,