Posted to commits@arrow.apache.org by al...@apache.org on 2021/06/22 18:06:47 UTC
[arrow-datafusion] branch master updated: Support qualified columns in queries (#55)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new f2c01de Support qualified columns in queries (#55)
f2c01de is described below
commit f2c01de7d620081eb370966d928673c8d38ac798
Author: QP Hou <qp...@scribd.com>
AuthorDate: Tue Jun 22 11:06:39 2021 -0700
Support qualified columns in queries (#55)
* support qualified columns in queries
* handle coalesced hash join partition in HashJoinStream
* implement Into<Column> for &str (see the sketch after this list)
* add todo for ARROW-10971
* fix cross join handling in projection push down optimizer
When a projection is pushed down to cross join inputs, fields from the
resulting plan's schema need to be trimmed to contain only the projected
fields.
* maintain field order during plan optimization using projections
* change TableScan name from Option<String> to String
* WIP: fix ballista
* separate logical and physical expressions in proto, fix ballista build
* fix join schema handling in projection push down optimizer
the schema needs to be recalculated based on the newly optimized inputs
* tpch 7 & 8 are now passing!
* fix roundtrip_join test
* fix clippy warnings
* fix sql planner test error checking with matches
`format!("{:?}", err)` yields different results between stable and
nightly Rust.
* address FIXMEs
* honor datafusion field name semantics
strip the qualifier from physical field names
* add more comments
* enable more queries in benchmark/run.sh
* use unzip to avoid unnecessary iterators
* reduce diff by discarding style-related changes
* simplify hash_join tests
* reduce diff for easier review
* fix unnecessary reference clippy error
* incorporate code review feedback
* fix window schema handling in projection pushdown optimizer
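
As a quick illustration of the API change above, here is a minimal
sketch of the new qualified Column type (struct-literal construction
and public fields are assumed from the ballista serde code in this
diff; the `&str` conversion is the one named in the bullet list, and
its exact parsing behavior for dotted names is not shown here):

    use datafusion::logical_plan::Column;

    fn main() {
        // A column reference now carries an optional relation
        // qualifier, e.g. the `t1` in `t1.id` (see the new `Column`
        // protobuf message and its conversions further down).
        let qualified = Column {
            relation: Some("t1".to_string()),
            name: "id".to_string(),
        };
        assert_eq!(qualified.relation.as_deref(), Some("t1"));

        // Plain strings still work at join call sites such as
        // `plan.join(&right, JoinType::Inner, vec!["id"], vec!["id"])`
        // (see the updated roundtrip_join test below) because `&str`
        // now converts into a Column; with no qualifier in the
        // string, the relation is expected to be None.
        let unqualified: Column = "id".into();
        assert_eq!(unqualified.name, "id");
        assert!(unqualified.relation.is_none());
    }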
---
ballista/rust/core/proto/ballista.proto | 175 ++++-
.../rust/core/src/serde/logical_plan/from_proto.rs | 275 ++-----
ballista/rust/core/src/serde/logical_plan/mod.rs | 31 +-
.../rust/core/src/serde/logical_plan/to_proto.rs | 54 +-
ballista/rust/core/src/serde/mod.rs | 224 ++++++
.../core/src/serde/physical_plan/from_proto.rs | 364 +++++++--
ballista/rust/core/src/serde/physical_plan/mod.rs | 32 +-
.../rust/core/src/serde/physical_plan/to_proto.rs | 160 ++--
benchmarks/run.sh | 2 +-
benchmarks/src/bin/tpch.rs | 10 +
datafusion/src/dataframe.rs | 2 +
datafusion/src/execution/context.rs | 52 +-
datafusion/src/execution/dataframe_impl.rs | 7 +-
datafusion/src/logical_plan/builder.rs | 332 +++++---
datafusion/src/logical_plan/dfschema.rs | 139 +++-
datafusion/src/logical_plan/expr.rs | 208 +++++-
datafusion/src/logical_plan/mod.rs | 22 +-
datafusion/src/logical_plan/plan.rs | 86 ++-
datafusion/src/optimizer/constant_folding.rs | 32 +-
datafusion/src/optimizer/eliminate_limit.rs | 2 +-
datafusion/src/optimizer/filter_push_down.rs | 234 +++---
datafusion/src/optimizer/hash_build_probe_order.rs | 29 +-
datafusion/src/optimizer/limit_push_down.rs | 6 +-
datafusion/src/optimizer/projection_push_down.rs | 272 ++++---
datafusion/src/optimizer/simplify_expressions.rs | 8 +-
datafusion/src/optimizer/utils.rs | 71 +-
datafusion/src/physical_optimizer/pruning.rs | 155 ++--
datafusion/src/physical_plan/expressions/binary.rs | 41 +-
datafusion/src/physical_plan/expressions/case.rs | 20 +-
datafusion/src/physical_plan/expressions/cast.rs | 12 +-
datafusion/src/physical_plan/expressions/column.rs | 33 +-
.../src/physical_plan/expressions/in_list.rs | 112 ++-
.../src/physical_plan/expressions/is_not_null.rs | 2 +-
.../src/physical_plan/expressions/is_null.rs | 5 +-
.../src/physical_plan/expressions/min_max.rs | 2 +-
datafusion/src/physical_plan/expressions/mod.rs | 8 +-
datafusion/src/physical_plan/expressions/not.rs | 4 +-
.../src/physical_plan/expressions/nth_value.rs | 32 +-
.../src/physical_plan/expressions/try_cast.rs | 9 +-
datafusion/src/physical_plan/filter.rs | 4 +-
datafusion/src/physical_plan/functions.rs | 4 +-
datafusion/src/physical_plan/hash_aggregate.rs | 55 +-
datafusion/src/physical_plan/hash_join.rs | 419 ++++++++---
datafusion/src/physical_plan/hash_utils.rs | 96 ++-
datafusion/src/physical_plan/mod.rs | 4 +-
datafusion/src/physical_plan/parquet.rs | 9 +-
datafusion/src/physical_plan/planner.rs | 511 ++++++++++---
datafusion/src/physical_plan/projection.rs | 6 +-
datafusion/src/physical_plan/repartition.rs | 10 +-
datafusion/src/physical_plan/sort.rs | 10 +-
.../src/physical_plan/sort_preserving_merge.rs | 84 +--
datafusion/src/physical_plan/type_coercion.rs | 4 +-
datafusion/src/physical_plan/windows.rs | 10 +-
datafusion/src/prelude.rs | 2 +-
datafusion/src/sql/planner.rs | 831 +++++++++++----------
datafusion/src/sql/utils.rs | 12 +-
datafusion/src/test/mod.rs | 2 +-
datafusion/tests/custom_sources.rs | 11 +-
datafusion/tests/sql.rs | 118 +--
datafusion/tests/user_defined_plan.rs | 15 +-
integration-tests/test_psql_parity.py | 2 +-
61 files changed, 3603 insertions(+), 1880 deletions(-)
diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 5aafd00..d75cbaa 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -28,11 +28,29 @@ option java_outer_classname = "BallistaProto";
// Ballista Logical Plan
///////////////////////////////////////////////////////////////////////////////////////////////////
+message ColumnRelation {
+ string relation = 1;
+}
+
+message Column {
+ string name = 1;
+ ColumnRelation relation = 2;
+}
+
+message DfField{
+ Field field = 1;
+ ColumnRelation qualifier = 2;
+}
+
+message DfSchema {
+ repeated DfField columns = 1;
+}
+
// logical expressions
message LogicalExprNode {
oneof ExprType {
// column references
- string column_name = 1;
+ Column column = 1;
// alias
AliasNode alias = 2;
@@ -295,7 +313,7 @@ message CreateExternalTableNode{
string location = 2;
FileType file_type = 3;
bool has_header = 4;
- Schema schema = 5;
+ DfSchema schema = 5;
}
enum FileType{
@@ -309,11 +327,6 @@ message ExplainNode{
bool verbose = 2;
}
-message DfField{
- string qualifier = 2;
- Field field = 1;
-}
-
message AggregateNode {
LogicalPlanNode input = 1;
repeated LogicalExprNode group_expr = 2;
@@ -369,8 +382,8 @@ message JoinNode {
LogicalPlanNode left = 1;
LogicalPlanNode right = 2;
JoinType join_type = 3;
- repeated string left_join_column = 4;
- repeated string right_join_column = 5;
+ repeated Column left_join_column = 4;
+ repeated Column right_join_column = 5;
}
message LimitNode {
@@ -408,6 +421,119 @@ message PhysicalPlanNode {
}
}
+// physical expressions
+message PhysicalExprNode {
+ oneof ExprType {
+ // column references
+ PhysicalColumn column = 1;
+
+ ScalarValue literal = 2;
+
+ // binary expressions
+ PhysicalBinaryExprNode binary_expr = 3;
+
+ // aggregate expressions
+ PhysicalAggregateExprNode aggregate_expr = 4;
+
+ // null checks
+ PhysicalIsNull is_null_expr = 5;
+ PhysicalIsNotNull is_not_null_expr = 6;
+ PhysicalNot not_expr = 7;
+
+ PhysicalCaseNode case_ = 8;
+ PhysicalCastNode cast = 9;
+ PhysicalSortExprNode sort = 10;
+ PhysicalNegativeNode negative = 11;
+ PhysicalInListNode in_list = 12;
+ PhysicalScalarFunctionNode scalar_function = 13;
+ PhysicalTryCastNode try_cast = 14;
+
+ // window expressions
+ PhysicalWindowExprNode window_expr = 15;
+ }
+}
+
+message PhysicalAggregateExprNode {
+ AggregateFunction aggr_function = 1;
+ PhysicalExprNode expr = 2;
+}
+
+message PhysicalWindowExprNode {
+ oneof window_function {
+ AggregateFunction aggr_function = 1;
+ BuiltInWindowFunction built_in_function = 2;
+ // udaf = 3
+ }
+ PhysicalExprNode expr = 4;
+}
+
+message PhysicalIsNull {
+ PhysicalExprNode expr = 1;
+}
+
+message PhysicalIsNotNull {
+ PhysicalExprNode expr = 1;
+}
+
+message PhysicalNot {
+ PhysicalExprNode expr = 1;
+}
+
+message PhysicalAliasNode {
+ PhysicalExprNode expr = 1;
+ string alias = 2;
+}
+
+message PhysicalBinaryExprNode {
+ PhysicalExprNode l = 1;
+ PhysicalExprNode r = 2;
+ string op = 3;
+}
+
+message PhysicalSortExprNode {
+ PhysicalExprNode expr = 1;
+ bool asc = 2;
+ bool nulls_first = 3;
+}
+
+message PhysicalWhenThen {
+ PhysicalExprNode when_expr = 1;
+ PhysicalExprNode then_expr = 2;
+}
+
+message PhysicalInListNode {
+ PhysicalExprNode expr = 1;
+ repeated PhysicalExprNode list = 2;
+ bool negated = 3;
+}
+
+message PhysicalCaseNode {
+ PhysicalExprNode expr = 1;
+ repeated PhysicalWhenThen when_then_expr = 2;
+ PhysicalExprNode else_expr = 3;
+}
+
+message PhysicalScalarFunctionNode {
+ string name = 1;
+ ScalarFunction fun = 2;
+ repeated PhysicalExprNode args = 3;
+ ArrowType return_type = 4;
+}
+
+message PhysicalTryCastNode {
+ PhysicalExprNode expr = 1;
+ ArrowType arrow_type = 2;
+}
+
+message PhysicalCastNode {
+ PhysicalExprNode expr = 1;
+ ArrowType arrow_type = 2;
+}
+
+message PhysicalNegativeNode {
+ PhysicalExprNode expr = 1;
+}
+
message UnresolvedShuffleExecNode {
repeated uint32 query_stage_ids = 1;
Schema schema = 2;
@@ -416,7 +542,7 @@ message UnresolvedShuffleExecNode {
message FilterExecNode {
PhysicalPlanNode input = 1;
- LogicalExprNode expr = 2;
+ PhysicalExprNode expr = 2;
}
message ParquetScanExecNode {
@@ -447,11 +573,15 @@ message HashJoinExecNode {
}
-message JoinOn {
- string left = 1;
- string right = 2;
+message PhysicalColumn {
+ string name = 1;
+ uint32 index = 2;
}
+message JoinOn {
+ PhysicalColumn left = 1;
+ PhysicalColumn right = 2;
+}
message EmptyExecNode {
bool produce_one_row = 1;
@@ -460,7 +590,7 @@ message EmptyExecNode {
message ProjectionExecNode {
PhysicalPlanNode input = 1;
- repeated LogicalExprNode expr = 2;
+ repeated PhysicalExprNode expr = 2;
repeated string expr_name = 3;
}
@@ -472,14 +602,14 @@ enum AggregateMode {
message WindowAggExecNode {
PhysicalPlanNode input = 1;
- repeated LogicalExprNode window_expr = 2;
+ repeated PhysicalExprNode window_expr = 2;
repeated string window_expr_name = 3;
Schema input_schema = 4;
}
message HashAggregateExecNode {
- repeated LogicalExprNode group_expr = 1;
- repeated LogicalExprNode aggr_expr = 2;
+ repeated PhysicalExprNode group_expr = 1;
+ repeated PhysicalExprNode aggr_expr = 2;
AggregateMode mode = 3;
PhysicalPlanNode input = 4;
repeated string group_expr_name = 5;
@@ -510,7 +640,7 @@ message LocalLimitExecNode {
message SortExecNode {
PhysicalPlanNode input = 1;
- repeated LogicalExprNode expr = 2;
+ repeated PhysicalExprNode expr = 2;
}
message CoalesceBatchesExecNode {
@@ -522,11 +652,16 @@ message MergeExecNode {
PhysicalPlanNode input = 1;
}
+message PhysicalHashRepartition {
+ repeated PhysicalExprNode hash_expr = 1;
+ uint64 partition_count = 2;
+}
+
message RepartitionExecNode{
PhysicalPlanNode input = 1;
oneof partition_method {
uint64 round_robin = 2;
- HashRepartition hash = 3;
+ PhysicalHashRepartition hash = 3;
uint64 unknown = 4;
}
}
@@ -803,7 +938,7 @@ message ScalarListValue{
message ScalarValue{
- oneof value{
+ oneof value {
bool bool_value = 1;
string utf8_value = 2;
string large_utf8_value = 3;
diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index c2c1001..1b7deb7 100644
--- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -18,7 +18,7 @@
//! Serde code to convert from protocol buffers to Rust data structures.
use crate::error::BallistaError;
-use crate::serde::{proto_error, protobuf};
+use crate::serde::{from_proto_binary_op, proto_error, protobuf};
use crate::{convert_box_required, convert_required};
use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::logical_plan::window_frames::{
@@ -26,7 +26,8 @@ use datafusion::logical_plan::window_frames::{
};
use datafusion::logical_plan::{
abs, acos, asin, atan, ceil, cos, exp, floor, ln, log10, log2, round, signum, sin,
- sqrt, tan, trunc, Expr, JoinType, LogicalPlan, LogicalPlanBuilder, Operator,
+ sqrt, tan, trunc, Column, DFField, DFSchema, Expr, JoinType, LogicalPlan,
+ LogicalPlanBuilder, Operator,
};
use datafusion::physical_plan::aggregates::AggregateFunction;
use datafusion::physical_plan::csv::CsvReadOptions;
@@ -36,6 +37,7 @@ use protobuf::logical_plan_node::LogicalPlanType;
use protobuf::{logical_expr_node::ExprType, scalar_type};
use std::{
convert::{From, TryInto},
+ sync::Arc,
unimplemented,
};
@@ -115,8 +117,8 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
.has_header(scan.has_header);
let mut projection = None;
- if let Some(column_names) = &scan.projection {
- let column_indices = column_names
+ if let Some(columns) = &scan.projection {
+ let column_indices = columns
.columns
.iter()
.map(|name| schema.index_of(name))
@@ -234,10 +236,10 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
.map_err(|e| e.into())
}
LogicalPlanType::Join(join) => {
- let left_keys: Vec<&str> =
- join.left_join_column.iter().map(|i| i.as_str()).collect();
- let right_keys: Vec<&str> =
- join.right_join_column.iter().map(|i| i.as_str()).collect();
+ let left_keys: Vec<Column> =
+ join.left_join_column.iter().map(|i| i.into()).collect();
+ let right_keys: Vec<Column> =
+ join.right_join_column.iter().map(|i| i.into()).collect();
let join_type =
protobuf::JoinType::from_i32(join.join_type).ok_or_else(|| {
proto_error(format!(
@@ -257,8 +259,8 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
.join(
&convert_box_required!(join.right)?,
join_type,
- &left_keys,
- &right_keys,
+ left_keys,
+ right_keys,
)?
.build()
.map_err(|e| e.into())
@@ -267,22 +269,48 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
}
}
-impl TryInto<datafusion::logical_plan::DFSchema> for protobuf::Schema {
+impl From<&protobuf::Column> for Column {
+ fn from(c: &protobuf::Column) -> Column {
+ let c = c.clone();
+ Column {
+ relation: c.relation.map(|r| r.relation),
+ name: c.name,
+ }
+ }
+}
+
+impl TryInto<DFSchema> for &protobuf::DfSchema {
type Error = BallistaError;
- fn try_into(self) -> Result<datafusion::logical_plan::DFSchema, Self::Error> {
- let schema: Schema = (&self).try_into()?;
- schema.try_into().map_err(BallistaError::DataFusionError)
+
+ fn try_into(self) -> Result<DFSchema, BallistaError> {
+ let fields = self
+ .columns
+ .iter()
+ .map(|c| c.try_into())
+ .collect::<Result<Vec<DFField>, _>>()?;
+ Ok(DFSchema::new(fields)?)
}
}
-impl TryInto<datafusion::logical_plan::DFSchemaRef> for protobuf::Schema {
+impl TryInto<datafusion::logical_plan::DFSchemaRef> for protobuf::DfSchema {
type Error = BallistaError;
+
fn try_into(self) -> Result<datafusion::logical_plan::DFSchemaRef, Self::Error> {
- use datafusion::logical_plan::ToDFSchema;
- let schema: Schema = (&self).try_into()?;
- schema
- .to_dfschema_ref()
- .map_err(BallistaError::DataFusionError)
+ let dfschema: DFSchema = (&self).try_into()?;
+ Ok(Arc::new(dfschema))
+ }
+}
+
+impl TryInto<DFField> for &protobuf::DfField {
+ type Error = BallistaError;
+
+ fn try_into(self) -> Result<DFField, Self::Error> {
+ let field: Field = convert_required!(self.field)?;
+
+ Ok(match &self.qualifier {
+ Some(q) => DFField::from_qualified(&q.relation, field),
+ None => DFField::from(field),
+ })
}
}
@@ -339,149 +367,6 @@ impl TryInto<DataType> for &protobuf::scalar_type::Datatype {
}
}
-impl TryInto<DataType> for &protobuf::arrow_type::ArrowTypeEnum {
- type Error = BallistaError;
- fn try_into(self) -> Result<DataType, Self::Error> {
- use protobuf::arrow_type;
- Ok(match self {
- arrow_type::ArrowTypeEnum::None(_) => DataType::Null,
- arrow_type::ArrowTypeEnum::Bool(_) => DataType::Boolean,
- arrow_type::ArrowTypeEnum::Uint8(_) => DataType::UInt8,
- arrow_type::ArrowTypeEnum::Int8(_) => DataType::Int8,
- arrow_type::ArrowTypeEnum::Uint16(_) => DataType::UInt16,
- arrow_type::ArrowTypeEnum::Int16(_) => DataType::Int16,
- arrow_type::ArrowTypeEnum::Uint32(_) => DataType::UInt32,
- arrow_type::ArrowTypeEnum::Int32(_) => DataType::Int32,
- arrow_type::ArrowTypeEnum::Uint64(_) => DataType::UInt64,
- arrow_type::ArrowTypeEnum::Int64(_) => DataType::Int64,
- arrow_type::ArrowTypeEnum::Float16(_) => DataType::Float16,
- arrow_type::ArrowTypeEnum::Float32(_) => DataType::Float32,
- arrow_type::ArrowTypeEnum::Float64(_) => DataType::Float64,
- arrow_type::ArrowTypeEnum::Utf8(_) => DataType::Utf8,
- arrow_type::ArrowTypeEnum::LargeUtf8(_) => DataType::LargeUtf8,
- arrow_type::ArrowTypeEnum::Binary(_) => DataType::Binary,
- arrow_type::ArrowTypeEnum::FixedSizeBinary(size) => {
- DataType::FixedSizeBinary(*size)
- }
- arrow_type::ArrowTypeEnum::LargeBinary(_) => DataType::LargeBinary,
- arrow_type::ArrowTypeEnum::Date32(_) => DataType::Date32,
- arrow_type::ArrowTypeEnum::Date64(_) => DataType::Date64,
- arrow_type::ArrowTypeEnum::Duration(time_unit) => {
- DataType::Duration(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?)
- }
- arrow_type::ArrowTypeEnum::Timestamp(protobuf::Timestamp {
- time_unit,
- timezone,
- }) => DataType::Timestamp(
- protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?,
- match timezone.len() {
- 0 => None,
- _ => Some(timezone.to_owned()),
- },
- ),
- arrow_type::ArrowTypeEnum::Time32(time_unit) => {
- DataType::Time32(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?)
- }
- arrow_type::ArrowTypeEnum::Time64(time_unit) => {
- DataType::Time64(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?)
- }
- arrow_type::ArrowTypeEnum::Interval(interval_unit) => DataType::Interval(
- protobuf::IntervalUnit::from_i32_to_arrow(*interval_unit)?,
- ),
- arrow_type::ArrowTypeEnum::Decimal(protobuf::Decimal {
- whole,
- fractional,
- }) => DataType::Decimal(*whole as usize, *fractional as usize),
- arrow_type::ArrowTypeEnum::List(list) => {
- let list_type: &protobuf::Field = list
- .as_ref()
- .field_type
- .as_ref()
- .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))?
- .as_ref();
- DataType::List(Box::new(list_type.try_into()?))
- }
- arrow_type::ArrowTypeEnum::LargeList(list) => {
- let list_type: &protobuf::Field = list
- .as_ref()
- .field_type
- .as_ref()
- .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))?
- .as_ref();
- DataType::LargeList(Box::new(list_type.try_into()?))
- }
- arrow_type::ArrowTypeEnum::FixedSizeList(list) => {
- let list_type: &protobuf::Field = list
- .as_ref()
- .field_type
- .as_ref()
- .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))?
- .as_ref();
- let list_size = list.list_size;
- DataType::FixedSizeList(Box::new(list_type.try_into()?), list_size)
- }
- arrow_type::ArrowTypeEnum::Struct(strct) => DataType::Struct(
- strct
- .sub_field_types
- .iter()
- .map(|field| field.try_into())
- .collect::<Result<Vec<_>, _>>()?,
- ),
- arrow_type::ArrowTypeEnum::Union(union) => DataType::Union(
- union
- .union_types
- .iter()
- .map(|field| field.try_into())
- .collect::<Result<Vec<_>, _>>()?,
- ),
- arrow_type::ArrowTypeEnum::Dictionary(dict) => {
- let pb_key_datatype = dict
- .as_ref()
- .key
- .as_ref()
- .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?;
- let pb_value_datatype = dict
- .as_ref()
- .value
- .as_ref()
- .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?;
- let key_datatype: DataType = pb_key_datatype.as_ref().try_into()?;
- let value_datatype: DataType = pb_value_datatype.as_ref().try_into()?;
- DataType::Dictionary(Box::new(key_datatype), Box::new(value_datatype))
- }
- })
- }
-}
-
-#[allow(clippy::from_over_into)]
-impl Into<DataType> for protobuf::PrimitiveScalarType {
- fn into(self) -> DataType {
- match self {
- protobuf::PrimitiveScalarType::Bool => DataType::Boolean,
- protobuf::PrimitiveScalarType::Uint8 => DataType::UInt8,
- protobuf::PrimitiveScalarType::Int8 => DataType::Int8,
- protobuf::PrimitiveScalarType::Uint16 => DataType::UInt16,
- protobuf::PrimitiveScalarType::Int16 => DataType::Int16,
- protobuf::PrimitiveScalarType::Uint32 => DataType::UInt32,
- protobuf::PrimitiveScalarType::Int32 => DataType::Int32,
- protobuf::PrimitiveScalarType::Uint64 => DataType::UInt64,
- protobuf::PrimitiveScalarType::Int64 => DataType::Int64,
- protobuf::PrimitiveScalarType::Float32 => DataType::Float32,
- protobuf::PrimitiveScalarType::Float64 => DataType::Float64,
- protobuf::PrimitiveScalarType::Utf8 => DataType::Utf8,
- protobuf::PrimitiveScalarType::LargeUtf8 => DataType::LargeUtf8,
- protobuf::PrimitiveScalarType::Date32 => DataType::Date32,
- protobuf::PrimitiveScalarType::TimeMicrosecond => {
- DataType::Time64(TimeUnit::Microsecond)
- }
- protobuf::PrimitiveScalarType::TimeNanosecond => {
- DataType::Time64(TimeUnit::Nanosecond)
- }
- protobuf::PrimitiveScalarType::Null => DataType::Null,
- }
- }
-}
-
//Does not typecheck lists
fn typechecked_scalar_value_conversion(
tested_type: &protobuf::scalar_value::Value,
@@ -899,7 +784,7 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
op: from_proto_binary_op(&binary_expr.op)?,
right: Box::new(parse_required_expr(&binary_expr.r)?),
}),
- ExprType::ColumnName(column_name) => Ok(Expr::Column(column_name.to_owned())),
+ ExprType::Column(column) => Ok(Expr::Column(column.into())),
ExprType::Literal(literal) => {
use datafusion::scalar::ScalarValue;
let scalar_value: datafusion::scalar::ScalarValue = literal.try_into()?;
@@ -1164,28 +1049,6 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
}
}
-fn from_proto_binary_op(op: &str) -> Result<Operator, BallistaError> {
- match op {
- "And" => Ok(Operator::And),
- "Or" => Ok(Operator::Or),
- "Eq" => Ok(Operator::Eq),
- "NotEq" => Ok(Operator::NotEq),
- "LtEq" => Ok(Operator::LtEq),
- "Lt" => Ok(Operator::Lt),
- "Gt" => Ok(Operator::Gt),
- "GtEq" => Ok(Operator::GtEq),
- "Plus" => Ok(Operator::Plus),
- "Minus" => Ok(Operator::Minus),
- "Multiply" => Ok(Operator::Multiply),
- "Divide" => Ok(Operator::Divide),
- "Like" => Ok(Operator::Like),
- other => Err(proto_error(format!(
- "Unsupported binary operator '{:?}'",
- other
- ))),
- }
-}
-
impl TryInto<DataType> for &protobuf::ScalarType {
type Error = BallistaError;
fn try_into(self) -> Result<DataType, Self::Error> {
@@ -1361,43 +1224,3 @@ impl TryFrom<protobuf::WindowFrame> for WindowFrame {
})
}
}
-
-impl From<protobuf::AggregateFunction> for AggregateFunction {
- fn from(aggr_function: protobuf::AggregateFunction) -> Self {
- match aggr_function {
- protobuf::AggregateFunction::Min => AggregateFunction::Min,
- protobuf::AggregateFunction::Max => AggregateFunction::Max,
- protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
- protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
- protobuf::AggregateFunction::Count => AggregateFunction::Count,
- }
- }
-}
-
-impl From<protobuf::BuiltInWindowFunction> for BuiltInWindowFunction {
- fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self {
- match built_in_function {
- protobuf::BuiltInWindowFunction::RowNumber => {
- BuiltInWindowFunction::RowNumber
- }
- protobuf::BuiltInWindowFunction::Rank => BuiltInWindowFunction::Rank,
- protobuf::BuiltInWindowFunction::PercentRank => {
- BuiltInWindowFunction::PercentRank
- }
- protobuf::BuiltInWindowFunction::DenseRank => {
- BuiltInWindowFunction::DenseRank
- }
- protobuf::BuiltInWindowFunction::Lag => BuiltInWindowFunction::Lag,
- protobuf::BuiltInWindowFunction::Lead => BuiltInWindowFunction::Lead,
- protobuf::BuiltInWindowFunction::FirstValue => {
- BuiltInWindowFunction::FirstValue
- }
- protobuf::BuiltInWindowFunction::CumeDist => BuiltInWindowFunction::CumeDist,
- protobuf::BuiltInWindowFunction::Ntile => BuiltInWindowFunction::Ntile,
- protobuf::BuiltInWindowFunction::NthValue => BuiltInWindowFunction::NthValue,
- protobuf::BuiltInWindowFunction::LastValue => {
- BuiltInWindowFunction::LastValue
- }
- }
- }
-}
diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs
index d2792b0..0d27c58 100644
--- a/ballista/rust/core/src/serde/logical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/logical_plan/mod.rs
@@ -26,7 +26,9 @@ mod roundtrip_tests {
use core::panic;
use datafusion::{
arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit},
- logical_plan::{Expr, LogicalPlan, LogicalPlanBuilder, Partitioning, ToDFSchema},
+ logical_plan::{
+ col, Expr, LogicalPlan, LogicalPlanBuilder, Partitioning, ToDFSchema,
+ },
physical_plan::{csv::CsvReadOptions, functions::BuiltinScalarFunction::Sqrt},
prelude::*,
scalar::ScalarValue,
@@ -61,10 +63,8 @@ mod roundtrip_tests {
let test_batch_sizes = [usize::MIN, usize::MAX, 43256];
- let test_expr: Vec<Expr> = vec![
- Expr::Column("c1".to_string()) + Expr::Column("c2".to_string()),
- Expr::Literal((4.0).into()),
- ];
+ let test_expr: Vec<Expr> =
+ vec![col("c1") + col("c2"), Expr::Literal((4.0).into())];
let schema = Schema::new(vec![
Field::new("id", DataType::Int32, false),
@@ -688,15 +688,20 @@ mod roundtrip_tests {
Field::new("salary", DataType::Int32, false),
]);
- let scan_plan = LogicalPlanBuilder::empty(false)
- .build()
- .map_err(BallistaError::DataFusionError)?;
+ let scan_plan = LogicalPlanBuilder::scan_csv(
+ "employee1",
+ CsvReadOptions::new().schema(&schema).has_header(true),
+ Some(vec![0, 3, 4]),
+ )?
+ .build()
+ .map_err(BallistaError::DataFusionError)?;
+
let plan = LogicalPlanBuilder::scan_csv(
- "employee.csv",
+ "employee2",
CsvReadOptions::new().schema(&schema).has_header(true),
- Some(vec![3, 4]),
+ Some(vec![0, 3, 4]),
)
- .and_then(|plan| plan.join(&scan_plan, JoinType::Inner, &["id"], &["id"]))
+ .and_then(|plan| plan.join(&scan_plan, JoinType::Inner, vec!["id"], vec!["id"]))
.and_then(|plan| plan.build())
.map_err(BallistaError::DataFusionError)?;
@@ -779,7 +784,7 @@ mod roundtrip_tests {
#[test]
fn roundtrip_is_null() -> Result<()> {
- let test_expr = Expr::IsNull(Box::new(Expr::Column("id".into())));
+ let test_expr = Expr::IsNull(Box::new(col("id")));
roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr);
@@ -788,7 +793,7 @@ mod roundtrip_tests {
#[test]
fn roundtrip_is_not_null() -> Result<()> {
- let test_expr = Expr::IsNotNull(Box::new(Expr::Column("id".into())));
+ let test_expr = Expr::IsNotNull(Box::new(col("id")));
roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr);
diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index c454d03..24e2b56 100644
--- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -26,7 +26,7 @@ use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUn
use datafusion::datasource::CsvFile;
use datafusion::logical_plan::{
window_frames::{WindowFrame, WindowFrameBound, WindowFrameUnits},
- Expr, JoinType, LogicalPlan,
+ Column, Expr, JoinType, LogicalPlan,
};
use datafusion::physical_plan::aggregates::AggregateFunction;
use datafusion::physical_plan::functions::BuiltinScalarFunction;
@@ -816,8 +816,8 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
JoinType::Semi => protobuf::JoinType::Semi,
JoinType::Anti => protobuf::JoinType::Anti,
};
- let left_join_column = on.iter().map(|on| on.0.to_owned()).collect();
- let right_join_column = on.iter().map(|on| on.1.to_owned()).collect();
+ let (left_join_column, right_join_column) =
+ on.iter().map(|(l, r)| (l.into(), r.into())).unzip();
Ok(protobuf::LogicalPlanNode {
logical_plan_type: Some(LogicalPlanType::Join(Box::new(
protobuf::JoinNode {
@@ -908,13 +908,6 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
schema: df_schema,
} => {
use datafusion::sql::parser::FileType;
- let schema: Schema = df_schema.as_ref().clone().into();
- let pb_schema: protobuf::Schema = (&schema).try_into().map_err(|e| {
- BallistaError::General(format!(
- "Could not convert schema into protobuf: {:?}",
- e
- ))
- })?;
let pb_file_type: protobuf::FileType = match file_type {
FileType::NdJson => protobuf::FileType::NdJson,
@@ -929,7 +922,7 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
location: location.clone(),
file_type: pb_file_type as i32,
has_header: *has_header,
- schema: Some(pb_schema),
+ schema: Some(df_schema.into()),
},
)),
})
@@ -971,9 +964,9 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
use datafusion::scalar::ScalarValue;
use protobuf::scalar_value::Value;
match self {
- Expr::Column(name) => {
+ Expr::Column(c) => {
let expr = protobuf::LogicalExprNode {
- expr_type: Some(ExprType::ColumnName(name.clone())),
+ expr_type: Some(ExprType::Column(c.into())),
};
Ok(expr)
}
@@ -1214,6 +1207,23 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
}
}
+impl From<Column> for protobuf::Column {
+ fn from(c: Column) -> protobuf::Column {
+ protobuf::Column {
+ relation: c
+ .relation
+ .map(|relation| protobuf::ColumnRelation { relation }),
+ name: c.name,
+ }
+ }
+}
+
+impl From<&Column> for protobuf::Column {
+ fn from(c: &Column) -> protobuf::Column {
+ c.clone().into()
+ }
+}
+
#[allow(clippy::from_over_into)]
impl Into<protobuf::Schema> for &Schema {
fn into(self) -> protobuf::Schema {
@@ -1227,6 +1237,24 @@ impl Into<protobuf::Schema> for &Schema {
}
}
+impl From<&datafusion::logical_plan::DFField> for protobuf::DfField {
+ fn from(f: &datafusion::logical_plan::DFField) -> protobuf::DfField {
+ protobuf::DfField {
+ field: Some(f.field().into()),
+ qualifier: f.qualifier().map(|r| protobuf::ColumnRelation {
+ relation: r.to_string(),
+ }),
+ }
+ }
+}
+
+impl From<&datafusion::logical_plan::DFSchemaRef> for protobuf::DfSchema {
+ fn from(s: &datafusion::logical_plan::DFSchemaRef) -> protobuf::DfSchema {
+ let columns = s.fields().iter().map(|f| f.into()).collect::<Vec<_>>();
+ protobuf::DfSchema { columns }
+ }
+}
+
impl From<&AggregateFunction> for protobuf::AggregateFunction {
fn from(value: &AggregateFunction) -> Self {
match value {
diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs
index b961639..af83660 100644
--- a/ballista/rust/core/src/serde/mod.rs
+++ b/ballista/rust/core/src/serde/mod.rs
@@ -20,6 +20,10 @@
use std::{convert::TryInto, io::Cursor};
+use datafusion::logical_plan::Operator;
+use datafusion::physical_plan::aggregates::AggregateFunction;
+use datafusion::physical_plan::window_functions::BuiltInWindowFunction;
+
use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction};
use prost::Message;
@@ -58,6 +62,17 @@ macro_rules! convert_required {
}
#[macro_export]
+macro_rules! into_required {
+ ($PB:expr) => {{
+ if let Some(field) = $PB.as_ref() {
+ Ok(field.into())
+ } else {
+ Err(proto_error("Missing required field in protobuf"))
+ }
+ }};
+}
+
+#[macro_export]
macro_rules! convert_box_required {
($PB:expr) => {{
if let Some(field) = $PB.as_ref() {
@@ -67,3 +82,212 @@ macro_rules! convert_box_required {
}
}};
}
+
+pub(crate) fn from_proto_binary_op(op: &str) -> Result<Operator, BallistaError> {
+ match op {
+ "And" => Ok(Operator::And),
+ "Or" => Ok(Operator::Or),
+ "Eq" => Ok(Operator::Eq),
+ "NotEq" => Ok(Operator::NotEq),
+ "LtEq" => Ok(Operator::LtEq),
+ "Lt" => Ok(Operator::Lt),
+ "Gt" => Ok(Operator::Gt),
+ "GtEq" => Ok(Operator::GtEq),
+ "Plus" => Ok(Operator::Plus),
+ "Minus" => Ok(Operator::Minus),
+ "Multiply" => Ok(Operator::Multiply),
+ "Divide" => Ok(Operator::Divide),
+ "Like" => Ok(Operator::Like),
+ other => Err(proto_error(format!(
+ "Unsupported binary operator '{:?}'",
+ other
+ ))),
+ }
+}
+
+impl From<protobuf::AggregateFunction> for AggregateFunction {
+ fn from(agg_fun: protobuf::AggregateFunction) -> AggregateFunction {
+ match agg_fun {
+ protobuf::AggregateFunction::Min => AggregateFunction::Min,
+ protobuf::AggregateFunction::Max => AggregateFunction::Max,
+ protobuf::AggregateFunction::Sum => AggregateFunction::Sum,
+ protobuf::AggregateFunction::Avg => AggregateFunction::Avg,
+ protobuf::AggregateFunction::Count => AggregateFunction::Count,
+ }
+ }
+}
+
+impl From<protobuf::BuiltInWindowFunction> for BuiltInWindowFunction {
+ fn from(built_in_function: protobuf::BuiltInWindowFunction) -> Self {
+ match built_in_function {
+ protobuf::BuiltInWindowFunction::RowNumber => {
+ BuiltInWindowFunction::RowNumber
+ }
+ protobuf::BuiltInWindowFunction::Rank => BuiltInWindowFunction::Rank,
+ protobuf::BuiltInWindowFunction::PercentRank => {
+ BuiltInWindowFunction::PercentRank
+ }
+ protobuf::BuiltInWindowFunction::DenseRank => {
+ BuiltInWindowFunction::DenseRank
+ }
+ protobuf::BuiltInWindowFunction::Lag => BuiltInWindowFunction::Lag,
+ protobuf::BuiltInWindowFunction::Lead => BuiltInWindowFunction::Lead,
+ protobuf::BuiltInWindowFunction::FirstValue => {
+ BuiltInWindowFunction::FirstValue
+ }
+ protobuf::BuiltInWindowFunction::CumeDist => BuiltInWindowFunction::CumeDist,
+ protobuf::BuiltInWindowFunction::Ntile => BuiltInWindowFunction::Ntile,
+ protobuf::BuiltInWindowFunction::NthValue => BuiltInWindowFunction::NthValue,
+ protobuf::BuiltInWindowFunction::LastValue => {
+ BuiltInWindowFunction::LastValue
+ }
+ }
+ }
+}
+
+impl TryInto<datafusion::arrow::datatypes::DataType>
+ for &protobuf::arrow_type::ArrowTypeEnum
+{
+ type Error = BallistaError;
+ fn try_into(self) -> Result<datafusion::arrow::datatypes::DataType, Self::Error> {
+ use datafusion::arrow::datatypes::DataType;
+ use protobuf::arrow_type;
+ Ok(match self {
+ arrow_type::ArrowTypeEnum::None(_) => DataType::Null,
+ arrow_type::ArrowTypeEnum::Bool(_) => DataType::Boolean,
+ arrow_type::ArrowTypeEnum::Uint8(_) => DataType::UInt8,
+ arrow_type::ArrowTypeEnum::Int8(_) => DataType::Int8,
+ arrow_type::ArrowTypeEnum::Uint16(_) => DataType::UInt16,
+ arrow_type::ArrowTypeEnum::Int16(_) => DataType::Int16,
+ arrow_type::ArrowTypeEnum::Uint32(_) => DataType::UInt32,
+ arrow_type::ArrowTypeEnum::Int32(_) => DataType::Int32,
+ arrow_type::ArrowTypeEnum::Uint64(_) => DataType::UInt64,
+ arrow_type::ArrowTypeEnum::Int64(_) => DataType::Int64,
+ arrow_type::ArrowTypeEnum::Float16(_) => DataType::Float16,
+ arrow_type::ArrowTypeEnum::Float32(_) => DataType::Float32,
+ arrow_type::ArrowTypeEnum::Float64(_) => DataType::Float64,
+ arrow_type::ArrowTypeEnum::Utf8(_) => DataType::Utf8,
+ arrow_type::ArrowTypeEnum::LargeUtf8(_) => DataType::LargeUtf8,
+ arrow_type::ArrowTypeEnum::Binary(_) => DataType::Binary,
+ arrow_type::ArrowTypeEnum::FixedSizeBinary(size) => {
+ DataType::FixedSizeBinary(*size)
+ }
+ arrow_type::ArrowTypeEnum::LargeBinary(_) => DataType::LargeBinary,
+ arrow_type::ArrowTypeEnum::Date32(_) => DataType::Date32,
+ arrow_type::ArrowTypeEnum::Date64(_) => DataType::Date64,
+ arrow_type::ArrowTypeEnum::Duration(time_unit) => {
+ DataType::Duration(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?)
+ }
+ arrow_type::ArrowTypeEnum::Timestamp(protobuf::Timestamp {
+ time_unit,
+ timezone,
+ }) => DataType::Timestamp(
+ protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?,
+ match timezone.len() {
+ 0 => None,
+ _ => Some(timezone.to_owned()),
+ },
+ ),
+ arrow_type::ArrowTypeEnum::Time32(time_unit) => {
+ DataType::Time32(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?)
+ }
+ arrow_type::ArrowTypeEnum::Time64(time_unit) => {
+ DataType::Time64(protobuf::TimeUnit::from_i32_to_arrow(*time_unit)?)
+ }
+ arrow_type::ArrowTypeEnum::Interval(interval_unit) => DataType::Interval(
+ protobuf::IntervalUnit::from_i32_to_arrow(*interval_unit)?,
+ ),
+ arrow_type::ArrowTypeEnum::Decimal(protobuf::Decimal {
+ whole,
+ fractional,
+ }) => DataType::Decimal(*whole as usize, *fractional as usize),
+ arrow_type::ArrowTypeEnum::List(list) => {
+ let list_type: &protobuf::Field = list
+ .as_ref()
+ .field_type
+ .as_ref()
+ .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))?
+ .as_ref();
+ DataType::List(Box::new(list_type.try_into()?))
+ }
+ arrow_type::ArrowTypeEnum::LargeList(list) => {
+ let list_type: &protobuf::Field = list
+ .as_ref()
+ .field_type
+ .as_ref()
+ .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))?
+ .as_ref();
+ DataType::LargeList(Box::new(list_type.try_into()?))
+ }
+ arrow_type::ArrowTypeEnum::FixedSizeList(list) => {
+ let list_type: &protobuf::Field = list
+ .as_ref()
+ .field_type
+ .as_ref()
+ .ok_or_else(|| proto_error("Protobuf deserialization error: List message missing required field 'field_type'"))?
+ .as_ref();
+ let list_size = list.list_size;
+ DataType::FixedSizeList(Box::new(list_type.try_into()?), list_size)
+ }
+ arrow_type::ArrowTypeEnum::Struct(strct) => DataType::Struct(
+ strct
+ .sub_field_types
+ .iter()
+ .map(|field| field.try_into())
+ .collect::<Result<Vec<_>, _>>()?,
+ ),
+ arrow_type::ArrowTypeEnum::Union(union) => DataType::Union(
+ union
+ .union_types
+ .iter()
+ .map(|field| field.try_into())
+ .collect::<Result<Vec<_>, _>>()?,
+ ),
+ arrow_type::ArrowTypeEnum::Dictionary(dict) => {
+ let pb_key_datatype = dict
+ .as_ref()
+ .key
+ .as_ref()
+ .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?;
+ let pb_value_datatype = dict
+ .as_ref()
+ .value
+ .as_ref()
+ .ok_or_else(|| proto_error("Protobuf deserialization error: Dictionary message missing required field 'key'"))?;
+ let key_datatype: DataType = pb_key_datatype.as_ref().try_into()?;
+ let value_datatype: DataType = pb_value_datatype.as_ref().try_into()?;
+ DataType::Dictionary(Box::new(key_datatype), Box::new(value_datatype))
+ }
+ })
+ }
+}
+
+#[allow(clippy::from_over_into)]
+impl Into<datafusion::arrow::datatypes::DataType> for protobuf::PrimitiveScalarType {
+ fn into(self) -> datafusion::arrow::datatypes::DataType {
+ use datafusion::arrow::datatypes::{DataType, TimeUnit};
+ match self {
+ protobuf::PrimitiveScalarType::Bool => DataType::Boolean,
+ protobuf::PrimitiveScalarType::Uint8 => DataType::UInt8,
+ protobuf::PrimitiveScalarType::Int8 => DataType::Int8,
+ protobuf::PrimitiveScalarType::Uint16 => DataType::UInt16,
+ protobuf::PrimitiveScalarType::Int16 => DataType::Int16,
+ protobuf::PrimitiveScalarType::Uint32 => DataType::UInt32,
+ protobuf::PrimitiveScalarType::Int32 => DataType::Int32,
+ protobuf::PrimitiveScalarType::Uint64 => DataType::UInt64,
+ protobuf::PrimitiveScalarType::Int64 => DataType::Int64,
+ protobuf::PrimitiveScalarType::Float32 => DataType::Float32,
+ protobuf::PrimitiveScalarType::Float64 => DataType::Float64,
+ protobuf::PrimitiveScalarType::Utf8 => DataType::Utf8,
+ protobuf::PrimitiveScalarType::LargeUtf8 => DataType::LargeUtf8,
+ protobuf::PrimitiveScalarType::Date32 => DataType::Date32,
+ protobuf::PrimitiveScalarType::TimeMicrosecond => {
+ DataType::Time64(TimeUnit::Microsecond)
+ }
+ protobuf::PrimitiveScalarType::TimeNanosecond => {
+ DataType::Time64(TimeUnit::Nanosecond)
+ }
+ protobuf::PrimitiveScalarType::Null => DataType::Null,
+ }
+ }
+}
diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index a2c9db9..4b87be4 100644
--- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -18,17 +18,16 @@
//! Serde code to convert from protocol buffers to Rust data structures.
use std::collections::HashMap;
-use std::convert::TryInto;
+use std::convert::{TryFrom, TryInto};
use std::sync::Arc;
use crate::error::BallistaError;
use crate::execution_plans::{ShuffleReaderExec, UnresolvedShuffleExec};
use crate::serde::protobuf::repartition_exec_node::PartitionMethod;
-use crate::serde::protobuf::LogicalExprNode;
use crate::serde::protobuf::ShuffleReaderPartition;
use crate::serde::scheduler::PartitionLocation;
-use crate::serde::{proto_error, protobuf};
-use crate::{convert_box_required, convert_required};
+use crate::serde::{from_proto_binary_op, proto_error, protobuf};
+use crate::{convert_box_required, convert_required, into_required};
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::catalog::catalog::{
CatalogList, CatalogProvider, MemoryCatalogList, MemoryCatalogProvider,
@@ -36,9 +35,8 @@ use datafusion::catalog::catalog::{
use datafusion::execution::context::{
ExecutionConfig, ExecutionContextState, ExecutionProps,
};
-use datafusion::logical_plan::{DFSchema, Expr};
-use datafusion::physical_plan::aggregates::AggregateFunction;
-use datafusion::physical_plan::expressions::col;
+use datafusion::logical_plan::{window_frames::WindowFrame, DFSchema, Expr};
+use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateFunction};
use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
use datafusion::physical_plan::hash_join::PartitionMode;
use datafusion::physical_plan::merge::MergeExec;
@@ -46,13 +44,18 @@ use datafusion::physical_plan::planner::DefaultPhysicalPlanner;
use datafusion::physical_plan::window_functions::{
BuiltInWindowFunction, WindowFunction,
};
-use datafusion::physical_plan::windows::WindowAggExec;
+use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec};
use datafusion::physical_plan::{
coalesce_batches::CoalesceBatchesExec,
csv::CsvExec,
empty::EmptyExec,
- expressions::{Avg, Column, PhysicalSortExpr},
+ expressions::{
+ col, Avg, BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr,
+ IsNullExpr, Literal, NegativeExpr, NotExpr, PhysicalSortExpr, TryCastExpr,
+ DEFAULT_DATAFUSION_CAST_OPTIONS,
+ },
filter::FilterExec,
+ functions::{self, BuiltinScalarFunction, ScalarFunctionExpr},
hash_join::HashJoinExec,
hash_utils::JoinType,
limit::{GlobalLimitExec, LocalLimitExec},
@@ -65,7 +68,7 @@ use datafusion::physical_plan::{
use datafusion::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, WindowExpr};
use datafusion::prelude::CsvReadOptions;
use log::debug;
-use protobuf::logical_expr_node::ExprType;
+use protobuf::physical_expr_node::ExprType;
use protobuf::physical_plan_node::PhysicalPlanType;
impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
@@ -86,23 +89,23 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
.expr
.iter()
.zip(projection.expr_name.iter())
- .map(|(expr, name)| {
- compile_expr(expr, &input.schema()).map(|e| (e, name.to_string()))
- })
- .collect::<Result<Vec<_>, _>>()?;
+ .map(|(expr, name)| Ok((expr.try_into()?, name.to_string())))
+ .collect::<Result<Vec<(Arc<dyn PhysicalExpr>, String)>, Self::Error>>(
+ )?;
Ok(Arc::new(ProjectionExec::try_new(exprs, input)?))
}
PhysicalPlanType::Filter(filter) => {
let input: Arc<dyn ExecutionPlan> = convert_box_required!(filter.input)?;
- let predicate = compile_expr(
- filter.expr.as_ref().ok_or_else(|| {
+ let predicate = filter
+ .expr
+ .as_ref()
+ .ok_or_else(|| {
BallistaError::General(
"filter (FilterExecNode) in PhysicalPlanNode is missing."
.to_owned(),
)
- })?,
- &input.schema(),
- )?;
+ })?
+ .try_into()?;
Ok(Arc::new(FilterExec::try_new(predicate, input)?))
}
PhysicalPlanType::CsvScan(scan) => {
@@ -153,7 +156,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
let expr = hash_part
.hash_expr
.iter()
- .map(|e| compile_expr(e, &input.schema()))
+ .map(|e| e.try_into())
.collect::<Result<Vec<Arc<dyn PhysicalExpr>>, _>>()?;
Ok(Arc::new(RepartitionExec::try_new(
@@ -207,25 +210,33 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
.clone();
let physical_schema: SchemaRef =
SchemaRef::new((&input_schema).try_into()?);
- let ctx_state = ExecutionContextState::new();
- let window_agg_expr: Vec<(Expr, String)> = window_agg
+
+ let physical_window_expr: Vec<Arc<dyn WindowExpr>> = window_agg
.window_expr
.iter()
.zip(window_agg.window_expr_name.iter())
- .map(|(expr, name)| expr.try_into().map(|expr| (expr, name.clone())))
- .collect::<Result<Vec<_>, _>>()?;
- let df_planner = DefaultPhysicalPlanner::default();
- let physical_window_expr = window_agg_expr
- .iter()
.map(|(expr, name)| {
- df_planner.create_window_expr_with_name(
- expr,
- name.to_string(),
- &physical_schema,
- &ctx_state,
- )
+ let expr_type = expr.expr_type.as_ref().ok_or_else(|| {
+ proto_error("Unexpected empty window physical expression")
+ })?;
+
+ match expr_type {
+ ExprType::WindowExpr(window_node) => Ok(create_window_expr(
+ &convert_required!(window_node.window_function)?,
+ name.to_owned(),
+ &[convert_box_required!(window_node.expr)?],
+ &[],
+ &[],
+ Some(WindowFrame::default()),
+ &physical_schema,
+ )?),
+ _ => Err(BallistaError::General(
+ "Invalid expression for WindowAggrExec".to_string(),
+ )),
+ }
})
.collect::<Result<Vec<_>, _>>()?;
+
Ok(Arc::new(WindowAggExec::try_new(
physical_window_expr,
input,
@@ -253,16 +264,10 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
.iter()
.zip(hash_agg.group_expr_name.iter())
.map(|(expr, name)| {
- compile_expr(expr, &input.schema()).map(|e| (e, name.to_string()))
+ expr.try_into().map(|expr| (expr, name.to_string()))
})
.collect::<Result<Vec<_>, _>>()?;
- let logical_agg_expr: Vec<(Expr, String)> = hash_agg
- .aggr_expr
- .iter()
- .zip(hash_agg.aggr_expr_name.iter())
- .map(|(expr, name)| expr.try_into().map(|expr| (expr, name.clone())))
- .collect::<Result<Vec<_>, _>>()?;
- let ctx_state = ExecutionContextState::new();
+
let input_schema = hash_agg
.input_schema
.as_ref()
@@ -274,18 +279,47 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
.clone();
let physical_schema: SchemaRef =
SchemaRef::new((&input_schema).try_into()?);
- let df_planner = DefaultPhysicalPlanner::default();
- let physical_aggr_expr = logical_agg_expr
+
+ let physical_aggr_expr: Vec<Arc<dyn AggregateExpr>> = hash_agg
+ .aggr_expr
.iter()
+ .zip(hash_agg.aggr_expr_name.iter())
.map(|(expr, name)| {
- df_planner.create_aggregate_expr_with_name(
- expr,
- name.to_string(),
- &physical_schema,
- &ctx_state,
- )
+ let expr_type = expr.expr_type.as_ref().ok_or_else(|| {
+ proto_error("Unexpected empty aggregate physical expression")
+ })?;
+
+ match expr_type {
+ ExprType::AggregateExpr(agg_node) => {
+ let aggr_function =
+ protobuf::AggregateFunction::from_i32(
+ agg_node.aggr_function,
+ )
+ .ok_or_else(
+ || {
+ proto_error(format!(
+ "Received an unknown aggregate function: {}",
+ agg_node.aggr_function
+ ))
+ },
+ )?;
+
+ Ok(create_aggregate_expr(
+ &aggr_function.into(),
+ false,
+ &[convert_box_required!(agg_node.expr)?],
+ &physical_schema,
+ name.to_string(),
+ )?)
+ }
+ _ => Err(BallistaError::General(
+ "Invalid aggregate expression for HashAggregateExec"
+ .to_string(),
+ )),
+ }
})
.collect::<Result<Vec<_>, _>>()?;
+
Ok(Arc::new(HashAggregateExec::try_new(
agg_mode,
group,
@@ -298,11 +332,15 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
let left: Arc<dyn ExecutionPlan> = convert_box_required!(hashjoin.left)?;
let right: Arc<dyn ExecutionPlan> =
convert_box_required!(hashjoin.right)?;
- let on: Vec<(String, String)> = hashjoin
+ let on: Vec<(Column, Column)> = hashjoin
.on
.iter()
- .map(|col| (col.left.clone(), col.right.clone()))
- .collect();
+ .map(|col| {
+ let left = into_required!(col.left)?;
+ let right = into_required!(col.right)?;
+ Ok((left, right))
+ })
+ .collect::<Result<_, Self::Error>>()?;
let join_type = protobuf::JoinType::from_i32(hashjoin.join_type)
.ok_or_else(|| {
proto_error(format!(
@@ -321,7 +359,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
Ok(Arc::new(HashJoinExec::try_new(
left,
right,
- &on,
+ on,
&join_type,
PartitionMode::CollectLeft,
)?))
@@ -358,7 +396,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
self
))
})?;
- if let protobuf::logical_expr_node::ExprType::Sort(sort_expr) = expr {
+ if let protobuf::physical_expr_node::ExprType::Sort(sort_expr) = expr {
let expr = sort_expr
.expr
.as_ref()
@@ -370,7 +408,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
})?
.as_ref();
Ok(PhysicalSortExpr {
- expr: compile_expr(expr, &input.schema())?,
+ expr: expr.try_into()?,
options: SortOptions {
descending: !sort_expr.asc,
nulls_first: sort_expr.nulls_first,
@@ -403,14 +441,210 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
}
}
-fn compile_expr(
- expr: &protobuf::LogicalExprNode,
- schema: &Schema,
-) -> Result<Arc<dyn PhysicalExpr>, BallistaError> {
- let df_planner = DefaultPhysicalPlanner::default();
- let state = ExecutionContextState::new();
- let expr: Expr = expr.try_into()?;
- df_planner
- .create_physical_expr(&expr, schema, &state)
- .map_err(|e| BallistaError::General(format!("{:?}", e)))
+impl From<&protobuf::PhysicalColumn> for Column {
+ fn from(c: &protobuf::PhysicalColumn) -> Column {
+ Column::new(&c.name, c.index as usize)
+ }
+}
+
+impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
+ fn from(f: &protobuf::ScalarFunction) -> BuiltinScalarFunction {
+ use protobuf::ScalarFunction;
+ match f {
+ ScalarFunction::Sqrt => BuiltinScalarFunction::Sqrt,
+ ScalarFunction::Sin => BuiltinScalarFunction::Sin,
+ ScalarFunction::Cos => BuiltinScalarFunction::Cos,
+ ScalarFunction::Tan => BuiltinScalarFunction::Tan,
+ ScalarFunction::Asin => BuiltinScalarFunction::Asin,
+ ScalarFunction::Acos => BuiltinScalarFunction::Acos,
+ ScalarFunction::Atan => BuiltinScalarFunction::Atan,
+ ScalarFunction::Exp => BuiltinScalarFunction::Exp,
+ ScalarFunction::Log => BuiltinScalarFunction::Log,
+ ScalarFunction::Log2 => BuiltinScalarFunction::Log2,
+ ScalarFunction::Log10 => BuiltinScalarFunction::Log10,
+ ScalarFunction::Floor => BuiltinScalarFunction::Floor,
+ ScalarFunction::Ceil => BuiltinScalarFunction::Ceil,
+ ScalarFunction::Round => BuiltinScalarFunction::Round,
+ ScalarFunction::Trunc => BuiltinScalarFunction::Trunc,
+ ScalarFunction::Abs => BuiltinScalarFunction::Abs,
+ ScalarFunction::Signum => BuiltinScalarFunction::Signum,
+ ScalarFunction::Octetlength => BuiltinScalarFunction::OctetLength,
+ ScalarFunction::Concat => BuiltinScalarFunction::Concat,
+ ScalarFunction::Lower => BuiltinScalarFunction::Lower,
+ ScalarFunction::Upper => BuiltinScalarFunction::Upper,
+ ScalarFunction::Trim => BuiltinScalarFunction::Trim,
+ ScalarFunction::Ltrim => BuiltinScalarFunction::Ltrim,
+ ScalarFunction::Rtrim => BuiltinScalarFunction::Rtrim,
+ ScalarFunction::Totimestamp => BuiltinScalarFunction::ToTimestamp,
+ ScalarFunction::Array => BuiltinScalarFunction::Array,
+ ScalarFunction::Nullif => BuiltinScalarFunction::NullIf,
+ ScalarFunction::Datetrunc => BuiltinScalarFunction::DateTrunc,
+ ScalarFunction::Md5 => BuiltinScalarFunction::MD5,
+ ScalarFunction::Sha224 => BuiltinScalarFunction::SHA224,
+ ScalarFunction::Sha256 => BuiltinScalarFunction::SHA256,
+ ScalarFunction::Sha384 => BuiltinScalarFunction::SHA384,
+ ScalarFunction::Sha512 => BuiltinScalarFunction::SHA512,
+ ScalarFunction::Ln => BuiltinScalarFunction::Ln,
+ }
+ }
+}
+
+impl TryFrom<&protobuf::PhysicalExprNode> for Arc<dyn PhysicalExpr> {
+ type Error = BallistaError;
+
+ fn try_from(expr: &protobuf::PhysicalExprNode) -> Result<Self, Self::Error> {
+ let expr_type = expr
+ .expr_type
+ .as_ref()
+ .ok_or_else(|| proto_error("Unexpected empty physical expression"))?;
+
+ let pexpr: Arc<dyn PhysicalExpr> = match expr_type {
+ ExprType::Column(c) => {
+ let pcol: Column = c.into();
+ Arc::new(pcol)
+ }
+ ExprType::Literal(scalar) => {
+ Arc::new(Literal::new(convert_required!(scalar.value)?))
+ }
+ ExprType::BinaryExpr(binary_expr) => Arc::new(BinaryExpr::new(
+ convert_box_required!(&binary_expr.l)?,
+ from_proto_binary_op(&binary_expr.op)?,
+ convert_box_required!(&binary_expr.r)?,
+ )),
+ ExprType::AggregateExpr(_) => {
+ return Err(BallistaError::General(
+ "Cannot convert aggregate expr node to physical expression"
+ .to_owned(),
+ ));
+ }
+ ExprType::WindowExpr(_) => {
+ return Err(BallistaError::General(
+ "Cannot convert window expr node to physical expression".to_owned(),
+ ));
+ }
+ ExprType::Sort(_) => {
+ return Err(BallistaError::General(
+ "Cannot convert sort expr node to physical expression".to_owned(),
+ ));
+ }
+ ExprType::IsNullExpr(e) => {
+ Arc::new(IsNullExpr::new(convert_box_required!(e.expr)?))
+ }
+ ExprType::IsNotNullExpr(e) => {
+ Arc::new(IsNotNullExpr::new(convert_box_required!(e.expr)?))
+ }
+ ExprType::NotExpr(e) => {
+ Arc::new(NotExpr::new(convert_box_required!(e.expr)?))
+ }
+ ExprType::Negative(e) => {
+ Arc::new(NegativeExpr::new(convert_box_required!(e.expr)?))
+ }
+ ExprType::InList(e) => Arc::new(InListExpr::new(
+ convert_box_required!(e.expr)?,
+ e.list
+ .iter()
+ .map(|x| x.try_into())
+ .collect::<Result<Vec<_>, _>>()?,
+ e.negated,
+ )),
+ ExprType::Case(e) => Arc::new(CaseExpr::try_new(
+ e.expr.as_ref().map(|e| e.as_ref().try_into()).transpose()?,
+ e.when_then_expr
+ .iter()
+ .map(|e| {
+ Ok((
+ convert_required!(e.when_expr)?,
+ convert_required!(e.then_expr)?,
+ ))
+ })
+ .collect::<Result<Vec<_>, BallistaError>>()?
+ .as_slice(),
+ e.else_expr
+ .as_ref()
+ .map(|e| e.as_ref().try_into())
+ .transpose()?,
+ )?),
+ ExprType::Cast(e) => Arc::new(CastExpr::new(
+ convert_box_required!(e.expr)?,
+ convert_required!(e.arrow_type)?,
+ DEFAULT_DATAFUSION_CAST_OPTIONS,
+ )),
+ ExprType::TryCast(e) => Arc::new(TryCastExpr::new(
+ convert_box_required!(e.expr)?,
+ convert_required!(e.arrow_type)?,
+ )),
+ ExprType::ScalarFunction(e) => {
+ let scalar_function = protobuf::ScalarFunction::from_i32(e.fun)
+ .ok_or_else(|| {
+ proto_error(format!(
+ "Received an unknown scalar function: {}",
+ e.fun,
+ ))
+ })?;
+
+ let args = e
+ .args
+ .iter()
+ .map(|x| x.try_into())
+ .collect::<Result<Vec<_>, _>>()?;
+
+ let catalog_list =
+ Arc::new(MemoryCatalogList::new()) as Arc<dyn CatalogList>;
+ let ctx_state = ExecutionContextState {
+ catalog_list,
+ scalar_functions: Default::default(),
+ var_provider: Default::default(),
+ aggregate_functions: Default::default(),
+ config: ExecutionConfig::new(),
+ execution_props: ExecutionProps::new(),
+ };
+
+ let fun_expr = functions::create_physical_fun(
+ &(&scalar_function).into(),
+ &ctx_state,
+ )?;
+
+ Arc::new(ScalarFunctionExpr::new(
+ &e.name,
+ fun_expr,
+ args,
+ &convert_required!(e.return_type)?,
+ ))
+ }
+ };
+
+ Ok(pexpr)
+ }
+}
+
+impl TryFrom<&protobuf::physical_window_expr_node::WindowFunction> for WindowFunction {
+ type Error = BallistaError;
+
+ fn try_from(
+ expr: &protobuf::physical_window_expr_node::WindowFunction,
+ ) -> Result<Self, Self::Error> {
+ match expr {
+ protobuf::physical_window_expr_node::WindowFunction::AggrFunction(n) => {
+ let f = protobuf::AggregateFunction::from_i32(*n).ok_or_else(|| {
+ proto_error(format!(
+ "Received an unknown window aggregate function: {}",
+ n
+ ))
+ })?;
+
+ Ok(WindowFunction::AggregateFunction(f.into()))
+ }
+ protobuf::physical_window_expr_node::WindowFunction::BuiltInFunction(n) => {
+ let f =
+ protobuf::BuiltInWindowFunction::from_i32(*n).ok_or_else(|| {
+ proto_error(format!(
+ "Received an unknown window builtin function: {}",
+ n
+ ))
+ })?;
+
+ Ok(WindowFunction::BuiltInWindowFunction(f.into()))
+ }
+ }
+ }
}
diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs
index fdba215..c0fe81f 100644
--- a/ballista/rust/core/src/serde/physical_plan/mod.rs
+++ b/ballista/rust/core/src/serde/physical_plan/mod.rs
@@ -30,7 +30,7 @@ mod roundtrip_tests {
logical_plan::Operator,
physical_plan::{
empty::EmptyExec,
- expressions::{binary, lit, InListExpr, NotExpr},
+ expressions::{binary, col, lit, InListExpr, NotExpr},
expressions::{Avg, Column, PhysicalSortExpr},
filter::FilterExec,
hash_aggregate::{AggregateMode, HashAggregateExec},
@@ -83,35 +83,35 @@ mod roundtrip_tests {
let field_a = Field::new("col", DataType::Int64, false);
let schema_left = Schema::new(vec![field_a.clone()]);
let schema_right = Schema::new(vec![field_a]);
+ let on = vec![(
+ Column::new("col", schema_left.index_of("col")?),
+ Column::new("col", schema_right.index_of("col")?),
+ )];
roundtrip_test(Arc::new(HashJoinExec::try_new(
Arc::new(EmptyExec::new(false, Arc::new(schema_left))),
Arc::new(EmptyExec::new(false, Arc::new(schema_right))),
- &[("col".to_string(), "col".to_string())],
+ on,
&JoinType::Inner,
PartitionMode::CollectLeft,
)?))
}
- fn col(name: &str) -> Arc<dyn PhysicalExpr> {
- Arc::new(Column::new(name))
- }
-
#[test]
fn rountrip_hash_aggregate() -> Result<()> {
+ let field_a = Field::new("a", DataType::Int64, false);
+ let field_b = Field::new("b", DataType::Int64, false);
+ let schema = Arc::new(Schema::new(vec![field_a, field_b]));
+
let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
- vec![(col("a"), "unused".to_string())];
+ vec![(col("a", &schema)?, "unused".to_string())];
let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
- col("b"),
+ col("b", &schema)?,
"AVG(b)".to_string(),
DataType::Float64,
))];
- let field_a = Field::new("a", DataType::Int64, false);
- let field_b = Field::new("b", DataType::Int64, false);
- let schema = Arc::new(Schema::new(vec![field_a, field_b]));
-
roundtrip_test(Arc::new(HashAggregateExec::try_new(
AggregateMode::Final,
groups.clone(),
@@ -127,9 +127,9 @@ mod roundtrip_tests {
let field_b = Field::new("b", DataType::Int64, false);
let field_c = Field::new("c", DataType::Int64, false);
let schema = Arc::new(Schema::new(vec![field_a, field_b, field_c]));
- let not = Arc::new(NotExpr::new(col("a")));
+ let not = Arc::new(NotExpr::new(col("a", &schema)?));
let in_list = Arc::new(InListExpr::new(
- col("b"),
+ col("b", &schema)?,
vec![
lit(ScalarValue::Int64(Some(1))),
lit(ScalarValue::Int64(Some(2))),
@@ -150,14 +150,14 @@ mod roundtrip_tests {
let schema = Arc::new(Schema::new(vec![field_a, field_b]));
let sort_exprs = vec![
PhysicalSortExpr {
- expr: col("a"),
+ expr: col("a", &schema)?,
options: SortOptions {
descending: true,
nulls_first: false,
},
},
PhysicalSortExpr {
- expr: col("b"),
+ expr: col("b", &schema)?,
options: SortOptions {
descending: false,
nulls_first: true,
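The updated tests above construct physical columns bound to a schema index rather than a bare name. A minimal sketch of the new behavior, assuming the public `datafusion::physical_plan::expressions` paths that mirror the test imports (illustration only, not part of the commit):

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::error::Result;
    use datafusion::physical_plan::expressions::{col, Column};

    fn main() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Int64, false),
        ]);
        // col(name, &schema) resolves the field index eagerly, so the
        // expression is pinned to a schema position at construction time.
        let expr = col("b", &schema)?;
        let column = expr.as_any().downcast_ref::<Column>().expect("column");
        assert_eq!(column.name(), "b");
        assert_eq!(column.index(), schema.index_of("b")?);
        Ok(())
    }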
diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index 15d5d4b..cf5401b 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -125,8 +125,14 @@ impl TryInto<protobuf::PhysicalPlanNode> for Arc<dyn ExecutionPlan> {
.on()
.iter()
.map(|tuple| protobuf::JoinOn {
- left: tuple.0.to_owned(),
- right: tuple.1.to_owned(),
+ left: Some(protobuf::PhysicalColumn {
+ name: tuple.0.name().to_string(),
+ index: tuple.0.index() as u32,
+ }),
+ right: Some(protobuf::PhysicalColumn {
+ name: tuple.1.name().to_string(),
+ index: tuple.1.index() as u32,
+ }),
})
.collect();
let join_type = match exec.join_type() {
@@ -300,7 +306,7 @@ impl TryInto<protobuf::PhysicalPlanNode> for Arc<dyn ExecutionPlan> {
let pb_partition_method = match exec.partitioning() {
Partitioning::Hash(exprs, partition_count) => {
- PartitionMethod::Hash(protobuf::HashRepartition {
+ PartitionMethod::Hash(protobuf::PhysicalHashRepartition {
hash_expr: exprs
.iter()
.map(|expr| expr.clone().try_into())
@@ -330,13 +336,13 @@ impl TryInto<protobuf::PhysicalPlanNode> for Arc<dyn ExecutionPlan> {
.expr()
.iter()
.map(|expr| {
- let sort_expr = Box::new(protobuf::SortExprNode {
+ let sort_expr = Box::new(protobuf::PhysicalSortExprNode {
expr: Some(Box::new(expr.expr.to_owned().try_into()?)),
asc: !expr.options.descending,
nulls_first: expr.options.nulls_first,
});
- Ok(protobuf::LogicalExprNode {
- expr_type: Some(protobuf::logical_expr_node::ExprType::Sort(
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(protobuf::physical_expr_node::ExprType::Sort(
sort_expr,
)),
})
@@ -373,10 +379,10 @@ impl TryInto<protobuf::PhysicalPlanNode> for Arc<dyn ExecutionPlan> {
}
}
-impl TryInto<protobuf::LogicalExprNode> for Arc<dyn AggregateExpr> {
+impl TryInto<protobuf::PhysicalExprNode> for Arc<dyn AggregateExpr> {
type Error = BallistaError;
- fn try_into(self) -> Result<protobuf::LogicalExprNode, Self::Error> {
+ fn try_into(self) -> Result<protobuf::PhysicalExprNode, Self::Error> {
let aggr_function = if self.as_any().downcast_ref::<Avg>().is_some() {
Ok(protobuf::AggregateFunction::Avg.into())
} else if self.as_any().downcast_ref::<Sum>().is_some() {
@@ -389,14 +395,14 @@ impl TryInto<protobuf::LogicalExprNode> for Arc<dyn AggregateExpr> {
self
)))
}?;
- let expressions: Vec<protobuf::LogicalExprNode> = self
+ let expressions: Vec<protobuf::PhysicalExprNode> = self
.expressions()
.iter()
.map(|e| e.clone().try_into())
.collect::<Result<Vec<_>, BallistaError>>()?;
- Ok(protobuf::LogicalExprNode {
- expr_type: Some(protobuf::logical_expr_node::ExprType::AggregateExpr(
- Box::new(protobuf::AggregateExprNode {
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr(
+ Box::new(protobuf::PhysicalAggregateExprNode {
aggr_function,
expr: Some(Box::new(expressions[0].clone())),
}),
@@ -405,90 +411,100 @@ impl TryInto<protobuf::LogicalExprNode> for Arc<dyn AggregateExpr> {
}
}
-impl TryFrom<Arc<dyn PhysicalExpr>> for protobuf::LogicalExprNode {
+impl TryFrom<Arc<dyn PhysicalExpr>> for protobuf::PhysicalExprNode {
type Error = BallistaError;
fn try_from(value: Arc<dyn PhysicalExpr>) -> Result<Self, Self::Error> {
let expr = value.as_any();
if let Some(expr) = expr.downcast_ref::<Column>() {
- Ok(protobuf::LogicalExprNode {
- expr_type: Some(protobuf::logical_expr_node::ExprType::ColumnName(
- expr.name().to_owned(),
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(protobuf::physical_expr_node::ExprType::Column(
+ protobuf::PhysicalColumn {
+ name: expr.name().to_string(),
+ index: expr.index() as u32,
+ },
)),
})
} else if let Some(expr) = expr.downcast_ref::<BinaryExpr>() {
- let binary_expr = Box::new(protobuf::BinaryExprNode {
+ let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode {
l: Some(Box::new(expr.left().to_owned().try_into()?)),
r: Some(Box::new(expr.right().to_owned().try_into()?)),
op: format!("{:?}", expr.op()),
});
- Ok(protobuf::LogicalExprNode {
- expr_type: Some(protobuf::logical_expr_node::ExprType::BinaryExpr(
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr(
binary_expr,
)),
})
} else if let Some(expr) = expr.downcast_ref::<CaseExpr>() {
- Ok(protobuf::LogicalExprNode {
- expr_type: Some(protobuf::logical_expr_node::ExprType::Case(Box::new(
- protobuf::CaseNode {
- expr: expr
- .expr()
- .as_ref()
- .map(|exp| exp.clone().try_into().map(Box::new))
- .transpose()?,
- when_then_expr: expr
- .when_then_expr()
- .iter()
- .map(|(when_expr, then_expr)| {
- try_parse_when_then_expr(when_expr, then_expr)
- })
- .collect::<Result<Vec<protobuf::WhenThen>, Self::Error>>()?,
- else_expr: expr
- .else_expr()
- .map(|a| a.clone().try_into().map(Box::new))
- .transpose()?,
- },
- ))),
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(
+ protobuf::physical_expr_node::ExprType::Case(
+ Box::new(
+ protobuf::PhysicalCaseNode {
+ expr: expr
+ .expr()
+ .as_ref()
+ .map(|exp| exp.clone().try_into().map(Box::new))
+ .transpose()?,
+ when_then_expr: expr
+ .when_then_expr()
+ .iter()
+ .map(|(when_expr, then_expr)| {
+ try_parse_when_then_expr(when_expr, then_expr)
+ })
+ .collect::<Result<
+ Vec<protobuf::PhysicalWhenThen>,
+ Self::Error,
+ >>()?,
+ else_expr: expr
+ .else_expr()
+ .map(|a| a.clone().try_into().map(Box::new))
+ .transpose()?,
+ },
+ ),
+ ),
+ ),
})
} else if let Some(expr) = expr.downcast_ref::<NotExpr>() {
- Ok(protobuf::LogicalExprNode {
- expr_type: Some(protobuf::logical_expr_node::ExprType::NotExpr(
- Box::new(protobuf::Not {
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr(
+ Box::new(protobuf::PhysicalNot {
expr: Some(Box::new(expr.arg().to_owned().try_into()?)),
}),
)),
})
} else if let Some(expr) = expr.downcast_ref::<IsNullExpr>() {
- Ok(protobuf::LogicalExprNode {
- expr_type: Some(protobuf::logical_expr_node::ExprType::IsNullExpr(
- Box::new(protobuf::IsNull {
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr(
+ Box::new(protobuf::PhysicalIsNull {
expr: Some(Box::new(expr.arg().to_owned().try_into()?)),
}),
)),
})
} else if let Some(expr) = expr.downcast_ref::<IsNotNullExpr>() {
- Ok(protobuf::LogicalExprNode {
- expr_type: Some(protobuf::logical_expr_node::ExprType::IsNotNullExpr(
- Box::new(protobuf::IsNotNull {
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr(
+ Box::new(protobuf::PhysicalIsNotNull {
expr: Some(Box::new(expr.arg().to_owned().try_into()?)),
}),
)),
})
} else if let Some(expr) = expr.downcast_ref::<InListExpr>() {
- Ok(protobuf::LogicalExprNode {
+ Ok(protobuf::PhysicalExprNode {
expr_type: Some(
- protobuf::logical_expr_node::ExprType::InList(
+ protobuf::physical_expr_node::ExprType::InList(
Box::new(
- protobuf::InListNode {
+ protobuf::PhysicalInListNode {
expr: Some(Box::new(expr.expr().to_owned().try_into()?)),
list: expr
.list()
.iter()
.map(|a| a.clone().try_into())
.collect::<Result<
- Vec<protobuf::LogicalExprNode>,
+ Vec<protobuf::PhysicalExprNode>,
Self::Error,
>>()?,
negated: expr.negated(),
@@ -498,32 +514,32 @@ impl TryFrom<Arc<dyn PhysicalExpr>> for protobuf::LogicalExprNode {
),
})
} else if let Some(expr) = expr.downcast_ref::<NegativeExpr>() {
- Ok(protobuf::LogicalExprNode {
- expr_type: Some(protobuf::logical_expr_node::ExprType::Negative(
- Box::new(protobuf::NegativeNode {
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(protobuf::physical_expr_node::ExprType::Negative(
+ Box::new(protobuf::PhysicalNegativeNode {
expr: Some(Box::new(expr.arg().to_owned().try_into()?)),
}),
)),
})
} else if let Some(lit) = expr.downcast_ref::<Literal>() {
- Ok(protobuf::LogicalExprNode {
- expr_type: Some(protobuf::logical_expr_node::ExprType::Literal(
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(protobuf::physical_expr_node::ExprType::Literal(
lit.value().try_into()?,
)),
})
} else if let Some(cast) = expr.downcast_ref::<CastExpr>() {
- Ok(protobuf::LogicalExprNode {
- expr_type: Some(protobuf::logical_expr_node::ExprType::Cast(Box::new(
- protobuf::CastNode {
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new(
+ protobuf::PhysicalCastNode {
expr: Some(Box::new(cast.expr().clone().try_into()?)),
arrow_type: Some(cast.cast_type().into()),
},
))),
})
} else if let Some(cast) = expr.downcast_ref::<TryCastExpr>() {
- Ok(protobuf::LogicalExprNode {
- expr_type: Some(protobuf::logical_expr_node::ExprType::TryCast(
- Box::new(protobuf::TryCastNode {
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(
+ Box::new(protobuf::PhysicalTryCastNode {
expr: Some(Box::new(cast.expr().clone().try_into()?)),
arrow_type: Some(cast.cast_type().into()),
}),
@@ -533,16 +549,18 @@ impl TryFrom<Arc<dyn PhysicalExpr>> for protobuf::LogicalExprNode {
let fun: BuiltinScalarFunction =
BuiltinScalarFunction::from_str(expr.name())?;
let fun: protobuf::ScalarFunction = (&fun).try_into()?;
- let expr: Vec<protobuf::LogicalExprNode> = expr
+ let args: Vec<protobuf::PhysicalExprNode> = expr
.args()
.iter()
.map(|e| e.to_owned().try_into())
.collect::<Result<Vec<_>, _>>()?;
- Ok(protobuf::LogicalExprNode {
- expr_type: Some(protobuf::logical_expr_node::ExprType::ScalarFunction(
- protobuf::ScalarFunctionNode {
+ Ok(protobuf::PhysicalExprNode {
+ expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarFunction(
+ protobuf::PhysicalScalarFunctionNode {
+ name: expr.name().to_string(),
fun: fun.into(),
- expr,
+ args,
+ return_type: Some(expr.return_type().into()),
},
)),
})
@@ -558,8 +576,8 @@ impl TryFrom<Arc<dyn PhysicalExpr>> for protobuf::LogicalExprNode {
fn try_parse_when_then_expr(
when_expr: &Arc<dyn PhysicalExpr>,
then_expr: &Arc<dyn PhysicalExpr>,
-) -> Result<protobuf::WhenThen, BallistaError> {
- Ok(protobuf::WhenThen {
+) -> Result<protobuf::PhysicalWhenThen, BallistaError> {
+ Ok(protobuf::PhysicalWhenThen {
when_expr: Some(when_expr.clone().try_into()?),
then_expr: Some(then_expr.clone().try_into()?),
})
diff --git a/benchmarks/run.sh b/benchmarks/run.sh
index 8e36424..21633d3 100755
--- a/benchmarks/run.sh
+++ b/benchmarks/run.sh
@@ -20,7 +20,7 @@ set -e
# This bash script is meant to be run inside the docker-compose environment. Check the README for instructions
cd /
-for query in 1 3 5 6 10 12
+for query in 1 3 5 6 7 8 9 10 12
do
/tpch benchmark ballista --host ballista-scheduler --port 50050 --query $query --path /data --format tbl --iterations 1 --debug
done
diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs
index 08c4763..286fe45 100644
--- a/benchmarks/src/bin/tpch.rs
+++ b/benchmarks/src/bin/tpch.rs
@@ -708,6 +708,16 @@ mod tests {
}
#[tokio::test]
+ async fn run_q7() -> Result<()> {
+ run_query(7).await
+ }
+
+ #[tokio::test]
+ async fn run_q8() -> Result<()> {
+ run_query(8).await
+ }
+
+ #[tokio::test]
async fn run_q9() -> Result<()> {
run_query(9).await
}
diff --git a/datafusion/src/dataframe.rs b/datafusion/src/dataframe.rs
index 9c7c2ef..507a798 100644
--- a/datafusion/src/dataframe.rs
+++ b/datafusion/src/dataframe.rs
@@ -188,6 +188,8 @@ pub trait DataFrame: Send + Sync {
right_cols: &[&str],
) -> Result<Arc<dyn DataFrame>>;
+ // TODO: add join_using
+
/// Repartition a DataFrame based on a logical partitioning scheme.
///
/// ```
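The TODO above refers to the `join_using` method added to `LogicalPlanBuilder` later in this diff; until a DataFrame-level wrapper exists, the builder can be called directly. A sketch, assuming the `datafusion::logical_plan` re-exports from this commit:

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::error::Result;
    use datafusion::logical_plan::{JoinType, LogicalPlanBuilder};

    fn main() -> Result<()> {
        let schema = Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]);
        let right = LogicalPlanBuilder::scan_empty(Some("t2"), &schema, None)?.build()?;
        // USING keeps a single copy of each join column, so the output
        // schema is t1.id, t1.name, t2.name (no duplicate id).
        let plan = LogicalPlanBuilder::scan_empty(Some("t1"), &schema, None)?
            .join_using(&right, JoinType::Inner, vec!["id"])?
            .build()?;
        assert_eq!(plan.schema().fields().len(), 3);
        Ok(())
    }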
diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index b42695b..926e2db 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -52,7 +52,7 @@ use crate::datasource::TableProvider;
use crate::error::{DataFusionError, Result};
use crate::execution::dataframe_impl::DataFrameImpl;
use crate::logical_plan::{
- FunctionRegistry, LogicalPlan, LogicalPlanBuilder, ToDFSchema,
+ FunctionRegistry, LogicalPlan, LogicalPlanBuilder, UNNAMED_TABLE,
};
use crate::optimizer::constant_folding::ConstantFolding;
use crate::optimizer::filter_push_down::FilterPushDown;
@@ -297,18 +297,9 @@ impl ExecutionContext {
&mut self,
provider: Arc<dyn TableProvider>,
) -> Result<Arc<dyn DataFrame>> {
- let schema = provider.schema();
- let table_scan = LogicalPlan::TableScan {
- table_name: "".to_string(),
- source: provider,
- projected_schema: schema.to_dfschema_ref()?,
- projection: None,
- filters: vec![],
- limit: None,
- };
Ok(Arc::new(DataFrameImpl::new(
self.state.clone(),
- &LogicalPlanBuilder::from(&table_scan).build()?,
+ &LogicalPlanBuilder::scan(UNNAMED_TABLE, provider, None)?.build()?,
)))
}
@@ -410,22 +401,15 @@ impl ExecutionContext {
) -> Result<Arc<dyn DataFrame>> {
let table_ref = table_ref.into();
let schema = self.state.lock().unwrap().schema_for_ref(table_ref)?;
-
match schema.table(table_ref.table()) {
Some(ref provider) => {
- let schema = provider.schema();
- let table_scan = LogicalPlan::TableScan {
- table_name: table_ref.table().to_owned(),
- source: Arc::clone(provider),
- projected_schema: schema.to_dfschema_ref()?,
- projection: None,
- filters: vec![],
- limit: None,
- };
- Ok(Arc::new(DataFrameImpl::new(
- self.state.clone(),
- &LogicalPlanBuilder::from(&table_scan).build()?,
- )))
+ let plan = LogicalPlanBuilder::scan(
+ table_ref.table(),
+ Arc::clone(provider),
+ None,
+ )?
+ .build()?;
+ Ok(Arc::new(DataFrameImpl::new(self.state.clone(), &plan)))
}
_ => Err(DataFusionError::Plan(format!(
"No table named '{}'",
@@ -1038,7 +1022,6 @@ mod tests {
let logical_plan = ctx.optimize(&logical_plan)?;
let physical_plan = ctx.create_physical_plan(&logical_plan)?;
- println!("{:?}", physical_plan);
let results = collect_partitioned(physical_plan).await?;
@@ -1110,7 +1093,7 @@ mod tests {
_ => panic!("expect optimized_plan to be projection"),
}
- let expected = "Projection: #c2\
+ let expected = "Projection: #test.c2\
\n TableScan: test projection=Some([1])";
assert_eq!(format!("{:?}", optimized_plan), expected);
@@ -1133,7 +1116,7 @@ mod tests {
let schema: Schema = ctx.table("test").unwrap().schema().clone().into();
assert!(!schema.field_with_name("c1")?.is_nullable());
- let plan = LogicalPlanBuilder::scan_empty("", &schema, None)?
+ let plan = LogicalPlanBuilder::scan_empty(None, &schema, None)?
.project(vec![col("c1")])?
.build()?;
@@ -1183,8 +1166,11 @@ mod tests {
_ => panic!("expect optimized_plan to be projection"),
}
- let expected = "Projection: #b\
- \n TableScan: projection=Some([1])";
+ let expected = format!(
+ "Projection: #{}.b\
+ \n TableScan: {} projection=Some([1])",
+ UNNAMED_TABLE, UNNAMED_TABLE
+ );
assert_eq!(format!("{:?}", optimized_plan), expected);
let physical_plan = ctx.create_physical_plan(&optimized_plan)?;
@@ -2138,9 +2124,9 @@ mod tests {
Field::new("c2", DataType::UInt32, false),
]));
- let plan = LogicalPlanBuilder::scan_empty("", schema.as_ref(), None)?
+ let plan = LogicalPlanBuilder::scan_empty(None, schema.as_ref(), None)?
.aggregate(vec![col("c1")], vec![sum(col("c2"))])?
- .project(vec![col("c1"), col("SUM(c2)").alias("total_salary")])?
+ .project(vec![col("c1"), sum(col("c2")).alias("total_salary")])?
.build()?;
let plan = ctx.optimize(&plan)?;
@@ -2590,7 +2576,7 @@ mod tests {
assert_eq!(
format!("{:?}", plan),
- "Projection: #a, #b, my_add(#a, #b)\n TableScan: t projection=None"
+ "Projection: #t.a, #t.b, my_add(#t.a, #t.b)\n TableScan: t projection=None"
);
let plan = ctx.optimize(&plan)?;
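The expected plan strings above show the user-visible change: column references now display with a table qualifier, and providers registered without a name fall back to UNNAMED_TABLE. A small sketch, with paths assumed from the doc examples in this commit:

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::error::Result;
    use datafusion::logical_plan::{col, lit, LogicalPlanBuilder};

    fn main() -> Result<()> {
        let schema = Schema::new(vec![Field::new("c2", DataType::Int32, false)]);
        // Qualified and bare spellings resolve to the same field; both
        // print uniformly as #test.c2 in the plan.
        let plan = LogicalPlanBuilder::scan_empty(Some("test"), &schema, None)?
            .filter(col("test.c2").eq(lit(10)))?
            .project(vec![col("c2")])?
            .build()?;
        println!("{:?}", plan);
        Ok(())
    }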
diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs
index a674e3c..99eb7f0 100644
--- a/datafusion/src/execution/dataframe_impl.rs
+++ b/datafusion/src/execution/dataframe_impl.rs
@@ -110,7 +110,12 @@ impl DataFrame for DataFrameImpl {
right_cols: &[&str],
) -> Result<Arc<dyn DataFrame>> {
let plan = LogicalPlanBuilder::from(&self.plan)
- .join(&right.to_logical_plan(), join_type, left_cols, right_cols)?
+ .join(
+ &right.to_logical_plan(),
+ join_type,
+ left_cols.to_vec(),
+ right_cols.to_vec(),
+ )?
.build()?;
Ok(Arc::new(DataFrameImpl::new(self.ctx_state.clone(), &plan)))
}
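Call sites can keep passing plain strings because this commit also adds `impl From<&str> for Column`: qualified spellings are parsed into (relation, name) pairs, and bare names are normalized against the corresponding input schema. A sketch against the builder API, under the same path assumptions as above:

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::error::Result;
    use datafusion::logical_plan::{JoinType, LogicalPlanBuilder};

    fn main() -> Result<()> {
        let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
        let right = LogicalPlanBuilder::scan_empty(Some("t2"), &schema, None)?.build()?;
        let plan = LogicalPlanBuilder::scan_empty(Some("t1"), &schema, None)?
            // "t1.id" parses as a qualified Column; the bare "id" is
            // normalized against the right side's schema to t2.id.
            .join(&right, JoinType::Inner, vec!["t1.id"], vec!["id"])?
            .build()?;
        println!("{}", plan.schema());
        Ok(())
    }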
diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index 6bd5181..4b4ed0f 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -24,19 +24,27 @@ use arrow::{
record_batch::RecordBatch,
};
-use super::dfschema::ToDFSchema;
-use super::{
- col, exprlist_to_fields, Expr, JoinType, LogicalPlan, PlanType, StringifiedPlan,
-};
use crate::datasource::TableProvider;
use crate::error::{DataFusionError, Result};
-use crate::logical_plan::{DFField, DFSchema, DFSchemaRef, Partitioning};
use crate::{
datasource::{empty::EmptyTable, parquet::ParquetTable, CsvFile, MemTable},
prelude::CsvReadOptions,
};
+
+use super::dfschema::ToDFSchema;
+use super::{
+ exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType,
+ StringifiedPlan,
+};
+use crate::logical_plan::{
+ columnize_expr, normalize_col, normalize_cols, Column, DFField, DFSchema,
+ DFSchemaRef, Partitioning,
+};
use std::collections::HashSet;
+/// Default table name for unnamed table
+pub const UNNAMED_TABLE: &str = "?table?";
+
/// Builder for logical plans
///
/// ```
@@ -62,7 +70,7 @@ use std::collections::HashSet;
/// // FROM employees
/// // WHERE salary < 1000
/// let plan = LogicalPlanBuilder::scan_empty(
-/// "employee.csv",
+/// Some("employee"),
/// &employee_schema(),
/// None,
/// )?
@@ -102,7 +110,7 @@ impl LogicalPlanBuilder {
projection: Option<Vec<usize>>,
) -> Result<Self> {
let provider = Arc::new(MemTable::try_new(schema, partitions)?);
- Self::scan("", provider, projection)
+ Self::scan(UNNAMED_TABLE, provider, projection)
}
/// Scan a CSV data source
@@ -112,7 +120,7 @@ impl LogicalPlanBuilder {
projection: Option<Vec<usize>>,
) -> Result<Self> {
let provider = Arc::new(CsvFile::try_new(path, options)?);
- Self::scan("", provider, projection)
+ Self::scan(path, provider, projection)
}
/// Scan a Parquet data source
@@ -122,38 +130,53 @@ impl LogicalPlanBuilder {
max_concurrency: usize,
) -> Result<Self> {
let provider = Arc::new(ParquetTable::try_new(path, max_concurrency)?);
- Self::scan("", provider, projection)
+ Self::scan(path, provider, projection)
}
/// Scan an empty data source, mainly used in tests
pub fn scan_empty(
- name: &str,
+ name: Option<&str>,
table_schema: &Schema,
projection: Option<Vec<usize>>,
) -> Result<Self> {
let table_schema = Arc::new(table_schema.clone());
let provider = Arc::new(EmptyTable::new(table_schema));
- Self::scan(name, provider, projection)
+ Self::scan(name.unwrap_or(UNNAMED_TABLE), provider, projection)
}
/// Convert a table provider into a builder with a TableScan
pub fn scan(
- name: &str,
+ table_name: &str,
provider: Arc<dyn TableProvider>,
projection: Option<Vec<usize>>,
) -> Result<Self> {
+ if table_name.is_empty() {
+ return Err(DataFusionError::Plan(
+ "table_name cannot be empty".to_string(),
+ ));
+ }
+
let schema = provider.schema();
let projected_schema = projection
.as_ref()
- .map(|p| Schema::new(p.iter().map(|i| schema.field(*i).clone()).collect()))
- .map_or(schema, SchemaRef::new)
- .to_dfschema_ref()?;
+ .map(|p| {
+ DFSchema::new(
+ p.iter()
+ .map(|i| {
+ DFField::from_qualified(table_name, schema.field(*i).clone())
+ })
+ .collect(),
+ )
+ })
+ .unwrap_or_else(|| {
+ DFSchema::try_from_qualified_schema(table_name, &schema)
+ })?;
let table_scan = LogicalPlan::TableScan {
- table_name: name.to_string(),
+ table_name: table_name.to_string(),
source: provider,
- projected_schema,
+ projected_schema: Arc::new(projected_schema),
projection,
filters: vec![],
limit: None,
@@ -170,16 +193,21 @@ impl LogicalPlanBuilder {
/// * An invalid expression is used (e.g. a `sort` expression)
pub fn project(&self, expr: impl IntoIterator<Item = Expr>) -> Result<Self> {
let input_schema = self.plan.schema();
+ let all_schemas = self.plan.all_schemas();
let mut projected_expr = vec![];
for e in expr {
match e {
Expr::Wildcard => {
(0..input_schema.fields().len()).for_each(|i| {
- projected_expr.push(col(input_schema.field(i).name()))
+ projected_expr
+ .push(Expr::Column(input_schema.field(i).qualified_column()))
});
}
- _ => projected_expr.push(e),
- };
+ _ => projected_expr.push(columnize_expr(
+ normalize_col(e, &all_schemas)?,
+ input_schema,
+ )),
+ }
}
validate_unique_names("Projections", projected_expr.iter(), input_schema)?;
@@ -195,6 +223,7 @@ impl LogicalPlanBuilder {
/// Apply a filter
pub fn filter(&self, expr: Expr) -> Result<Self> {
+ let expr = normalize_col(expr, &self.plan.all_schemas())?;
Ok(Self::from(&LogicalPlan::Filter {
predicate: expr,
input: Arc::new(self.plan.clone()),
@@ -210,69 +239,103 @@ impl LogicalPlanBuilder {
}
/// Apply a sort
- pub fn sort(&self, expr: impl IntoIterator<Item = Expr>) -> Result<Self> {
+ pub fn sort(&self, exprs: impl IntoIterator<Item = Expr>) -> Result<Self> {
+ let schemas = self.plan.all_schemas();
Ok(Self::from(&LogicalPlan::Sort {
- expr: expr.into_iter().collect(),
+ expr: normalize_cols(exprs, &schemas)?,
input: Arc::new(self.plan.clone()),
}))
}
/// Apply a union
pub fn union(&self, plan: LogicalPlan) -> Result<Self> {
- let schema = self.plan.schema();
+ Ok(Self::from(&union_with_alias(
+ self.plan.clone(),
+ plan,
+ None,
+ )?))
+ }
- if plan.schema() != schema {
+ /// Apply a join with an ON constraint
+ pub fn join(
+ &self,
+ right: &LogicalPlan,
+ join_type: JoinType,
+ left_keys: Vec<impl Into<Column>>,
+ right_keys: Vec<impl Into<Column>>,
+ ) -> Result<Self> {
+ if left_keys.len() != right_keys.len() {
return Err(DataFusionError::Plan(
- "Schema's for union should be the same ".to_string(),
+ "left_keys and right_keys were not the same length".to_string(),
));
}
- // Add plan to existing union if possible
- let mut inputs = match &self.plan {
- LogicalPlan::Union { inputs, .. } => inputs.clone(),
- _ => vec![self.plan.clone()],
- };
- inputs.push(plan);
- Ok(Self::from(&LogicalPlan::Union {
- inputs,
- schema: schema.clone(),
- alias: None,
+ let left_keys: Vec<Column> = left_keys
+ .into_iter()
+ .map(|c| c.into().normalize(&self.plan.all_schemas()))
+ .collect::<Result<_>>()?;
+ let right_keys: Vec<Column> = right_keys
+ .into_iter()
+ .map(|c| c.into().normalize(&right.all_schemas()))
+ .collect::<Result<_>>()?;
+ let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect();
+ let join_schema = build_join_schema(
+ self.plan.schema(),
+ right.schema(),
+ &on,
+ &join_type,
+ &JoinConstraint::On,
+ )?;
+
+ Ok(Self::from(&LogicalPlan::Join {
+ left: Arc::new(self.plan.clone()),
+ right: Arc::new(right.clone()),
+ on,
+ join_type,
+ join_constraint: JoinConstraint::On,
+ schema: DFSchemaRef::new(join_schema),
}))
}
- /// Apply a join
- pub fn join(
+ /// Apply a join with a USING constraint, which keeps only one copy of each join column in the output schema.
+ pub fn join_using(
&self,
right: &LogicalPlan,
join_type: JoinType,
- left_keys: &[&str],
- right_keys: &[&str],
+ using_keys: Vec<impl Into<Column> + Clone>,
) -> Result<Self> {
- if left_keys.len() != right_keys.len() {
- Err(DataFusionError::Plan(
- "left_keys and right_keys were not the same length".to_string(),
- ))
- } else {
- let on: Vec<_> = left_keys
- .iter()
- .zip(right_keys.iter())
- .map(|(x, y)| (x.to_string(), y.to_string()))
- .collect::<Vec<_>>();
- let join_schema =
- build_join_schema(self.plan.schema(), right.schema(), &on, &join_type)?;
- Ok(Self::from(&LogicalPlan::Join {
- left: Arc::new(self.plan.clone()),
- right: Arc::new(right.clone()),
- on,
- join_type,
- schema: DFSchemaRef::new(join_schema),
- }))
- }
+ let left_keys: Vec<Column> = using_keys
+ .clone()
+ .into_iter()
+ .map(|c| c.into().normalize(&self.plan.all_schemas()))
+ .collect::<Result<_>>()?;
+ let right_keys: Vec<Column> = using_keys
+ .into_iter()
+ .map(|c| c.into().normalize(&right.all_schemas()))
+ .collect::<Result<_>>()?;
+
+ let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect();
+ let join_schema = build_join_schema(
+ self.plan.schema(),
+ right.schema(),
+ &on,
+ &join_type,
+ &JoinConstraint::Using,
+ )?;
+
+ Ok(Self::from(&LogicalPlan::Join {
+ left: Arc::new(self.plan.clone()),
+ right: Arc::new(right.clone()),
+ on,
+ join_type,
+ join_constraint: JoinConstraint::Using,
+ schema: DFSchemaRef::new(join_schema),
+ }))
}
+
/// Apply a cross join
pub fn cross_join(&self, right: &LogicalPlan) -> Result<Self> {
let schema = self.plan.schema().join(right.schema())?;
-
Ok(Self::from(&LogicalPlan::CrossJoin {
left: Arc::new(self.plan.clone()),
right: Arc::new(right.clone()),
@@ -320,9 +383,9 @@ impl LogicalPlanBuilder {
group_expr: impl IntoIterator<Item = Expr>,
aggr_expr: impl IntoIterator<Item = Expr>,
) -> Result<Self> {
- let group_expr = group_expr.into_iter().collect::<Vec<Expr>>();
- let aggr_expr = aggr_expr.into_iter().collect::<Vec<Expr>>();
-
+ let schemas = self.plan.all_schemas();
+ let group_expr = normalize_cols(group_expr, &schemas)?;
+ let aggr_expr = normalize_cols(aggr_expr, &schemas)?;
let all_expr = group_expr.iter().chain(aggr_expr.iter());
validate_unique_names("Aggregations", all_expr.clone(), self.plan.schema())?;
@@ -363,27 +426,35 @@ impl LogicalPlanBuilder {
/// Creates a schema for a join operation.
/// The fields from the left side are first
-fn build_join_schema(
+pub fn build_join_schema(
left: &DFSchema,
right: &DFSchema,
- on: &[(String, String)],
+ on: &[(Column, Column)],
join_type: &JoinType,
+ join_constraint: &JoinConstraint,
) -> Result<DFSchema> {
let fields: Vec<DFField> = match join_type {
JoinType::Inner | JoinType::Left | JoinType::Full => {
- // remove right-side join keys if they have the same names as the left-side
- let duplicate_keys = &on
- .iter()
- .filter(|(l, r)| l == r)
- .map(|on| on.1.to_string())
- .collect::<HashSet<_>>();
+ let duplicate_keys = match join_constraint {
+ JoinConstraint::On => on
+ .iter()
+ .filter(|(l, r)| l == r)
+ .map(|on| on.1.clone())
+ .collect::<HashSet<_>>(),
+ // a USING join requires unique join columns in the output schema, so we mark all
+ // right join keys as duplicates
+ JoinConstraint::Using => {
+ on.iter().map(|on| on.1.clone()).collect::<HashSet<_>>()
+ }
+ };
let left_fields = left.fields().iter();
+ // remove right-side join keys if they have the same names as the left-side
let right_fields = right
.fields()
.iter()
- .filter(|f| !duplicate_keys.contains(f.name()));
+ .filter(|f| !duplicate_keys.contains(&f.qualified_column()));
// left then right
left_fields.chain(right_fields).cloned().collect()
@@ -393,17 +464,24 @@ fn build_join_schema(
left.fields().clone()
}
JoinType::Right => {
- // remove left-side join keys if they have the same names as the right-side
- let duplicate_keys = &on
- .iter()
- .filter(|(l, r)| l == r)
- .map(|on| on.1.to_string())
- .collect::<HashSet<_>>();
+ let duplicate_keys = match join_constraint {
+ JoinConstraint::On => on
+ .iter()
+ .filter(|(l, r)| l == r)
+ .map(|on| on.1.clone())
+ .collect::<HashSet<_>>(),
+ // a USING join requires unique join columns in the output schema, so we mark all
+ // left join keys as duplicates
+ JoinConstraint::Using => {
+ on.iter().map(|on| on.0.clone()).collect::<HashSet<_>>()
+ }
+ };
+ // remove left-side join keys if they have the same names as the right-side
let left_fields = left
.fields()
.iter()
- .filter(|f| !duplicate_keys.contains(f.name()));
+ .filter(|f| !duplicate_keys.contains(&f.qualified_column()));
let right_fields = right.fields().iter();
@@ -411,6 +489,7 @@ fn build_join_schema(
left_fields.chain(right_fields).cloned().collect()
}
};
+
DFSchema::new(fields)
}
@@ -441,17 +520,56 @@ fn validate_unique_names<'a>(
})
}
+/// Union two logical plans with an optional alias.
+pub fn union_with_alias(
+ left_plan: LogicalPlan,
+ right_plan: LogicalPlan,
+ alias: Option<String>,
+) -> Result<LogicalPlan> {
+ let inputs = vec![left_plan, right_plan]
+ .into_iter()
+ .flat_map(|p| match p {
+ LogicalPlan::Union { inputs, .. } => inputs,
+ x => vec![x],
+ })
+ .collect::<Vec<_>>();
+ if inputs.is_empty() {
+ return Err(DataFusionError::Plan("Empty UNION".to_string()));
+ }
+
+ let union_schema = (**inputs[0].schema()).clone();
+ let union_schema = Arc::new(match alias {
+ Some(ref alias) => union_schema.replace_qualifier(alias.as_str()),
+ None => union_schema.strip_qualifiers(),
+ });
+ if !inputs.iter().skip(1).all(|input_plan| {
+ // union changes all qualifiers in the resulting schema, so we only need to
+ // match against arrow schema here, which doesn't include qualifiers
+ union_schema.matches_arrow_schema(&((**input_plan.schema()).clone().into()))
+ }) {
+ return Err(DataFusionError::Plan(
+ "UNION ALL schemas are expected to be the same".to_string(),
+ ));
+ }
+
+ Ok(LogicalPlan::Union {
+ schema: union_schema,
+ inputs,
+ alias,
+ })
+}
+
#[cfg(test)]
mod tests {
use arrow::datatypes::{DataType, Field};
- use super::super::{lit, sum};
+ use super::super::{col, lit, sum};
use super::*;
#[test]
fn plan_builder_simple() -> Result<()> {
let plan = LogicalPlanBuilder::scan_empty(
- "employee.csv",
+ Some("employee_csv"),
&employee_schema(),
Some(vec![0, 3]),
)?
@@ -459,9 +577,9 @@ mod tests {
.project(vec![col("id")])?
.build()?;
- let expected = "Projection: #id\
- \n Filter: #state Eq Utf8(\"CO\")\
- \n TableScan: employee.csv projection=Some([0, 3])";
+ let expected = "Projection: #employee_csv.id\
+ \n Filter: #employee_csv.state Eq Utf8(\"CO\")\
+ \n TableScan: employee_csv projection=Some([0, 3])";
assert_eq!(expected, format!("{:?}", plan));
@@ -471,7 +589,7 @@ mod tests {
#[test]
fn plan_builder_aggregate() -> Result<()> {
let plan = LogicalPlanBuilder::scan_empty(
- "employee.csv",
+ Some("employee_csv"),
&employee_schema(),
Some(vec![3, 4]),
)?
@@ -482,9 +600,9 @@ mod tests {
.project(vec![col("state"), col("total_salary")])?
.build()?;
- let expected = "Projection: #state, #total_salary\
- \n Aggregate: groupBy=[[#state]], aggr=[[SUM(#salary) AS total_salary]]\
- \n TableScan: employee.csv projection=Some([3, 4])";
+ let expected = "Projection: #employee_csv.state, #total_salary\
+ \n Aggregate: groupBy=[[#employee_csv.state]], aggr=[[SUM(#employee_csv.salary) AS total_salary]]\
+ \n TableScan: employee_csv projection=Some([3, 4])";
assert_eq!(expected, format!("{:?}", plan));
@@ -494,7 +612,7 @@ mod tests {
#[test]
fn plan_builder_sort() -> Result<()> {
let plan = LogicalPlanBuilder::scan_empty(
- "employee.csv",
+ Some("employee_csv"),
&employee_schema(),
Some(vec![3, 4]),
)?
@@ -505,15 +623,15 @@ mod tests {
nulls_first: true,
},
Expr::Sort {
- expr: Box::new(col("total_salary")),
+ expr: Box::new(col("salary")),
asc: false,
nulls_first: false,
},
])?
.build()?;
- let expected = "Sort: #state ASC NULLS FIRST, #total_salary DESC NULLS LAST\
- \n TableScan: employee.csv projection=Some([3, 4])";
+ let expected = "Sort: #employee_csv.state ASC NULLS FIRST, #employee_csv.salary DESC NULLS LAST\
+ \n TableScan: employee_csv projection=Some([3, 4])";
assert_eq!(expected, format!("{:?}", plan));
@@ -523,7 +641,7 @@ mod tests {
#[test]
fn plan_builder_union_combined_single_union() -> Result<()> {
let plan = LogicalPlanBuilder::scan_empty(
- "employee.csv",
+ Some("employee_csv"),
&employee_schema(),
Some(vec![3, 4]),
)?;
@@ -536,10 +654,10 @@ mod tests {
// output has only one union
let expected = "Union\
- \n TableScan: employee.csv projection=Some([3, 4])\
- \n TableScan: employee.csv projection=Some([3, 4])\
- \n TableScan: employee.csv projection=Some([3, 4])\
- \n TableScan: employee.csv projection=Some([3, 4])";
+ \n TableScan: employee_csv projection=Some([3, 4])\
+ \n TableScan: employee_csv projection=Some([3, 4])\
+ \n TableScan: employee_csv projection=Some([3, 4])\
+ \n TableScan: employee_csv projection=Some([3, 4])";
assert_eq!(expected, format!("{:?}", plan));
@@ -549,9 +667,10 @@ mod tests {
#[test]
fn projection_non_unique_names() -> Result<()> {
let plan = LogicalPlanBuilder::scan_empty(
- "employee.csv",
+ Some("employee_csv"),
&employee_schema(),
- Some(vec![0, 3]),
+ // project id and first_name by column index
+ Some(vec![0, 1]),
)?
// two columns with the same name => error
.project(vec![col("id"), col("first_name").alias("id")]);
@@ -560,9 +679,8 @@ mod tests {
Err(DataFusionError::Plan(e)) => {
assert_eq!(
e,
- "Projections require unique expression names \
- but the expression \"#id\" at position 0 and \"#first_name AS id\" at \
- position 1 have the same name. Consider aliasing (\"AS\") one of them."
+ "Schema contains qualified field name 'employee_csv.id' \
+ and unqualified field name 'id' which would be ambiguous"
);
Ok(())
}
@@ -575,9 +693,10 @@ mod tests {
#[test]
fn aggregate_non_unique_names() -> Result<()> {
let plan = LogicalPlanBuilder::scan_empty(
- "employee.csv",
+ Some("employee_csv"),
&employee_schema(),
- Some(vec![0, 3]),
+ // project state and salary by column index
+ Some(vec![3, 4]),
)?
// two columns with the same name => error
.aggregate(vec![col("state")], vec![sum(col("salary")).alias("state")]);
@@ -586,9 +705,8 @@ mod tests {
Err(DataFusionError::Plan(e)) => {
assert_eq!(
e,
- "Aggregations require unique expression names \
- but the expression \"#state\" at position 0 and \"SUM(#salary) AS state\" at \
- position 1 have the same name. Consider aliasing (\"AS\") one of them."
+ "Schema contains qualified field name 'employee_csv.state' and \
+ unqualified field name 'state' which would be ambiguous"
);
Ok(())
}
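Since UNION re-qualifies every output field, the new `union_with_alias` either strips all qualifiers (no alias) or replaces them with the alias. A sketch using the re-export added in logical_plan/mod.rs below (illustration only):

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::error::Result;
    use datafusion::logical_plan::{union_with_alias, LogicalPlanBuilder};

    fn main() -> Result<()> {
        let schema = Schema::new(vec![Field::new("state", DataType::Utf8, false)]);
        let left = LogicalPlanBuilder::scan_empty(Some("t1"), &schema, None)?.build()?;
        let right = LogicalPlanBuilder::scan_empty(Some("t2"), &schema, None)?.build()?;
        // With an alias every field is re-qualified; with None the output
        // schema would carry the unqualified name `state`.
        let plan = union_with_alias(left, right, Some("u".to_string()))?;
        assert_eq!(plan.schema().field(0).qualified_name(), "u.state");
        Ok(())
    }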
diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs
index c5437b3..e754add 100644
--- a/datafusion/src/logical_plan/dfschema.rs
+++ b/datafusion/src/logical_plan/dfschema.rs
@@ -23,6 +23,7 @@ use std::convert::TryFrom;
use std::sync::Arc;
use crate::error::{DataFusionError, Result};
+use crate::logical_plan::Column;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use std::fmt::{Display, Formatter};
@@ -88,7 +89,7 @@ impl DFSchema {
}
/// Create a `DFSchema` from an Arrow schema
- pub fn try_from_qualified(qualifier: &str, schema: &Schema) -> Result<Self> {
+ pub fn try_from_qualified_schema(qualifier: &str, schema: &Schema) -> Result<Self> {
Self::new(
schema
.fields()
@@ -108,6 +109,21 @@ impl DFSchema {
Self::new(fields)
}
+ /// Merge a schema into self
+ pub fn merge(&mut self, other_schema: &DFSchema) {
+ for field in other_schema.fields() {
+ // skip duplicate columns
+ let duplicated_field = match field.qualifier() {
+ Some(q) => self.field_with_name(Some(q.as_str()), field.name()).is_ok(),
+ // for unqualified columns, check by unqualified name
+ None => self.field_with_unqualified_name(field.name()).is_ok(),
+ };
+ if !duplicated_field {
+ self.fields.push(field.clone());
+ }
+ }
+ }
+
/// Get a list of fields
pub fn fields(&self) -> &Vec<DFField> {
&self.fields
@@ -119,7 +135,7 @@ impl DFSchema {
&self.fields[i]
}
- /// Find the index of the column with the given name
+ /// Find the index of the column with the given unqualified name
pub fn index_of(&self, name: &str) -> Result<usize> {
for i in 0..self.fields.len() {
if self.fields[i].name() == name {
@@ -129,6 +145,20 @@ impl DFSchema {
Err(DataFusionError::Plan(format!("No field named '{}'", name)))
}
+ /// Find the index of the column with the given qualifier and name
+ pub fn index_of_column(&self, col: &Column) -> Result<usize> {
+ for i in 0..self.fields.len() {
+ let field = &self.fields[i];
+ if field.qualifier() == col.relation.as_ref() && field.name() == &col.name {
+ return Ok(i);
+ }
+ }
+ Err(DataFusionError::Plan(format!(
+ "No field matches column '{}'",
+ col,
+ )))
+ }
+
/// Find the field with the given name
pub fn field_with_name(
&self,
@@ -150,7 +180,10 @@ impl DFSchema {
.filter(|field| field.name() == name)
.collect();
match matches.len() {
- 0 => Err(DataFusionError::Plan(format!("No field named '{}'", name))),
+ 0 => Err(DataFusionError::Plan(format!(
+ "No field with unqualified name '{}'",
+ name
+ ))),
1 => Ok(matches[0].to_owned()),
_ => Err(DataFusionError::Plan(format!(
"Ambiguous reference to field named '{}'",
@@ -184,6 +217,62 @@ impl DFSchema {
))),
}
}
+
+ /// Find the field with the given qualified column
+ pub fn field_from_qualified_column(&self, column: &Column) -> Result<DFField> {
+ match &column.relation {
+ Some(r) => self.field_with_qualified_name(r, &column.name),
+ None => self.field_with_unqualified_name(&column.name),
+ }
+ }
+
+ /// Check whether the unqualified field names match the field names of an Arrow schema
+ pub fn matches_arrow_schema(&self, arrow_schema: &Schema) -> bool {
+ self.fields
+ .iter()
+ .zip(arrow_schema.fields().iter())
+ .all(|(dffield, arrowfield)| dffield.name() == arrowfield.name())
+ }
+
+ /// Strip all field qualifiers from the schema
+ pub fn strip_qualifiers(self) -> Self {
+ DFSchema {
+ fields: self
+ .fields
+ .into_iter()
+ .map(|f| {
+ if f.qualifier().is_some() {
+ DFField::new(
+ None,
+ f.name(),
+ f.data_type().to_owned(),
+ f.is_nullable(),
+ )
+ } else {
+ f
+ }
+ })
+ .collect(),
+ }
+ }
+
+ /// Replace all field qualifiers with a new value in the schema
+ pub fn replace_qualifier(self, qualifier: &str) -> Self {
+ DFSchema {
+ fields: self
+ .fields
+ .into_iter()
+ .map(|f| {
+ DFField::new(
+ Some(qualifier),
+ f.name(),
+ f.data_type().to_owned(),
+ f.is_nullable(),
+ )
+ })
+ .collect(),
+ }
+ }
}
impl Into<Schema> for DFSchema {
@@ -195,7 +284,7 @@ impl Into<Schema> for DFSchema {
.map(|f| {
if f.qualifier().is_some() {
Field::new(
- f.qualified_name().as_str(),
+ f.name().as_str(),
f.data_type().to_owned(),
f.is_nullable(),
)
@@ -208,6 +297,13 @@ impl Into<Schema> for DFSchema {
}
}
+impl Into<Schema> for &DFSchema {
+ /// Convert a DFSchema reference into an Arrow schema
+ fn into(self) -> Schema {
+ Schema::new(self.fields.iter().map(|f| f.field.clone()).collect())
+ }
+}
+
/// Create a `DFSchema` from an Arrow schema
impl TryFrom<Schema> for DFSchema {
type Error = DataFusionError;
@@ -340,7 +436,7 @@ impl DFField {
self.field.is_nullable()
}
- /// Returns a reference to the `DFField`'s qualified name
+ /// Returns the `DFField`'s qualified name as a String
pub fn qualified_name(&self) -> String {
if let Some(relation_name) = &self.qualifier {
format!("{}.{}", relation_name, self.field.name())
@@ -349,10 +445,23 @@ impl DFField {
}
}
+ /// Builds a qualified column based on self
+ pub fn qualified_column(&self) -> Column {
+ Column {
+ relation: self.qualifier.clone(),
+ name: self.field.name().to_string(),
+ }
+ }
+
/// Get the optional qualifier
pub fn qualifier(&self) -> Option<&String> {
self.qualifier.as_ref()
}
+
+ /// Get the arrow field
+ pub fn field(&self) -> &Field {
+ &self.field
+ }
}
#[cfg(test)]
@@ -385,25 +494,25 @@ mod tests {
#[test]
fn from_qualified_schema() -> Result<()> {
- let schema = DFSchema::try_from_qualified("t1", &test_schema_1())?;
+ let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
assert_eq!("t1.c0, t1.c1", schema.to_string());
Ok(())
}
#[test]
fn from_qualified_schema_into_arrow_schema() -> Result<()> {
- let schema = DFSchema::try_from_qualified("t1", &test_schema_1())?;
+ let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
let arrow_schema: Schema = schema.into();
- let expected = "Field { name: \"t1.c0\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }, \
- Field { name: \"t1.c1\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }";
+ let expected = "Field { name: \"c0\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }, \
+ Field { name: \"c1\", data_type: Boolean, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }";
assert_eq!(expected, arrow_schema.to_string());
Ok(())
}
#[test]
fn join_qualified() -> Result<()> {
- let left = DFSchema::try_from_qualified("t1", &test_schema_1())?;
- let right = DFSchema::try_from_qualified("t2", &test_schema_1())?;
+ let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
+ let right = DFSchema::try_from_qualified_schema("t2", &test_schema_1())?;
let join = left.join(&right)?;
assert_eq!("t1.c0, t1.c1, t2.c0, t2.c1", join.to_string());
// test valid access
@@ -418,8 +527,8 @@ mod tests {
#[test]
fn join_qualified_duplicate() -> Result<()> {
- let left = DFSchema::try_from_qualified("t1", &test_schema_1())?;
- let right = DFSchema::try_from_qualified("t1", &test_schema_1())?;
+ let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
+ let right = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
let join = left.join(&right);
assert!(join.is_err());
assert_eq!(
@@ -446,7 +555,7 @@ mod tests {
#[test]
fn join_mixed() -> Result<()> {
- let left = DFSchema::try_from_qualified("t1", &test_schema_1())?;
+ let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
let right = DFSchema::try_from(test_schema_2())?;
let join = left.join(&right)?;
assert_eq!("t1.c0, t1.c1, c100, c101", join.to_string());
@@ -464,7 +573,7 @@ mod tests {
#[test]
fn join_mixed_duplicate() -> Result<()> {
- let left = DFSchema::try_from_qualified("t1", &test_schema_1())?;
+ let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
let right = DFSchema::try_from(test_schema_1())?;
let join = left.join(&right);
assert!(join.is_err());
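A sketch of the new qualifier-aware lookups on DFSchema, using the Column and DFSchema types as re-exported from datafusion::logical_plan (assumed paths):

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::error::Result;
    use datafusion::logical_plan::{Column, DFSchema};

    fn main() -> Result<()> {
        let schema = Schema::new(vec![Field::new("c0", DataType::Boolean, true)]);
        let df_schema = DFSchema::try_from_qualified_schema("t1", &schema)?;
        // Unqualified lookups still work while the name is unambiguous ...
        let field = df_schema.field_with_unqualified_name("c0")?;
        assert_eq!(field.qualified_name(), "t1.c0");
        // ... and a Column resolves by its (relation, name) pair.
        let c = Column::from_qualified_name("t1.c0");
        assert_eq!(df_schema.index_of_column(&c)?, 0);
        Ok(())
    }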
diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index 58dba16..1c5cc77 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -20,7 +20,7 @@
pub use super::Operator;
use crate::error::{DataFusionError, Result};
-use crate::logical_plan::{window_frames, DFField, DFSchema};
+use crate::logical_plan::{window_frames, DFField, DFSchema, DFSchemaRef};
use crate::physical_plan::{
aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF,
window_functions,
@@ -33,6 +33,90 @@ use std::collections::HashSet;
use std::fmt;
use std::sync::Arc;
+/// A named reference to a qualified field in a schema.
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub struct Column {
+ /// relation/table name.
+ pub relation: Option<String>,
+ /// field/column name.
+ pub name: String,
+}
+
+impl Column {
+ /// Create a Column from an unqualified name.
+ pub fn from_name(name: String) -> Self {
+ Self {
+ relation: None,
+ name,
+ }
+ }
+
+ /// Deserialize a fully qualified name string into a column
+ pub fn from_qualified_name(flat_name: &str) -> Self {
+ use sqlparser::tokenizer::Token;
+
+ let dialect = sqlparser::dialect::GenericDialect {};
+ let mut tokenizer = sqlparser::tokenizer::Tokenizer::new(&dialect, flat_name);
+ if let Ok(tokens) = tokenizer.tokenize() {
+ if let [Token::Word(relation), Token::Period, Token::Word(name)] =
+ tokens.as_slice()
+ {
+ return Column {
+ relation: Some(relation.value.clone()),
+ name: name.value.clone(),
+ };
+ }
+ }
+ // any expression that's not in the form of `foo.bar` will be treated as unqualified column
+ // name
+ Column {
+ relation: None,
+ name: String::from(flat_name),
+ }
+ }
+
+ /// Serialize column into a flat name string
+ pub fn flat_name(&self) -> String {
+ match &self.relation {
+ Some(r) => format!("{}.{}", r, self.name),
+ None => self.name.clone(),
+ }
+ }
+
+ /// Normalize this Column, resolving its qualifier against the provided schemas.
+ pub fn normalize(self, schemas: &[&DFSchemaRef]) -> Result<Self> {
+ if self.relation.is_some() {
+ return Ok(self);
+ }
+
+ for schema in schemas {
+ if let Ok(field) = schema.field_with_unqualified_name(&self.name) {
+ return Ok(field.qualified_column());
+ }
+ }
+
+ Err(DataFusionError::Plan(format!(
+ "Column {} not found in provided schemas",
+ self
+ )))
+ }
+}
+
+impl From<&str> for Column {
+ fn from(c: &str) -> Self {
+ Self::from_qualified_name(c)
+ }
+}
+
+impl fmt::Display for Column {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ match &self.relation {
+ Some(r) => write!(f, "#{}.{}", r, self.name),
+ None => write!(f, "#{}", self.name),
+ }
+ }
+}
+
/// `Expr` is a central struct of DataFusion's query API, and
/// represent logical expressions such as `A + 1`, or `CAST(c1 AS
/// int)`.
@@ -47,7 +131,7 @@ use std::sync::Arc;
/// ```
/// # use datafusion::logical_plan::*;
/// let expr = col("c1");
-/// assert_eq!(expr, Expr::Column("c1".to_string()));
+/// assert_eq!(expr, Expr::Column(Column::from_name("c1".to_string())));
/// ```
///
/// ## Create the expression `c1 + c2` to add columns "c1" and "c2" together
@@ -81,8 +165,8 @@ use std::sync::Arc;
pub enum Expr {
/// An expression with a specific name.
Alias(Box<Expr>, String),
- /// A named reference to a field in a schema.
- Column(String),
+ /// A named reference to a qualified field in a schema.
+ Column(Column),
/// A named reference to a variable in a registry.
ScalarVariable(Vec<String>),
/// A constant value.
@@ -232,10 +316,9 @@ impl Expr {
pub fn get_type(&self, schema: &DFSchema) -> Result<DataType> {
match self {
Expr::Alias(expr, _) => expr.get_type(schema),
- Expr::Column(name) => Ok(schema
- .field_with_unqualified_name(name)?
- .data_type()
- .clone()),
+ Expr::Column(c) => {
+ Ok(schema.field_from_qualified_column(c)?.data_type().clone())
+ }
Expr::ScalarVariable(_) => Ok(DataType::Utf8),
Expr::Literal(l) => Ok(l.get_datatype()),
Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema),
@@ -307,9 +390,9 @@ impl Expr {
pub fn nullable(&self, input_schema: &DFSchema) -> Result<bool> {
match self {
Expr::Alias(expr, _) => expr.nullable(input_schema),
- Expr::Column(name) => Ok(input_schema
- .field_with_unqualified_name(name)?
- .is_nullable()),
+ Expr::Column(c) => {
+ Ok(input_schema.field_from_qualified_column(c)?.is_nullable())
+ }
Expr::Literal(value) => Ok(value.is_null()),
Expr::ScalarVariable(_) => Ok(true),
Expr::Case {
@@ -355,7 +438,7 @@ impl Expr {
}
}
- /// Returns the name of this expression based on [arrow::datatypes::Schema].
+ /// Returns the name of this expression based on [crate::logical_plan::DFSchema].
///
/// This represents how a column with this expression is named when no alias is chosen
pub fn name(&self, input_schema: &DFSchema) -> Result<String> {
@@ -364,12 +447,20 @@ impl Expr {
/// Returns a [arrow::datatypes::Field] compatible with this expression.
pub fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> {
- Ok(DFField::new(
- None, //TODO qualifier
- &self.name(input_schema)?,
- self.get_type(input_schema)?,
- self.nullable(input_schema)?,
- ))
+ match self {
+ Expr::Column(c) => Ok(DFField::new(
+ c.relation.as_deref(),
+ &c.name,
+ self.get_type(input_schema)?,
+ self.nullable(input_schema)?,
+ )),
+ _ => Ok(DFField::new(
+ None,
+ &self.name(input_schema)?,
+ self.get_type(input_schema)?,
+ self.nullable(input_schema)?,
+ )),
+ }
}
/// Wraps this expression in a cast to a target [arrow::datatypes::DataType].
@@ -540,7 +631,7 @@ impl Expr {
// recurse (and cover all expression types)
let visitor = match self {
Expr::Alias(expr, _) => expr.accept(visitor),
- Expr::Column(..) => Ok(visitor),
+ Expr::Column(_) => Ok(visitor),
Expr::ScalarVariable(..) => Ok(visitor),
Expr::Literal(..) => Ok(visitor),
Expr::BinaryExpr { left, right, .. } => {
@@ -668,7 +759,7 @@ impl Expr {
// recurse into all sub expressions(and cover all expression types)
let expr = match self {
Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name),
- Expr::Column(name) => Expr::Column(name),
+ Expr::Column(_) => self.clone(),
Expr::ScalarVariable(names) => Expr::ScalarVariable(names),
Expr::Literal(value) => Expr::Literal(value),
Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr {
@@ -985,9 +1076,72 @@ pub fn or(left: Expr, right: Expr) -> Expr {
}
}
-/// Create a column expression based on a column name
-pub fn col(name: &str) -> Expr {
- Expr::Column(name.to_owned())
+/// Create a column expression based on a qualified or unqualified column name
+pub fn col(ident: &str) -> Expr {
+ Expr::Column(ident.into())
+}
+
+/// Rewrite an expression into a Column expression if it is already produced by the input plan.
+///
+/// For example, it rewrites:
+///
+/// ```ignore
+/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
+/// .project(vec![col("c1"), sum(col("c2"))?
+/// ```
+///
+/// Into:
+///
+/// ```ignore
+/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
+/// .project(vec![col("c1"), col("SUM(#c2)")?
+/// ```
+pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr {
+ match e {
+ Expr::Column(_) => e,
+ Expr::Alias(inner_expr, name) => {
+ Expr::Alias(Box::new(columnize_expr(*inner_expr, input_schema)), name)
+ }
+ _ => match e.name(input_schema) {
+ Ok(name) => match input_schema.field_with_unqualified_name(&name) {
+ Ok(field) => Expr::Column(field.qualified_column()),
+ // expression not provided as input, do not convert to a column reference
+ Err(_) => e,
+ },
+ Err(_) => e,
+ },
+ }
+}
+
+/// Recursively normalize all Column expressions in a given expression tree
+pub fn normalize_col(e: Expr, schemas: &[&DFSchemaRef]) -> Result<Expr> {
+ struct ColumnNormalizer<'a, 'b> {
+ schemas: &'a [&'b DFSchemaRef],
+ }
+
+ impl<'a, 'b> ExprRewriter for ColumnNormalizer<'a, 'b> {
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ if let Expr::Column(c) = expr {
+ Ok(Expr::Column(c.normalize(self.schemas)?))
+ } else {
+ Ok(expr)
+ }
+ }
+ }
+
+ e.rewrite(&mut ColumnNormalizer { schemas })
+}
+
+/// Recursively normalize all Column expressions in a list of expression trees
+#[inline]
+pub fn normalize_cols(
+ exprs: impl IntoIterator<Item = Expr>,
+ schemas: &[&DFSchemaRef],
+) -> Result<Vec<Expr>> {
+ exprs
+ .into_iter()
+ .map(|e| normalize_col(e, schemas))
+ .collect()
}
/// Create an expression to represent the min() aggregate function
@@ -1240,7 +1394,7 @@ impl fmt::Debug for Expr {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias),
- Expr::Column(name) => write!(f, "#{}", name),
+ Expr::Column(c) => write!(f, "{}", c),
Expr::ScalarVariable(var_names) => write!(f, "{}", var_names.join(".")),
Expr::Literal(v) => write!(f, "{:?}", v),
Expr::Case {
@@ -1373,7 +1527,7 @@ fn create_function_name(
fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
match e {
Expr::Alias(_, name) => Ok(name.clone()),
- Expr::Column(name) => Ok(name.clone()),
+ Expr::Column(c) => Ok(c.flat_name()),
Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")),
Expr::Literal(value) => Ok(format!("{:?}", value)),
Expr::BinaryExpr { left, op, right } => {
@@ -1524,8 +1678,8 @@ mod tests {
#[test]
fn filter_is_null_and_is_not_null() {
- let col_null = Expr::Column("col1".to_string());
- let col_not_null = Expr::Column("col2".to_string());
+ let col_null = col("col1");
+ let col_not_null = col("col2");
assert_eq!(format!("{:?}", col_null.is_null()), "#col1 IS NULL");
assert_eq!(
format!("{:?}", col_not_null.is_not_null()),
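The tokenizer-based parser above is what gives `col` its dual behavior; a few grounded cases (types re-exported from datafusion::logical_plan):

    use datafusion::logical_plan::{col, Column, Expr};

    fn main() {
        // `foo.bar` splits into a relation and a field name ...
        let c = Column::from_qualified_name("t1.c0");
        assert_eq!(c.relation.as_deref(), Some("t1"));
        assert_eq!(c.flat_name(), "t1.c0");
        // ... while anything else, such as a generated aggregate name,
        // stays unqualified.
        let c = Column::from_qualified_name("SUM(c2)");
        assert_eq!(c.relation, None);
        // col routes through the same parser via From<&str>.
        assert_eq!(col("t1.c0"), Expr::Column(Column::from_qualified_name("t1.c0")));
    }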
diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs
index 4a39e11..69d03d2 100644
--- a/datafusion/src/logical_plan/mod.rs
+++ b/datafusion/src/logical_plan/mod.rs
@@ -30,22 +30,26 @@ mod operators;
mod plan;
mod registry;
pub mod window_frames;
-pub use builder::LogicalPlanBuilder;
+pub use builder::{
+ build_join_schema, union_with_alias, LogicalPlanBuilder, UNNAMED_TABLE,
+};
pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema};
pub use display::display_schema;
pub use expr::{
abs, acos, and, array, ascii, asin, atan, avg, binary_expr, bit_length, btrim, case,
- ceil, character_length, chr, col, combine_filters, concat, concat_ws, cos, count,
- count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor, in_list,
- initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5, min, now,
- octet_length, or, random, regexp_match, regexp_replace, repeat, replace, reverse,
- right, round, rpad, rtrim, sha224, sha256, sha384, sha512, signum, sin, split_part,
- sqrt, starts_with, strpos, substr, sum, tan, to_hex, translate, trim, trunc, upper,
- when, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion,
+ ceil, character_length, chr, col, columnize_expr, combine_filters, concat, concat_ws,
+ cos, count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor,
+ in_list, initcap, left, length, lit, ln, log10, log2, lower, lpad, ltrim, max, md5,
+ min, normalize_col, normalize_cols, now, octet_length, or, random, regexp_match,
+ regexp_replace, repeat, replace, reverse, right, round, rpad, rtrim, sha224, sha256,
+ sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan,
+ to_hex, translate, trim, trunc, upper, when, Column, Expr, ExprRewriter,
+ ExpressionVisitor, Literal, Recursion,
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
pub use plan::{
- JoinType, LogicalPlan, Partitioning, PlanType, PlanVisitor, StringifiedPlan,
+ JoinConstraint, JoinType, LogicalPlan, Partitioning, PlanType, PlanVisitor,
+ StringifiedPlan,
};
pub use registry::FunctionRegistry;
diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs
index a80bc54..2562472 100644
--- a/datafusion/src/logical_plan/plan.rs
+++ b/datafusion/src/logical_plan/plan.rs
@@ -17,18 +17,14 @@
//! This module contains the `LogicalPlan` enum that describes queries
//! via a logical query plan.
-use super::expr::Expr;
+use super::display::{GraphvizVisitor, IndentVisitor};
+use super::expr::{Column, Expr};
use super::extension::UserDefinedLogicalNode;
-use super::{
- col,
- display::{GraphvizVisitor, IndentVisitor},
-};
use crate::datasource::TableProvider;
use crate::logical_plan::dfschema::DFSchemaRef;
use crate::sql::parser::FileType;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use std::{
- cmp::min,
fmt::{self, Display},
sync::Arc,
};
@@ -50,6 +46,15 @@ pub enum JoinType {
Anti,
}
+/// Join constraint
+#[derive(Debug, Clone, Copy)]
+pub enum JoinConstraint {
+ /// Join ON
+ On,
+ /// Join USING
+ Using,
+}
+
/// A LogicalPlan represents the different types of relational
/// operators (such as Projection, Filter, etc) and can be created by
/// the SQL query planner and the DataFrame API.
@@ -125,9 +130,11 @@ pub enum LogicalPlan {
/// Right input
right: Arc<LogicalPlan>,
/// Equijoin clause expressed as pairs of (left, right) join columns
- on: Vec<(String, String)>,
+ on: Vec<(Column, Column)>,
/// Join type
join_type: JoinType,
+ /// Join constraint
+ join_constraint: JoinConstraint,
/// The output schema, containing fields from the left and right inputs
schema: DFSchemaRef,
},
@@ -312,9 +319,10 @@ impl LogicalPlan {
aggr_expr,
..
} => group_expr.iter().chain(aggr_expr.iter()).cloned().collect(),
- LogicalPlan::Join { on, .. } => {
- on.iter().flat_map(|(l, r)| vec![col(l), col(r)]).collect()
- }
+ LogicalPlan::Join { on, .. } => on
+ .iter()
+ .flat_map(|(l, r)| vec![Expr::Column(l.clone()), Expr::Column(r.clone())])
+ .collect(),
LogicalPlan::Sort { expr, .. } => expr.clone(),
LogicalPlan::Extension { node } => node.expressions(),
// plans without expressions
@@ -479,9 +487,9 @@ impl LogicalPlan {
/// per node. For example:
///
/// ```text
- /// Projection: #id
- /// Filter: #state Eq Utf8(\"CO\")\
- /// CsvScan: employee.csv projection=Some([0, 3])
+ /// Projection: #employee.id
+ /// Filter: #employee.state Eq Utf8(\"CO\")\
+ /// CsvScan: employee projection=Some([0, 3])
/// ```
///
/// ```
@@ -490,15 +498,15 @@ impl LogicalPlan {
/// let schema = Schema::new(vec![
/// Field::new("id", DataType::Int32, false),
/// ]);
- /// let plan = LogicalPlanBuilder::scan_empty("foo.csv", &schema, None).unwrap()
+ /// let plan = LogicalPlanBuilder::scan_empty(Some("foo_csv"), &schema, None).unwrap()
/// .filter(col("id").eq(lit(5))).unwrap()
/// .build().unwrap();
///
/// // Format using display_indent
/// let display_string = format!("{}", plan.display_indent());
///
- /// assert_eq!("Filter: #id Eq Int32(5)\
- /// \n TableScan: foo.csv projection=None",
+ /// assert_eq!("Filter: #foo_csv.id Eq Int32(5)\
+ /// \n TableScan: foo_csv projection=None",
/// display_string);
/// ```
pub fn display_indent(&self) -> impl fmt::Display + '_ {
@@ -520,9 +528,9 @@ impl LogicalPlan {
/// per node that includes the output schema. For example:
///
/// ```text
- /// Projection: #id [id:Int32]\
- /// Filter: #state Eq Utf8(\"CO\") [id:Int32, state:Utf8]\
- /// TableScan: employee.csv projection=Some([0, 3]) [id:Int32, state:Utf8]";
+ /// Projection: #employee.id [id:Int32]\
+ /// Filter: #employee.state Eq Utf8(\"CO\") [id:Int32, state:Utf8]\
+ /// TableScan: employee projection=Some([0, 3]) [id:Int32, state:Utf8]";
/// ```
///
/// ```
@@ -531,15 +539,15 @@ impl LogicalPlan {
/// let schema = Schema::new(vec![
/// Field::new("id", DataType::Int32, false),
/// ]);
- /// let plan = LogicalPlanBuilder::scan_empty("foo.csv", &schema, None).unwrap()
+ /// let plan = LogicalPlanBuilder::scan_empty(Some("foo_csv"), &schema, None).unwrap()
/// .filter(col("id").eq(lit(5))).unwrap()
/// .build().unwrap();
///
/// // Format using display_indent_schema
/// let display_string = format!("{}", plan.display_indent_schema());
///
- /// assert_eq!("Filter: #id Eq Int32(5) [id:Int32]\
- /// \n TableScan: foo.csv projection=None [id:Int32]",
+ /// assert_eq!("Filter: #foo_csv.id Eq Int32(5) [id:Int32]\
+ /// \n TableScan: foo_csv projection=None [id:Int32]",
/// display_string);
/// ```
pub fn display_indent_schema(&self) -> impl fmt::Display + '_ {
@@ -571,7 +579,7 @@ impl LogicalPlan {
/// let schema = Schema::new(vec![
/// Field::new("id", DataType::Int32, false),
/// ]);
- /// let plan = LogicalPlanBuilder::scan_empty("foo.csv", &schema, None).unwrap()
+ /// let plan = LogicalPlanBuilder::scan_empty(Some("foo.csv"), &schema, None).unwrap()
/// .filter(col("id").eq(lit(5))).unwrap()
/// .build().unwrap();
///
@@ -630,7 +638,7 @@ impl LogicalPlan {
/// let schema = Schema::new(vec![
/// Field::new("id", DataType::Int32, false),
/// ]);
- /// let plan = LogicalPlanBuilder::scan_empty("foo.csv", &schema, None).unwrap()
+ /// let plan = LogicalPlanBuilder::scan_empty(Some("foo.csv"), &schema, None).unwrap()
/// .build().unwrap();
///
/// // Format using display
@@ -653,11 +661,10 @@ impl LogicalPlan {
ref limit,
..
} => {
- let sep = " ".repeat(min(1, table_name.len()));
write!(
f,
- "TableScan: {}{}projection={:?}",
- table_name, sep, projection
+ "TableScan: {} projection={:?}",
+ table_name, projection
)?;
if !filters.is_empty() {
@@ -826,7 +833,7 @@ mod tests {
fn display_plan() -> LogicalPlan {
LogicalPlanBuilder::scan_empty(
- "employee.csv",
+ Some("employee_csv"),
&employee_schema(),
Some(vec![0, 3]),
)
@@ -843,9 +850,9 @@ mod tests {
fn test_display_indent() {
let plan = display_plan();
- let expected = "Projection: #id\
- \n Filter: #state Eq Utf8(\"CO\")\
- \n TableScan: employee.csv projection=Some([0, 3])";
+ let expected = "Projection: #employee_csv.id\
+ \n Filter: #employee_csv.state Eq Utf8(\"CO\")\
+ \n TableScan: employee_csv projection=Some([0, 3])";
assert_eq!(expected, format!("{}", plan.display_indent()));
}
@@ -854,9 +861,9 @@ mod tests {
fn test_display_indent_schema() {
let plan = display_plan();
- let expected = "Projection: #id [id:Int32]\
- \n Filter: #state Eq Utf8(\"CO\") [id:Int32, state:Utf8]\
- \n TableScan: employee.csv projection=Some([0, 3]) [id:Int32, state:Utf8]";
+ let expected = "Projection: #employee_csv.id [id:Int32]\
+ \n Filter: #employee_csv.state Eq Utf8(\"CO\") [id:Int32, state:Utf8]\
+ \n TableScan: employee_csv projection=Some([0, 3]) [id:Int32, state:Utf8]";
assert_eq!(expected, format!("{}", plan.display_indent_schema()));
}
@@ -878,12 +885,12 @@ mod tests {
);
assert!(
graphviz.contains(
- r#"[shape=box label="TableScan: employee.csv projection=Some([0, 3])"]"#
+ r#"[shape=box label="TableScan: employee_csv projection=Some([0, 3])"]"#
),
"\n{}",
plan.display_graphviz()
);
- assert!(graphviz.contains(r#"[shape=box label="TableScan: employee.csv projection=Some([0, 3])\nSchema: [id:Int32, state:Utf8]"]"#),
+ assert!(graphviz.contains(r#"[shape=box label="TableScan: employee_csv projection=Some([0, 3])\nSchema: [id:Int32, state:Utf8]"]"#),
"\n{}", plan.display_graphviz());
assert!(
graphviz.contains(r#"// End DataFusion GraphViz Plan"#),
@@ -1128,9 +1135,12 @@ mod tests {
}
fn test_plan() -> LogicalPlan {
- let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
+ let schema = Schema::new(vec![
+ Field::new("id", DataType::Int32, false),
+ Field::new("state", DataType::Utf8, false),
+ ]);
- LogicalPlanBuilder::scan_empty("", &schema, Some(vec![0]))
+ LogicalPlanBuilder::scan_empty(None, &schema, Some(vec![0, 1]))
.unwrap()
.filter(col("state").eq(lit("CO")))
.unwrap()
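One consequence of making `table_name` a plain `String` is visible in the `Display` arm above: the old `" ".repeat(min(1, table_name.len()))` separator trick for unnamed scans is gone. A hypothetical helper sketching the simplified formatting (not the actual impl):
```rust
// Hypothetical helper mirroring the simplified TableScan Display arm:
// the separator is now a plain space because a table name is always present.
fn format_table_scan(table_name: &str, projection: Option<&[usize]>) -> String {
    format!("TableScan: {} projection={:?}", table_name, projection)
}

fn main() {
    assert_eq!(
        format_table_scan("employee_csv", Some(&[0, 3])),
        "TableScan: employee_csv projection=Some([0, 3])"
    );
}
```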
diff --git a/datafusion/src/optimizer/constant_folding.rs b/datafusion/src/optimizer/constant_folding.rs
index d2ac5ce..956f74a 100644
--- a/datafusion/src/optimizer/constant_folding.rs
+++ b/datafusion/src/optimizer/constant_folding.rs
@@ -293,7 +293,7 @@ mod tests {
Field::new("c", DataType::Boolean, false),
Field::new("d", DataType::UInt32, false),
]);
- LogicalPlanBuilder::scan_empty("test", &schema, None)?.build()
+ LogicalPlanBuilder::scan_empty(Some("test"), &schema, None)?.build()
}
fn expr_test_schema() -> DFSchemaRef {
@@ -551,9 +551,9 @@ mod tests {
.build()?;
let expected = "\
- Projection: #a\
- \n Filter: NOT #c\
- \n Filter: #b\
+ Projection: #test.a\
+ \n Filter: NOT #test.c\
+ \n Filter: #test.b\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
@@ -571,10 +571,10 @@ mod tests {
.build()?;
let expected = "\
- Projection: #a\
+ Projection: #test.a\
\n Limit: 1\
- \n Filter: #c\
- \n Filter: NOT #b\
+ \n Filter: #test.c\
+ \n Filter: NOT #test.b\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
@@ -590,8 +590,8 @@ mod tests {
.build()?;
let expected = "\
- Projection: #a\
- \n Filter: NOT #b And #c\
+ Projection: #test.a\
+ \n Filter: NOT #test.b And #test.c\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
@@ -607,8 +607,8 @@ mod tests {
.build()?;
let expected = "\
- Projection: #a\
- \n Filter: NOT #b Or NOT #c\
+ Projection: #test.a\
+ \n Filter: NOT #test.b Or NOT #test.c\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
@@ -624,8 +624,8 @@ mod tests {
.build()?;
let expected = "\
- Projection: #a\
- \n Filter: #b\
+ Projection: #test.a\
+ \n Filter: #test.b\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
@@ -640,7 +640,7 @@ mod tests {
.build()?;
let expected = "\
- Projection: #a, #d, NOT #b\
+ Projection: #test.a, #test.d, NOT #test.b\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
@@ -659,8 +659,8 @@ mod tests {
.build()?;
let expected = "\
- Aggregate: groupBy=[[#a, #c]], aggr=[[MAX(#b), MIN(#b)]]\
- \n Projection: #a, #c, #b\
+ Aggregate: groupBy=[[#test.a, #test.c]], aggr=[[MAX(#test.b), MIN(#test.b)]]\
+ \n Projection: #test.a, #test.c, #test.b\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
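Every expected plan string in these tests gains a qualifier because display tokens are now derived from a field's optional qualifier plus its name. A small sketch of that derivation (assumed to mirror the `DFField::qualified_name()` this commit introduces):
```rust
// Sketch of how "#test.a"-style display tokens are derived; assumed
// behavior, shown for illustration.
fn qualified_name(qualifier: Option<&str>, name: &str) -> String {
    match qualifier {
        Some(q) => format!("{}.{}", q, name),
        None => name.to_string(),
    }
}

fn main() {
    assert_eq!(qualified_name(Some("test"), "a"), "test.a"); // rendered as #test.a
    assert_eq!(qualified_name(None, "b"), "b");
}
```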
diff --git a/datafusion/src/optimizer/eliminate_limit.rs b/datafusion/src/optimizer/eliminate_limit.rs
index 1b965f1..4b5a634 100644
--- a/datafusion/src/optimizer/eliminate_limit.rs
+++ b/datafusion/src/optimizer/eliminate_limit.rs
@@ -122,7 +122,7 @@ mod tests {
// Left side is removed
let expected = "Union\
\n EmptyRelation\
- \n Aggregate: groupBy=[[#a]], aggr=[[SUM(#b)]]\
+ \n Aggregate: groupBy=[[#test.a]], aggr=[[SUM(#test.b)]]\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
}
diff --git a/datafusion/src/optimizer/filter_push_down.rs b/datafusion/src/optimizer/filter_push_down.rs
index dc4d5e9..e5f8dcf 100644
--- a/datafusion/src/optimizer/filter_push_down.rs
+++ b/datafusion/src/optimizer/filter_push_down.rs
@@ -16,7 +16,7 @@
use crate::datasource::datasource::TableProviderFilterPushDown;
use crate::execution::context::ExecutionProps;
-use crate::logical_plan::{and, LogicalPlan};
+use crate::logical_plan::{and, Column, LogicalPlan};
use crate::logical_plan::{DFSchema, Expr};
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
@@ -56,15 +56,15 @@ pub struct FilterPushDown {}
#[derive(Debug, Clone, Default)]
struct State {
// (predicate, columns on the predicate)
- filters: Vec<(Expr, HashSet<String>)>,
+ filters: Vec<(Expr, HashSet<Column>)>,
}
-type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet<String>>);
+type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet<Column>>);
/// returns all predicates in `state` that depend on any of `used_columns`
fn get_predicates<'a>(
state: &'a State,
- used_columns: &HashSet<String>,
+ used_columns: &HashSet<Column>,
) -> Predicates<'a> {
state
.filters
@@ -89,19 +89,19 @@ fn get_join_predicates<'a>(
left: &DFSchema,
right: &DFSchema,
) -> (
- Vec<&'a HashSet<String>>,
- Vec<&'a HashSet<String>>,
+ Vec<&'a HashSet<Column>>,
+ Vec<&'a HashSet<Column>>,
Predicates<'a>,
) {
let left_columns = &left
.fields()
.iter()
- .map(|f| f.name().clone())
+ .map(|f| f.qualified_column())
.collect::<HashSet<_>>();
let right_columns = &right
.fields()
.iter()
- .map(|f| f.name().clone())
+ .map(|f| f.qualified_column())
.collect::<HashSet<_>>();
let filters = state
@@ -173,9 +173,9 @@ fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan {
// remove all filters from `filters` that are in `predicate_columns`
fn remove_filters(
- filters: &[(Expr, HashSet<String>)],
- predicate_columns: &[&HashSet<String>],
-) -> Vec<(Expr, HashSet<String>)> {
+ filters: &[(Expr, HashSet<Column>)],
+ predicate_columns: &[&HashSet<Column>],
+) -> Vec<(Expr, HashSet<Column>)> {
filters
.iter()
.filter(|(_, columns)| !predicate_columns.contains(&columns))
@@ -185,9 +185,9 @@ fn remove_filters(
// keeps all filters from `filters` that are in `predicate_columns`
fn keep_filters(
- filters: &[(Expr, HashSet<String>)],
- predicate_columns: &[&HashSet<String>],
-) -> Vec<(Expr, HashSet<String>)> {
+ filters: &[(Expr, HashSet<Column>)],
+ predicate_columns: &[&HashSet<Column>],
+) -> Vec<(Expr, HashSet<Column>)> {
filters
.iter()
.filter(|(_, columns)| predicate_columns.contains(&columns))
@@ -199,7 +199,7 @@ fn keep_filters(
/// in `state` depend on the columns `used_columns`.
fn issue_filters(
mut state: State,
- used_columns: HashSet<String>,
+ used_columns: HashSet<Column>,
plan: &LogicalPlan,
) -> Result<LogicalPlan> {
let (predicates, predicate_columns) = get_predicates(&state, &used_columns);
@@ -248,8 +248,8 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
predicates
.into_iter()
.try_for_each::<_, Result<()>>(|predicate| {
- let mut columns: HashSet<String> = HashSet::new();
- utils::expr_to_column_names(predicate, &mut columns)?;
+ let mut columns: HashSet<Column> = HashSet::new();
+ utils::expr_to_columns(predicate, &mut columns)?;
if columns.is_empty() {
no_col_predicates.push(predicate)
} else {
@@ -282,7 +282,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
expr => expr.clone(),
};
- projection.insert(field.name().clone(), expr);
+ projection.insert(field.qualified_name(), expr);
});
// re-write all filters based on this projection
@@ -291,7 +291,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
*predicate = rewrite(predicate, &projection)?;
columns.clear();
- utils::expr_to_column_names(predicate, columns)?;
+ utils::expr_to_columns(predicate, columns)?;
}
// optimize inner
@@ -308,11 +308,11 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
// construct set of columns that `aggr_expr` depends on
let mut used_columns = HashSet::new();
- utils::exprlist_to_column_names(aggr_expr, &mut used_columns)?;
+ utils::exprlist_to_columns(aggr_expr, &mut used_columns)?;
let agg_columns = aggr_expr
.iter()
- .map(|x| x.name(input.schema()))
+ .map(|x| Ok(Column::from_name(x.name(input.schema())?)))
.collect::<Result<HashSet<_>>>()?;
used_columns.extend(agg_columns);
@@ -332,7 +332,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
.schema()
.fields()
.iter()
- .map(|f| f.name().clone())
+ .map(|f| f.qualified_column())
.collect::<HashSet<_>>();
issue_filters(state, used_columns, plan)
}
@@ -415,7 +415,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
.schema()
.fields()
.iter()
- .map(|f| f.name().clone())
+ .map(|f| f.qualified_column())
.collect::<HashSet<_>>();
issue_filters(state, used_columns, plan)
}
@@ -448,8 +448,8 @@ fn rewrite(expr: &Expr, projection: &HashMap<String, Expr>) -> Result<Expr> {
.map(|e| rewrite(e, projection))
.collect::<Result<Vec<_>>>()?;
- if let Expr::Column(name) = expr {
- if let Some(expr) = projection.get(name) {
+ if let Expr::Column(c) = expr {
+ if let Some(expr) = projection.get(&c.flat_name()) {
return Ok(expr.clone());
}
}
@@ -489,8 +489,8 @@ mod tests {
.build()?;
// filter is before projection
let expected = "\
- Projection: #a, #b\
- \n Filter: #a Eq Int64(1)\
+ Projection: #test.a, #test.b\
+ \n Filter: #test.a Eq Int64(1)\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
@@ -506,9 +506,9 @@ mod tests {
.build()?;
// filter is before single projection
let expected = "\
- Filter: #a Eq Int64(1)\
+ Filter: #test.a Eq Int64(1)\
\n Limit: 10\
- \n Projection: #a, #b\
+ \n Projection: #test.a, #test.b\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
@@ -537,9 +537,9 @@ mod tests {
.build()?;
// filter is before double projection
let expected = "\
- Projection: #c, #b\
- \n Projection: #a, #b, #c\
- \n Filter: #a Eq Int64(1)\
+ Projection: #test.c, #test.b\
+ \n Projection: #test.a, #test.b, #test.c\
+ \n Filter: #test.a Eq Int64(1)\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
@@ -554,8 +554,8 @@ mod tests {
.build()?;
// filter of key aggregation is commutative
let expected = "\
- Aggregate: groupBy=[[#a]], aggr=[[SUM(#b) AS total_salary]]\
- \n Filter: #a Gt Int64(10)\
+ Aggregate: groupBy=[[#test.a]], aggr=[[SUM(#test.b) AS total_salary]]\
+ \n Filter: #test.a Gt Int64(10)\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
@@ -571,7 +571,7 @@ mod tests {
// filter of aggregate is after aggregation since they are non-commutative
let expected = "\
Filter: #b Gt Int64(10)\
- \n Aggregate: groupBy=[[#a]], aggr=[[SUM(#b) AS b]]\
+ \n Aggregate: groupBy=[[#test.a]], aggr=[[SUM(#test.b) AS b]]\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
@@ -587,8 +587,8 @@ mod tests {
.build()?;
// filter is before projection
let expected = "\
- Projection: #a AS b, #c\
- \n Filter: #a Eq Int64(1)\
+ Projection: #test.a AS b, #test.c\
+ \n Filter: #test.a Eq Int64(1)\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
@@ -627,14 +627,14 @@ mod tests {
format!("{:?}", plan),
"\
Filter: #b Eq Int64(1)\
- \n Projection: #a Multiply Int32(2) Plus #c AS b, #c\
+ \n Projection: #test.a Multiply Int32(2) Plus #test.c AS b, #test.c\
\n TableScan: test projection=None"
);
// filter is before projection
let expected = "\
- Projection: #a Multiply Int32(2) Plus #c AS b, #c\
- \n Filter: #a Multiply Int32(2) Plus #c Eq Int64(1)\
+ Projection: #test.a Multiply Int32(2) Plus #test.c AS b, #test.c\
+ \n Filter: #test.a Multiply Int32(2) Plus #test.c Eq Int64(1)\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
@@ -659,16 +659,16 @@ mod tests {
format!("{:?}", plan),
"\
Filter: #a Eq Int64(1)\
- \n Projection: #b Multiply Int32(3) AS a, #c\
- \n Projection: #a Multiply Int32(2) Plus #c AS b, #c\
+ \n Projection: #b Multiply Int32(3) AS a, #test.c\
+ \n Projection: #test.a Multiply Int32(2) Plus #test.c AS b, #test.c\
\n TableScan: test projection=None"
);
// filter is before the projections
let expected = "\
- Projection: #b Multiply Int32(3) AS a, #c\
- \n Projection: #a Multiply Int32(2) Plus #c AS b, #c\
- \n Filter: #a Multiply Int32(2) Plus #c Multiply Int32(3) Eq Int64(1)\
+ Projection: #b Multiply Int32(3) AS a, #test.c\
+ \n Projection: #test.a Multiply Int32(2) Plus #test.c AS b, #test.c\
+ \n Filter: #test.a Multiply Int32(2) Plus #test.c Multiply Int32(3) Eq Int64(1)\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
@@ -684,26 +684,26 @@ mod tests {
.project(vec![col("a").alias("b"), col("c")])?
.aggregate(vec![col("b")], vec![sum(col("c"))])?
.filter(col("b").gt(lit(10i64)))?
- .filter(col("SUM(c)").gt(lit(10i64)))?
+ .filter(col("SUM(test.c)").gt(lit(10i64)))?
.build()?;
// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
"\
- Filter: #SUM(c) Gt Int64(10)\
+ Filter: #SUM(test.c) Gt Int64(10)\
\n Filter: #b Gt Int64(10)\
- \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\
- \n Projection: #a AS b, #c\
+ \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#test.c)]]\
+ \n Projection: #test.a AS b, #test.c\
\n TableScan: test projection=None"
);
// filter is before the projections
let expected = "\
- Filter: #SUM(c) Gt Int64(10)\
- \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\
- \n Projection: #a AS b, #c\
- \n Filter: #a Gt Int64(10)\
+ Filter: #SUM(test.c) Gt Int64(10)\
+ \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#test.c)]]\
+ \n Projection: #test.a AS b, #test.c\
+ \n Filter: #test.a Gt Int64(10)\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
@@ -720,8 +720,8 @@ mod tests {
.project(vec![col("a").alias("b"), col("c")])?
.aggregate(vec![col("b")], vec![sum(col("c"))])?
.filter(and(
- col("SUM(c)").gt(lit(10i64)),
- and(col("b").gt(lit(10i64)), col("SUM(c)").lt(lit(20i64))),
+ col("SUM(test.c)").gt(lit(10i64)),
+ and(col("b").gt(lit(10i64)), col("SUM(test.c)").lt(lit(20i64))),
))?
.build()?;
@@ -729,18 +729,18 @@ mod tests {
assert_eq!(
format!("{:?}", plan),
"\
- Filter: #SUM(c) Gt Int64(10) And #b Gt Int64(10) And #SUM(c) Lt Int64(20)\
- \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\
- \n Projection: #a AS b, #c\
+ Filter: #SUM(test.c) Gt Int64(10) And #b Gt Int64(10) And #SUM(test.c) Lt Int64(20)\
+ \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#test.c)]]\
+ \n Projection: #test.a AS b, #test.c\
\n TableScan: test projection=None"
);
// filter is before the projections
let expected = "\
- Filter: #SUM(c) Gt Int64(10) And #SUM(c) Lt Int64(20)\
- \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#c)]]\
- \n Projection: #a AS b, #c\
- \n Filter: #a Gt Int64(10)\
+ Filter: #SUM(test.c) Gt Int64(10) And #SUM(test.c) Lt Int64(20)\
+ \n Aggregate: groupBy=[[#b]], aggr=[[SUM(#test.c)]]\
+ \n Projection: #test.a AS b, #test.c\
+ \n Filter: #test.a Gt Int64(10)\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
@@ -760,11 +760,11 @@ mod tests {
.build()?;
// filter is not pushed below any of the limits
let expected = "\
- Projection: #a, #b\
- \n Filter: #a Eq Int64(1)\
+ Projection: #test.a, #test.b\
+ \n Filter: #test.a Eq Int64(1)\
\n Limit: 10\
\n Limit: 20\
- \n Projection: #a, #b\
+ \n Projection: #test.a, #test.b\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
@@ -804,20 +804,20 @@ mod tests {
// not part of the test
assert_eq!(
format!("{:?}", plan),
- "Filter: #a GtEq Int64(1)\
- \n Projection: #a\
+ "Filter: #test.a GtEq Int64(1)\
+ \n Projection: #test.a\
\n Limit: 1\
- \n Filter: #a LtEq Int64(1)\
- \n Projection: #a\
+ \n Filter: #test.a LtEq Int64(1)\
+ \n Projection: #test.a\
\n TableScan: test projection=None"
);
let expected = "\
- Projection: #a\
- \n Filter: #a GtEq Int64(1)\
+ Projection: #test.a\
+ \n Filter: #test.a GtEq Int64(1)\
\n Limit: 1\
- \n Projection: #a\
- \n Filter: #a LtEq Int64(1)\
+ \n Projection: #test.a\
+ \n Filter: #test.a LtEq Int64(1)\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
@@ -838,16 +838,16 @@ mod tests {
// not part of the test
assert_eq!(
format!("{:?}", plan),
- "Projection: #a\
- \n Filter: #a GtEq Int64(1)\
- \n Filter: #a LtEq Int64(1)\
+ "Projection: #test.a\
+ \n Filter: #test.a GtEq Int64(1)\
+ \n Filter: #test.a LtEq Int64(1)\
\n Limit: 1\
\n TableScan: test projection=None"
);
let expected = "\
- Projection: #a\
- \n Filter: #a GtEq Int64(1) And #a LtEq Int64(1)\
+ Projection: #test.a\
+ \n Filter: #test.a GtEq Int64(1) And #test.a LtEq Int64(1)\
\n Limit: 1\
\n TableScan: test projection=None";
@@ -868,7 +868,7 @@ mod tests {
let expected = "\
TestUserDefined\
- \n Filter: #a LtEq Int64(1)\
+ \n Filter: #test.a LtEq Int64(1)\
\n TableScan: test projection=None";
// not part of the test
@@ -887,7 +887,12 @@ mod tests {
.project(vec![col("a")])?
.build()?;
let plan = LogicalPlanBuilder::from(&left)
- .join(&right, JoinType::Inner, &["a"], &["a"])?
+ .join(
+ &right,
+ JoinType::Inner,
+ vec![Column::from_name("a".to_string())],
+ vec![Column::from_name("a".to_string())],
+ )?
.filter(col("a").lt_eq(lit(1i64)))?
.build()?;
@@ -895,20 +900,20 @@ mod tests {
assert_eq!(
format!("{:?}", plan),
"\
- Filter: #a LtEq Int64(1)\
- \n Join: a = a\
+ Filter: #test.a LtEq Int64(1)\
+ \n Join: #test.a = #test.a\
\n TableScan: test projection=None\
- \n Projection: #a\
+ \n Projection: #test.a\
\n TableScan: test projection=None"
);
// filter sent to side before the join
let expected = "\
- Join: a = a\
- \n Filter: #a LtEq Int64(1)\
+ Join: #test.a = #test.a\
+ \n Filter: #test.a LtEq Int64(1)\
\n TableScan: test projection=None\
- \n Projection: #a\
- \n Filter: #a LtEq Int64(1)\
+ \n Projection: #test.a\
+ \n Filter: #test.a LtEq Int64(1)\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
@@ -925,7 +930,12 @@ mod tests {
.project(vec![col("a"), col("b")])?
.build()?;
let plan = LogicalPlanBuilder::from(&left)
- .join(&right, JoinType::Inner, &["a"], &["a"])?
+ .join(
+ &right,
+ JoinType::Inner,
+ vec![Column::from_name("a".to_string())],
+ vec![Column::from_name("a".to_string())],
+ )?
// "b" and "c" are not shared by either side: they are only available together after the join
.filter(col("c").lt_eq(col("b")))?
.build()?;
@@ -934,11 +944,11 @@ mod tests {
assert_eq!(
format!("{:?}", plan),
"\
- Filter: #c LtEq #b\
- \n Join: a = a\
- \n Projection: #a, #c\
+ Filter: #test.c LtEq #test.b\
+ \n Join: #test.a = #test.a\
+ \n Projection: #test.a, #test.c\
\n TableScan: test projection=None\
- \n Projection: #a, #b\
+ \n Projection: #test.a, #test.b\
\n TableScan: test projection=None"
);
@@ -959,7 +969,12 @@ mod tests {
.project(vec![col("a"), col("c")])?
.build()?;
let plan = LogicalPlanBuilder::from(&left)
- .join(&right, JoinType::Inner, &["a"], &["a"])?
+ .join(
+ &right,
+ JoinType::Inner,
+ vec![Column::from_name("a".to_string())],
+ vec![Column::from_name("a".to_string())],
+ )?
.filter(col("b").lt_eq(lit(1i64)))?
.build()?;
@@ -967,20 +982,20 @@ mod tests {
assert_eq!(
format!("{:?}", plan),
"\
- Filter: #b LtEq Int64(1)\
- \n Join: a = a\
- \n Projection: #a, #b\
+ Filter: #test.b LtEq Int64(1)\
+ \n Join: #test.a = #test.a\
+ \n Projection: #test.a, #test.b\
\n TableScan: test projection=None\
- \n Projection: #a, #c\
+ \n Projection: #test.a, #test.c\
\n TableScan: test projection=None"
);
let expected = "\
- Join: a = a\
- \n Projection: #a, #b\
- \n Filter: #b LtEq Int64(1)\
+ Join: #test.a = #test.a\
+ \n Projection: #test.a, #test.b\
+ \n Filter: #test.b LtEq Int64(1)\
\n TableScan: test projection=None\
- \n Projection: #a, #c\
+ \n Projection: #test.a, #test.c\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
@@ -1030,14 +1045,15 @@ mod tests {
fn table_scan_with_pushdown_provider(
filter_support: TableProviderFilterPushDown,
) -> Result<LogicalPlan> {
+ use std::convert::TryFrom;
+
let test_provider = PushDownProvider { filter_support };
let table_scan = LogicalPlan::TableScan {
- table_name: "".into(),
+ table_name: "test".to_string(),
filters: vec![],
- projected_schema: Arc::new(DFSchema::try_from_qualified(
- "",
- &*test_provider.schema(),
+ projected_schema: Arc::new(DFSchema::try_from(
+ (*test_provider.schema()).clone(),
)?),
projection: None,
source: Arc::new(test_provider),
@@ -1054,7 +1070,7 @@ mod tests {
let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Exact)?;
let expected = "\
- TableScan: projection=None, filters=[#a Eq Int64(1)]";
+ TableScan: test projection=None, filters=[#a Eq Int64(1)]";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
@@ -1066,7 +1082,7 @@ mod tests {
let expected = "\
Filter: #a Eq Int64(1)\
- \n TableScan: projection=None, filters=[#a Eq Int64(1)]";
+ \n TableScan: test projection=None, filters=[#a Eq Int64(1)]";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
@@ -1080,7 +1096,7 @@ mod tests {
let expected = "\
Filter: #a Eq Int64(1)\
- \n TableScan: projection=None, filters=[#a Eq Int64(1)]";
+ \n TableScan: test projection=None, filters=[#a Eq Int64(1)]";
// Optimizing the same plan multiple times should produce the same plan
// each time.
@@ -1095,7 +1111,7 @@ mod tests {
let expected = "\
Filter: #a Eq Int64(1)\
- \n TableScan: projection=None";
+ \n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
Ok(())
}
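The core of the filter push-down change is that each predicate is now tracked together with the set of `Column`s it references, and a predicate may only be pushed below a join input if its columns are a subset of that input's (qualified) schema. A self-contained sketch of the side test, with flat strings like "test.a" standing in for the real `Column` type:
```rust
use std::collections::HashSet;

// Sketch: a predicate can be pushed below a join input only if every
// column it references belongs to that input's schema.
fn can_push_to(predicate_columns: &HashSet<String>, input_columns: &HashSet<String>) -> bool {
    predicate_columns.is_subset(input_columns)
}

fn main() {
    let left: HashSet<String> = ["test.a", "test.b"].iter().map(|s| s.to_string()).collect();
    let right: HashSet<String> = ["test2.c1"].iter().map(|s| s.to_string()).collect();
    let pred: HashSet<String> = ["test.a"].iter().map(|s| s.to_string()).collect();

    assert!(can_push_to(&pred, &left)); // the filter may move below the left input
    assert!(!can_push_to(&pred, &right)); // but not below the right one
}
```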
diff --git a/datafusion/src/optimizer/hash_build_probe_order.rs b/datafusion/src/optimizer/hash_build_probe_order.rs
index 74d2b00..a2a99ae 100644
--- a/datafusion/src/optimizer/hash_build_probe_order.rs
+++ b/datafusion/src/optimizer/hash_build_probe_order.rs
@@ -22,7 +22,7 @@
use std::sync::Arc;
-use crate::logical_plan::LogicalPlan;
+use crate::logical_plan::{Expr, LogicalPlan, LogicalPlanBuilder};
use crate::optimizer::optimizer::OptimizerRule;
use crate::{error::Result, prelude::JoinType};
@@ -131,6 +131,7 @@ impl OptimizerRule for HashBuildProbeOrder {
right,
on,
join_type,
+ join_constraint,
schema,
} => {
let left = self.optimize(left, execution_props)?;
@@ -140,11 +141,9 @@ impl OptimizerRule for HashBuildProbeOrder {
Ok(LogicalPlan::Join {
left: Arc::new(right),
right: Arc::new(left),
- on: on
- .iter()
- .map(|(l, r)| (r.to_string(), l.to_string()))
- .collect(),
+ on: on.iter().map(|(l, r)| (r.clone(), l.clone())).collect(),
join_type: swap_join_type(*join_type),
+ join_constraint: *join_constraint,
schema: schema.clone(),
})
} else {
@@ -154,6 +153,7 @@ impl OptimizerRule for HashBuildProbeOrder {
right: Arc::new(right),
on: on.clone(),
join_type: *join_type,
+ join_constraint: *join_constraint,
schema: schema.clone(),
})
}
@@ -166,12 +166,19 @@ impl OptimizerRule for HashBuildProbeOrder {
let left = self.optimize(left, execution_props)?;
let right = self.optimize(right, execution_props)?;
if should_swap_join_order(&left, &right) {
- // Swap left and right
- Ok(LogicalPlan::CrossJoin {
- left: Arc::new(right),
- right: Arc::new(left),
- schema: schema.clone(),
- })
+ let swapped = LogicalPlanBuilder::from(&right).cross_join(&left)?;
+ // wrap plan with projection to maintain column order
+ let left_cols = left
+ .schema()
+ .fields()
+ .iter()
+ .map(|f| Expr::Column(f.qualified_column()));
+ let right_cols = right
+ .schema()
+ .fields()
+ .iter()
+ .map(|f| Expr::Column(f.qualified_column()));
+ swapped.project(left_cols.chain(right_cols))?.build()
} else {
// Keep join as is
Ok(LogicalPlan::CrossJoin {
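Two details of this rule are worth calling out: when the build and probe sides are exchanged, each `(left, right)` equijoin pair must be reversed, and a swapped cross join is wrapped in a projection so the output column order is preserved. A sketch of the key swap under those assumptions:
```rust
// Sketch of the join-key swap in HashBuildProbeOrder: exchanging build
// and probe sides reverses every (left, right) equijoin pair.
fn swap_join_keys<T: Clone>(on: &[(T, T)]) -> Vec<(T, T)> {
    on.iter().map(|(l, r)| (r.clone(), l.clone())).collect()
}

fn main() {
    let on = vec![("test.a", "test2.c1")];
    assert_eq!(swap_join_keys(&on), vec![("test2.c1", "test.a")]);
}
```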
diff --git a/datafusion/src/optimizer/limit_push_down.rs b/datafusion/src/optimizer/limit_push_down.rs
index e616869..afd9937 100644
--- a/datafusion/src/optimizer/limit_push_down.rs
+++ b/datafusion/src/optimizer/limit_push_down.rs
@@ -163,7 +163,7 @@ mod test {
// Should push the limit down to table provider
// When it has a select
let expected = "Limit: 1000\
- \n Projection: #a\
+ \n Projection: #test.a\
\n TableScan: test projection=None, limit=1000";
assert_optimized_plan_eq(&plan, expected);
@@ -202,7 +202,7 @@ mod test {
// Limit should *not* push down aggregate node
let expected = "Limit: 1000\
- \n Aggregate: groupBy=[[#a]], aggr=[[MAX(#b)]]\
+ \n Aggregate: groupBy=[[#test.a]], aggr=[[MAX(#test.b)]]\
\n TableScan: test projection=None";
assert_optimized_plan_eq(&plan, expected);
@@ -244,7 +244,7 @@ mod test {
// Limit should use deeper LIMIT 1000, but Limit 10 shouldn't push down aggregation
let expected = "Limit: 10\
- \n Aggregate: groupBy=[[#a]], aggr=[[MAX(#b)]]\
+ \n Aggregate: groupBy=[[#test.a]], aggr=[[MAX(#test.b)]]\
\n Limit: 1000\
\n TableScan: test projection=None, limit=1000";
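The limit push-down tests change only in their expected strings, but the reason is the same throughout: aggregate output names now embed the qualified input column, so later references must spell out the full generated name. A trivial sketch of the naming convention (illustrative, not the actual name-generation code):
```rust
// Sketch: MAX over column b of table test is named "MAX(test.b)", so a
// later projection must reference it as col("MAX(test.b)").
fn aggregate_display_name(func: &str, qualified_input: &str) -> String {
    format!("{}({})", func, qualified_input)
}

fn main() {
    assert_eq!(aggregate_display_name("MAX", "test.b"), "MAX(test.b)");
}
```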
diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs
index ad795f5..2544d89 100644
--- a/datafusion/src/optimizer/projection_push_down.rs
+++ b/datafusion/src/optimizer/projection_push_down.rs
@@ -20,11 +20,14 @@
use crate::error::Result;
use crate::execution::context::ExecutionProps;
-use crate::logical_plan::{DFField, DFSchema, DFSchemaRef, LogicalPlan, ToDFSchema};
+use crate::logical_plan::{
+ build_join_schema, Column, DFField, DFSchema, DFSchemaRef, LogicalPlan,
+ LogicalPlanBuilder, ToDFSchema,
+};
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::utils;
use crate::sql::utils::find_sort_exprs;
-use arrow::datatypes::Schema;
+use arrow::datatypes::{Field, Schema};
use arrow::error::Result as ArrowResult;
use std::{collections::HashSet, sync::Arc};
use utils::optimize_explain;
@@ -44,8 +47,8 @@ impl OptimizerRule for ProjectionPushDown {
.schema()
.fields()
.iter()
- .map(|f| f.name().clone())
- .collect::<HashSet<String>>();
+ .map(|f| f.qualified_column())
+ .collect::<HashSet<Column>>();
optimize_plan(self, plan, &required_columns, false, execution_props)
}
@@ -62,8 +65,9 @@ impl ProjectionPushDown {
}
fn get_projected_schema(
+ table_name: Option<&String>,
schema: &Schema,
- required_columns: &HashSet<String>,
+ required_columns: &HashSet<Column>,
has_projection: bool,
) -> Result<(Vec<usize>, DFSchemaRef)> {
// once we reach the table scan, we can use the accumulated set of column
@@ -73,7 +77,8 @@ fn get_projected_schema(
// e.g. when the column derives from an aggregation
let mut projection: Vec<usize> = required_columns
.iter()
- .map(|name| schema.index_of(name))
+ .filter(|c| c.relation.as_ref() == table_name)
+ .map(|c| schema.index_of(&c.name))
.filter_map(ArrowResult::ok)
.collect();
@@ -98,8 +103,20 @@ fn get_projected_schema(
// create the projected schema
let mut projected_fields: Vec<DFField> = Vec::with_capacity(projection.len());
- for i in &projection {
- projected_fields.push(DFField::from(schema.fields()[*i].clone()));
+ match table_name {
+ Some(qualifer) => {
+ for i in &projection {
+ projected_fields.push(DFField::from_qualified(
+ qualifer,
+ schema.fields()[*i].clone(),
+ ));
+ }
+ }
+ None => {
+ for i in &projection {
+ projected_fields.push(DFField::from(schema.fields()[*i].clone()));
+ }
+ }
}
Ok((projection, projected_fields.to_dfschema_ref()?))
@@ -109,7 +126,7 @@ fn get_projected_schema(
fn optimize_plan(
optimizer: &ProjectionPushDown,
plan: &LogicalPlan,
- required_columns: &HashSet<String>, // set of columns required up to this step
+ required_columns: &HashSet<Column>, // set of columns required up to this step
has_projection: bool,
execution_props: &ExecutionProps,
) -> Result<LogicalPlan> {
@@ -133,12 +150,12 @@ fn optimize_plan(
.iter()
.enumerate()
.try_for_each(|(i, field)| {
- if required_columns.contains(field.name()) {
+ if required_columns.contains(&field.qualified_column()) {
new_expr.push(expr[i].clone());
new_fields.push(field.clone());
// gather the new set of required columns
- utils::expr_to_column_names(&expr[i], &mut new_required_columns)
+ utils::expr_to_columns(&expr[i], &mut new_required_columns)
} else {
Ok(())
}
@@ -167,31 +184,45 @@ fn optimize_plan(
right,
on,
join_type,
- schema,
+ join_constraint,
+ ..
} => {
for (l, r) in on {
- new_required_columns.insert(l.to_owned());
- new_required_columns.insert(r.to_owned());
+ new_required_columns.insert(l.clone());
+ new_required_columns.insert(r.clone());
}
- Ok(LogicalPlan::Join {
- left: Arc::new(optimize_plan(
- optimizer,
- left,
- &new_required_columns,
- true,
- execution_props,
- )?),
- right: Arc::new(optimize_plan(
- optimizer,
- right,
- &new_required_columns,
- true,
- execution_props,
- )?),
+ let optimized_left = Arc::new(optimize_plan(
+ optimizer,
+ left,
+ &new_required_columns,
+ true,
+ execution_props,
+ )?);
+
+ let optimized_right = Arc::new(optimize_plan(
+ optimizer,
+ right,
+ &new_required_columns,
+ true,
+ execution_props,
+ )?);
+
+ let schema = build_join_schema(
+ optimized_left.schema(),
+ optimized_right.schema(),
+ on,
+ join_type,
+ join_constraint,
+ )?;
+
+ Ok(LogicalPlan::Join {
+ left: optimized_left,
+ right: optimized_right,
join_type: *join_type,
+ join_constraint: *join_constraint,
on: on.clone(),
- schema: schema.clone(),
+ schema: DFSchemaRef::new(schema),
})
}
LogicalPlan::Window {
@@ -205,11 +236,12 @@ fn optimize_plan(
{
window_expr.iter().try_for_each(|expr| {
let name = &expr.name(schema)?;
- if required_columns.contains(name) {
+ let column = Column::from_name(name.to_string());
+ if required_columns.contains(&column) {
new_window_expr.push(expr.clone());
- new_required_columns.insert(name.clone());
+ new_required_columns.insert(column);
// add to the new set of required columns
- utils::expr_to_column_names(expr, &mut new_required_columns)
+ utils::expr_to_columns(expr, &mut new_required_columns)
} else {
Ok(())
}
@@ -217,31 +249,20 @@ fn optimize_plan(
}
// for all the retained window expr, find their sort expressions if any, and retain these
- utils::exprlist_to_column_names(
+ utils::exprlist_to_columns(
&find_sort_exprs(&new_window_expr),
&mut new_required_columns,
)?;
- let new_schema = DFSchema::new(
- schema
- .fields()
- .iter()
- .filter(|x| new_required_columns.contains(x.name()))
- .cloned()
- .collect(),
- )?;
-
- Ok(LogicalPlan::Window {
- window_expr: new_window_expr,
- input: Arc::new(optimize_plan(
- optimizer,
- input,
- &new_required_columns,
- true,
- execution_props,
- )?),
- schema: DFSchemaRef::new(new_schema),
- })
+ LogicalPlanBuilder::from(&optimize_plan(
+ optimizer,
+ input,
+ &new_required_columns,
+ true,
+ execution_props,
+ )?)
+ .window(new_window_expr)?
+ .build()
}
LogicalPlan::Aggregate {
schema,
@@ -254,19 +275,20 @@ fn optimize_plan(
// * remove any aggregate expression that is not required
// * construct the new set of required columns
- utils::exprlist_to_column_names(group_expr, &mut new_required_columns)?;
+ utils::exprlist_to_columns(group_expr, &mut new_required_columns)?;
// Gather all columns needed for expressions in this Aggregate
let mut new_aggr_expr = Vec::new();
aggr_expr.iter().try_for_each(|expr| {
let name = &expr.name(schema)?;
+ let column = Column::from_name(name.to_string());
- if required_columns.contains(name) {
+ if required_columns.contains(&column) {
new_aggr_expr.push(expr.clone());
- new_required_columns.insert(name.clone());
+ new_required_columns.insert(column);
// add to the new set of required columns
- utils::expr_to_column_names(expr, &mut new_required_columns)
+ utils::expr_to_columns(expr, &mut new_required_columns)
} else {
Ok(())
}
@@ -276,7 +298,7 @@ fn optimize_plan(
schema
.fields()
.iter()
- .filter(|x| new_required_columns.contains(x.name()))
+ .filter(|x| new_required_columns.contains(&x.qualified_column()))
.cloned()
.collect(),
)?;
@@ -303,12 +325,15 @@ fn optimize_plan(
limit,
..
} => {
- let (projection, projected_schema) =
- get_projected_schema(&source.schema(), required_columns, has_projection)?;
-
+ let (projection, projected_schema) = get_projected_schema(
+ Some(table_name),
+ &source.schema(),
+ required_columns,
+ has_projection,
+ )?;
// return the table scan with projection
Ok(LogicalPlan::TableScan {
- table_name: table_name.to_string(),
+ table_name: table_name.clone(),
source: source.clone(),
projection: Some(projection),
projected_schema,
@@ -332,6 +357,48 @@ fn optimize_plan(
execution_props,
)
}
+ LogicalPlan::Union {
+ inputs,
+ schema,
+ alias,
+ } => {
+ // UNION inputs will reference the same column with different identifiers, so we need
+ // to populate new_required_columns by unqualified column name based on required fields
+ // from the resulting UNION output
+ let union_required_fields = schema
+ .fields()
+ .iter()
+ .filter(|f| new_required_columns.contains(&f.qualified_column()))
+ .map(|f| f.field())
+ .collect::<HashSet<&Field>>();
+
+ let new_inputs = inputs
+ .iter()
+ .map(|input_plan| {
+ input_plan
+ .schema()
+ .fields()
+ .iter()
+ .filter(|f| union_required_fields.contains(f.field()))
+ .for_each(|f| {
+ new_required_columns.insert(f.qualified_column());
+ });
+ optimize_plan(
+ optimizer,
+ input_plan,
+ &new_required_columns,
+ has_projection,
+ execution_props,
+ )
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ Ok(LogicalPlan::Union {
+ inputs: new_inputs,
+ schema: schema.clone(),
+ alias: alias.clone(),
+ })
+ }
// all other nodes: Add any additional columns used by
// expressions in this node to the list of required columns
LogicalPlan::Limit { .. }
@@ -340,21 +407,20 @@ fn optimize_plan(
| LogicalPlan::EmptyRelation { .. }
| LogicalPlan::Sort { .. }
| LogicalPlan::CreateExternalTable { .. }
- | LogicalPlan::Union { .. }
| LogicalPlan::CrossJoin { .. }
| LogicalPlan::Extension { .. } => {
let expr = plan.expressions();
// collect all required columns by this plan
- utils::exprlist_to_column_names(&expr, &mut new_required_columns)?;
+ utils::exprlist_to_columns(&expr, &mut new_required_columns)?;
// apply the optimization to all inputs of the plan
let inputs = plan.inputs();
let new_inputs = inputs
.iter()
- .map(|plan| {
+ .map(|input_plan| {
optimize_plan(
optimizer,
- plan,
+ input_plan,
&new_required_columns,
has_projection,
execution_props,
@@ -371,8 +437,7 @@ fn optimize_plan(
mod tests {
use super::*;
- use crate::logical_plan::{col, lit};
- use crate::logical_plan::{max, min, Expr, LogicalPlanBuilder};
+ use crate::logical_plan::{col, lit, max, min, Expr, JoinType, LogicalPlanBuilder};
use crate::test::*;
use arrow::datatypes::DataType;
@@ -384,7 +449,7 @@ mod tests {
.aggregate(vec![], vec![max(col("b"))])?
.build()?;
- let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#b)]]\
+ let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#test.b)]]\
\n TableScan: test projection=Some([1])";
assert_optimized_plan_eq(&plan, expected);
@@ -400,7 +465,7 @@ mod tests {
.aggregate(vec![col("c")], vec![max(col("b"))])?
.build()?;
- let expected = "Aggregate: groupBy=[[#c]], aggr=[[MAX(#b)]]\
+ let expected = "Aggregate: groupBy=[[#test.c]], aggr=[[MAX(#test.b)]]\
\n TableScan: test projection=Some([1, 2])";
assert_optimized_plan_eq(&plan, expected);
@@ -417,8 +482,8 @@ mod tests {
.aggregate(vec![], vec![max(col("b"))])?
.build()?;
- let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#b)]]\
- \n Filter: #c\
+ let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(#test.b)]]\
+ \n Filter: #test.c\
\n TableScan: test projection=Some([1, 2])";
assert_optimized_plan_eq(&plan, expected);
@@ -427,6 +492,43 @@ mod tests {
}
#[test]
+ fn join_schema_trim() -> Result<()> {
+ let table_scan = test_table_scan()?;
+
+ let schema = Schema::new(vec![Field::new("c1", DataType::UInt32, false)]);
+ let table2_scan =
+ LogicalPlanBuilder::scan_empty(Some("test2"), &schema, None)?.build()?;
+
+ let plan = LogicalPlanBuilder::from(&table_scan)
+ .join(&table2_scan, JoinType::Left, vec!["a"], vec!["c1"])?
+ .project(vec![col("a"), col("b"), col("c1")])?
+ .build()?;
+
+ // make sure projections are pushed down to table scan
+ let expected = "Projection: #test.a, #test.b, #test2.c1\
+ \n Join: #test.a = #test2.c1\
+ \n TableScan: test projection=Some([0, 1])\
+ \n TableScan: test2 projection=Some([0])";
+
+ let optimized_plan = optimize(&plan)?;
+ let formatted_plan = format!("{:?}", optimized_plan);
+ assert_eq!(formatted_plan, expected);
+
+ // make sure schema for join node doesn't include c1 column
+ let optimized_join = optimized_plan.inputs()[0];
+ assert_eq!(
+ **optimized_join.schema(),
+ DFSchema::new(vec![
+ DFField::new(Some("test"), "a", DataType::UInt32, false),
+ DFField::new(Some("test"), "b", DataType::UInt32, false),
+ DFField::new(Some("test2"), "c1", DataType::UInt32, false),
+ ])?,
+ );
+
+ Ok(())
+ }
+
+ #[test]
fn cast() -> Result<()> {
let table_scan = test_table_scan()?;
@@ -437,7 +539,7 @@ mod tests {
}])?
.build()?;
- let expected = "Projection: CAST(#c AS Float64)\
+ let expected = "Projection: CAST(#test.c AS Float64)\
\n TableScan: test projection=Some([2])";
assert_optimized_plan_eq(&projection, expected);
@@ -457,7 +559,7 @@ mod tests {
assert_fields_eq(&plan, vec!["a", "b"]);
- let expected = "Projection: #a, #b\
+ let expected = "Projection: #test.a, #test.b\
\n TableScan: test projection=Some([0, 1])";
assert_optimized_plan_eq(&plan, expected);
@@ -479,7 +581,7 @@ mod tests {
assert_fields_eq(&plan, vec!["c", "a"]);
let expected = "Limit: 5\
- \n Projection: #c, #a\
+ \n Projection: #test.c, #test.a\
\n TableScan: test projection=Some([0, 2])";
assert_optimized_plan_eq(&plan, expected);
@@ -523,12 +625,12 @@ mod tests {
.aggregate(vec![col("c")], vec![max(col("a"))])?
.build()?;
- assert_fields_eq(&plan, vec!["c", "MAX(a)"]);
+ assert_fields_eq(&plan, vec!["c", "MAX(test.a)"]);
let expected = "\
- Aggregate: groupBy=[[#c]], aggr=[[MAX(#a)]]\
- \n Filter: #c Gt Int32(1)\
- \n Projection: #c, #a\
+ Aggregate: groupBy=[[#test.c]], aggr=[[MAX(#test.a)]]\
+ \n Filter: #test.c Gt Int32(1)\
+ \n Projection: #test.c, #test.a\
\n TableScan: test projection=Some([0, 2])";
assert_optimized_plan_eq(&plan, expected);
@@ -591,15 +693,15 @@ mod tests {
let plan = LogicalPlanBuilder::from(&table_scan)
.aggregate(vec![col("a"), col("c")], vec![max(col("b")), min(col("b"))])?
.filter(col("c").gt(lit(1)))?
- .project(vec![col("c"), col("a"), col("MAX(b)")])?
+ .project(vec![col("c"), col("a"), col("MAX(test.b)")])?
.build()?;
- assert_fields_eq(&plan, vec!["c", "a", "MAX(b)"]);
+ assert_fields_eq(&plan, vec!["c", "a", "MAX(test.b)"]);
let expected = "\
- Projection: #c, #a, #MAX(b)\
- \n Filter: #c Gt Int32(1)\
- \n Aggregate: groupBy=[[#a, #c]], aggr=[[MAX(#b)]]\
+ Projection: #test.c, #test.a, #MAX(test.b)\
+ \n Filter: #test.c Gt Int32(1)\
+ \n Aggregate: groupBy=[[#test.a, #test.c]], aggr=[[MAX(#test.b)]]\
\n TableScan: test projection=Some([0, 1, 2])";
assert_optimized_plan_eq(&plan, expected);
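The key change in `get_projected_schema` is the qualifier filter: only required columns whose relation matches the current scan's table name are mapped to projection indices in that scan's Arrow schema, which is what lets each side of a join trim to exactly its own columns. A self-contained sketch of that filtering (simplified types, hypothetical helper):
```rust
use std::collections::HashSet;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct Column {
    relation: Option<String>,
    name: String,
}

// Sketch: resolve only the required columns that belong to this table.
fn projection_indices(
    table_name: Option<&String>,
    schema_fields: &[&str], // field names of the scanned table
    required: &HashSet<Column>,
) -> Vec<usize> {
    let mut idx: Vec<usize> = required
        .iter()
        .filter(|c| c.relation.as_ref() == table_name)
        .filter_map(|c| schema_fields.iter().position(|f| *f == c.name))
        .collect();
    idx.sort_unstable();
    idx
}

fn main() {
    let t = "test".to_string();
    let required: HashSet<Column> = vec![
        Column { relation: Some("test".into()), name: "a".into() },
        Column { relation: Some("test2".into()), name: "c1".into() },
    ]
    .into_iter()
    .collect();
    // only test.a resolves against table `test`; test2.c1 is ignored here
    assert_eq!(projection_indices(Some(&t), &["a", "b", "c"], &required), vec![0]);
}
```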
diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs
index 9ad7a94..4253d2f 100644
--- a/datafusion/src/optimizer/simplify_expressions.rs
+++ b/datafusion/src/optimizer/simplify_expressions.rs
@@ -510,8 +510,8 @@ mod tests {
assert_optimized_plan_eq(
&plan,
"\
- Filter: #b Gt Int32(1)\
- \n Projection: #a\
+ Filter: #test.b Gt Int32(1)\
+ \n Projection: #test.a\
\n TableScan: test projection=None",
);
Ok(())
@@ -532,8 +532,8 @@ mod tests {
assert_optimized_plan_eq(
&plan,
"\
- Filter: #a Gt Int32(5) And #b Lt Int32(6)\
- \n Projection: #a\
+ Filter: #test.a Gt Int32(5) And #test.b Lt Int32(6)\
+ \n Projection: #test.a\
\n TableScan: test projection=None",
);
Ok(())
diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs
index 014ec74..76f44b8 100644
--- a/datafusion/src/optimizer/utils.rs
+++ b/datafusion/src/optimizer/utils.rs
@@ -24,8 +24,8 @@ use arrow::datatypes::Schema;
use super::optimizer::OptimizerRule;
use crate::execution::context::ExecutionProps;
use crate::logical_plan::{
- Expr, LogicalPlan, Operator, Partitioning, PlanType, Recursion, StringifiedPlan,
- ToDFSchema,
+ build_join_schema, Column, DFSchemaRef, Expr, LogicalPlan, LogicalPlanBuilder,
+ Operator, Partitioning, PlanType, Recursion, StringifiedPlan, ToDFSchema,
};
use crate::prelude::lit;
use crate::scalar::ScalarValue;
@@ -39,14 +39,11 @@ const CASE_ELSE_MARKER: &str = "__DATAFUSION_CASE_ELSE__";
const WINDOW_PARTITION_MARKER: &str = "__DATAFUSION_WINDOW_PARTITION__";
const WINDOW_SORT_MARKER: &str = "__DATAFUSION_WINDOW_SORT__";
-/// Recursively walk a list of expression trees, collecting the unique set of column
-/// names referenced in the expression
-pub fn exprlist_to_column_names(
- expr: &[Expr],
- accum: &mut HashSet<String>,
-) -> Result<()> {
+/// Recursively walk a list of expression trees, collecting the unique set of columns
+/// referenced in the expression
+pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result<()> {
for e in expr {
- expr_to_column_names(e, accum)?;
+ expr_to_columns(e, accum)?;
}
Ok(())
}
@@ -54,17 +51,17 @@ pub fn exprlist_to_column_names(
/// Recursively walk an expression tree, collecting the unique set of columns
/// referenced in the expression
struct ColumnNameVisitor<'a> {
- accum: &'a mut HashSet<String>,
+ accum: &'a mut HashSet<Column>,
}
impl ExpressionVisitor for ColumnNameVisitor<'_> {
fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>> {
match expr {
- Expr::Column(name) => {
- self.accum.insert(name.clone());
+ Expr::Column(qc) => {
+ self.accum.insert(qc.clone());
}
Expr::ScalarVariable(var_names) => {
- self.accum.insert(var_names.join("."));
+ self.accum.insert(Column::from_name(var_names.join(".")));
}
Expr::Alias(_, _) => {}
Expr::Literal(_) => {}
@@ -90,9 +87,9 @@ impl ExpressionVisitor for ColumnNameVisitor<'_> {
}
}
-/// Recursively walk an expression tree, collecting the unique set of column names
+/// Recursively walk an expression tree, collecting the unique set of columns
/// referenced in the expression
-pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet<String>) -> Result<()> {
+pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
expr.accept(ColumnNameVisitor { accum })?;
Ok(())
}
@@ -214,21 +211,31 @@ pub fn from_plan(
}),
LogicalPlan::Join {
join_type,
+ join_constraint,
on,
- schema,
..
- } => Ok(LogicalPlan::Join {
- left: Arc::new(inputs[0].clone()),
- right: Arc::new(inputs[1].clone()),
- join_type: *join_type,
- on: on.clone(),
- schema: schema.clone(),
- }),
- LogicalPlan::CrossJoin { schema, .. } => Ok(LogicalPlan::CrossJoin {
- left: Arc::new(inputs[0].clone()),
- right: Arc::new(inputs[1].clone()),
- schema: schema.clone(),
- }),
+ } => {
+ let schema = build_join_schema(
+ inputs[0].schema(),
+ inputs[1].schema(),
+ on,
+ join_type,
+ join_constraint,
+ )?;
+ Ok(LogicalPlan::Join {
+ left: Arc::new(inputs[0].clone()),
+ right: Arc::new(inputs[1].clone()),
+ join_type: *join_type,
+ join_constraint: *join_constraint,
+ on: on.clone(),
+ schema: DFSchemaRef::new(schema),
+ })
+ }
+ LogicalPlan::CrossJoin { .. } => {
+ let left = &inputs[0];
+ let right = &inputs[1];
+ LogicalPlanBuilder::from(left).cross_join(right)?.build()
+ }
LogicalPlan::Limit { n, .. } => Ok(LogicalPlan::Limit {
n: *n,
input: Arc::new(inputs[0].clone()),
@@ -493,15 +500,15 @@ mod tests {
#[test]
fn test_collect_expr() -> Result<()> {
- let mut accum: HashSet<String> = HashSet::new();
- expr_to_column_names(
+ let mut accum: HashSet<Column> = HashSet::new();
+ expr_to_columns(
&Expr::Cast {
expr: Box::new(col("a")),
data_type: DataType::Float64,
},
&mut accum,
)?;
- expr_to_column_names(
+ expr_to_columns(
&Expr::Cast {
expr: Box::new(col("a")),
data_type: DataType::Float64,
@@ -509,7 +516,7 @@ mod tests {
&mut accum,
)?;
assert_eq!(1, accum.len());
- assert!(accum.contains("a"));
+ assert!(accum.contains(&Column::from_name("a".to_string())));
Ok(())
}
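The renamed `expr_to_columns` is a straightforward visitor over the expression tree, accumulating every column reference into a set. A toy version over a reduced expression enum (flat strings stand in for `Column`):
```rust
use std::collections::HashSet;

// Reduced expression type for illustration; the real Expr has many more
// variants, all of which the visitor walks the same way.
enum Expr {
    Column(String),
    BinaryExpr { left: Box<Expr>, right: Box<Expr> },
    Literal(i64),
}

fn expr_to_columns(expr: &Expr, accum: &mut HashSet<String>) {
    match expr {
        Expr::Column(c) => {
            accum.insert(c.clone());
        }
        Expr::BinaryExpr { left, right } => {
            expr_to_columns(left, accum);
            expr_to_columns(right, accum);
        }
        Expr::Literal(_) => {}
    }
}

fn main() {
    // test.a = 1
    let e = Expr::BinaryExpr {
        left: Box::new(Expr::Column("test.a".to_string())),
        right: Box::new(Expr::Literal(1)),
    };
    let mut cols = HashSet::new();
    expr_to_columns(&e, &mut cols);
    assert!(cols.contains("test.a"));
    assert_eq!(cols.len(), 1);
}
```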
diff --git a/datafusion/src/physical_optimizer/pruning.rs b/datafusion/src/physical_optimizer/pruning.rs
index 9e8d9fa..5585c4d 100644
--- a/datafusion/src/physical_optimizer/pruning.rs
+++ b/datafusion/src/physical_optimizer/pruning.rs
@@ -28,6 +28,7 @@
//! https://github.com/apache/arrow-datafusion/issues/363 it will
//! be genericized.
+use std::convert::TryFrom;
use std::{collections::HashSet, sync::Arc};
use arrow::{
@@ -39,7 +40,7 @@ use arrow::{
use crate::{
error::{DataFusionError, Result},
execution::context::ExecutionContextState,
- logical_plan::{Expr, Operator},
+ logical_plan::{Column, DFSchema, Expr, Operator},
optimizer::utils,
physical_plan::{planner::DefaultPhysicalPlanner, ColumnarValue, PhysicalExpr},
};
@@ -65,11 +66,11 @@ use crate::{
pub trait PruningStatistics {
/// return the minimum values for the named column, if known.
/// Note: the returned array must contain `num_containers()` rows
- fn min_values(&self, column: &str) -> Option<ArrayRef>;
+ fn min_values(&self, column: &Column) -> Option<ArrayRef>;
/// return the maximum values for the named column, if known.
/// Note: the returned array must contain `num_containers()` rows.
- fn max_values(&self, column: &str) -> Option<ArrayRef>;
+ fn max_values(&self, column: &Column) -> Option<ArrayRef>;
/// return the number of containers (e.g. row groups) being
/// pruned with these statistics
@@ -120,9 +121,11 @@ impl PruningPredicate {
.map(|(_, _, f)| f.clone())
.collect::<Vec<_>>();
let stat_schema = Schema::new(stat_fields);
+ let stat_dfschema = DFSchema::try_from(stat_schema.clone())?;
let execution_context_state = ExecutionContextState::new();
let predicate_expr = DefaultPhysicalPlanner::default().create_physical_expr(
&logical_predicate_expr,
+ &stat_dfschema,
&stat_schema,
&execution_context_state,
)?;
@@ -196,11 +199,11 @@ impl PruningPredicate {
#[derive(Debug, Default, Clone)]
struct RequiredStatColumns {
/// The statistics required to evaluate this predicate:
- /// * The column name in the input schema
+ /// * The unqualified column in the input schema
/// * Statistics type (e.g. Min or Max)
/// * The field the statistics value should be placed in for
/// pruning predicate evaluation
- columns: Vec<(String, StatisticsType, Field)>,
+ columns: Vec<(Column, StatisticsType, Field)>,
}
impl RequiredStatColumns {
@@ -210,22 +213,22 @@ impl RequiredStatColumns {
/// Return an iterator over items in columns (see doc on
/// `self.columns` for details)
- fn iter(&self) -> impl Iterator<Item = &(String, StatisticsType, Field)> {
+ fn iter(&self) -> impl Iterator<Item = &(Column, StatisticsType, Field)> {
self.columns.iter()
}
fn is_stat_column_missing(
&self,
- column_name: &str,
+ column: &Column,
statistics_type: StatisticsType,
) -> bool {
!self
.columns
.iter()
- .any(|(c, t, _f)| c == column_name && t == &statistics_type)
+ .any(|(c, t, _f)| c == column && t == &statistics_type)
}
- /// Rewrites column_expr so that all appearances of column_name
+ /// Rewrites column_expr so that all appearances of column
/// are replaced with a reference to either the min or max
/// statistics column, while keeping track that a reference to the statistics
/// column is required
@@ -235,49 +238,53 @@ impl RequiredStatColumns {
/// 5` with the appropriate entry noted in self.columns
fn stat_column_expr(
&mut self,
- column_name: &str,
+ column: &Column,
column_expr: &Expr,
field: &Field,
stat_type: StatisticsType,
suffix: &str,
) -> Result<Expr> {
- let stat_column_name = format!("{}_{}", column_name, suffix);
+ let stat_column = Column {
+ relation: column.relation.clone(),
+ name: format!("{}_{}", column.flat_name(), suffix),
+ };
+
let stat_field = Field::new(
- stat_column_name.as_str(),
+ stat_column.flat_name().as_str(),
field.data_type().clone(),
field.is_nullable(),
);
- if self.is_stat_column_missing(column_name, stat_type) {
+
+ if self.is_stat_column_missing(column, stat_type) {
// only add statistics column if not previously added
- self.columns
- .push((column_name.to_string(), stat_type, stat_field));
+ self.columns.push((column.clone(), stat_type, stat_field));
}
- rewrite_column_expr(column_expr, column_name, stat_column_name.as_str())
+ rewrite_column_expr(column_expr, column, &stat_column)
}
/// rewrite col --> col_min
fn min_column_expr(
&mut self,
- column_name: &str,
+ column: &Column,
column_expr: &Expr,
field: &Field,
) -> Result<Expr> {
- self.stat_column_expr(column_name, column_expr, field, StatisticsType::Min, "min")
+ self.stat_column_expr(column, column_expr, field, StatisticsType::Min, "min")
}
/// rewrite col --> col_max
fn max_column_expr(
&mut self,
- column_name: &str,
+ column: &Column,
column_expr: &Expr,
field: &Field,
) -> Result<Expr> {
- self.stat_column_expr(column_name, column_expr, field, StatisticsType::Max, "max")
+ self.stat_column_expr(column, column_expr, field, StatisticsType::Max, "max")
}
}
-impl From<Vec<(String, StatisticsType, Field)>> for RequiredStatColumns {
- fn from(columns: Vec<(String, StatisticsType, Field)>) -> Self {
+impl From<Vec<(Column, StatisticsType, Field)>> for RequiredStatColumns {
+ fn from(columns: Vec<(Column, StatisticsType, Field)>) -> Self {
Self { columns }
}
}
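The `stat_column_expr` change above preserves the original column's relation while appending the suffix to its flat name. A minimal sketch of that naming, under the same assumptions:
```rust
// Sketch of the statistics-column naming in stat_column_expr: keep the
// relation, suffix the flat name (illustrative types, not DataFusion's).
#[derive(Debug, Clone)]
struct Column {
    relation: Option<String>,
    name: String,
}

impl Column {
    fn flat_name(&self) -> String {
        match &self.relation {
            Some(r) => format!("{}.{}", r, self.name),
            None => self.name.clone(),
        }
    }
}

fn stat_column(column: &Column, suffix: &str) -> Column {
    Column {
        relation: column.relation.clone(),
        name: format!("{}_{}", column.flat_name(), suffix),
    }
}

fn main() {
    let s1 = Column { relation: None, name: "s1".into() };
    assert_eq!(stat_column(&s1, "min").name, "s1_min");

    let qualified = Column { relation: Some("test".into()), name: "s1".into() };
    assert_eq!(stat_column(&qualified, "max").name, "test.s1_max");
}
```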
@@ -314,14 +321,14 @@ fn build_statistics_record_batch<S: PruningStatistics>(
let mut fields = Vec::<Field>::new();
let mut arrays = Vec::<ArrayRef>::new();
// For each needed statistics column:
- for (column_name, statistics_type, stat_field) in required_columns.iter() {
+ for (column, statistics_type, stat_field) in required_columns.iter() {
let data_type = stat_field.data_type();
let num_containers = statistics.num_containers();
let array = match statistics_type {
- StatisticsType::Min => statistics.min_values(column_name),
- StatisticsType::Max => statistics.max_values(column_name),
+ StatisticsType::Min => statistics.min_values(column),
+ StatisticsType::Max => statistics.max_values(column),
};
let array = array.unwrap_or_else(|| new_null_array(data_type, num_containers));
@@ -347,7 +354,7 @@ fn build_statistics_record_batch<S: PruningStatistics>(
}
struct PruningExpressionBuilder<'a> {
- column_name: String,
+ column: Column,
column_expr: &'a Expr,
scalar_expr: &'a Expr,
field: &'a Field,
@@ -363,11 +370,11 @@ impl<'a> PruningExpressionBuilder<'a> {
required_columns: &'a mut RequiredStatColumns,
) -> Result<Self> {
// find column; input could be a more complicated expression
- let mut left_columns = HashSet::<String>::new();
- utils::expr_to_column_names(left, &mut left_columns)?;
- let mut right_columns = HashSet::<String>::new();
- utils::expr_to_column_names(right, &mut right_columns)?;
- let (column_expr, scalar_expr, column_names, reverse_operator) =
+ let mut left_columns = HashSet::<Column>::new();
+ utils::expr_to_columns(left, &mut left_columns)?;
+ let mut right_columns = HashSet::<Column>::new();
+ utils::expr_to_columns(right, &mut right_columns)?;
+ let (column_expr, scalar_expr, columns, reverse_operator) =
match (left_columns.len(), right_columns.len()) {
(1, 0) => (left, right, left_columns, false),
(0, 1) => (right, left, right_columns, true),
@@ -379,8 +386,8 @@ impl<'a> PruningExpressionBuilder<'a> {
));
}
};
- let column_name = column_names.iter().next().unwrap().clone();
- let field = match schema.column_with_name(&column_name) {
+ let column = columns.iter().next().unwrap().clone();
+ let field = match schema.column_with_name(&column.flat_name()) {
Some((_, f)) => f,
_ => {
return Err(DataFusionError::Plan(
@@ -390,7 +397,7 @@ impl<'a> PruningExpressionBuilder<'a> {
};
Ok(Self {
- column_name,
+ column,
column_expr,
scalar_expr,
field,
@@ -418,63 +425,56 @@ impl<'a> PruningExpressionBuilder<'a> {
}
fn min_column_expr(&mut self) -> Result<Expr> {
- self.required_columns.min_column_expr(
- &self.column_name,
- self.column_expr,
- self.field,
- )
+ self.required_columns
+ .min_column_expr(&self.column, self.column_expr, self.field)
}
fn max_column_expr(&mut self) -> Result<Expr> {
- self.required_columns.max_column_expr(
- &self.column_name,
- self.column_expr,
- self.field,
- )
+ self.required_columns
+ .max_column_expr(&self.column, self.column_expr, self.field)
}
}
/// replaces a column with an old name with a new name in an expression
fn rewrite_column_expr(
expr: &Expr,
- column_old_name: &str,
- column_new_name: &str,
+ column_old: &Column,
+ column_new: &Column,
) -> Result<Expr> {
let expressions = utils::expr_sub_expressions(expr)?;
let expressions = expressions
.iter()
- .map(|e| rewrite_column_expr(e, column_old_name, column_new_name))
+ .map(|e| rewrite_column_expr(e, column_old, column_new))
.collect::<Result<Vec<_>>>()?;
- if let Expr::Column(name) = expr {
- if name == column_old_name {
- return Ok(Expr::Column(column_new_name.to_string()));
+ if let Expr::Column(c) = expr {
+ if c == column_old {
+ return Ok(Expr::Column(column_new.clone()));
}
}
utils::rewrite_expression(expr, &expressions)
}
-/// Given a column reference to `column_name`, returns a pruning
+/// Given a column reference to `column`, returns a pruning
/// expression in terms of the min and max that will evaluate to true
/// if the column may contain values, and false if it definitely does not
/// contain values
fn build_single_column_expr(
- column_name: &str,
+ column: &Column,
schema: &Schema,
required_columns: &mut RequiredStatColumns,
is_not: bool, // if true, treat as !col
) -> Option<Expr> {
- use crate::logical_plan;
- let field = schema.field_with_name(column_name).ok()?;
+ let field = schema.field_with_name(&column.name).ok()?;
if matches!(field.data_type(), &DataType::Boolean) {
- let col_ref = logical_plan::col(column_name);
+ let col_ref = Expr::Column(column.clone());
let min = required_columns
- .min_column_expr(column_name, &col_ref, field)
+ .min_column_expr(column, &col_ref, field)
.ok()?;
let max = required_columns
- .max_column_expr(column_name, &col_ref, field)
+ .max_column_expr(column, &col_ref, field)
.ok()?;
// remember -- we want an expression that is:
@@ -514,15 +514,15 @@ fn build_predicate_expression(
// predicate expression can only be a binary expression
let (left, op, right) = match expr {
Expr::BinaryExpr { left, op, right } => (left, *op, right),
- Expr::Column(name) => {
- let expr = build_single_column_expr(name, schema, required_columns, false)
+ Expr::Column(col) => {
+ let expr = build_single_column_expr(col, schema, required_columns, false)
.unwrap_or(unhandled);
return Ok(expr);
}
// match !col (don't do so recursively)
Expr::Not(input) => {
- if let Expr::Column(name) = input.as_ref() {
- let expr = build_single_column_expr(name, schema, required_columns, true)
+ if let Expr::Column(col) = input.as_ref() {
+ let expr = build_single_column_expr(col, schema, required_columns, true)
.unwrap_or(unhandled);
return Ok(expr);
} else {
@@ -674,7 +674,7 @@ mod tests {
#[derive(Debug, Default)]
struct TestStatistics {
// key: the qualified column
- stats: HashMap<String, ContainerStats>,
+ stats: HashMap<Column, ContainerStats>,
}
impl TestStatistics {
@@ -687,20 +687,21 @@ mod tests {
name: impl Into<String>,
container_stats: ContainerStats,
) -> Self {
- self.stats.insert(name.into(), container_stats);
+ self.stats
+ .insert(Column::from_name(name.into()), container_stats);
self
}
}
impl PruningStatistics for TestStatistics {
- fn min_values(&self, column: &str) -> Option<ArrayRef> {
+ fn min_values(&self, column: &Column) -> Option<ArrayRef> {
self.stats
.get(column)
.map(|container_stats| container_stats.min())
.unwrap_or(None)
}
- fn max_values(&self, column: &str) -> Option<ArrayRef> {
+ fn max_values(&self, column: &Column) -> Option<ArrayRef> {
self.stats
.get(column)
.map(|container_stats| container_stats.max())
@@ -724,11 +725,11 @@ mod tests {
}
impl PruningStatistics for OneContainerStats {
- fn min_values(&self, _column: &str) -> Option<ArrayRef> {
+ fn min_values(&self, _column: &Column) -> Option<ArrayRef> {
self.min_values.clone()
}
- fn max_values(&self, _column: &str) -> Option<ArrayRef> {
+ fn max_values(&self, _column: &Column) -> Option<ArrayRef> {
self.max_values.clone()
}
@@ -743,25 +744,25 @@ mod tests {
let required_columns = RequiredStatColumns::from(vec![
// min of original column s1, named s1_min
(
- "s1".to_string(),
+ "s1".into(),
StatisticsType::Min,
Field::new("s1_min", DataType::Int32, true),
),
// max of original column s2, named s2_max
(
- "s2".to_string(),
+ "s2".into(),
StatisticsType::Max,
Field::new("s2_max", DataType::Int32, true),
),
// max of original column s3, named s3_max
(
- "s3".to_string(),
+ "s3".into(),
StatisticsType::Max,
Field::new("s3_max", DataType::Utf8, true),
),
// min of original column s3, named s3_min
(
- "s3".to_string(),
+ "s3".into(),
StatisticsType::Min,
Field::new("s3_min", DataType::Utf8, true),
),
@@ -813,7 +814,7 @@ mod tests {
// Request a record batch with s1_min as a timestamp
let required_columns = RequiredStatColumns::from(vec![(
- "s1".to_string(),
+ "s3".into(),
StatisticsType::Min,
Field::new(
"s1_min",
@@ -867,7 +868,7 @@ mod tests {
// Request a record batch with s1_min as a timestamp
let required_columns = RequiredStatColumns::from(vec![(
- "s1".to_string(),
+ "s3".into(),
StatisticsType::Min,
Field::new("s1_min", DataType::Utf8, true),
)]);
@@ -896,7 +897,7 @@ mod tests {
fn test_build_statistics_inconsistent_length() {
// return an inconsistent length to the actual statistics arrays
let required_columns = RequiredStatColumns::from(vec![(
- "s1".to_string(),
+ "s1".into(),
StatisticsType::Min,
Field::new("s1_min", DataType::Int64, true),
)]);
@@ -1143,18 +1144,18 @@ mod tests {
let c1_min_field = Field::new("c1_min", DataType::Int32, false);
assert_eq!(
required_columns.columns[0],
- ("c1".to_owned(), StatisticsType::Min, c1_min_field)
+ ("c1".into(), StatisticsType::Min, c1_min_field)
);
// c2 = 2 should add c2_min and c2_max
let c2_min_field = Field::new("c2_min", DataType::Int32, false);
assert_eq!(
required_columns.columns[1],
- ("c2".to_owned(), StatisticsType::Min, c2_min_field)
+ ("c2".into(), StatisticsType::Min, c2_min_field)
);
let c2_max_field = Field::new("c2_max", DataType::Int32, false);
assert_eq!(
required_columns.columns[2],
- ("c2".to_owned(), StatisticsType::Max, c2_max_field)
+ ("c2".into(), StatisticsType::Max, c2_max_field)
);
// c2 = 3 shouldn't add any new statistics fields
assert_eq!(required_columns.columns.len(), 3);
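For implementors tracking this change: `PruningStatistics` is now keyed by the logical `Column` rather than a `&str` name. A minimal sketch of a map-backed implementation, assuming the trait's remaining required method is `num_containers()` as at this revision (the type and field names below are hypothetical):

    use std::collections::HashMap;
    use arrow::array::ArrayRef;
    use datafusion::logical_plan::Column;
    use datafusion::physical_optimizer::pruning::PruningStatistics;

    /// Hypothetical container-level statistics keyed by the qualified column.
    struct MapStats {
        mins: HashMap<Column, ArrayRef>, // one array entry per container
        maxs: HashMap<Column, ArrayRef>,
        containers: usize,
    }

    impl PruningStatistics for MapStats {
        fn min_values(&self, column: &Column) -> Option<ArrayRef> {
            self.mins.get(column).cloned()
        }
        fn max_values(&self, column: &Column) -> Option<ArrayRef> {
            self.maxs.get(column).cloned()
        }
        fn num_containers(&self) -> usize {
            self.containers
        }
    }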
diff --git a/datafusion/src/physical_plan/expressions/binary.rs b/datafusion/src/physical_plan/expressions/binary.rs
index 5ed0c74..a69b776 100644
--- a/datafusion/src/physical_plan/expressions/binary.rs
+++ b/datafusion/src/physical_plan/expressions/binary.rs
@@ -611,11 +611,12 @@ mod tests {
]);
let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
let b = Int32Array::from(vec![1, 2, 4, 8, 16]);
+
+ // expression: "a < b"
+ let lt = binary_simple(col("a", &schema)?, Operator::Lt, col("b", &schema)?);
let batch =
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
- // expression: "a < b"
- let lt = binary_simple(col("a"), Operator::Lt, col("b"));
let result = lt.evaluate(&batch)?.into_array(batch.num_rows());
assert_eq!(result.len(), 5);
@@ -639,16 +640,17 @@ mod tests {
]);
let a = Int32Array::from(vec![2, 4, 6, 8, 10]);
let b = Int32Array::from(vec![2, 5, 4, 8, 8]);
- let batch =
- RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
// expression: "a < b OR a == b"
let expr = binary_simple(
- binary_simple(col("a"), Operator::Lt, col("b")),
+ binary_simple(col("a", &schema)?, Operator::Lt, col("b", &schema)?),
Operator::Or,
- binary_simple(col("a"), Operator::Eq, col("b")),
+ binary_simple(col("a", &schema)?, Operator::Eq, col("b", &schema)?),
);
- assert_eq!("a < b OR a = b", format!("{}", expr));
+ let batch =
+ RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])?;
+
+ assert_eq!("a@0 < b@1 OR a@0 = b@1", format!("{}", expr));
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
assert_eq!(result.len(), 5);
@@ -680,14 +682,15 @@ mod tests {
]);
let a = $A_ARRAY::from($A_VEC);
let b = $B_ARRAY::from($B_VEC);
+
+ // verify that we can construct the expression
+ let expression =
+ binary(col("a", &schema)?, $OP, col("b", &schema)?, &schema)?;
let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(a), Arc::new(b)],
)?;
- // verify that we can construct the expression
- let expression = binary(col("a"), $OP, col("b"), &schema)?;
-
// verify that the expression's type is correct
assert_eq!(expression.data_type(&schema)?, $C_TYPE);
@@ -863,7 +866,12 @@ mod tests {
// Test 1: dict = str
// verify that we can construct the expression
- let expression = binary(col("dict"), Operator::Eq, col("str"), &schema)?;
+ let expression = binary(
+ col("dict", &schema)?,
+ Operator::Eq,
+ col("str", &schema)?,
+ &schema,
+ )?;
assert_eq!(expression.data_type(&schema)?, DataType::Boolean);
// evaluate and verify the result type matched
@@ -877,7 +885,12 @@ mod tests {
// str = dict
// verify that we can construct the expression
- let expression = binary(col("str"), Operator::Eq, col("dict"), &schema)?;
+ let expression = binary(
+ col("str", &schema)?,
+ Operator::Eq,
+ col("dict", &schema)?,
+ &schema,
+ )?;
assert_eq!(expression.data_type(&schema)?, DataType::Boolean);
// evaluate and verify the result type matched
@@ -989,7 +1002,7 @@ mod tests {
op: Operator,
expected: PrimitiveArray<T>,
) -> Result<()> {
- let arithmetic_op = binary_simple(col("a"), op, col("b"));
+ let arithmetic_op = binary_simple(col("a", &schema)?, op, col("b", &schema)?);
let batch = RecordBatch::try_new(schema, data)?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
@@ -1004,7 +1017,7 @@ mod tests {
op: Operator,
expected: BooleanArray,
) -> Result<()> {
- let arithmetic_op = binary_simple(col("a"), op, col("b"));
+ let arithmetic_op = binary_simple(col("a", &schema)?, op, col("b", &schema)?);
let data: Vec<ArrayRef> = vec![Arc::new(left), Arc::new(right)];
let batch = RecordBatch::try_new(schema, data)?;
let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
diff --git a/datafusion/src/physical_plan/expressions/case.rs b/datafusion/src/physical_plan/expressions/case.rs
index f89ea8d..a46522d 100644
--- a/datafusion/src/physical_plan/expressions/case.rs
+++ b/datafusion/src/physical_plan/expressions/case.rs
@@ -451,6 +451,7 @@ mod tests {
#[test]
fn case_with_expr() -> Result<()> {
let batch = case_test_batch()?;
+ let schema = batch.schema();
// CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 END
let when1 = lit(ScalarValue::Utf8(Some("foo".to_string())));
@@ -458,7 +459,11 @@ mod tests {
let when2 = lit(ScalarValue::Utf8(Some("bar".to_string())));
let then2 = lit(ScalarValue::Int32(Some(456)));
- let expr = case(Some(col("a")), &[(when1, then1), (when2, then2)], None)?;
+ let expr = case(
+ Some(col("a", &schema)?),
+ &[(when1, then1), (when2, then2)],
+ None,
+ )?;
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
@@ -475,6 +480,7 @@ mod tests {
#[test]
fn case_with_expr_else() -> Result<()> {
let batch = case_test_batch()?;
+ let schema = batch.schema();
// CASE a WHEN 'foo' THEN 123 WHEN 'bar' THEN 456 ELSE 999 END
let when1 = lit(ScalarValue::Utf8(Some("foo".to_string())));
@@ -484,7 +490,7 @@ mod tests {
let else_value = lit(ScalarValue::Int32(Some(999)));
let expr = case(
- Some(col("a")),
+ Some(col("a", &schema)?),
&[(when1, then1), (when2, then2)],
Some(else_value),
)?;
@@ -505,17 +511,18 @@ mod tests {
#[test]
fn case_without_expr() -> Result<()> {
let batch = case_test_batch()?;
+ let schema = batch.schema();
// CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END
let when1 = binary(
- col("a"),
+ col("a", &schema)?,
Operator::Eq,
lit(ScalarValue::Utf8(Some("foo".to_string()))),
&batch.schema(),
)?;
let then1 = lit(ScalarValue::Int32(Some(123)));
let when2 = binary(
- col("a"),
+ col("a", &schema)?,
Operator::Eq,
lit(ScalarValue::Utf8(Some("bar".to_string()))),
&batch.schema(),
@@ -539,17 +546,18 @@ mod tests {
#[test]
fn case_without_expr_else() -> Result<()> {
let batch = case_test_batch()?;
+ let schema = batch.schema();
// CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 ELSE 999 END
let when1 = binary(
- col("a"),
+ col("a", &schema)?,
Operator::Eq,
lit(ScalarValue::Utf8(Some("foo".to_string()))),
&batch.schema(),
)?;
let then1 = lit(ScalarValue::Int32(Some(123)));
let when2 = binary(
- col("a"),
+ col("a", &schema)?,
Operator::Eq,
lit(ScalarValue::Utf8(Some("bar".to_string()))),
&batch.schema(),
diff --git a/datafusion/src/physical_plan/expressions/cast.rs b/datafusion/src/physical_plan/expressions/cast.rs
index 558b1e5..bba125e 100644
--- a/datafusion/src/physical_plan/expressions/cast.rs
+++ b/datafusion/src/physical_plan/expressions/cast.rs
@@ -180,10 +180,14 @@ mod tests {
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
// verify that we can construct the expression
- let expression = cast_with_options(col("a"), &schema, $TYPE, $CAST_OPTIONS)?;
+ let expression =
+ cast_with_options(col("a", &schema)?, &schema, $TYPE, $CAST_OPTIONS)?;
// verify that its display is correct
- assert_eq!(format!("CAST(a AS {:?})", $TYPE), format!("{}", expression));
+ assert_eq!(
+ format!("CAST(a@0 AS {:?})", $TYPE),
+ format!("{}", expression)
+ );
// verify that the expression's type is correct
assert_eq!(expression.data_type(&schema)?, $TYPE);
@@ -272,7 +276,7 @@ mod tests {
// Ensure a useful error happens at plan time if invalid casts are used
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
- let result = cast(col("a"), &schema, DataType::LargeBinary);
+ let result = cast(col("a", &schema).unwrap(), &schema, DataType::LargeBinary);
result.expect_err("expected Invalid CAST");
}
@@ -283,7 +287,7 @@ mod tests {
let a = StringArray::from(vec!["9.1"]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
let expression = cast_with_options(
- col("a"),
+ col("a", &schema)?,
&schema,
DataType::Int32,
DEFAULT_DATAFUSION_CAST_OPTIONS,
diff --git a/datafusion/src/physical_plan/expressions/column.rs b/datafusion/src/physical_plan/expressions/column.rs
index 7e0304e..d6eafbb 100644
--- a/datafusion/src/physical_plan/expressions/column.rs
+++ b/datafusion/src/physical_plan/expressions/column.rs
@@ -28,28 +28,40 @@ use crate::error::Result;
use crate::physical_plan::{ColumnarValue, PhysicalExpr};
/// Represents the column at a given index in a RecordBatch
-#[derive(Debug)]
+#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub struct Column {
name: String,
+ index: usize,
}
impl Column {
/// Create a new column expression
- pub fn new(name: &str) -> Self {
+ pub fn new(name: &str, index: usize) -> Self {
Self {
name: name.to_owned(),
+ index,
}
}
+ /// Create a new column expression based on column name and schema
+ pub fn new_with_schema(name: &str, schema: &Schema) -> Result<Self> {
+ Ok(Column::new(name, schema.index_of(name)?))
+ }
+
/// Get the column name
pub fn name(&self) -> &str {
&self.name
}
+
+ /// Get the column index
+ pub fn index(&self) -> usize {
+ self.index
+ }
}
impl std::fmt::Display for Column {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
- write!(f, "{}", self.name)
+ write!(f, "{}@{}", self.name, self.index)
}
}
@@ -61,26 +73,21 @@ impl PhysicalExpr for Column {
/// Get the data type of this expression, given the schema of the input
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
- Ok(input_schema
- .field_with_name(&self.name)?
- .data_type()
- .clone())
+ Ok(input_schema.field(self.index).data_type().clone())
}
/// Decide whether this expression is nullable, given the schema of the input
fn nullable(&self, input_schema: &Schema) -> Result<bool> {
- Ok(input_schema.field_with_name(&self.name)?.is_nullable())
+ Ok(input_schema.field(self.index).is_nullable())
}
/// Evaluate the expression
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
- Ok(ColumnarValue::Array(
- batch.column(batch.schema().index_of(&self.name)?).clone(),
- ))
+ Ok(ColumnarValue::Array(batch.column(self.index).clone()))
}
}
/// Create a column expression
-pub fn col(name: &str) -> Arc<dyn PhysicalExpr> {
- Arc::new(Column::new(name))
+pub fn col(name: &str, schema: &Schema) -> Result<Arc<dyn PhysicalExpr>> {
+ Ok(Arc::new(Column::new_with_schema(name, schema)?))
}
diff --git a/datafusion/src/physical_plan/expressions/in_list.rs b/datafusion/src/physical_plan/expressions/in_list.rs
index 41f1110..38b2b9d 100644
--- a/datafusion/src/physical_plan/expressions/in_list.rs
+++ b/datafusion/src/physical_plan/expressions/in_list.rs
@@ -296,8 +296,8 @@ mod tests {
// applies the in_list expr to an input batch and list
macro_rules! in_list {
- ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr) => {{
- let expr = in_list(col("a"), $LIST, $NEGATED).unwrap();
+ ($BATCH:expr, $LIST:expr, $NEGATED:expr, $EXPECTED:expr, $COL:expr) => {{
+ let expr = in_list($COL, $LIST, $NEGATED).unwrap();
let result = expr.evaluate(&$BATCH)?.into_array($BATCH.num_rows());
let result = result
.as_any()
@@ -312,6 +312,7 @@ mod tests {
fn in_list_utf8() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
let a = StringArray::from(vec![Some("a"), Some("d"), None]);
+ let col_a = col("a", &schema)?;
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
// expression: "a in ("a", "b")"
@@ -319,14 +320,26 @@ mod tests {
lit(ScalarValue::Utf8(Some("a".to_string()))),
lit(ScalarValue::Utf8(Some("b".to_string()))),
];
- in_list!(batch, list, &false, vec![Some(true), Some(false), None]);
+ in_list!(
+ batch,
+ list,
+ &false,
+ vec![Some(true), Some(false), None],
+ col_a.clone()
+ );
// expression: "a not in ("a", "b")"
let list = vec![
lit(ScalarValue::Utf8(Some("a".to_string()))),
lit(ScalarValue::Utf8(Some("b".to_string()))),
];
- in_list!(batch, list, &true, vec![Some(false), Some(true), None]);
+ in_list!(
+ batch,
+ list,
+ &true,
+ vec![Some(false), Some(true), None],
+ col_a.clone()
+ );
// expression: "a not in ("a", "b")"
let list = vec![
@@ -334,7 +347,13 @@ mod tests {
lit(ScalarValue::Utf8(Some("b".to_string()))),
lit(ScalarValue::Utf8(None)),
];
- in_list!(batch, list, &false, vec![Some(true), None, None]);
+ in_list!(
+ batch,
+ list,
+ &false,
+ vec![Some(true), None, None],
+ col_a.clone()
+ );
// expression: "a not in ("a", "b")"
let list = vec![
@@ -342,7 +361,13 @@ mod tests {
lit(ScalarValue::Utf8(Some("b".to_string()))),
lit(ScalarValue::Utf8(None)),
];
- in_list!(batch, list, &true, vec![Some(false), None, None]);
+ in_list!(
+ batch,
+ list,
+ &true,
+ vec![Some(false), None, None],
+ col_a.clone()
+ );
Ok(())
}
@@ -351,6 +376,7 @@ mod tests {
fn in_list_int64() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
let a = Int64Array::from(vec![Some(0), Some(2), None]);
+ let col_a = col("a", &schema)?;
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
// expression: "a in (0, 1)"
@@ -358,14 +384,26 @@ mod tests {
lit(ScalarValue::Int64(Some(0))),
lit(ScalarValue::Int64(Some(1))),
];
- in_list!(batch, list, &false, vec![Some(true), Some(false), None]);
+ in_list!(
+ batch,
+ list,
+ &false,
+ vec![Some(true), Some(false), None],
+ col_a.clone()
+ );
// expression: "a not in (0, 1)"
let list = vec![
lit(ScalarValue::Int64(Some(0))),
lit(ScalarValue::Int64(Some(1))),
];
- in_list!(batch, list, &true, vec![Some(false), Some(true), None]);
+ in_list!(
+ batch,
+ list,
+ &true,
+ vec![Some(false), Some(true), None],
+ col_a.clone()
+ );
// expression: "a in (0, 1, NULL)"
let list = vec![
@@ -373,7 +411,13 @@ mod tests {
lit(ScalarValue::Int64(Some(1))),
lit(ScalarValue::Utf8(None)),
];
- in_list!(batch, list, &false, vec![Some(true), None, None]);
+ in_list!(
+ batch,
+ list,
+ &false,
+ vec![Some(true), None, None],
+ col_a.clone()
+ );
// expression: "a not in (0, 1, NULL)"
let list = vec![
@@ -381,7 +425,13 @@ mod tests {
lit(ScalarValue::Int64(Some(1))),
lit(ScalarValue::Utf8(None)),
];
- in_list!(batch, list, &true, vec![Some(false), None, None]);
+ in_list!(
+ batch,
+ list,
+ &true,
+ vec![Some(false), None, None],
+ col_a.clone()
+ );
Ok(())
}
@@ -390,6 +440,7 @@ mod tests {
fn in_list_float64() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
let a = Float64Array::from(vec![Some(0.0), Some(0.2), None]);
+ let col_a = col("a", &schema)?;
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
// expression: "a in (0.0, 0.2)"
@@ -397,14 +448,26 @@ mod tests {
lit(ScalarValue::Float64(Some(0.0))),
lit(ScalarValue::Float64(Some(0.1))),
];
- in_list!(batch, list, &false, vec![Some(true), Some(false), None]);
+ in_list!(
+ batch,
+ list,
+ &false,
+ vec![Some(true), Some(false), None],
+ col_a.clone()
+ );
// expression: "a not in (0.0, 0.2)"
let list = vec![
lit(ScalarValue::Float64(Some(0.0))),
lit(ScalarValue::Float64(Some(0.1))),
];
- in_list!(batch, list, &true, vec![Some(false), Some(true), None]);
+ in_list!(
+ batch,
+ list,
+ &true,
+ vec![Some(false), Some(true), None],
+ col_a.clone()
+ );
// expression: "a in (0.0, 0.2, NULL)"
let list = vec![
@@ -412,7 +475,13 @@ mod tests {
lit(ScalarValue::Float64(Some(0.1))),
lit(ScalarValue::Utf8(None)),
];
- in_list!(batch, list, &false, vec![Some(true), None, None]);
+ in_list!(
+ batch,
+ list,
+ &false,
+ vec![Some(true), None, None],
+ col_a.clone()
+ );
// expression: "a not in (0.0, 0.2, NULL)"
let list = vec![
@@ -420,7 +489,13 @@ mod tests {
lit(ScalarValue::Float64(Some(0.1))),
lit(ScalarValue::Utf8(None)),
];
- in_list!(batch, list, &true, vec![Some(false), None, None]);
+ in_list!(
+ batch,
+ list,
+ &true,
+ vec![Some(false), None, None],
+ col_a.clone()
+ );
Ok(())
}
@@ -429,29 +504,30 @@ mod tests {
fn in_list_bool() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
let a = BooleanArray::from(vec![Some(true), None]);
+ let col_a = col("a", &schema)?;
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
// expression: "a in (true)"
let list = vec![lit(ScalarValue::Boolean(Some(true)))];
- in_list!(batch, list, &false, vec![Some(true), None]);
+ in_list!(batch, list, &false, vec![Some(true), None], col_a.clone());
// expression: "a not in (true)"
let list = vec![lit(ScalarValue::Boolean(Some(true)))];
- in_list!(batch, list, &true, vec![Some(false), None]);
+ in_list!(batch, list, &true, vec![Some(false), None], col_a.clone());
// expression: "a in (true, NULL)"
let list = vec![
lit(ScalarValue::Boolean(Some(true))),
lit(ScalarValue::Utf8(None)),
];
- in_list!(batch, list, &false, vec![Some(true), None]);
+ in_list!(batch, list, &false, vec![Some(true), None], col_a.clone());
// expression: "a not in (true, NULL)"
let list = vec![
lit(ScalarValue::Boolean(Some(true))),
lit(ScalarValue::Utf8(None)),
];
- in_list!(batch, list, &true, vec![Some(false), None]);
+ in_list!(batch, list, &true, vec![Some(false), None], col_a.clone());
Ok(())
}
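The `in_list!` test macro now takes the pre-built column expression as an argument so the schema-resolved column can be constructed once and reused across cases. A standalone equivalent of a single invocation, assuming `in_list` and `lit` are re-exported from `physical_plan::expressions` as in this revision:

    use std::sync::Arc;
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::error::Result;
    use datafusion::physical_plan::expressions::{col, in_list, lit};
    use datafusion::physical_plan::PhysicalExpr;
    use datafusion::scalar::ScalarValue;

    fn example() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let col_a = col("a", &schema)?; // resolved once, reused per case
        let batch = RecordBatch::try_new(
            Arc::new(schema),
            vec![Arc::new(StringArray::from(vec![Some("a"), Some("d"), None]))],
        )?;
        // expression: "a in ('a')"
        let list = vec![lit(ScalarValue::Utf8(Some("a".to_string())))];
        let expr = in_list(col_a, list, &false)?;
        let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
        assert_eq!(result.len(), 3);
        Ok(())
    }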
diff --git a/datafusion/src/physical_plan/expressions/is_not_null.rs b/datafusion/src/physical_plan/expressions/is_not_null.rs
index 7ac2110..cce27e3 100644
--- a/datafusion/src/physical_plan/expressions/is_not_null.rs
+++ b/datafusion/src/physical_plan/expressions/is_not_null.rs
@@ -100,10 +100,10 @@ mod tests {
fn is_not_null_op() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
let a = StringArray::from(vec![Some("foo"), None]);
+ let expr = is_not_null(col("a", &schema)?).unwrap();
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
// expression: "a is not null"
- let expr = is_not_null(col("a")).unwrap();
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
diff --git a/datafusion/src/physical_plan/expressions/is_null.rs b/datafusion/src/physical_plan/expressions/is_null.rs
index dfa53f3..dbb57df 100644
--- a/datafusion/src/physical_plan/expressions/is_null.rs
+++ b/datafusion/src/physical_plan/expressions/is_null.rs
@@ -100,10 +100,11 @@ mod tests {
fn is_null_op() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
let a = StringArray::from(vec![Some("foo"), None]);
- let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
// expression: "a is null"
- let expr = is_null(col("a")).unwrap();
+ let expr = is_null(col("a", &schema)?).unwrap();
+ let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
let result = result
.as_any()
diff --git a/datafusion/src/physical_plan/expressions/min_max.rs b/datafusion/src/physical_plan/expressions/min_max.rs
index ea917d3..680e739 100644
--- a/datafusion/src/physical_plan/expressions/min_max.rs
+++ b/datafusion/src/physical_plan/expressions/min_max.rs
@@ -278,7 +278,7 @@ macro_rules! min_max {
}
e => {
return Err(DataFusionError::Internal(format!(
- "MIN/MAX is not expected to receive a scalar {:?}",
+ "MIN/MAX is not expected to receive scalars of incompatible types {:?}",
e
)))
}
diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs
index f8cb40c..0b32dca 100644
--- a/datafusion/src/physical_plan/expressions/mod.rs
+++ b/datafusion/src/physical_plan/expressions/mod.rs
@@ -66,6 +66,7 @@ pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES};
pub use row_number::RowNumber;
pub use sum::{sum_return_type, Sum};
pub use try_cast::{try_cast, TryCastExpr};
+
/// returns the name of the state
pub fn format_state_name(name: &str, state_name: &str) -> String {
format!("{}[{}]", name, state_name)
@@ -126,8 +127,11 @@ mod tests {
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![$ARRAY])?;
- let agg =
- Arc::new(<$OP>::new(col("a"), "bla".to_string(), $EXPECTED_DATATYPE));
+ let agg = Arc::new(<$OP>::new(
+ col("a", &schema)?,
+ "bla".to_string(),
+ $EXPECTED_DATATYPE,
+ ));
let actual = aggregate(&batch, agg)?;
let expected = ScalarValue::from($EXPECTED);
diff --git a/datafusion/src/physical_plan/expressions/not.rs b/datafusion/src/physical_plan/expressions/not.rs
index 7a997b6..341d38a 100644
--- a/datafusion/src/physical_plan/expressions/not.rs
+++ b/datafusion/src/physical_plan/expressions/not.rs
@@ -127,7 +127,7 @@ mod tests {
fn neg_op() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
- let expr = not(col("a"), &schema)?;
+ let expr = not(col("a", &schema)?, &schema)?;
assert_eq!(expr.data_type(&schema)?, DataType::Boolean);
assert!(expr.nullable(&schema)?);
@@ -152,7 +152,7 @@ mod tests {
fn neg_op_not_null() {
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
- let expr = not(col("a"), &schema);
+ let expr = not(col("a", &schema).unwrap(), &schema);
assert!(expr.is_err());
}
}
diff --git a/datafusion/src/physical_plan/expressions/nth_value.rs b/datafusion/src/physical_plan/expressions/nth_value.rs
index 16897d4..577c19b 100644
--- a/datafusion/src/physical_plan/expressions/nth_value.rs
+++ b/datafusion/src/physical_plan/expressions/nth_value.rs
@@ -148,7 +148,7 @@ impl BuiltInWindowFunctionExpr for NthValue {
mod tests {
use super::*;
use crate::error::Result;
- use crate::physical_plan::expressions::col;
+ use crate::physical_plan::expressions::Column;
use arrow::record_batch::RecordBatch;
use arrow::{array::*, datatypes::*};
@@ -166,32 +166,46 @@ mod tests {
#[test]
fn first_value() -> Result<()> {
- let first_value =
- NthValue::first_value("first_value".to_owned(), col("arr"), DataType::Int32);
+ let first_value = NthValue::first_value(
+ "first_value".to_owned(),
+ Arc::new(Column::new("arr", 0)),
+ DataType::Int32,
+ );
test_i32_result(first_value, vec![1; 8])?;
Ok(())
}
#[test]
fn last_value() -> Result<()> {
- let last_value =
- NthValue::last_value("last_value".to_owned(), col("arr"), DataType::Int32);
+ let last_value = NthValue::last_value(
+ "last_value".to_owned(),
+ Arc::new(Column::new("arr", 0)),
+ DataType::Int32,
+ );
test_i32_result(last_value, vec![8; 8])?;
Ok(())
}
#[test]
fn nth_value_1() -> Result<()> {
- let nth_value =
- NthValue::nth_value("nth_value".to_owned(), col("arr"), DataType::Int32, 1)?;
+ let nth_value = NthValue::nth_value(
+ "nth_value".to_owned(),
+ Arc::new(Column::new("arr", 0)),
+ DataType::Int32,
+ 1,
+ )?;
test_i32_result(nth_value, vec![1; 8])?;
Ok(())
}
#[test]
fn nth_value_2() -> Result<()> {
- let nth_value =
- NthValue::nth_value("nth_value".to_owned(), col("arr"), DataType::Int32, 2)?;
+ let nth_value = NthValue::nth_value(
+ "nth_value".to_owned(),
+ Arc::new(Column::new("arr", 0)),
+ DataType::Int32,
+ 2,
+ )?;
test_i32_result(nth_value, vec![-2; 8])?;
Ok(())
}
diff --git a/datafusion/src/physical_plan/expressions/try_cast.rs b/datafusion/src/physical_plan/expressions/try_cast.rs
index 5e402fd..1ba4a50 100644
--- a/datafusion/src/physical_plan/expressions/try_cast.rs
+++ b/datafusion/src/physical_plan/expressions/try_cast.rs
@@ -139,10 +139,13 @@ mod tests {
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
// verify that we can construct the expression
- let expression = try_cast(col("a"), &schema, $TYPE)?;
+ let expression = try_cast(col("a", &schema)?, &schema, $TYPE)?;
// verify that its display is correct
- assert_eq!(format!("CAST(a AS {:?})", $TYPE), format!("{}", expression));
+ assert_eq!(
+ format!("CAST(a@0 AS {:?})", $TYPE),
+ format!("{}", expression)
+ );
// verify that the expression's type is correct
assert_eq!(expression.data_type(&schema)?, $TYPE);
@@ -241,7 +244,7 @@ mod tests {
// Ensure a useful error happens at plan time if invalid casts are used
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
- let result = try_cast(col("a"), &schema, DataType::LargeBinary);
+ let result = try_cast(col("a", &schema).unwrap(), &schema, DataType::LargeBinary);
result.expect_err("expected Invalid CAST");
}
}
diff --git a/datafusion/src/physical_plan/filter.rs b/datafusion/src/physical_plan/filter.rs
index 0a8c825..9e7fa9d 100644
--- a/datafusion/src/physical_plan/filter.rs
+++ b/datafusion/src/physical_plan/filter.rs
@@ -223,14 +223,14 @@ mod tests {
let predicate: Arc<dyn PhysicalExpr> = binary(
binary(
- col("c2"),
+ col("c2", &schema)?,
Operator::Gt,
lit(ScalarValue::from(1u32)),
&schema,
)?,
Operator::And,
binary(
- col("c2"),
+ col("c2", &schema)?,
Operator::Lt,
lit(ScalarValue::from(4u32)),
&schema,
diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs
index 0e2be51..01f7e95 100644
--- a/datafusion/src/physical_plan/functions.rs
+++ b/datafusion/src/physical_plan/functions.rs
@@ -3651,7 +3651,7 @@ mod tests {
let expr = create_physical_expr(
&BuiltinScalarFunction::Array,
- &[col("a"), col("b")],
+ &[col("a", &schema)?, col("b", &schema)?],
&schema,
&ctx_state,
)?;
@@ -3718,7 +3718,7 @@ mod tests {
let columns: Vec<ArrayRef> = vec![col_value];
let expr = create_physical_expr(
&BuiltinScalarFunction::RegexpMatch,
- &[col("a"), pattern],
+ &[col("a", &schema)?, pattern],
&schema,
&ctx_state,
)?;
diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs
index f1611eb..250ba2b 100644
--- a/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/datafusion/src/physical_plan/hash_aggregate.rs
@@ -663,9 +663,12 @@ async fn compute_grouped_hash_aggregate(
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
mut input: SendableRecordBatchStream,
) -> ArrowResult<RecordBatch> {
- // the expressions to evaluate the batch, one vec of expressions per aggregation
- let aggregate_expressions = aggregate_expressions(&aggr_expr, &mode)
- .map_err(DataFusionError::into_arrow_external_error)?;
+ // The expressions to evaluate the batch, one vec of expressions per aggregation.
+ // Assuming create_schema() always puts group columns in front of the aggregate columns,
+ // we set col_idx_base to the number of group expressions.
+ let aggregate_expressions =
+ aggregate_expressions(&aggr_expr, &mode, group_expr.len())
+ .map_err(DataFusionError::into_arrow_external_error)?;
// mapping key -> (set of accumulators, indices of the key in the batch)
// * the indexes are updated at each row
@@ -794,14 +797,21 @@ fn evaluate_many(
.collect::<Result<Vec<_>>>()
}
-/// uses `state_fields` to build a vec of expressions required to merge the AggregateExpr' accumulator's state.
+/// uses `state_fields` to build a vec of physical column expressions required to merge the
+/// AggregateExpr's accumulator state.
+///
+/// `index_base` is the starting physical column index for the next expanded state field.
fn merge_expressions(
+ index_base: usize,
expr: &Arc<dyn AggregateExpr>,
) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
Ok(expr
.state_fields()?
.iter()
- .map(|f| Arc::new(Column::new(f.name())) as Arc<dyn PhysicalExpr>)
+ .enumerate()
+ .map(|(idx, f)| {
+ Arc::new(Column::new(f.name(), index_base + idx)) as Arc<dyn PhysicalExpr>
+ })
.collect::<Vec<_>>())
}
@@ -809,22 +819,27 @@ fn merge_expressions(
/// The expressions are different depending on `mode`:
/// * Partial: AggregateExpr::expressions
/// * Final: columns of `AggregateExpr::state_fields()`
-/// The return value is to be understood as:
-/// * index 0 is the aggregation
-/// * index 1 is the expression i of the aggregation
fn aggregate_expressions(
aggr_expr: &[Arc<dyn AggregateExpr>],
mode: &AggregateMode,
+ col_idx_base: usize,
) -> Result<Vec<Vec<Arc<dyn PhysicalExpr>>>> {
match mode {
AggregateMode::Partial => {
Ok(aggr_expr.iter().map(|agg| agg.expressions()).collect())
}
// in this mode, we build the merge expressions of the aggregation
- AggregateMode::Final | AggregateMode::FinalPartitioned => Ok(aggr_expr
- .iter()
- .map(|agg| merge_expressions(agg))
- .collect::<Result<Vec<_>>>()?),
+ AggregateMode::Final | AggregateMode::FinalPartitioned => {
+ let mut col_idx_base = col_idx_base;
+ Ok(aggr_expr
+ .iter()
+ .map(|agg| {
+ let exprs = merge_expressions(col_idx_base, agg)?;
+ col_idx_base += exprs.len();
+ Ok(exprs)
+ })
+ .collect::<Result<Vec<_>>>()?)
+ }
}
}
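To make the index bookkeeping concrete: in Final mode each aggregate reads its merged state from a contiguous run of physical columns that create_schema() placed after the group columns, so col_idx_base starts at the group-expression count and advances by each aggregate's state width. A hypothetical layout for GROUP BY g, AVG(b), MAX(c), with state-field names following format_state_name:

    // index 0: g               (group column)
    // index 1: AVG(b)[count]   (col_idx_base = 1 when expanding AVG(b)'s state)
    // index 2: AVG(b)[sum]
    // index 3: MAX(c)[max]     (col_idx_base has advanced to 3 for MAX(c))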
@@ -846,10 +861,8 @@ async fn compute_hash_aggregate(
) -> ArrowResult<RecordBatch> {
let mut accumulators = create_accumulators(&aggr_expr)
.map_err(DataFusionError::into_arrow_external_error)?;
-
- let expressions = aggregate_expressions(&aggr_expr, &mode)
+ let expressions = aggregate_expressions(&aggr_expr, &mode, 0)
.map_err(DataFusionError::into_arrow_external_error)?;
-
let expressions = Arc::new(expressions);
// 1 for each batch, update / merge accumulators with the expressions' values
@@ -1253,16 +1266,17 @@ mod tests {
/// build the aggregates on the data from some_data() and check the results
async fn check_aggregates(input: Arc<dyn ExecutionPlan>) -> Result<()> {
+ let input_schema = input.schema();
+
let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
- vec![(col("a"), "a".to_string())];
+ vec![(col("a", &input_schema)?, "a".to_string())];
let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
- col("b"),
+ col("b", &input_schema)?,
"AVG(b)".to_string(),
DataType::Float64,
))];
- let input_schema = input.schema();
let partial_aggregate = Arc::new(HashAggregateExec::try_new(
AggregateMode::Partial,
groups.clone(),
@@ -1286,8 +1300,9 @@ mod tests {
let merge = Arc::new(MergeExec::new(partial_aggregate));
- let final_group: Vec<Arc<dyn PhysicalExpr>> =
- (0..groups.len()).map(|i| col(&groups[i].1)).collect();
+ let final_group: Vec<Arc<dyn PhysicalExpr>> = (0..groups.len())
+ .map(|i| col(&groups[i].1, &input_schema))
+ .collect::<Result<_>>()?;
let merged_aggregate = Arc::new(HashAggregateExec::try_new(
AggregateMode::Final,
diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs
index 928392a..ad35607 100644
--- a/datafusion/src/physical_plan/hash_join.rs
+++ b/datafusion/src/physical_plan/hash_join.rs
@@ -52,7 +52,7 @@ use arrow::array::{
UInt64Array, UInt8Array,
};
-use super::expressions::col;
+use super::expressions::Column;
use super::{
hash_utils::{build_join_schema, check_join_is_valid, JoinOn, JoinType},
merge::MergeExec,
@@ -64,6 +64,7 @@ use super::{
SendableRecordBatchStream,
};
use crate::physical_plan::coalesce_batches::concat_batches;
+use crate::physical_plan::PhysicalExpr;
use log::debug;
// Maps a `u64` hash value computed from the left-side "on" values to a list of indices sharing that value.
@@ -90,7 +91,7 @@ pub struct HashJoinExec {
/// right (probe) side which are filtered by the hash table
right: Arc<dyn ExecutionPlan>,
/// Set of common columns used to join on
- on: Vec<(String, String)>,
+ on: Vec<(Column, Column)>,
/// How the join is performed
join_type: JoinType,
/// The schema once the join is applied
@@ -127,26 +128,21 @@ impl HashJoinExec {
pub fn try_new(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
- on: &JoinOn,
+ on: JoinOn,
join_type: &JoinType,
partition_mode: PartitionMode,
) -> Result<Self> {
let left_schema = left.schema();
let right_schema = right.schema();
- check_join_is_valid(&left_schema, &right_schema, on)?;
+ check_join_is_valid(&left_schema, &right_schema, &on)?;
let schema = Arc::new(build_join_schema(
&left_schema,
&right_schema,
- on,
+ &on,
join_type,
));
- let on = on
- .iter()
- .map(|(l, r)| (l.to_string(), r.to_string()))
- .collect();
-
let random_state = RandomState::with_seeds(0, 0, 0, 0);
Ok(HashJoinExec {
@@ -172,7 +168,7 @@ impl HashJoinExec {
}
/// Set of common columns used to join on
- pub fn on(&self) -> &[(String, String)] {
+ pub fn on(&self) -> &[(Column, Column)] {
&self.on
}
@@ -236,7 +232,7 @@ impl ExecutionPlan for HashJoinExec {
2 => Ok(Arc::new(HashJoinExec::try_new(
children[0].clone(),
children[1].clone(),
- &self.on,
+ self.on.clone(),
&self.join_type,
self.mode,
)?)),
@@ -307,10 +303,10 @@ impl ExecutionPlan for HashJoinExec {
*build_side = Some(left_side.clone());
debug!(
- "Built build-side of hash join containing {} rows in {} ms",
- num_rows,
- start.elapsed().as_millis()
- );
+ "Built build-side of hash join containing {} rows in {} ms",
+ num_rows,
+ start.elapsed().as_millis()
+ );
left_side
}
@@ -372,7 +368,7 @@ impl ExecutionPlan for HashJoinExec {
// we have the batches and the hash map with their keys. We can now create a stream
// over the right that uses this information to issue new batches.
- let stream = self.right.execute(partition).await?;
+ let right_stream = self.right.execute(partition).await?;
let on_right = self.on.iter().map(|on| on.1.clone()).collect::<Vec<_>>();
let column_indices = self.column_indices_from_schema()?;
@@ -383,23 +379,17 @@ impl ExecutionPlan for HashJoinExec {
}
JoinType::Inner | JoinType::Right => vec![],
};
- Ok(Box::pin(HashJoinStream {
- schema: self.schema.clone(),
+ Ok(Box::pin(HashJoinStream::new(
+ self.schema.clone(),
on_left,
on_right,
- join_type: self.join_type,
+ self.join_type,
left_data,
- right: stream,
+ right_stream,
column_indices,
- num_input_batches: 0,
- num_input_rows: 0,
- num_output_batches: 0,
- num_output_rows: 0,
- join_time: 0,
- random_state: self.random_state.clone(),
+ self.random_state.clone(),
visited_left_side,
- is_exhausted: false,
- }))
+ )))
}
fn fmt_as(
@@ -422,7 +412,7 @@ impl ExecutionPlan for HashJoinExec {
/// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`,
/// assuming that the [RecordBatch] corresponds to the `index`th batch
fn update_hash(
- on: &[String],
+ on: &[Column],
batch: &RecordBatch,
hash: &mut JoinHashMap,
offset: usize,
@@ -432,7 +422,7 @@ fn update_hash(
// evaluate the keys
let keys_values = on
.iter()
- .map(|name| Ok(col(name).evaluate(batch)?.into_array(batch.num_rows())))
+ .map(|c| Ok(c.evaluate(batch)?.into_array(batch.num_rows())))
.collect::<Result<Vec<_>>>()?;
// calculate the hash values
@@ -461,9 +451,9 @@ struct HashJoinStream {
/// Input schema
schema: Arc<Schema>,
/// columns from the left
- on_left: Vec<String>,
+ on_left: Vec<Column>,
/// columns from the right used to compute the hash
- on_right: Vec<String>,
+ on_right: Vec<Column>,
/// type of the join
join_type: JoinType,
/// information from the left
@@ -490,6 +480,39 @@ struct HashJoinStream {
is_exhausted: bool,
}
+#[allow(clippy::too_many_arguments)]
+impl HashJoinStream {
+ fn new(
+ schema: Arc<Schema>,
+ on_left: Vec<Column>,
+ on_right: Vec<Column>,
+ join_type: JoinType,
+ left_data: JoinLeftData,
+ right: SendableRecordBatchStream,
+ column_indices: Vec<ColumnIndex>,
+ random_state: RandomState,
+ visited_left_side: Vec<bool>,
+ ) -> Self {
+ HashJoinStream {
+ schema,
+ on_left,
+ on_right,
+ join_type,
+ left_data,
+ right,
+ column_indices,
+ num_input_batches: 0,
+ num_input_rows: 0,
+ num_output_batches: 0,
+ num_output_rows: 0,
+ join_time: 0,
+ random_state,
+ visited_left_side,
+ is_exhausted: false,
+ }
+ }
+}
+
impl RecordBatchStream for HashJoinStream {
fn schema(&self) -> SchemaRef {
self.schema.clone()
@@ -531,8 +554,8 @@ fn build_batch_from_indices(
fn build_batch(
batch: &RecordBatch,
left_data: &JoinLeftData,
- on_left: &[String],
- on_right: &[String],
+ on_left: &[Column],
+ on_right: &[Column],
join_type: JoinType,
schema: &Schema,
column_indices: &[ColumnIndex],
@@ -590,21 +613,17 @@ fn build_join_indexes(
left_data: &JoinLeftData,
right: &RecordBatch,
join_type: JoinType,
- left_on: &[String],
- right_on: &[String],
+ left_on: &[Column],
+ right_on: &[Column],
random_state: &RandomState,
) -> Result<(UInt64Array, UInt32Array)> {
let keys_values = right_on
.iter()
- .map(|name| Ok(col(name).evaluate(right)?.into_array(right.num_rows())))
+ .map(|c| Ok(c.evaluate(right)?.into_array(right.num_rows())))
.collect::<Result<Vec<_>>>()?;
let left_join_values = left_on
.iter()
- .map(|name| {
- Ok(col(name)
- .evaluate(&left_data.1)?
- .into_array(left_data.1.num_rows()))
- })
+ .map(|c| Ok(c.evaluate(&left_data.1)?.into_array(left_data.1.num_rows())))
.collect::<Result<Vec<_>>>()?;
let hashes_buffer = &mut vec![0; keys_values[0].len()];
let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
@@ -1250,6 +1269,7 @@ impl Stream for HashJoinStream {
| JoinType::Right => {}
}
+ // End of right batch, print stats in debug mode
debug!(
"Processed {} probe-side input batches containing {} rows and \
produced {} output batches containing {} rows in {} ms",
@@ -1269,7 +1289,9 @@ impl Stream for HashJoinStream {
mod tests {
use crate::{
assert_batches_sorted_eq,
- physical_plan::{common, memory::MemoryExec},
+ physical_plan::{
+ common, expressions::Column, memory::MemoryExec, repartition::RepartitionExec,
+ },
test::{build_table_i32, columns},
};
@@ -1289,14 +1311,74 @@ mod tests {
fn join(
left: Arc<dyn ExecutionPlan>,
right: Arc<dyn ExecutionPlan>,
- on: &[(&str, &str)],
+ on: JoinOn,
join_type: &JoinType,
) -> Result<HashJoinExec> {
- let on: Vec<_> = on
+ HashJoinExec::try_new(left, right, on, join_type, PartitionMode::CollectLeft)
+ }
+
+ async fn join_collect(
+ left: Arc<dyn ExecutionPlan>,
+ right: Arc<dyn ExecutionPlan>,
+ on: JoinOn,
+ join_type: &JoinType,
+ ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
+ let join = join(left, right, on, join_type)?;
+ let columns = columns(&join.schema());
+
+ let stream = join.execute(0).await?;
+ let batches = common::collect(stream).await?;
+
+ Ok((columns, batches))
+ }
+
+ async fn partitioned_join_collect(
+ left: Arc<dyn ExecutionPlan>,
+ right: Arc<dyn ExecutionPlan>,
+ on: JoinOn,
+ join_type: &JoinType,
+ ) -> Result<(Vec<String>, Vec<RecordBatch>)> {
+ let partition_count = 4;
+
+ let (left_expr, right_expr) = on
.iter()
- .map(|(l, r)| (l.to_string(), r.to_string()))
- .collect();
- HashJoinExec::try_new(left, right, &on, join_type, PartitionMode::CollectLeft)
+ .map(|(l, r)| {
+ (
+ Arc::new(l.clone()) as Arc<dyn PhysicalExpr>,
+ Arc::new(r.clone()) as Arc<dyn PhysicalExpr>,
+ )
+ })
+ .unzip();
+
+ let join = HashJoinExec::try_new(
+ Arc::new(RepartitionExec::try_new(
+ left,
+ Partitioning::Hash(left_expr, partition_count),
+ )?),
+ Arc::new(RepartitionExec::try_new(
+ right,
+ Partitioning::Hash(right_expr, partition_count),
+ )?),
+ on,
+ join_type,
+ PartitionMode::Partitioned,
+ )?;
+
+ let columns = columns(&join.schema());
+
+ let mut batches = vec![];
+ for i in 0..partition_count {
+ let stream = join.execute(i).await?;
+ let more_batches = common::collect(stream).await?;
+ batches.extend(
+ more_batches
+ .into_iter()
+ .filter(|b| b.num_rows() > 0)
+ .collect::<Vec<_>>(),
+ );
+ }
+
+ Ok((columns, batches))
}
#[tokio::test]
@@ -1311,15 +1393,58 @@ mod tests {
("b1", &vec![4, 5, 6]),
("c2", &vec![70, 80, 90]),
);
- let on = &[("b1", "b1")];
- let join = join(left, right, on, &JoinType::Inner)?;
+ let on = vec![(
+ Column::new_with_schema("b1", &left.schema())?,
+ Column::new_with_schema("b1", &right.schema())?,
+ )];
+
+ let (columns, batches) =
+ join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Inner)
+ .await?;
- let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]);
- let stream = join.execute(0).await?;
- let batches = common::collect(stream).await?;
+ let expected = vec![
+ "+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | c2 |",
+ "+----+----+----+----+----+",
+ "| 1 | 4 | 7 | 10 | 70 |",
+ "| 2 | 5 | 8 | 20 | 80 |",
+ "| 3 | 5 | 9 | 20 | 80 |",
+ "+----+----+----+----+----+",
+ ];
+ assert_batches_sorted_eq!(expected, &batches);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn partitioned_join_inner_one() -> Result<()> {
+ let left = build_table(
+ ("a1", &vec![1, 2, 3]),
+ ("b1", &vec![4, 5, 5]), // this has a repetition
+ ("c1", &vec![7, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b1", &vec![4, 5, 6]),
+ ("c2", &vec![70, 80, 90]),
+ );
+ let on = vec![(
+ Column::new_with_schema("b1", &left.schema())?,
+ Column::new_with_schema("b1", &right.schema())?,
+ )];
+
+ let (columns, batches) = partitioned_join_collect(
+ left.clone(),
+ right.clone(),
+ on.clone(),
+ &JoinType::Inner,
+ )
+ .await?;
+
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]);
let expected = vec![
"+----+----+----+----+----+",
@@ -1347,16 +1472,15 @@ mod tests {
("b2", &vec![4, 5, 6]),
("c2", &vec![70, 80, 90]),
);
- let on = &[("b1", "b2")];
+ let on = vec![(
+ Column::new_with_schema("b1", &left.schema())?,
+ Column::new_with_schema("b2", &right.schema())?,
+ )];
- let join = join(left, right, on, &JoinType::Inner)?;
+ let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?;
- let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]);
- let stream = join.execute(0).await?;
- let batches = common::collect(stream).await?;
-
let expected = vec![
"+----+----+----+----+----+----+",
"| a1 | b1 | c1 | a2 | b2 | c2 |",
@@ -1384,15 +1508,21 @@ mod tests {
("b2", &vec![1, 2, 2]),
("c2", &vec![70, 80, 90]),
);
- let on = &[("a1", "a1"), ("b2", "b2")];
+ let on = vec![
+ (
+ Column::new_with_schema("a1", &left.schema())?,
+ Column::new_with_schema("a1", &right.schema())?,
+ ),
+ (
+ Column::new_with_schema("b2", &left.schema())?,
+ Column::new_with_schema("b2", &right.schema())?,
+ ),
+ ];
- let join = join(left, right, on, &JoinType::Inner)?;
+ let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?;
- let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]);
- let stream = join.execute(0).await?;
- let batches = common::collect(stream).await?;
assert_eq!(batches.len(), 1);
let expected = vec![
@@ -1430,15 +1560,21 @@ mod tests {
("b2", &vec![1, 2, 2]),
("c2", &vec![70, 80, 90]),
);
- let on = &[("a1", "a1"), ("b2", "b2")];
+ let on = vec![
+ (
+ Column::new_with_schema("a1", &left.schema())?,
+ Column::new_with_schema("a1", &right.schema())?,
+ ),
+ (
+ Column::new_with_schema("b2", &left.schema())?,
+ Column::new_with_schema("b2", &right.schema())?,
+ ),
+ ];
- let join = join(left, right, on, &JoinType::Inner)?;
+ let (columns, batches) = join_collect(left, right, on, &JoinType::Inner).await?;
- let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b2", "c1", "c2"]);
- let stream = join.execute(0).await?;
- let batches = common::collect(stream).await?;
assert_eq!(batches.len(), 1);
let expected = vec![
@@ -1477,7 +1613,10 @@ mod tests {
MemoryExec::try_new(&[vec![batch1], vec![batch2]], schema, None).unwrap(),
);
- let on = &[("b1", "b1")];
+ let on = vec![(
+ Column::new_with_schema("b1", &left.schema())?,
+ Column::new_with_schema("b1", &right.schema())?,
+ )];
let join = join(left, right, on, &JoinType::Inner)?;
@@ -1540,7 +1679,10 @@ mod tests {
("b1", &vec![4, 5, 6]),
("c2", &vec![70, 80, 90]),
);
- let on = &[("b1", "b1")];
+ let on = vec![(
+ Column::new_with_schema("b1", &left.schema()).unwrap(),
+ Column::new_with_schema("b1", &right.schema()).unwrap(),
+ )];
let join = join(left, right, on, &JoinType::Left).unwrap();
@@ -1578,7 +1720,10 @@ mod tests {
("b2", &vec![4, 5, 6]),
("c2", &vec![70, 80, 90]),
);
- let on = &[("b1", "b2")];
+ let on = vec![(
+ Column::new_with_schema("b1", &left.schema()).unwrap(),
+ Column::new_with_schema("b2", &right.schema()).unwrap(),
+ )];
let join = join(left, right, on, &JoinType::Full).unwrap();
@@ -1613,7 +1758,10 @@ mod tests {
("c1", &vec![7, 8, 9]),
);
let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2", &vec![]));
- let on = &[("b1", "b1")];
+ let on = vec![(
+ Column::new_with_schema("b1", &left.schema()).unwrap(),
+ Column::new_with_schema("b1", &right.schema()).unwrap(),
+ )];
let schema = right.schema();
let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap());
let join = join(left, right, on, &JoinType::Left).unwrap();
@@ -1645,7 +1793,10 @@ mod tests {
("c1", &vec![7, 8, 9]),
);
let right = build_table_i32(("a2", &vec![]), ("b2", &vec![]), ("c2", &vec![]));
- let on = &[("b1", "b2")];
+ let on = vec![(
+ Column::new_with_schema("b1", &left.schema()).unwrap(),
+ Column::new_with_schema("b2", &right.schema()).unwrap(),
+ )];
let schema = right.schema();
let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema, None).unwrap());
let join = join(left, right, on, &JoinType::Full).unwrap();
@@ -1681,15 +1832,55 @@ mod tests {
("b1", &vec![4, 5, 6]),
("c2", &vec![70, 80, 90]),
);
- let on = &[("b1", "b1")];
+ let on = vec![(
+ Column::new_with_schema("b1", &left.schema())?,
+ Column::new_with_schema("b1", &right.schema())?,
+ )];
+
+ let (columns, batches) =
+ join_collect(left.clone(), right.clone(), on.clone(), &JoinType::Left)
+ .await?;
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]);
- let join = join(left, right, on, &JoinType::Left)?;
+ let expected = vec![
+ "+----+----+----+----+----+",
+ "| a1 | b1 | c1 | a2 | c2 |",
+ "+----+----+----+----+----+",
+ "| 1 | 4 | 7 | 10 | 70 |",
+ "| 2 | 5 | 8 | 20 | 80 |",
+ "| 3 | 7 | 9 | | |",
+ "+----+----+----+----+----+",
+ ];
+ assert_batches_sorted_eq!(expected, &batches);
- let columns = columns(&join.schema());
- assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]);
+ Ok(())
+ }
- let stream = join.execute(0).await?;
- let batches = common::collect(stream).await?;
+ #[tokio::test]
+ async fn partitioned_join_left_one() -> Result<()> {
+ let left = build_table(
+ ("a1", &vec![1, 2, 3]),
+ ("b1", &vec![4, 5, 7]), // 7 does not exist on the right
+ ("c1", &vec![7, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b1", &vec![4, 5, 6]),
+ ("c2", &vec![70, 80, 90]),
+ );
+ let on = vec![(
+ Column::new_with_schema("b1", &left.schema())?,
+ Column::new_with_schema("b1", &right.schema())?,
+ )];
+
+ let (columns, batches) = partitioned_join_collect(
+ left.clone(),
+ right.clone(),
+ on.clone(),
+ &JoinType::Left,
+ )
+ .await?;
+ assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "c2"]);
let expected = vec![
"+----+----+----+----+----+",
@@ -1717,7 +1908,10 @@ mod tests {
("b1", &vec![4, 5, 6, 5]), // 5 is double on the right
("c2", &vec![70, 80, 90, 100]),
);
- let on = &[("b1", "b1")];
+ let on = vec![(
+ Column::new_with_schema("b1", &left.schema())?,
+ Column::new_with_schema("b1", &right.schema())?,
+ )];
let join = join(left, right, on, &JoinType::Semi)?;
@@ -1753,7 +1947,10 @@ mod tests {
("b1", &vec![4, 5, 6, 5]), // 5 is double on the right
("c2", &vec![70, 80, 90, 100]),
);
- let on = &[("b1", "b1")];
+ let on = vec![(
+ Column::new_with_schema("b1", &left.schema())?,
+ Column::new_with_schema("b1", &right.schema())?,
+ )];
let join = join(left, right, on, &JoinType::Anti)?;
@@ -1787,15 +1984,51 @@ mod tests {
("b1", &vec![4, 5, 6]), // 6 does not exist on the left
("c2", &vec![70, 80, 90]),
);
- let on = &[("b1", "b1")];
+ let on = vec![(
+ Column::new_with_schema("b1", &left.schema())?,
+ Column::new_with_schema("b1", &right.schema())?,
+ )];
- let join = join(left, right, on, &JoinType::Right)?;
+ let (columns, batches) = join_collect(left, right, on, &JoinType::Right).await?;
- let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "c1", "a2", "b1", "c2"]);
- let stream = join.execute(0).await?;
- let batches = common::collect(stream).await?;
+ let expected = vec![
+ "+----+----+----+----+----+",
+ "| a1 | c1 | a2 | b1 | c2 |",
+ "+----+----+----+----+----+",
+ "| | | 30 | 6 | 90 |",
+ "| 1 | 7 | 10 | 4 | 70 |",
+ "| 2 | 8 | 20 | 5 | 80 |",
+ "+----+----+----+----+----+",
+ ];
+
+ assert_batches_sorted_eq!(expected, &batches);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn partitioned_join_right_one() -> Result<()> {
+ let left = build_table(
+ ("a1", &vec![1, 2, 3]),
+ ("b1", &vec![4, 5, 7]),
+ ("c1", &vec![7, 8, 9]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 20, 30]),
+ ("b1", &vec![4, 5, 6]), // 6 does not exist on the left
+ ("c2", &vec![70, 80, 90]),
+ );
+ let on = vec![(
+ Column::new_with_schema("b1", &left.schema())?,
+ Column::new_with_schema("b1", &right.schema())?,
+ )];
+
+ let (columns, batches) =
+ partitioned_join_collect(left, right, on, &JoinType::Right).await?;
+
+ assert_eq!(columns, vec!["a1", "c1", "a2", "b1", "c2"]);
let expected = vec![
"+----+----+----+----+----+",
@@ -1824,7 +2057,10 @@ mod tests {
("b2", &vec![4, 5, 6]),
("c2", &vec![70, 80, 90]),
);
- let on = &[("b1", "b2")];
+ let on = vec![(
+ Column::new_with_schema("b1", &left.schema()).unwrap(),
+ Column::new_with_schema("b2", &right.schema()).unwrap(),
+ )];
let join = join(left, right, on, &JoinType::Full)?;
@@ -1904,8 +2140,8 @@ mod tests {
&left_data,
&right,
JoinType::Inner,
- &["a".to_string()],
- &["a".to_string()],
+ &[Column::new("a", 0)],
+ &[Column::new("a", 0)],
&random_state,
)?;
@@ -1914,7 +2150,6 @@ mod tests {
left_ids.append_value(1)?;
let mut right_ids = UInt32Builder::new(0);
-
right_ids.append_value(0)?;
right_ids.append_value(1)?;
diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs
index a48710b..0cf0b92 100644
--- a/datafusion/src/physical_plan/hash_utils.rs
+++ b/datafusion/src/physical_plan/hash_utils.rs
@@ -21,6 +21,8 @@ use crate::error::{DataFusionError, Result};
use arrow::datatypes::{Field, Schema};
use std::collections::HashSet;
+use crate::physical_plan::expressions::Column;
+
/// All valid types of joins.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JoinType {
@@ -39,14 +41,25 @@ pub enum JoinType {
}
/// The on clause of the join, as a vector of (left, right) columns.
-pub type JoinOn = [(String, String)];
+pub type JoinOn = Vec<(Column, Column)>;
+/// Reference for JoinOn.
+pub type JoinOnRef<'a> = &'a [(Column, Column)];
/// Checks whether the schemas "left" and "right" and columns "on" represent a valid join.
/// They are valid whenever their columns' intersection equals the set `on`
-pub fn check_join_is_valid(left: &Schema, right: &Schema, on: &JoinOn) -> Result<()> {
- let left: HashSet<String> = left.fields().iter().map(|f| f.name().clone()).collect();
- let right: HashSet<String> =
- right.fields().iter().map(|f| f.name().clone()).collect();
+pub fn check_join_is_valid(left: &Schema, right: &Schema, on: JoinOnRef) -> Result<()> {
+ let left: HashSet<Column> = left
+ .fields()
+ .iter()
+ .enumerate()
+ .map(|(idx, f)| Column::new(f.name(), idx))
+ .collect();
+ let right: HashSet<Column> = right
+ .fields()
+ .iter()
+ .enumerate()
+ .map(|(idx, f)| Column::new(f.name(), idx))
+ .collect();
check_join_set_is_valid(&left, &right, on)
}
@@ -54,14 +67,14 @@ pub fn check_join_is_valid(left: &Schema, right: &Schema, on: &JoinOn) -> Result
/// Checks whether the sets left, right and on compose a valid join.
/// They are valid whenever their intersection equals the set `on`
fn check_join_set_is_valid(
- left: &HashSet<String>,
- right: &HashSet<String>,
- on: &JoinOn,
+ left: &HashSet<Column>,
+ right: &HashSet<Column>,
+ on: &[(Column, Column)],
) -> Result<()> {
- let on_left = &on.iter().map(|on| on.0.to_string()).collect::<HashSet<_>>();
+ let on_left = &on.iter().map(|on| on.0.clone()).collect::<HashSet<_>>();
let left_missing = on_left.difference(left).collect::<HashSet<_>>();
- let on_right = &on.iter().map(|on| on.1.to_string()).collect::<HashSet<_>>();
+ let on_right = &on.iter().map(|on| on.1.clone()).collect::<HashSet<_>>();
let right_missing = on_right.difference(right).collect::<HashSet<_>>();
if !left_missing.is_empty() | !right_missing.is_empty() {
@@ -75,7 +88,7 @@ fn check_join_set_is_valid(
let remaining = right
.difference(on_right)
.cloned()
- .collect::<HashSet<String>>();
+ .collect::<HashSet<Column>>();
let collisions = left.intersection(&remaining).collect::<HashSet<_>>();
@@ -94,7 +107,7 @@ fn check_join_set_is_valid(
pub fn build_join_schema(
left: &Schema,
right: &Schema,
- on: &JoinOn,
+ on: JoinOnRef,
join_type: &JoinType,
) -> Schema {
let fields: Vec<Field> = match join_type {
@@ -102,8 +115,8 @@ pub fn build_join_schema(
// remove right-side join keys if they have the same names as the left-side
let duplicate_keys = &on
.iter()
- .filter(|(l, r)| l == r)
- .map(|on| on.1.to_string())
+ .filter(|(l, r)| l.name() == r.name())
+ .map(|on| on.1.name())
.collect::<HashSet<_>>();
let left_fields = left.fields().iter();
@@ -111,7 +124,7 @@ pub fn build_join_schema(
let right_fields = right
.fields()
.iter()
- .filter(|f| !duplicate_keys.contains(f.name()));
+ .filter(|f| !duplicate_keys.contains(f.name().as_str()));
// left then right
left_fields.chain(right_fields).cloned().collect()
@@ -120,14 +133,14 @@ pub fn build_join_schema(
// remove left-side join keys if they have the same names as the right-side
let duplicate_keys = &on
.iter()
- .filter(|(l, r)| l == r)
- .map(|on| on.1.to_string())
+ .filter(|(l, r)| l.name() == r.name())
+ .map(|on| on.1.name())
.collect::<HashSet<_>>();
let left_fields = left
.fields()
.iter()
- .filter(|f| !duplicate_keys.contains(f.name()));
+ .filter(|f| !duplicate_keys.contains(f.name().as_str()));
let right_fields = right.fields().iter();
@@ -141,24 +154,25 @@ pub fn build_join_schema(
#[cfg(test)]
mod tests {
-
use super::*;
- fn check(left: &[&str], right: &[&str], on: &[(&str, &str)]) -> Result<()> {
- let left = left.iter().map(|x| x.to_string()).collect::<HashSet<_>>();
- let right = right.iter().map(|x| x.to_string()).collect::<HashSet<_>>();
- let on: Vec<_> = on
+ fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) -> Result<()> {
+ let left = left
+ .iter()
+ .map(|x| x.to_owned())
+ .collect::<HashSet<Column>>();
+ let right = right
.iter()
- .map(|(l, r)| (l.to_string(), r.to_string()))
- .collect();
- check_join_set_is_valid(&left, &right, &on)
+ .map(|x| x.to_owned())
+ .collect::<HashSet<Column>>();
+ check_join_set_is_valid(&left, &right, on)
}
#[test]
fn check_valid() -> Result<()> {
- let left = vec!["a", "b1"];
- let right = vec!["a", "b2"];
- let on = &[("a", "a")];
+ let left = vec![Column::new("a", 0), Column::new("b1", 1)];
+ let right = vec![Column::new("a", 0), Column::new("b2", 1)];
+ let on = &[(Column::new("a", 0), Column::new("a", 0))];
check(&left, &right, on)?;
Ok(())
@@ -166,18 +180,18 @@ mod tests {
#[test]
fn check_not_in_right() {
- let left = vec!["a", "b"];
- let right = vec!["b"];
- let on = &[("a", "a")];
+ let left = vec![Column::new("a", 0), Column::new("b", 1)];
+ let right = vec![Column::new("b", 0)];
+ let on = &[(Column::new("a", 0), Column::new("a", 0))];
assert!(check(&left, &right, on).is_err());
}
#[test]
fn check_not_in_left() {
- let left = vec!["b"];
- let right = vec!["a"];
- let on = &[("a", "a")];
+ let left = vec![Column::new("b", 0)];
+ let right = vec![Column::new("a", 0)];
+ let on = &[(Column::new("a", 0), Column::new("a", 0))];
assert!(check(&left, &right, on).is_err());
}
@@ -185,18 +199,18 @@ mod tests {
#[test]
fn check_collision() {
// column "a" would appear both in left and right
- let left = vec!["a", "c"];
- let right = vec!["a", "b"];
- let on = &[("a", "b")];
+ let left = vec![Column::new("a", 0), Column::new("c", 1)];
+ let right = vec![Column::new("a", 0), Column::new("b", 1)];
+ let on = &[(Column::new("a", 0), Column::new("b", 1))];
assert!(check(&left, &right, on).is_err());
}
#[test]
fn check_in_right() {
- let left = vec!["a", "c"];
- let right = vec!["b"];
- let on = &[("a", "b")];
+ let left = vec![Column::new("a", 0), Column::new("c", 1)];
+ let right = vec![Column::new("b", 0)];
+ let on = &[(Column::new("a", 0), Column::new("b", 0))];
assert!(check(&left, &right, on).is_ok());
}
diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs
index 50c30a5..7b26d7b 100644
--- a/datafusion/src/physical_plan/mod.rs
+++ b/datafusion/src/physical_plan/mod.rs
@@ -211,9 +211,9 @@ pub trait ExecutionPlan: Debug + Send + Sync {
/// let displayable_plan = displayable(physical_plan.as_ref());
/// let plan_string = format!("{}", displayable_plan.indent());
///
-/// assert_eq!("ProjectionExec: expr=[a]\
+/// assert_eq!("ProjectionExec: expr=[a@0 as a]\
/// \n CoalesceBatchesExec: target_batch_size=4096\
-/// \n FilterExec: a < 5\
+/// \n FilterExec: a@0 < 5\
/// \n RepartitionExec: partitioning=RoundRobinBatch(3)\
/// \n CsvExec: source=Path(tests/example.csv: [tests/example.csv]), has_header=true",
/// plan_string.trim());
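
The a@0 form reflects that physical columns now carry and display their field index; a hypothetical check, assuming the Display impl matches the doctest output above:

    use datafusion::physical_plan::expressions::Column;

    // rendering assumed from the expected plan strings in the doctest
    let c = Column::new("a", 0);
    assert_eq!(c.to_string(), "a@0");
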
diff --git a/datafusion/src/physical_plan/parquet.rs b/datafusion/src/physical_plan/parquet.rs
index 2bea94a..3d20a9b 100644
--- a/datafusion/src/physical_plan/parquet.rs
+++ b/datafusion/src/physical_plan/parquet.rs
@@ -25,7 +25,7 @@ use std::{any::Any, convert::TryInto};
use crate::{
error::{DataFusionError, Result},
- logical_plan::Expr,
+ logical_plan::{Column, Expr},
physical_optimizer::pruning::{PruningPredicate, PruningStatistics},
physical_plan::{
common, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream,
@@ -497,7 +497,7 @@ macro_rules! get_statistic {
// Extract the min or max value calling `func` or `bytes_func` on the ParquetStatistics as appropriate
macro_rules! get_min_max_values {
($self:expr, $column:expr, $func:ident, $bytes_func:ident) => {{
- let (column_index, field) = if let Some((v, f)) = $self.parquet_schema.column_with_name($column) {
+ let (column_index, field) = if let Some((v, f)) = $self.parquet_schema.column_with_name(&$column.name) {
(v, f)
} else {
// Named column was not present
@@ -532,11 +532,11 @@ macro_rules! get_min_max_values {
}
impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> {
- fn min_values(&self, column: &str) -> Option<ArrayRef> {
+ fn min_values(&self, column: &Column) -> Option<ArrayRef> {
get_min_max_values!(self, column, min, min_bytes)
}
- fn max_values(&self, column: &str) -> Option<ArrayRef> {
+ fn max_values(&self, column: &Column) -> Option<ArrayRef> {
get_min_max_values!(self, column, max, max_bytes)
}
@@ -593,7 +593,6 @@ fn read_files(
loop {
match batch_reader.next() {
Some(Ok(batch)) => {
- //println!("ParquetExec got new batch from {}", filename);
total_rows += batch.num_rows();
send_result(&response_tx, Ok(batch))?;
if limit.map(|l| total_rows >= l).unwrap_or(false) {
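
Callers now hand PruningStatistics a logical Column rather than a bare name; a sketch, where `stats` stands in for any implementor such as RowGroupPruningStatistics:

    use datafusion::logical_plan::Column;

    // statistics are keyed by the column's unqualified name, matching the
    // `$column.name` lookup in get_min_max_values! above
    let column = Column::from_name("c1".to_string());
    let mins = stats.min_values(&column);   // Option<ArrayRef>
    let maxes = stats.max_values(&column);  // None when the column is absent
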
diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs
index af0e60f..a4c20a7 100644
--- a/datafusion/src/physical_plan/planner.rs
+++ b/datafusion/src/physical_plan/planner.rs
@@ -56,6 +56,121 @@ use expressions::col;
use log::debug;
use std::sync::Arc;
+fn create_function_physical_name(
+ fun: &str,
+ distinct: bool,
+ args: &[Expr],
+ input_schema: &DFSchema,
+) -> Result<String> {
+ let names: Vec<String> = args
+ .iter()
+ .map(|e| physical_name(e, input_schema))
+ .collect::<Result<_>>()?;
+
+ let distinct_str = match distinct {
+ true => "DISTINCT ",
+ false => "",
+ };
+ Ok(format!("{}({}{})", fun, distinct_str, names.join(",")))
+}
+
+fn physical_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
+ match e {
+ Expr::Column(c) => Ok(c.name.clone()),
+ Expr::Alias(_, name) => Ok(name.clone()),
+ Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")),
+ Expr::Literal(value) => Ok(format!("{:?}", value)),
+ Expr::BinaryExpr { left, op, right } => {
+ let left = physical_name(left, input_schema)?;
+ let right = physical_name(right, input_schema)?;
+ Ok(format!("{} {:?} {}", left, op, right))
+ }
+ Expr::Case {
+ expr,
+ when_then_expr,
+ else_expr,
+ } => {
+ let mut name = "CASE ".to_string();
+ if let Some(e) = expr {
+ name += &format!("{:?} ", e);
+ }
+ for (w, t) in when_then_expr {
+ name += &format!("WHEN {:?} THEN {:?} ", w, t);
+ }
+ if let Some(e) = else_expr {
+ name += &format!("ELSE {:?} ", e);
+ }
+ name += "END";
+ Ok(name)
+ }
+ Expr::Cast { expr, data_type } => {
+ let expr = physical_name(expr, input_schema)?;
+ Ok(format!("CAST({} AS {:?})", expr, data_type))
+ }
+ Expr::TryCast { expr, data_type } => {
+ let expr = physical_name(expr, input_schema)?;
+ Ok(format!("TRY_CAST({} AS {:?})", expr, data_type))
+ }
+ Expr::Not(expr) => {
+ let expr = physical_name(expr, input_schema)?;
+ Ok(format!("NOT {}", expr))
+ }
+ Expr::Negative(expr) => {
+ let expr = physical_name(expr, input_schema)?;
+ Ok(format!("(- {})", expr))
+ }
+ Expr::IsNull(expr) => {
+ let expr = physical_name(expr, input_schema)?;
+ Ok(format!("{} IS NULL", expr))
+ }
+ Expr::IsNotNull(expr) => {
+ let expr = physical_name(expr, input_schema)?;
+ Ok(format!("{} IS NOT NULL", expr))
+ }
+ Expr::ScalarFunction { fun, args, .. } => {
+ create_function_physical_name(&fun.to_string(), false, args, input_schema)
+ }
+ Expr::ScalarUDF { fun, args, .. } => {
+ create_function_physical_name(&fun.name, false, args, input_schema)
+ }
+ Expr::WindowFunction { fun, args, .. } => {
+ create_function_physical_name(&fun.to_string(), false, args, input_schema)
+ }
+ Expr::AggregateFunction {
+ fun,
+ distinct,
+ args,
+ ..
+ } => {
+ create_function_physical_name(&fun.to_string(), *distinct, args, input_schema)
+ }
+ Expr::AggregateUDF { fun, args } => {
+ let mut names = Vec::with_capacity(args.len());
+ for e in args {
+ names.push(physical_name(e, input_schema)?);
+ }
+ Ok(format!("{}({})", fun.name, names.join(",")))
+ }
+ Expr::InList {
+ expr,
+ list,
+ negated,
+ } => {
+ let expr = physical_name(expr, input_schema)?;
+ let list = list
+ .iter()
+ .map(|expr| physical_name(expr, input_schema))
+ .collect::<Result<Vec<_>>>()?;
+ if *negated {
+ Ok(format!("{} NOT IN ({})", expr, list.join(", ")))
+ } else {
+ Ok(format!("{} IN ({})", expr, list.join(", ")))
+ }
+ }
+ other => Err(DataFusionError::NotImplemented(format!(
+ "Cannot derive physical field name for logical expression {:?}",
+ other
+ ))),
+ }
+}
+
/// This trait exposes the ability to plan an [`ExecutionPlan`] out of a [`LogicalPlan`].
pub trait ExtensionPlanner {
/// Create a physical plan for a [`UserDefinedLogicalNode`].
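
Returning to the new physical_name helper above (a private planner function): a sketch of the qualifier-stripping it performs, assuming public Column fields and eliding the schema value:

    use datafusion::logical_plan::{Column, Expr};

    // the logical column keeps its table qualifier ...
    let e = Expr::Column(Column {
        relation: Some("t1".to_string()),
        name: "c1".to_string(),
    });
    // ... while the derived physical field name drops it:
    //   physical_name(&e, &input_schema)? == "c1"
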
@@ -150,10 +265,8 @@ impl DefaultPhysicalPlanner {
}
let input_exec = self.create_initial_plan(input, ctx_state)?;
- let input_schema = input_exec.schema();
-
+ let physical_input_schema = input_exec.schema();
let logical_input_schema = input.as_ref().schema();
- let physical_input_schema = input_exec.as_ref().schema();
let window_expr = window_expr
.iter()
@@ -170,7 +283,7 @@ impl DefaultPhysicalPlanner {
Ok(Arc::new(WindowAggExec::try_new(
window_expr,
input_exec.clone(),
- input_schema,
+ physical_input_schema,
)?))
}
LogicalPlan::Aggregate {
@@ -181,8 +294,7 @@ impl DefaultPhysicalPlanner {
} => {
// Initially need to perform the aggregate and then merge the partitions
let input_exec = self.create_initial_plan(input, ctx_state)?;
- let input_schema = input_exec.schema();
- let physical_input_schema = input_exec.as_ref().schema();
+ let physical_input_schema = input_exec.schema();
let logical_input_schema = input.as_ref().schema();
let groups = group_expr
@@ -191,10 +303,11 @@ impl DefaultPhysicalPlanner {
tuple_err((
self.create_physical_expr(
e,
+ logical_input_schema,
&physical_input_schema,
ctx_state,
),
- e.name(logical_input_schema),
+ physical_name(e, logical_input_schema),
))
})
.collect::<Result<Vec<_>>>()?;
@@ -215,11 +328,13 @@ impl DefaultPhysicalPlanner {
groups.clone(),
aggregates.clone(),
input_exec,
- input_schema.clone(),
+ physical_input_schema.clone(),
)?);
- let final_group: Vec<Arc<dyn PhysicalExpr>> =
- (0..groups.len()).map(|i| col(&groups[i].1)).collect();
+ // update group column indices based on partial aggregate plan evaluation
+ let final_group: Vec<Arc<dyn PhysicalExpr>> = (0..groups.len())
+ .map(|i| col(&groups[i].1, &initial_aggr.schema()))
+ .collect::<Result<_>>()?;
// TODO: dictionary type not yet supported in Hash Repartition
let contains_dict = groups
@@ -261,31 +376,74 @@ impl DefaultPhysicalPlanner {
.collect(),
aggregates,
initial_aggr,
- input_schema,
+ physical_input_schema.clone(),
)?))
}
LogicalPlan::Projection { input, expr, .. } => {
let input_exec = self.create_initial_plan(input, ctx_state)?;
let input_schema = input.as_ref().schema();
- let runtime_expr = expr
+
+ let physical_exprs = expr
.iter()
.map(|e| {
+ // For projections, the SQL planner and logical plan builder may convert
+ // user-provided expressions into logical Column expressions if their results
+ // are already provided by the input plans. Because we work with qualified
+ // columns in the logical plan, derived columns involving operators or
+ // functions will contain qualifiers as well. This results in logical
+ // columns with names like `SUM(t1.c1)`, `t1.c1 + t1.c2`, etc.
+ //
+ // If we ran these logical columns through the physical_name function, we
+ // would get physical names with column qualifiers, which violates
+ // DataFusion's field name semantics. To account for this, we need to derive
+ // the physical name from the physical input instead.
+ //
+ // This depends on the invariant that a logical schema field's index MUST
+ // match the corresponding physical schema field's index.
+ let physical_name = if let Expr::Column(col) = e {
+ match input_schema.index_of_column(col) {
+ Ok(idx) => {
+ // index physical field using logical field index
+ Ok(input_exec.schema().field(idx).name().to_string())
+ }
+ // the logical column is not a derived column, so it is safe to pass
+ // along to physical_name
+ Err(_) => physical_name(e, input_schema),
+ }
+ } else {
+ physical_name(e, input_schema)
+ };
+
tuple_err((
- self.create_physical_expr(e, &input_exec.schema(), ctx_state),
- e.name(input_schema),
+ self.create_physical_expr(
+ e,
+ input_schema,
+ &input_exec.schema(),
+ ctx_state,
+ ),
+ physical_name,
))
})
.collect::<Result<Vec<_>>>()?;
- Ok(Arc::new(ProjectionExec::try_new(runtime_expr, input_exec)?))
+
+ Ok(Arc::new(ProjectionExec::try_new(
+ physical_exprs,
+ input_exec,
+ )?))
}
LogicalPlan::Filter {
input, predicate, ..
} => {
- let input = self.create_initial_plan(input, ctx_state)?;
- let input_schema = input.as_ref().schema();
- let runtime_expr =
- self.create_physical_expr(predicate, &input_schema, ctx_state)?;
- Ok(Arc::new(FilterExec::try_new(runtime_expr, input)?))
+ let physical_input = self.create_initial_plan(input, ctx_state)?;
+ let input_schema = physical_input.as_ref().schema();
+ let input_dfschema = input.as_ref().schema();
+ let runtime_expr = self.create_physical_expr(
+ predicate,
+ input_dfschema,
+ &input_schema,
+ ctx_state,
+ )?;
+ Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?))
}
LogicalPlan::Union { inputs, .. } => {
let physical_plans = inputs
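
Condensed, the projection arm's two naming paths look like this (col and input_exec as in the hunk above):

    // derived column (e.g. logical name "SUM(t1.c1)"): resolve by position,
    // then take the physical field's unqualified name
    let idx = input_schema.index_of_column(col)?;              // logical DFSchema
    let name = input_exec.schema().field(idx).name().clone();  // e.g. "SUM(c1)"

    // anything else falls back to physical_name(e, input_schema)
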
@@ -298,8 +456,9 @@ impl DefaultPhysicalPlanner {
input,
partitioning_scheme,
} => {
- let input = self.create_initial_plan(input, ctx_state)?;
- let input_schema = input.schema();
+ let physical_input = self.create_initial_plan(input, ctx_state)?;
+ let input_schema = physical_input.schema();
+ let input_dfschema = input.as_ref().schema();
let physical_partitioning = match partitioning_scheme {
LogicalPartitioning::RoundRobinBatch(n) => {
Partitioning::RoundRobinBatch(*n)
@@ -308,20 +467,26 @@ impl DefaultPhysicalPlanner {
let runtime_expr = expr
.iter()
.map(|e| {
- self.create_physical_expr(e, &input_schema, ctx_state)
+ self.create_physical_expr(
+ e,
+ input_dfschema,
+ &input_schema,
+ ctx_state,
+ )
})
.collect::<Result<Vec<_>>>()?;
Partitioning::Hash(runtime_expr, *n)
}
};
Ok(Arc::new(RepartitionExec::try_new(
- input,
+ physical_input,
physical_partitioning,
)?))
}
LogicalPlan::Sort { expr, input, .. } => {
- let input = self.create_initial_plan(input, ctx_state)?;
- let input_schema = input.as_ref().schema();
+ let physical_input = self.create_initial_plan(input, ctx_state)?;
+ let input_schema = physical_input.as_ref().schema();
+ let input_dfschema = input.as_ref().schema();
let sort_expr = expr
.iter()
@@ -332,6 +497,7 @@ impl DefaultPhysicalPlanner {
nulls_first,
} => self.create_physical_sort_expr(
expr,
+ input_dfschema,
&input_schema,
SortOptions {
descending: !*asc,
@@ -345,7 +511,7 @@ impl DefaultPhysicalPlanner {
})
.collect::<Result<Vec<_>>>()?;
- Ok(Arc::new(SortExec::try_new(sort_expr, input)?))
+ Ok(Arc::new(SortExec::try_new(sort_expr, physical_input)?))
}
LogicalPlan::Join {
left,
@@ -354,8 +520,10 @@ impl DefaultPhysicalPlanner {
join_type,
..
} => {
- let left = self.create_initial_plan(left, ctx_state)?;
- let right = self.create_initial_plan(right, ctx_state)?;
+ let left_df_schema = left.schema();
+ let physical_left = self.create_initial_plan(left, ctx_state)?;
+ let right_df_schema = right.schema();
+ let physical_right = self.create_initial_plan(right, ctx_state)?;
let physical_join_type = match join_type {
JoinType::Inner => hash_utils::JoinType::Inner,
JoinType::Left => hash_utils::JoinType::Left,
@@ -364,30 +532,47 @@ impl DefaultPhysicalPlanner {
JoinType::Semi => hash_utils::JoinType::Semi,
JoinType::Anti => hash_utils::JoinType::Anti,
};
+ let join_on = keys
+ .iter()
+ .map(|(l, r)| {
+ Ok((
+ Column::new(&l.name, left_df_schema.index_of_column(l)?),
+ Column::new(&r.name, right_df_schema.index_of_column(r)?),
+ ))
+ })
+ .collect::<Result<hash_utils::JoinOn>>()?;
+
if ctx_state.config.concurrency > 1 && ctx_state.config.repartition_joins
{
- let left_expr = keys.iter().map(|x| col(&x.0)).collect();
- let right_expr = keys.iter().map(|x| col(&x.1)).collect();
+ let (left_expr, right_expr) = join_on
+ .iter()
+ .map(|(l, r)| {
+ (
+ Arc::new(l.clone()) as Arc<dyn PhysicalExpr>,
+ Arc::new(r.clone()) as Arc<dyn PhysicalExpr>,
+ )
+ })
+ .unzip();
// Use hash partition by default to parallelize hash joins
Ok(Arc::new(HashJoinExec::try_new(
Arc::new(RepartitionExec::try_new(
- left,
+ physical_left,
Partitioning::Hash(left_expr, ctx_state.config.concurrency),
)?),
Arc::new(RepartitionExec::try_new(
- right,
+ physical_right,
Partitioning::Hash(right_expr, ctx_state.config.concurrency),
)?),
- keys,
+ join_on,
&physical_join_type,
PartitionMode::Partitioned,
)?))
} else {
Ok(Arc::new(HashJoinExec::try_new(
- left,
- right,
- keys,
+ physical_left,
+ physical_right,
+ join_on,
&physical_join_type,
PartitionMode::CollectLeft,
)?))
@@ -476,10 +661,10 @@ impl DefaultPhysicalPlanner {
"No installed planner was able to convert the custom node to an execution plan: {:?}", node
)))?;
- // Ensure the ExecutionPlan's schema matches the
+ // Ensure the ExecutionPlan's schema matches the
// declared logical schema to catch and warn about
// logic errors when creating user defined plans.
- if plan.schema() != node.schema().as_ref().to_owned().into() {
+ if !node.schema().matches_arrow_schema(&plan.schema()) {
Err(DataFusionError::Plan(format!(
"Extension planner for {:?} created an ExecutionPlan with mismatched schema. \
LogicalPlan schema: {:?}, ExecutionPlan schema: {:?}",
@@ -496,17 +681,20 @@ impl DefaultPhysicalPlanner {
pub fn create_physical_expr(
&self,
e: &Expr,
+ input_dfschema: &DFSchema,
input_schema: &Schema,
ctx_state: &ExecutionContextState,
) -> Result<Arc<dyn PhysicalExpr>> {
match e {
- Expr::Alias(expr, ..) => {
- Ok(self.create_physical_expr(expr, input_schema, ctx_state)?)
- }
- Expr::Column(name) => {
- // check that name exists
- input_schema.field_with_name(name)?;
- Ok(Arc::new(Column::new(name)))
+ Expr::Alias(expr, ..) => Ok(self.create_physical_expr(
+ expr,
+ input_dfschema,
+ input_schema,
+ ctx_state,
+ )?),
+ Expr::Column(c) => {
+ let idx = input_dfschema.index_of_column(c)?;
+ Ok(Arc::new(Column::new(&c.name, idx)))
}
Expr::Literal(value) => Ok(Arc::new(Literal::new(value.clone()))),
Expr::ScalarVariable(variable_names) => {
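
Column expressions are now resolved positionally through DFSchema::index_of_column; a minimal sketch of that lookup:

    use arrow::datatypes::{DataType, Field, Schema};
    use datafusion::error::Result;
    use datafusion::logical_plan::{Column, DFSchema};
    use std::convert::TryFrom;

    fn sketch() -> Result<()> {
        let schema = DFSchema::try_from(Schema::new(vec![
            Field::new("a", DataType::Int32, false),
            Field::new("b", DataType::Int32, false),
        ]))?;
        // an unqualified column resolves to its flat field offset
        assert_eq!(schema.index_of_column(&Column::from_name("b".to_string()))?, 1);
        Ok(())
    }
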
@@ -535,8 +723,18 @@ impl DefaultPhysicalPlanner {
}
}
Expr::BinaryExpr { left, op, right } => {
- let lhs = self.create_physical_expr(left, input_schema, ctx_state)?;
- let rhs = self.create_physical_expr(right, input_schema, ctx_state)?;
+ let lhs = self.create_physical_expr(
+ left,
+ input_dfschema,
+ input_schema,
+ ctx_state,
+ )?;
+ let rhs = self.create_physical_expr(
+ right,
+ input_dfschema,
+ input_schema,
+ ctx_state,
+ )?;
binary(lhs, *op, rhs, input_schema)
}
Expr::Case {
@@ -548,6 +746,7 @@ impl DefaultPhysicalPlanner {
let expr: Option<Arc<dyn PhysicalExpr>> = if let Some(e) = expr {
Some(self.create_physical_expr(
e.as_ref(),
+ input_dfschema,
input_schema,
ctx_state,
)?)
@@ -557,13 +756,23 @@ impl DefaultPhysicalPlanner {
let when_expr = when_then_expr
.iter()
.map(|(w, _)| {
- self.create_physical_expr(w.as_ref(), input_schema, ctx_state)
+ self.create_physical_expr(
+ w.as_ref(),
+ input_dfschema,
+ input_schema,
+ ctx_state,
+ )
})
.collect::<Result<Vec<_>>>()?;
let then_expr = when_then_expr
.iter()
.map(|(_, t)| {
- self.create_physical_expr(t.as_ref(), input_schema, ctx_state)
+ self.create_physical_expr(
+ t.as_ref(),
+ input_dfschema,
+ input_schema,
+ ctx_state,
+ )
})
.collect::<Result<Vec<_>>>()?;
let when_then_expr: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> =
@@ -576,6 +785,7 @@ impl DefaultPhysicalPlanner {
{
Some(self.create_physical_expr(
e.as_ref(),
+ input_dfschema,
input_schema,
ctx_state,
)?)
@@ -589,35 +799,43 @@ impl DefaultPhysicalPlanner {
)?))
}
Expr::Cast { expr, data_type } => expressions::cast(
- self.create_physical_expr(expr, input_schema, ctx_state)?,
+ self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?,
input_schema,
data_type.clone(),
),
Expr::TryCast { expr, data_type } => expressions::try_cast(
- self.create_physical_expr(expr, input_schema, ctx_state)?,
+ self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?,
input_schema,
data_type.clone(),
),
Expr::Not(expr) => expressions::not(
- self.create_physical_expr(expr, input_schema, ctx_state)?,
+ self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?,
input_schema,
),
Expr::Negative(expr) => expressions::negative(
- self.create_physical_expr(expr, input_schema, ctx_state)?,
+ self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?,
input_schema,
),
Expr::IsNull(expr) => expressions::is_null(self.create_physical_expr(
expr,
+ input_dfschema,
input_schema,
ctx_state,
)?),
Expr::IsNotNull(expr) => expressions::is_not_null(
- self.create_physical_expr(expr, input_schema, ctx_state)?,
+ self.create_physical_expr(expr, input_dfschema, input_schema, ctx_state)?,
),
Expr::ScalarFunction { fun, args } => {
let physical_args = args
.iter()
- .map(|e| self.create_physical_expr(e, input_schema, ctx_state))
+ .map(|e| {
+ self.create_physical_expr(
+ e,
+ input_dfschema,
+ input_schema,
+ ctx_state,
+ )
+ })
.collect::<Result<Vec<_>>>()?;
functions::create_physical_expr(
fun,
@@ -631,6 +849,7 @@ impl DefaultPhysicalPlanner {
for e in args {
physical_args.push(self.create_physical_expr(
e,
+ input_dfschema,
input_schema,
ctx_state,
)?);
@@ -648,11 +867,24 @@ impl DefaultPhysicalPlanner {
low,
high,
} => {
- let value_expr =
- self.create_physical_expr(expr, input_schema, ctx_state)?;
- let low_expr = self.create_physical_expr(low, input_schema, ctx_state)?;
- let high_expr =
- self.create_physical_expr(high, input_schema, ctx_state)?;
+ let value_expr = self.create_physical_expr(
+ expr,
+ input_dfschema,
+ input_schema,
+ ctx_state,
+ )?;
+ let low_expr = self.create_physical_expr(
+ low,
+ input_dfschema,
+ input_schema,
+ ctx_state,
+ )?;
+ let high_expr = self.create_physical_expr(
+ high,
+ input_dfschema,
+ input_schema,
+ ctx_state,
+ )?;
// rewrite the between into the two binary operators
let binary_expr = binary(
@@ -677,44 +909,54 @@ impl DefaultPhysicalPlanner {
Ok(expressions::lit(ScalarValue::Boolean(None)))
}
_ => {
- let value_expr =
- self.create_physical_expr(expr, input_schema, ctx_state)?;
+ let value_expr = self.create_physical_expr(
+ expr,
+ input_dfschema,
+ input_schema,
+ ctx_state,
+ )?;
let value_expr_data_type = value_expr.data_type(input_schema)?;
- let list_exprs =
- list.iter()
- .map(|expr| match expr {
- Expr::Literal(ScalarValue::Utf8(None)) => self
- .create_physical_expr(expr, input_schema, ctx_state),
- _ => {
- let list_expr = self.create_physical_expr(
- expr,
+ let list_exprs = list
+ .iter()
+ .map(|expr| match expr {
+ Expr::Literal(ScalarValue::Utf8(None)) => self
+ .create_physical_expr(
+ expr,
+ input_dfschema,
+ input_schema,
+ ctx_state,
+ ),
+ _ => {
+ let list_expr = self.create_physical_expr(
+ expr,
+ input_dfschema,
+ input_schema,
+ ctx_state,
+ )?;
+ let list_expr_data_type =
+ list_expr.data_type(input_schema)?;
+
+ if list_expr_data_type == value_expr_data_type {
+ Ok(list_expr)
+ } else if can_cast_types(
+ &list_expr_data_type,
+ &value_expr_data_type,
+ ) {
+ expressions::cast(
+ list_expr,
input_schema,
- ctx_state,
- )?;
- let list_expr_data_type =
- list_expr.data_type(input_schema)?;
-
- if list_expr_data_type == value_expr_data_type {
- Ok(list_expr)
- } else if can_cast_types(
- &list_expr_data_type,
- &value_expr_data_type,
- ) {
- expressions::cast(
- list_expr,
- input_schema,
- value_expr.data_type(input_schema)?,
- )
- } else {
- Err(DataFusionError::Plan(format!(
- "Unsupported CAST from {:?} to {:?}",
- list_expr_data_type, value_expr_data_type
- )))
- }
+ value_expr.data_type(input_schema)?,
+ )
+ } else {
+ Err(DataFusionError::Plan(format!(
+ "Unsupported CAST from {:?} to {:?}",
+ list_expr_data_type, value_expr_data_type
+ )))
}
- })
- .collect::<Result<Vec<_>>>()?;
+ }
+ })
+ .collect::<Result<Vec<_>>>()?;
expressions::in_list(value_expr, list_exprs, negated)
}
@@ -731,6 +973,7 @@ impl DefaultPhysicalPlanner {
&self,
e: &Expr,
name: String,
+ logical_input_schema: &DFSchema,
physical_input_schema: &Schema,
ctx_state: &ExecutionContextState,
) -> Result<Arc<dyn WindowExpr>> {
@@ -745,13 +988,23 @@ impl DefaultPhysicalPlanner {
let args = args
.iter()
.map(|e| {
- self.create_physical_expr(e, physical_input_schema, ctx_state)
+ self.create_physical_expr(
+ e,
+ logical_input_schema,
+ physical_input_schema,
+ ctx_state,
+ )
})
.collect::<Result<Vec<_>>>()?;
let partition_by = partition_by
.iter()
.map(|e| {
- self.create_physical_expr(e, physical_input_schema, ctx_state)
+ self.create_physical_expr(
+ e,
+ logical_input_schema,
+ physical_input_schema,
+ ctx_state,
+ )
})
.collect::<Result<Vec<_>>>()?;
let order_by = order_by
@@ -763,6 +1016,7 @@ impl DefaultPhysicalPlanner {
nulls_first,
} => self.create_physical_sort_expr(
expr,
+ logical_input_schema,
physical_input_schema,
SortOptions {
descending: !*asc,
@@ -809,9 +1063,15 @@ impl DefaultPhysicalPlanner {
// unpack aliased logical expressions, e.g. "sum(col) over () as total"
let (name, e) = match e {
Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()),
- _ => (e.name(logical_input_schema)?, e),
+ _ => (physical_name(e, logical_input_schema)?, e),
};
- self.create_window_expr_with_name(e, name, physical_input_schema, ctx_state)
+ self.create_window_expr_with_name(
+ e,
+ name,
+ logical_input_schema,
+ physical_input_schema,
+ ctx_state,
+ )
}
/// Create an aggregate expression with a name from a logical expression
@@ -819,6 +1079,7 @@ impl DefaultPhysicalPlanner {
&self,
e: &Expr,
name: String,
+ logical_input_schema: &DFSchema,
physical_input_schema: &Schema,
ctx_state: &ExecutionContextState,
) -> Result<Arc<dyn AggregateExpr>> {
@@ -832,7 +1093,12 @@ impl DefaultPhysicalPlanner {
let args = args
.iter()
.map(|e| {
- self.create_physical_expr(e, physical_input_schema, ctx_state)
+ self.create_physical_expr(
+ e,
+ logical_input_schema,
+ physical_input_schema,
+ ctx_state,
+ )
})
.collect::<Result<Vec<_>>>()?;
aggregates::create_aggregate_expr(
@@ -847,7 +1113,12 @@ impl DefaultPhysicalPlanner {
let args = args
.iter()
.map(|e| {
- self.create_physical_expr(e, physical_input_schema, ctx_state)
+ self.create_physical_expr(
+ e,
+ logical_input_schema,
+ physical_input_schema,
+ ctx_state,
+ )
})
.collect::<Result<Vec<_>>>()?;
@@ -871,21 +1142,34 @@ impl DefaultPhysicalPlanner {
// unpack aliased logical expressions, e.g. "sum(col) as total"
let (name, e) = match e {
Expr::Alias(sub_expr, alias) => (alias.clone(), sub_expr.as_ref()),
- _ => (e.name(logical_input_schema)?, e),
+ _ => (physical_name(e, logical_input_schema)?, e),
};
- self.create_aggregate_expr_with_name(e, name, physical_input_schema, ctx_state)
+
+ self.create_aggregate_expr_with_name(
+ e,
+ name,
+ logical_input_schema,
+ physical_input_schema,
+ ctx_state,
+ )
}
/// Create a physical sort expression from a logical expression
pub fn create_physical_sort_expr(
&self,
e: &Expr,
+ input_dfschema: &DFSchema,
input_schema: &Schema,
options: SortOptions,
ctx_state: &ExecutionContextState,
) -> Result<PhysicalSortExpr> {
Ok(PhysicalSortExpr {
- expr: self.create_physical_expr(e, input_schema, ctx_state)?,
+ expr: self.create_physical_expr(
+ e,
+ input_dfschema,
+ input_schema,
+ ctx_state,
+ )?,
options,
})
}
@@ -913,6 +1197,7 @@ mod tests {
use arrow::datatypes::{DataType, Field, SchemaRef};
use async_trait::async_trait;
use fmt::Debug;
+ use std::convert::TryFrom;
use std::{any::Any, fmt};
fn make_ctx_state() -> ExecutionContextState {
@@ -945,7 +1230,7 @@ mod tests {
// verify that the plan correctly casts u8 to i64
// the cast here is implicit so has CastOptions with safe=true
- let expected = "BinaryExpr { left: Column { name: \"c7\" }, op: Lt, right: TryCastExpr { expr: Literal { value: UInt8(5) }, cast_type: Int64 } }";
+ let expected = "BinaryExpr { left: Column { name: \"c7\", index: 6 }, op: Lt, right: TryCastExpr { expr: Literal { value: UInt8(5) }, cast_type: Int64 } }";
assert!(format!("{:?}", plan).contains(expected));
Ok(())
@@ -954,12 +1239,17 @@ mod tests {
#[test]
fn test_create_not() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
+ let dfschema = DFSchema::try_from(schema.clone())?;
let planner = DefaultPhysicalPlanner::default();
- let expr =
- planner.create_physical_expr(&col("a").not(), &schema, &make_ctx_state())?;
- let expected = expressions::not(expressions::col("a"), &schema)?;
+ let expr = planner.create_physical_expr(
+ &col("a").not(),
+ &dfschema,
+ &schema,
+ &make_ctx_state(),
+ )?;
+ let expected = expressions::not(expressions::col("a", &schema)?, &schema)?;
assert_eq!(format!("{:?}", expr), format!("{:?}", expected));
@@ -980,7 +1270,7 @@ mod tests {
// c12 is f64, c7 is u8 -> cast c7 to f64
// the cast here is implicit so has CastOptions with safe=true
- let expected = "predicate: BinaryExpr { left: TryCastExpr { expr: Column { name: \"c7\" }, cast_type: Float64 }, op: Lt, right: Column { name: \"c12\" } }";
+ let expected = "predicate: BinaryExpr { left: TryCastExpr { expr: Column { name: \"c7\", index: 6 }, cast_type: Float64 }, op: Lt, right: Column { name: \"c12\", index: 11 } }";
assert!(format!("{:?}", plan).contains(expected));
Ok(())
}
@@ -1105,8 +1395,7 @@ mod tests {
.build()?;
let execution_plan = plan(&logical_plan)?;
// verify that the plan correctly adds cast from Int64(1) to Utf8
- let expected = "InListExpr { expr: Column { name: \"c1\" }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], negated: false }";
- println!("{:?}", execution_plan);
+ let expected = "InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, CastExpr { expr: Literal { value: Int64(1) }, cast_type: Utf8, cast_options: CastOptions { safe: false } }], negated: false }";
assert!(format!("{:?}", execution_plan).contains(expected));
// expression: "a in (true, 'a')"
diff --git a/datafusion/src/physical_plan/projection.rs b/datafusion/src/physical_plan/projection.rs
index d4c0459..5110e5b 100644
--- a/datafusion/src/physical_plan/projection.rs
+++ b/datafusion/src/physical_plan/projection.rs
@@ -233,8 +233,10 @@ mod tests {
)?;
// pick column c1 and name it column c1 in the output schema
- let projection =
- ProjectionExec::try_new(vec![(col("c1"), "c1".to_string())], Arc::new(csv))?;
+ let projection = ProjectionExec::try_new(
+ vec![(col("c1", &schema)?, "c1".to_string())],
+ Arc::new(csv),
+ )?;
let mut partition_count = 0;
let mut row_count = 0;
diff --git a/datafusion/src/physical_plan/repartition.rs b/datafusion/src/physical_plan/repartition.rs
index a7b17c4..e67e4c2 100644
--- a/datafusion/src/physical_plan/repartition.rs
+++ b/datafusion/src/physical_plan/repartition.rs
@@ -435,7 +435,7 @@ mod tests {
use super::*;
use crate::{
assert_batches_sorted_eq,
- physical_plan::memory::MemoryExec,
+ physical_plan::{expressions::col, memory::MemoryExec},
test::exec::{BarrierExec, ErrorExec, MockExec},
};
use arrow::datatypes::{DataType, Field, Schema};
@@ -513,12 +513,7 @@ mod tests {
let output_partitions = repartition(
&schema,
partitions,
- Partitioning::Hash(
- vec![Arc::new(crate::physical_plan::expressions::Column::new(
- "c0",
- ))],
- 8,
- ),
+ Partitioning::Hash(vec![col("c0", &schema)?], 8),
)
.await?;
@@ -761,6 +756,7 @@ mod tests {
partitioning: Partitioning::Hash(
vec![Arc::new(crate::physical_plan::expressions::Column::new(
"my_awesome_field",
+ 0,
))],
2,
),
diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs
index 437519a..3650978 100644
--- a/datafusion/src/physical_plan/sort.rs
+++ b/datafusion/src/physical_plan/sort.rs
@@ -343,17 +343,17 @@ mod tests {
vec![
// c1 string column
PhysicalSortExpr {
- expr: col("c1"),
+ expr: col("c1", &schema)?,
options: SortOptions::default(),
},
// c2 uint32 column
PhysicalSortExpr {
- expr: col("c2"),
+ expr: col("c2", &schema)?,
options: SortOptions::default(),
},
// c7 uint8 column
PhysicalSortExpr {
- expr: col("c7"),
+ expr: col("c7", &schema)?,
options: SortOptions::default(),
},
],
@@ -417,14 +417,14 @@ mod tests {
let sort_exec = Arc::new(SortExec::try_new(
vec![
PhysicalSortExpr {
- expr: col("a"),
+ expr: col("a", &schema)?,
options: SortOptions {
descending: true,
nulls_first: true,
},
},
PhysicalSortExpr {
- expr: col("b"),
+ expr: col("b", &schema)?,
options: SortOptions {
descending: false,
nulls_first: false,
diff --git a/datafusion/src/physical_plan/sort_preserving_merge.rs b/datafusion/src/physical_plan/sort_preserving_merge.rs
index c39acc4..b8ca97c 100644
--- a/datafusion/src/physical_plan/sort_preserving_merge.rs
+++ b/datafusion/src/physical_plan/sort_preserving_merge.rs
@@ -579,21 +579,18 @@ mod tests {
let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
let schema = b1.schema();
+ let sort = vec![
+ PhysicalSortExpr {
+ expr: col("b", &schema).unwrap(),
+ options: Default::default(),
+ },
+ PhysicalSortExpr {
+ expr: col("c", &schema).unwrap(),
+ options: Default::default(),
+ },
+ ];
let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap();
- let merge = Arc::new(SortPreservingMergeExec::new(
- vec![
- PhysicalSortExpr {
- expr: col("b"),
- options: Default::default(),
- },
- PhysicalSortExpr {
- expr: col("c"),
- options: Default::default(),
- },
- ],
- Arc::new(exec),
- 1024,
- ));
+ let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 1024));
let collected = collect(merge).await.unwrap();
assert_eq!(collected.len(), 1);
@@ -668,18 +665,18 @@ mod tests {
let sort = vec![
PhysicalSortExpr {
- expr: col("c1"),
+ expr: col("c1", &schema).unwrap(),
options: SortOptions {
descending: true,
nulls_first: true,
},
},
PhysicalSortExpr {
- expr: col("c2"),
+ expr: col("c2", &schema).unwrap(),
options: Default::default(),
},
PhysicalSortExpr {
- expr: col("c7"),
+ expr: col("c7", &schema).unwrap(),
options: SortOptions::default(),
},
];
@@ -744,25 +741,26 @@ mod tests {
#[tokio::test]
async fn test_partition_sort_streaming_input() {
+ let schema = test::aggr_test_schema();
let sort = vec![
// uint8
PhysicalSortExpr {
- expr: col("c7"),
+ expr: col("c7", &schema).unwrap(),
options: Default::default(),
},
// int16
PhysicalSortExpr {
- expr: col("c4"),
+ expr: col("c4", &schema).unwrap(),
options: Default::default(),
},
// utf-8
PhysicalSortExpr {
- expr: col("c1"),
+ expr: col("c1", &schema).unwrap(),
options: SortOptions::default(),
},
// utf-8
PhysicalSortExpr {
- expr: col("c13"),
+ expr: col("c13", &schema).unwrap(),
options: SortOptions::default(),
},
];
@@ -782,15 +780,17 @@ mod tests {
#[tokio::test]
async fn test_partition_sort_streaming_input_output() {
+ let schema = test::aggr_test_schema();
+
let sort = vec![
// float64
PhysicalSortExpr {
- expr: col("c12"),
+ expr: col("c12", &schema).unwrap(),
options: Default::default(),
},
// utf-8
PhysicalSortExpr {
- expr: col("c13"),
+ expr: col("c13", &schema).unwrap(),
options: Default::default(),
},
];
@@ -850,27 +850,24 @@ mod tests {
let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b), ("c", c)]).unwrap();
let schema = b1.schema();
- let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap();
- let merge = Arc::new(SortPreservingMergeExec::new(
- vec![
- PhysicalSortExpr {
- expr: col("b"),
- options: SortOptions {
- descending: false,
- nulls_first: true,
- },
+ let sort = vec![
+ PhysicalSortExpr {
+ expr: col("b", &schema).unwrap(),
+ options: SortOptions {
+ descending: false,
+ nulls_first: true,
},
- PhysicalSortExpr {
- expr: col("c"),
- options: SortOptions {
- descending: false,
- nulls_first: false,
- },
+ },
+ PhysicalSortExpr {
+ expr: col("c", &schema).unwrap(),
+ options: SortOptions {
+ descending: false,
+ nulls_first: false,
},
- ],
- Arc::new(exec),
- 1024,
- ));
+ },
+ ];
+ let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap();
+ let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 1024));
let collected = collect(merge).await.unwrap();
assert_eq!(collected.len(), 1);
@@ -898,8 +895,9 @@ mod tests {
#[tokio::test]
async fn test_async() {
+ let schema = test::aggr_test_schema();
let sort = vec![PhysicalSortExpr {
- expr: col("c7"),
+ expr: col("c7", &schema).unwrap(),
options: SortOptions::default(),
}];
diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs
index fe87ecd..ffd8f20 100644
--- a/datafusion/src/physical_plan/type_coercion.rs
+++ b/datafusion/src/physical_plan/type_coercion.rs
@@ -267,7 +267,9 @@ mod tests {
let expressions = |t: Vec<DataType>, schema| -> Result<Vec<_>> {
t.iter()
.enumerate()
- .map(|(i, t)| try_cast(col(&format!("c{}", i)), &schema, t.clone()))
+ .map(|(i, t)| {
+ try_cast(col(&format!("c{}", i), &schema)?, &schema, t.clone())
+ })
.collect::<Result<Vec<_>>>()
};
diff --git a/datafusion/src/physical_plan/windows.rs b/datafusion/src/physical_plan/windows.rs
index 466cc51..a214ef1 100644
--- a/datafusion/src/physical_plan/windows.rs
+++ b/datafusion/src/physical_plan/windows.rs
@@ -369,7 +369,7 @@ impl WindowAggExec {
input: Arc<dyn ExecutionPlan>,
input_schema: SchemaRef,
) -> Result<Self> {
- let schema = create_schema(&input.schema(), &window_expr)?;
+ let schema = create_schema(&input_schema, &window_expr)?;
let schema = Arc::new(schema);
Ok(WindowAggExec {
input,
@@ -599,7 +599,7 @@ mod tests {
vec![create_window_expr(
&WindowFunction::AggregateFunction(AggregateFunction::Count),
"count".to_owned(),
- &[col("c3")],
+ &[col("c3", &schema)?],
&[],
&[],
Some(WindowFrame::default()),
@@ -632,7 +632,7 @@ mod tests {
create_window_expr(
&WindowFunction::AggregateFunction(AggregateFunction::Count),
"count".to_owned(),
- &[col("c3")],
+ &[col("c3", &schema)?],
&[],
&[],
Some(WindowFrame::default()),
@@ -641,7 +641,7 @@ mod tests {
create_window_expr(
&WindowFunction::AggregateFunction(AggregateFunction::Max),
"max".to_owned(),
- &[col("c3")],
+ &[col("c3", &schema)?],
&[],
&[],
Some(WindowFrame::default()),
@@ -650,7 +650,7 @@ mod tests {
create_window_expr(
&WindowFunction::AggregateFunction(AggregateFunction::Min),
"min".to_owned(),
- &[col("c3")],
+ &[col("c3", &schema)?],
&[],
&[],
Some(WindowFrame::default()),
diff --git a/datafusion/src/prelude.rs b/datafusion/src/prelude.rs
index e1f1d7b..e7ad04e 100644
--- a/datafusion/src/prelude.rs
+++ b/datafusion/src/prelude.rs
@@ -32,6 +32,6 @@ pub use crate::logical_plan::{
count, create_udf, in_list, initcap, left, length, lit, lower, lpad, ltrim, max, md5,
min, now, octet_length, random, regexp_replace, repeat, replace, reverse, right,
rpad, rtrim, sha224, sha256, sha384, sha512, split_part, starts_with, strpos, substr,
- sum, to_hex, translate, trim, upper, JoinType, Partitioning,
+ sum, to_hex, translate, trim, upper, Column, JoinType, Partitioning,
};
pub use crate::physical_plan::csv::CsvReadOptions;
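
With Column re-exported from the prelude, user code can construct qualified columns directly; a sketch assuming the struct fields are public (the Into<Column> for &str impl mentioned in the commit log is not shown here):

    use datafusion::prelude::*;

    // a qualified logical column, as the planner builds for "person.age"
    let c = Column {
        relation: Some("person".to_string()),
        name: "age".to_string(),
    };
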
diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs
index 547e9af..7912241 100644
--- a/datafusion/src/sql/planner.rs
+++ b/datafusion/src/sql/planner.rs
@@ -17,13 +17,18 @@
//! SQL Query Planner (produces logical plan from SQL AST)
+use std::collections::HashSet;
+use std::str::FromStr;
+use std::sync::Arc;
+use std::{convert::TryInto, vec};
+
use crate::catalog::TableReference;
use crate::datasource::TableProvider;
use crate::logical_plan::window_frames::{WindowFrame, WindowFrameUnits};
use crate::logical_plan::Expr::Alias;
use crate::logical_plan::{
- and, lit, DFSchema, Expr, LogicalPlan, LogicalPlanBuilder, Operator, PlanType,
- StringifiedPlan, ToDFSchema,
+ and, lit, union_with_alias, Column, DFSchema, Expr, LogicalPlan, LogicalPlanBuilder,
+ Operator, PlanType, StringifiedPlan, ToDFSchema,
};
use crate::prelude::JoinType;
use crate::scalar::ScalarValue;
@@ -47,9 +52,6 @@ use sqlparser::ast::{
use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
use sqlparser::ast::{OrderByExpr, Statement};
use sqlparser::parser::ParserError::ParserError;
-use std::str::FromStr;
-use std::sync::Arc;
-use std::{convert::TryInto, vec};
use super::{
parser::DFParser,
@@ -163,29 +165,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
(SetOperator::Union, true) => {
let left_plan = self.set_expr_to_plan(left.as_ref(), None, ctes)?;
let right_plan = self.set_expr_to_plan(right.as_ref(), None, ctes)?;
- let inputs = vec![left_plan, right_plan]
- .into_iter()
- .flat_map(|p| match p {
- LogicalPlan::Union { inputs, .. } => inputs,
- x => vec![x],
- })
- .collect::<Vec<_>>();
- if inputs.is_empty() {
- return Err(DataFusionError::Plan(format!(
- "Empty UNION: {}",
- set_expr
- )));
- }
- if !inputs.iter().all(|s| s.schema() == inputs[0].schema()) {
- return Err(DataFusionError::Plan(
- "UNION ALL schemas are expected to be the same".to_string(),
- ));
- }
- Ok(LogicalPlan::Union {
- schema: inputs[0].schema().clone(),
- inputs,
- alias,
- })
+ union_with_alias(left_plan, right_plan, alias)
}
_ => Err(DataFusionError::NotImplemented(format!(
"Only UNION ALL is supported, found {}",
@@ -382,7 +362,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
) -> Result<LogicalPlan> {
match constraint {
JoinConstraint::On(sql_expr) => {
- let mut keys: Vec<(String, String)> = vec![];
+ let mut keys: Vec<(Column, Column)> = vec![];
let join_schema = left.schema().join(right.schema())?;
// parse ON expression
@@ -390,20 +370,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// extract join keys
extract_join_keys(&expr, &mut keys)?;
- let left_keys: Vec<&str> =
- keys.iter().map(|pair| pair.0.as_str()).collect();
- let right_keys: Vec<&str> =
- keys.iter().map(|pair| pair.1.as_str()).collect();
+ let (left_keys, right_keys): (Vec<Column>, Vec<Column>) =
+ keys.into_iter().unzip();
// return the logical plan representing the join
LogicalPlanBuilder::from(left)
- .join(right, join_type, &left_keys, &right_keys)?
+ .join(right, join_type, left_keys, right_keys)?
.build()
}
JoinConstraint::Using(idents) => {
- let keys: Vec<&str> = idents.iter().map(|x| x.value.as_str()).collect();
+ let keys: Vec<Column> = idents
+ .iter()
+ .map(|x| Column::from_name(x.value.clone()))
+ .collect();
LogicalPlanBuilder::from(left)
- .join(right, join_type, &keys, &keys)?
+ .join_using(right, join_type, keys)?
.build()
}
JoinConstraint::Natural => {
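
USING constraints now route through the new join_using builder; a sketch, where left and right stand for the two input LogicalPlans:

    use datafusion::logical_plan::{Column, LogicalPlanBuilder};
    use datafusion::prelude::JoinType;

    // JOIN ... USING (id): the same unqualified key resolves on both sides
    let keys = vec![Column::from_name("id".to_string())];
    let plan = LogicalPlanBuilder::from(&left)
        .join_using(&right, JoinType::Inner, keys)?
        .build()?;
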
@@ -489,37 +470,38 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let mut possible_join_keys = vec![];
extract_possible_join_keys(&filter_expr, &mut possible_join_keys)?;
- let mut all_join_keys = vec![];
+ let mut all_join_keys = HashSet::new();
let mut left = plans[0].clone();
for right in plans.iter().skip(1) {
let left_schema = left.schema();
let right_schema = right.schema();
let mut join_keys = vec![];
for (l, r) in &possible_join_keys {
- if left_schema.field_with_unqualified_name(l).is_ok()
- && right_schema.field_with_unqualified_name(r).is_ok()
+ if left_schema.field_from_qualified_column(l).is_ok()
+ && right_schema.field_from_qualified_column(r).is_ok()
{
- join_keys.push((l.as_str(), r.as_str()));
- } else if left_schema.field_with_unqualified_name(r).is_ok()
- && right_schema.field_with_unqualified_name(l).is_ok()
+ join_keys.push((l.clone(), r.clone()));
+ } else if left_schema.field_from_qualified_column(r).is_ok()
+ && right_schema.field_from_qualified_column(l).is_ok()
{
- join_keys.push((r.as_str(), l.as_str()));
+ join_keys.push((r.clone(), l.clone()));
}
}
if join_keys.is_empty() {
left =
LogicalPlanBuilder::from(&left).cross_join(right)?.build()?;
} else {
- let left_keys: Vec<_> =
- join_keys.iter().map(|(l, _)| *l).collect();
- let right_keys: Vec<_> =
- join_keys.iter().map(|(_, r)| *r).collect();
+ let left_keys: Vec<Column> =
+ join_keys.iter().map(|(l, _)| l.clone()).collect();
+ let right_keys: Vec<Column> =
+ join_keys.iter().map(|(_, r)| r.clone()).collect();
let builder = LogicalPlanBuilder::from(&left);
left = builder
- .join(right, JoinType::Inner, &left_keys, &right_keys)?
+ .join(right, JoinType::Inner, left_keys, right_keys)?
.build()?;
}
- all_join_keys.extend_from_slice(&join_keys);
+
+ all_join_keys.extend(join_keys);
}
// remove join expressions from filter
@@ -548,12 +530,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// The SELECT expressions, with wildcards expanded.
let select_exprs = self.prepare_select_exprs(&plan, &select.projection)?;
+ // HAVING and GROUP BY clauses may reference aliases defined in the SELECT projection
+ let projected_plan = self.project(&plan, select_exprs.clone())?;
+ let mut combined_schema = (**projected_plan.schema()).clone();
+ combined_schema.merge(plan.schema());
+
// Optionally the HAVING expression.
let having_expr_opt = select
.having
.as_ref()
.map::<Result<Expr>, _>(|having_expr| {
- let having_expr = self.sql_expr_to_logical_expr(having_expr)?;
+ let having_expr =
+ self.sql_expr_to_logical_expr(having_expr, &combined_schema)?;
// This step "dereferences" any aliases in the HAVING clause.
//
@@ -582,7 +570,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// The outer expressions we will search through for
// aggregates. Aggregates may be sourced from the SELECT...
let mut aggr_expr_haystack = select_exprs.clone();
-
// ... or from the HAVING.
if let Some(having_expr) = &having_expr_opt {
aggr_expr_haystack.push(having_expr.clone());
@@ -596,7 +583,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
.group_by
.iter()
.map(|e| {
- let group_by_expr = self.sql_expr_to_logical_expr(e)?;
+ let group_by_expr = self.sql_expr_to_logical_expr(e, &combined_schema)?;
let group_by_expr = resolve_aliases_to_exprs(&group_by_expr, &alias_map)?;
let group_by_expr =
resolve_positions_to_exprs(&group_by_expr, &select_exprs)?;
@@ -816,16 +803,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let order_by_rex = order_by
.iter()
- .map(|e| self.order_by_to_sort_expr(e))
+ .map(|e| self.order_by_to_sort_expr(e, plan.schema()))
.collect::<Result<Vec<_>>>()?;
LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build()
}
/// convert sql OrderByExpr to Expr::Sort
- fn order_by_to_sort_expr(&self, e: &OrderByExpr) -> Result<Expr> {
+ fn order_by_to_sort_expr(&self, e: &OrderByExpr, schema: &DFSchema) -> Result<Expr> {
Ok(Expr::Sort {
- expr: Box::new(self.sql_expr_to_logical_expr(&e.expr)?),
+ expr: Box::new(self.sql_expr_to_logical_expr(&e.expr, schema)?),
// by default asc
asc: e.asc.unwrap_or(true),
// by default nulls first to be consistent with spark
@@ -842,11 +829,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
find_column_exprs(exprs)
.iter()
.try_for_each(|col| match col {
- Expr::Column(name) => {
- schema.field_with_unqualified_name(name).map_err(|_| {
+ Expr::Column(col) => {
+ match &col.relation {
+ Some(r) => schema.field_with_qualified_name(r, &col.name),
+ None => schema.field_with_unqualified_name(&col.name),
+ }
+ .map_err(|_| {
DataFusionError::Plan(format!(
"Invalid identifier '{}' for schema {}",
- name,
+ col,
schema.to_string()
))
})?;
@@ -873,19 +864,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
/// Generate a relational expression from a SQL expression
pub fn sql_to_rex(&self, sql: &SQLExpr, schema: &DFSchema) -> Result<Expr> {
- let expr = self.sql_expr_to_logical_expr(sql)?;
+ let expr = self.sql_expr_to_logical_expr(sql, schema)?;
self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?;
Ok(expr)
}
- fn sql_fn_arg_to_logical_expr(&self, sql: &FunctionArg) -> Result<Expr> {
+ fn sql_fn_arg_to_logical_expr(
+ &self,
+ sql: &FunctionArg,
+ schema: &DFSchema,
+ ) -> Result<Expr> {
match sql {
- FunctionArg::Named { name: _, arg } => self.sql_expr_to_logical_expr(arg),
- FunctionArg::Unnamed(value) => self.sql_expr_to_logical_expr(value),
+ FunctionArg::Named { name: _, arg } => {
+ self.sql_expr_to_logical_expr(arg, schema)
+ }
+ FunctionArg::Unnamed(value) => self.sql_expr_to_logical_expr(value, schema),
}
}
- fn sql_expr_to_logical_expr(&self, sql: &SQLExpr) -> Result<Expr> {
+ fn sql_expr_to_logical_expr(&self, sql: &SQLExpr, schema: &DFSchema) -> Result<Expr> {
match sql {
SQLExpr::Value(Value::Number(n, _)) => match n.parse::<i64>() {
Ok(n) => Ok(lit(n)),
@@ -900,7 +897,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
fun: functions::BuiltinScalarFunction::DatePart,
args: vec![
Expr::Literal(ScalarValue::Utf8(Some(format!("{}", field)))),
- self.sql_expr_to_logical_expr(expr)?,
+ self.sql_expr_to_logical_expr(expr, schema)?,
],
}),
@@ -923,7 +920,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let var_names = vec![id.value.clone()];
Ok(Expr::ScalarVariable(var_names))
} else {
- Ok(Expr::Column(id.value.to_string()))
+ Ok(Expr::Column(
+ schema
+ .field_with_unqualified_name(&id.value)?
+ .qualified_column(),
+ ))
}
}
@@ -934,6 +935,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
if &var_names[0][0..1] == "@" {
Ok(Expr::ScalarVariable(var_names))
+ } else if var_names.len() == 2 {
+ // table.column identifier
+ let name = var_names.pop().unwrap();
+ let relation = Some(var_names.pop().unwrap());
+ Ok(Expr::Column(Column { relation, name }))
} else {
Err(DataFusionError::NotImplemented(format!(
"Unsupported compound identifier '{:?}'",
@@ -951,20 +957,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
else_result,
} => {
let expr = if let Some(e) = operand {
- Some(Box::new(self.sql_expr_to_logical_expr(e)?))
+ Some(Box::new(self.sql_expr_to_logical_expr(e, schema)?))
} else {
None
};
let when_expr = conditions
.iter()
- .map(|e| self.sql_expr_to_logical_expr(e))
+ .map(|e| self.sql_expr_to_logical_expr(e, schema))
.collect::<Result<Vec<_>>>()?;
let then_expr = results
.iter()
- .map(|e| self.sql_expr_to_logical_expr(e))
+ .map(|e| self.sql_expr_to_logical_expr(e, schema))
.collect::<Result<Vec<_>>>()?;
let else_expr = if let Some(e) = else_result {
- Some(Box::new(self.sql_expr_to_logical_expr(e)?))
+ Some(Box::new(self.sql_expr_to_logical_expr(e, schema)?))
} else {
None
};
@@ -984,7 +990,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
ref expr,
ref data_type,
} => Ok(Expr::Cast {
- expr: Box::new(self.sql_expr_to_logical_expr(expr)?),
+ expr: Box::new(self.sql_expr_to_logical_expr(expr, schema)?),
data_type: convert_data_type(data_type)?,
}),
@@ -992,7 +998,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
ref expr,
ref data_type,
} => Ok(Expr::TryCast {
- expr: Box::new(self.sql_expr_to_logical_expr(expr)?),
+ expr: Box::new(self.sql_expr_to_logical_expr(expr, schema)?),
data_type: convert_data_type(data_type)?,
}),
@@ -1004,19 +1010,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
data_type: convert_data_type(data_type)?,
}),
- SQLExpr::IsNull(ref expr) => {
- Ok(Expr::IsNull(Box::new(self.sql_expr_to_logical_expr(expr)?)))
- }
+ SQLExpr::IsNull(ref expr) => Ok(Expr::IsNull(Box::new(
+ self.sql_expr_to_logical_expr(expr, schema)?,
+ ))),
SQLExpr::IsNotNull(ref expr) => Ok(Expr::IsNotNull(Box::new(
- self.sql_expr_to_logical_expr(expr)?,
+ self.sql_expr_to_logical_expr(expr, schema)?,
))),
SQLExpr::UnaryOp { ref op, ref expr } => match op {
- UnaryOperator::Not => {
- Ok(Expr::Not(Box::new(self.sql_expr_to_logical_expr(expr)?)))
- }
- UnaryOperator::Plus => Ok(self.sql_expr_to_logical_expr(expr)?),
+ UnaryOperator::Not => Ok(Expr::Not(Box::new(
+ self.sql_expr_to_logical_expr(expr, schema)?,
+ ))),
+ UnaryOperator::Plus => Ok(self.sql_expr_to_logical_expr(expr, schema)?),
UnaryOperator::Minus => {
match expr.as_ref() {
// optimization: if it's a number literal, we apply the negative operator
@@ -1032,7 +1038,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
})?)),
},
// not a literal, apply negative operator on expression
- _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr(expr)?))),
+ _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr(expr, schema)?))),
}
}
_ => Err(DataFusionError::NotImplemented(format!(
@@ -1047,10 +1053,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
ref low,
ref high,
} => Ok(Expr::Between {
- expr: Box::new(self.sql_expr_to_logical_expr(expr)?),
+ expr: Box::new(self.sql_expr_to_logical_expr(expr, schema)?),
negated: *negated,
- low: Box::new(self.sql_expr_to_logical_expr(low)?),
- high: Box::new(self.sql_expr_to_logical_expr(high)?),
+ low: Box::new(self.sql_expr_to_logical_expr(low, schema)?),
+ high: Box::new(self.sql_expr_to_logical_expr(high, schema)?),
}),
SQLExpr::InList {
@@ -1060,11 +1066,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
} => {
let list_expr = list
.iter()
- .map(|e| self.sql_expr_to_logical_expr(e))
+ .map(|e| self.sql_expr_to_logical_expr(e, schema))
.collect::<Result<Vec<_>>>()?;
Ok(Expr::InList {
- expr: Box::new(self.sql_expr_to_logical_expr(expr)?),
+ expr: Box::new(self.sql_expr_to_logical_expr(expr, schema)?),
list: list_expr,
negated: *negated,
})
@@ -1098,9 +1104,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}?;
Ok(Expr::BinaryExpr {
- left: Box::new(self.sql_expr_to_logical_expr(left)?),
+ left: Box::new(self.sql_expr_to_logical_expr(left, schema)?),
op: operator,
- right: Box::new(self.sql_expr_to_logical_expr(right)?),
+ right: Box::new(self.sql_expr_to_logical_expr(right, schema)?),
})
}
@@ -1121,7 +1127,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// first, scalar built-in
if let Ok(fun) = functions::BuiltinScalarFunction::from_str(&name) {
- let args = self.function_args_to_expr(function)?;
+ let args = self.function_args_to_expr(function, schema)?;
return Ok(Expr::ScalarFunction { fun, args });
};
@@ -1131,12 +1137,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let partition_by = window
.partition_by
.iter()
- .map(|e| self.sql_expr_to_logical_expr(e))
+ .map(|e| self.sql_expr_to_logical_expr(e, schema))
.collect::<Result<Vec<_>>>()?;
let order_by = window
.order_by
.iter()
- .map(|e| self.order_by_to_sort_expr(e))
+ .map(|e| self.order_by_to_sort_expr(e, schema))
.collect::<Result<Vec<_>>>()?;
let window_frame = window
.window_frame
@@ -1163,8 +1169,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
fun: window_functions::WindowFunction::AggregateFunction(
aggregate_fun.clone(),
),
- args: self
- .aggregate_fn_to_expr(&aggregate_fun, function)?,
+ args: self.aggregate_fn_to_expr(
+ &aggregate_fun,
+ function,
+ schema,
+ )?,
partition_by,
order_by,
window_frame,
@@ -1177,7 +1186,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
fun: window_functions::WindowFunction::BuiltInWindowFunction(
window_fun,
),
- args: self.function_args_to_expr(function)?,
+ args: self.function_args_to_expr(function, schema)?,
partition_by,
order_by,
window_frame,
@@ -1188,7 +1197,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// next, aggregate built-ins
if let Ok(fun) = aggregates::AggregateFunction::from_str(&name) {
- let args = self.aggregate_fn_to_expr(&fun, function)?;
+ let args = self.aggregate_fn_to_expr(&fun, function, schema)?;
return Ok(Expr::AggregateFunction {
fun,
distinct: function.distinct,
@@ -1199,13 +1208,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// finally, user-defined functions (UDF) and UDAF
match self.schema_provider.get_function_meta(&name) {
Some(fm) => {
- let args = self.function_args_to_expr(function)?;
+ let args = self.function_args_to_expr(function, schema)?;
Ok(Expr::ScalarUDF { fun: fm, args })
}
None => match self.schema_provider.get_aggregate_meta(&name) {
Some(fm) => {
- let args = self.function_args_to_expr(function)?;
+ let args = self.function_args_to_expr(function, schema)?;
Ok(Expr::AggregateUDF { fun: fm, args })
}
_ => Err(DataFusionError::Plan(format!(
@@ -1216,7 +1225,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}
}
- SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(e),
+ SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(e, schema),
_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported ast node {:?} in sqltorel",
@@ -1228,11 +1237,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
fn function_args_to_expr(
&self,
function: &sqlparser::ast::Function,
+ schema: &DFSchema,
) -> Result<Vec<Expr>> {
function
.args
.iter()
- .map(|a| self.sql_fn_arg_to_logical_expr(a))
+ .map(|a| self.sql_fn_arg_to_logical_expr(a, schema))
.collect::<Result<Vec<Expr>>>()
}
@@ -1240,6 +1250,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
&self,
fun: &aggregates::AggregateFunction,
function: &sqlparser::ast::Function,
+ schema: &DFSchema,
) -> Result<Vec<Expr>> {
if *fun == aggregates::AggregateFunction::Count {
function
@@ -1250,11 +1261,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Ok(lit(1_u8))
}
FunctionArg::Unnamed(SQLExpr::Wildcard) => Ok(lit(1_u8)),
- _ => self.sql_fn_arg_to_logical_expr(a),
+ _ => self.sql_fn_arg_to_logical_expr(a, schema),
})
.collect::<Result<Vec<Expr>>>()
} else {
- self.function_args_to_expr(function)
+ self.function_args_to_expr(function, schema)
}
}
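
The COUNT special case above normalizes COUNT(*) and COUNT(<number>) to the literal 1_u8, since both simply count rows. A reduced illustration with hypothetical stand-in types (the real code matches on sqlparser's FunctionArg and builds a DataFusion Expr via lit):

#[derive(Debug, PartialEq)]
enum Expr {
    LiteralU8(u8),
    Column(String),
}

enum FunctionArg {
    Wildcard,
    Value(i64),
    Column(String),
}

fn count_arg_to_expr(arg: &FunctionArg) -> Expr {
    match arg {
        // COUNT(*) and COUNT(1) are equivalent: count every row.
        FunctionArg::Wildcard | FunctionArg::Value(_) => Expr::LiteralU8(1),
        FunctionArg::Column(name) => Expr::Column(name.clone()),
    }
}

fn main() {
    assert_eq!(count_arg_to_expr(&FunctionArg::Wildcard), Expr::LiteralU8(1));
    assert_eq!(count_arg_to_expr(&FunctionArg::Value(1)), Expr::LiteralU8(1));
    assert_eq!(
        count_arg_to_expr(&FunctionArg::Column("age".into())),
        Expr::Column("age".into())
    );
}
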
@@ -1519,13 +1530,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
/// Remove join expressions from a filter expression
fn remove_join_expressions(
expr: &Expr,
- join_columns: &[(&str, &str)],
+ join_columns: &HashSet<(Column, Column)>,
) -> Result<Option<Expr>> {
match expr {
Expr::BinaryExpr { left, op, right } => match op {
Operator::Eq => match (left.as_ref(), right.as_ref()) {
(Expr::Column(l), Expr::Column(r)) => {
- if join_columns.contains(&(l, r)) || join_columns.contains(&(r, l)) {
+ if join_columns.contains(&(l.clone(), r.clone()))
+ || join_columns.contains(&(r.clone(), l.clone()))
+ {
Ok(None)
} else {
Ok(Some(expr.clone()))
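
Join keys are now stored as pairs of the new Column type in a HashSet, and because an equality predicate may name either side first, membership is checked in both orientations. A self-contained sketch of that check with a simplified Column:

use std::collections::HashSet;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct Column {
    relation: Option<String>,
    name: String,
}

fn col(relation: &str, name: &str) -> Column {
    Column { relation: Some(relation.to_string()), name: name.to_string() }
}

fn is_join_key(join_columns: &HashSet<(Column, Column)>, l: &Column, r: &Column) -> bool {
    // The set stores (left, right) pairs, so test both orientations.
    join_columns.contains(&(l.clone(), r.clone()))
        || join_columns.contains(&(r.clone(), l.clone()))
}

fn main() {
    let mut join_columns = HashSet::new();
    join_columns.insert((col("t1", "id"), col("t2", "id")));

    // Matches regardless of which side the predicate named first.
    assert!(is_join_key(&join_columns, &col("t1", "id"), &col("t2", "id")));
    assert!(is_join_key(&join_columns, &col("t2", "id"), &col("t1", "id")));
    assert!(!is_join_key(&join_columns, &col("t1", "id"), &col("t2", "age")));
}
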
@@ -1556,12 +1569,12 @@ fn remove_join_expressions(
/// foo = bar
/// foo = bar AND bar = baz AND ...
///
-fn extract_join_keys(expr: &Expr, accum: &mut Vec<(String, String)>) -> Result<()> {
+fn extract_join_keys(expr: &Expr, accum: &mut Vec<(Column, Column)>) -> Result<()> {
match expr {
Expr::BinaryExpr { left, op, right } => match op {
Operator::Eq => match (left.as_ref(), right.as_ref()) {
(Expr::Column(l), Expr::Column(r)) => {
- accum.push((l.to_owned(), r.to_owned()));
+ accum.push((l.clone(), r.clone()));
Ok(())
}
other => Err(DataFusionError::SQL(ParserError(format!(
@@ -1588,13 +1601,13 @@ fn extract_join_keys(expr: &Expr, accum: &mut Vec<(String, String)>) -> Result<(
/// Extract join keys from a WHERE clause
fn extract_possible_join_keys(
expr: &Expr,
- accum: &mut Vec<(String, String)>,
+ accum: &mut Vec<(Column, Column)>,
) -> Result<()> {
match expr {
Expr::BinaryExpr { left, op, right } => match op {
Operator::Eq => match (left.as_ref(), right.as_ref()) {
(Expr::Column(l), Expr::Column(r)) => {
- accum.push((l.to_owned(), r.to_owned()));
+ accum.push((l.clone(), r.clone()));
Ok(())
}
_ => Ok(()),
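
Note the asymmetry between the two extractors: extract_join_keys (for explicit JOIN ... ON conditions) rejects an equality whose sides are not both columns, while extract_possible_join_keys (for WHERE clauses) silently skips it so the predicate can be handled as an ordinary filter. A reduced sketch of that contrast, with a hypothetical simplified Expr and no recursion over AND:

enum Expr {
    Eq(Box<Expr>, Box<Expr>),
    Column(String),
    Literal(i64),
}

// ON clause: non-column equalities are an error.
fn extract_join_keys(expr: &Expr, accum: &mut Vec<(String, String)>) -> Result<(), String> {
    match expr {
        Expr::Eq(l, r) => match (l.as_ref(), r.as_ref()) {
            (Expr::Column(l), Expr::Column(r)) => {
                accum.push((l.clone(), r.clone()));
                Ok(())
            }
            _ => Err("unsupported expression in JOIN ... ON".to_string()),
        },
        _ => Err("unsupported expression in JOIN ... ON".to_string()),
    }
}

// WHERE clause: anything that is not column = column is ignored here
// and left for the normal filter path.
fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(String, String)>) {
    if let Expr::Eq(l, r) = expr {
        if let (Expr::Column(l), Expr::Column(r)) = (l.as_ref(), r.as_ref()) {
            accum.push((l.clone(), r.clone()));
        }
    }
}

fn main() {
    let col_eq = Expr::Eq(
        Box::new(Expr::Column("t1.id".into())),
        Box::new(Expr::Column("t2.id".into())),
    );
    let lit_eq = Expr::Eq(
        Box::new(Expr::Column("t1.id".into())),
        Box::new(Expr::Literal(5)),
    );

    let mut on_keys = Vec::new();
    assert!(extract_join_keys(&col_eq, &mut on_keys).is_ok());
    assert!(extract_join_keys(&lit_eq, &mut on_keys).is_err());

    let mut where_keys = Vec::new();
    extract_possible_join_keys(&col_eq, &mut where_keys);
    extract_possible_join_keys(&lit_eq, &mut where_keys); // ignored, no error
    assert_eq!(where_keys.len(), 1);
}
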
@@ -1635,9 +1648,6 @@ mod tests {
use crate::{logical_plan::create_udf, sql::parser::DFParser};
use functions::ScalarFunctionImplementation;
- const PERSON_COLUMN_NAMES: &str =
- "id, first_name, last_name, age, state, salary, birth_date, 😀";
-
#[test]
fn select_no_relation() {
quick_test(
@@ -1651,13 +1661,10 @@ mod tests {
fn select_column_does_not_exist() {
let sql = "SELECT doesnotexist FROM person";
let err = logical_plan(sql).expect_err("query should have failed");
- assert_eq!(
- format!(
- r#"Plan("Invalid identifier 'doesnotexist' for schema {}")"#,
- PERSON_COLUMN_NAMES
- ),
- format!("{:?}", err)
- );
+ assert!(matches!(
+ err,
+ DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'",
+ ));
}
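
These tests now assert on the error variant and message with matches! instead of comparing the error's full Debug rendering, which decouples them from the exact Debug formatting. A stand-alone illustration of the pattern, with a simplified DataFusionError:

#[derive(Debug)]
enum DataFusionError {
    Plan(String),
}

fn main() {
    let err = DataFusionError::Plan(
        "No field with unqualified name 'doesnotexist'".to_string(),
    );
    // Pattern-match on the variant and guard on the message, rather
    // than asserting on the Debug output of the whole error value.
    assert!(matches!(
        err,
        DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'"
    ));
}
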
#[test]
@@ -1665,7 +1672,7 @@ mod tests {
let sql = "SELECT age, age FROM person";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
- r##"Plan("Projections require unique expression names but the expression \"#age\" at position 0 and \"#age\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##,
+ r##"Plan("Projections require unique expression names but the expression \"#person.age\" at position 0 and \"#person.age\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##,
format!("{:?}", err)
);
}
@@ -1675,7 +1682,7 @@ mod tests {
let sql = "SELECT *, age FROM person";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
- r##"Plan("Projections require unique expression names but the expression \"#age\" at position 3 and \"#age\" at position 8 have the same name. Consider aliasing (\"AS\") one of them.")"##,
+ r##"Plan("Projections require unique expression names but the expression \"#person.age\" at position 3 and \"#person.age\" at position 8 have the same name. Consider aliasing (\"AS\") one of them.")"##,
format!("{:?}", err)
);
}
@@ -1684,7 +1691,7 @@ mod tests {
fn select_wildcard_with_repeated_column_but_is_aliased() {
quick_test(
"SELECT *, first_name AS fn from person",
- "Projection: #id, #first_name, #last_name, #age, #state, #salary, #birth_date, #😀, #first_name AS fn\
+ "Projection: #person.id, #person.first_name, #person.last_name, #person.age, #person.state, #person.salary, #person.birth_date, #person.😀, #person.first_name AS fn\
\n TableScan: person projection=None",
);
}
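
The expected plans throughout the rest of these tests reflect the new qualified display convention: a column resolved against table person prints as #person.age, while projection aliases such as fn stay unqualified. A minimal sketch of such a Display impl (simplified; not DataFusion's actual implementation):

use std::fmt;

struct Column {
    relation: Option<String>,
    name: String,
}

impl fmt::Display for Column {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.relation {
            // Qualified columns print as #relation.name ...
            Some(r) => write!(f, "#{}.{}", r, self.name),
            // ... unqualified ones (e.g. aliases) as #name.
            None => write!(f, "#{}", self.name),
        }
    }
}

fn main() {
    let qualified = Column { relation: Some("person".into()), name: "age".into() };
    let unqualified = Column { relation: None, name: "fn1".into() };
    assert_eq!(qualified.to_string(), "#person.age");
    assert_eq!(unqualified.to_string(), "#fn1");
}
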
@@ -1702,8 +1709,8 @@ mod tests {
fn select_simple_filter() {
let sql = "SELECT id, first_name, last_name \
FROM person WHERE state = 'CO'";
- let expected = "Projection: #id, #first_name, #last_name\
- \n Filter: #state Eq Utf8(\"CO\")\
+ let expected = "Projection: #person.id, #person.first_name, #person.last_name\
+ \n Filter: #person.state Eq Utf8(\"CO\")\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -1712,34 +1719,28 @@ mod tests {
fn select_filter_column_does_not_exist() {
let sql = "SELECT first_name FROM person WHERE doesnotexist = 'A'";
let err = logical_plan(sql).expect_err("query should have failed");
- assert_eq!(
- format!(
- r#"Plan("Invalid identifier 'doesnotexist' for schema {}")"#,
- PERSON_COLUMN_NAMES
- ),
- format!("{:?}", err)
- );
+ assert!(matches!(
+ err,
+ DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'",
+ ));
}
#[test]
fn select_filter_cannot_use_alias() {
let sql = "SELECT first_name AS x FROM person WHERE x = 'A'";
let err = logical_plan(sql).expect_err("query should have failed");
- assert_eq!(
- format!(
- r#"Plan("Invalid identifier 'x' for schema {}")"#,
- PERSON_COLUMN_NAMES
- ),
- format!("{:?}", err)
- );
+ assert!(matches!(
+ err,
+ DataFusionError::Plan(msg) if msg == "No field with unqualified name 'x'",
+ ));
}
#[test]
fn select_neg_filter() {
let sql = "SELECT id, first_name, last_name \
FROM person WHERE NOT state";
- let expected = "Projection: #id, #first_name, #last_name\
- \n Filter: NOT #state\
+ let expected = "Projection: #person.id, #person.first_name, #person.last_name\
+ \n Filter: NOT #person.state\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -1748,8 +1749,8 @@ mod tests {
fn select_compound_filter() {
let sql = "SELECT id, first_name, last_name \
FROM person WHERE state = 'CO' AND age >= 21 AND age <= 65";
- let expected = "Projection: #id, #first_name, #last_name\
- \n Filter: #state Eq Utf8(\"CO\") And #age GtEq Int64(21) And #age LtEq Int64(65)\
+ let expected = "Projection: #person.id, #person.first_name, #person.last_name\
+ \n Filter: #person.state Eq Utf8(\"CO\") And #person.age GtEq Int64(21) And #person.age LtEq Int64(65)\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -1759,8 +1760,8 @@ mod tests {
let sql =
"SELECT state FROM person WHERE birth_date < CAST (158412331400600000 as timestamp)";
- let expected = "Projection: #state\
- \n Filter: #birth_date Lt CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\
+ let expected = "Projection: #person.state\
+ \n Filter: #person.birth_date Lt CAST(Int64(158412331400600000) AS Timestamp(Nanosecond, None))\
\n TableScan: person projection=None";
quick_test(sql, expected);
@@ -1771,8 +1772,8 @@ mod tests {
let sql =
"SELECT state FROM person WHERE birth_date < CAST ('2020-01-01' as date)";
- let expected = "Projection: #state\
- \n Filter: #birth_date Lt CAST(Utf8(\"2020-01-01\") AS Date32)\
+ let expected = "Projection: #person.state\
+ \n Filter: #person.birth_date Lt CAST(Utf8(\"2020-01-01\") AS Date32)\
\n TableScan: person projection=None";
quick_test(sql, expected);
@@ -1788,13 +1789,13 @@ mod tests {
AND age >= 21 \
AND age < 65 \
AND age <= 65";
- let expected = "Projection: #age, #first_name, #last_name\
- \n Filter: #age Eq Int64(21) \
- And #age NotEq Int64(21) \
- And #age Gt Int64(21) \
- And #age GtEq Int64(21) \
- And #age Lt Int64(65) \
- And #age LtEq Int64(65)\
+ let expected = "Projection: #person.age, #person.first_name, #person.last_name\
+ \n Filter: #person.age Eq Int64(21) \
+ And #person.age NotEq Int64(21) \
+ And #person.age Gt Int64(21) \
+ And #person.age GtEq Int64(21) \
+ And #person.age Lt Int64(65) \
+ And #person.age LtEq Int64(65)\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -1802,8 +1803,8 @@ mod tests {
#[test]
fn select_between() {
let sql = "SELECT state FROM person WHERE age BETWEEN 21 AND 65";
- let expected = "Projection: #state\
- \n Filter: #age BETWEEN Int64(21) AND Int64(65)\
+ let expected = "Projection: #person.state\
+ \n Filter: #person.age BETWEEN Int64(21) AND Int64(65)\
\n TableScan: person projection=None";
quick_test(sql, expected);
@@ -1812,8 +1813,8 @@ mod tests {
#[test]
fn select_between_negated() {
let sql = "SELECT state FROM person WHERE age NOT BETWEEN 21 AND 65";
- let expected = "Projection: #state\
- \n Filter: #age NOT BETWEEN Int64(21) AND Int64(65)\
+ let expected = "Projection: #person.state\
+ \n Filter: #person.age NOT BETWEEN Int64(21) AND Int64(65)\
\n TableScan: person projection=None";
quick_test(sql, expected);
@@ -1829,9 +1830,9 @@ mod tests {
FROM person
)
)";
- let expected = "Projection: #fn2, #last_name\
- \n Projection: #fn1 AS fn2, #last_name, #birth_date\
- \n Projection: #first_name AS fn1, #last_name, #birth_date, #age\
+ let expected = "Projection: #fn2, #person.last_name\
+ \n Projection: #fn1 AS fn2, #person.last_name, #person.birth_date\
+ \n Projection: #person.first_name AS fn1, #person.last_name, #person.birth_date, #person.age\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -1846,10 +1847,10 @@ mod tests {
)
WHERE fn1 = 'X' AND age < 30";
- let expected = "Projection: #fn1, #age\
- \n Filter: #fn1 Eq Utf8(\"X\") And #age Lt Int64(30)\
- \n Projection: #first_name AS fn1, #age\
- \n Filter: #age Gt Int64(20)\
+ let expected = "Projection: #fn1, #person.age\
+ \n Filter: #fn1 Eq Utf8(\"X\") And #person.age Lt Int64(30)\
+ \n Projection: #person.first_name AS fn1, #person.age\
+ \n Filter: #person.age Gt Int64(20)\
\n TableScan: person projection=None";
quick_test(sql, expected);
@@ -1860,8 +1861,8 @@ mod tests {
let sql = "SELECT id, age
FROM person
HAVING age > 100 AND age < 200";
- let expected = "Projection: #id, #age\
- \n Filter: #age Gt Int64(100) And #age Lt Int64(200)\
+ let expected = "Projection: #person.id, #person.age\
+ \n Filter: #person.age Gt Int64(100) And #person.age Lt Int64(200)\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -1907,9 +1908,9 @@ mod tests {
let sql = "SELECT MAX(age)
FROM person
HAVING MAX(age) < 30";
- let expected = "Projection: #MAX(age)\
- \n Filter: #MAX(age) Lt Int64(30)\
- \n Aggregate: groupBy=[[]], aggr=[[MAX(#age)]]\
+ let expected = "Projection: #MAX(person.age)\
+ \n Filter: #MAX(person.age) Lt Int64(30)\
+ \n Aggregate: groupBy=[[]], aggr=[[MAX(#person.age)]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -1919,9 +1920,9 @@ mod tests {
let sql = "SELECT MAX(age)
FROM person
HAVING MAX(first_name) > 'M'";
- let expected = "Projection: #MAX(age)\
- \n Filter: #MAX(first_name) Gt Utf8(\"M\")\
- \n Aggregate: groupBy=[[]], aggr=[[MAX(#age), MAX(#first_name)]]\
+ let expected = "Projection: #MAX(person.age)\
+ \n Filter: #MAX(person.first_name) Gt Utf8(\"M\")\
+ \n Aggregate: groupBy=[[]], aggr=[[MAX(#person.age), MAX(#person.first_name)]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -1943,9 +1944,10 @@ mod tests {
let sql = "SELECT MAX(age) as max_age
FROM person
HAVING max_age < 30";
- let expected = "Projection: #MAX(age) AS max_age\
- \n Filter: #MAX(age) Lt Int64(30)\
- \n Aggregate: groupBy=[[]], aggr=[[MAX(#age)]]\
+ // FIXME: add test for having in execution
+ let expected = "Projection: #MAX(person.age) AS max_age\
+ \n Filter: #MAX(person.age) Lt Int64(30)\
+ \n Aggregate: groupBy=[[]], aggr=[[MAX(#person.age)]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -1955,9 +1957,9 @@ mod tests {
let sql = "SELECT MAX(age) as max_age
FROM person
HAVING MAX(age) < 30";
- let expected = "Projection: #MAX(age) AS max_age\
- \n Filter: #MAX(age) Lt Int64(30)\
- \n Aggregate: groupBy=[[]], aggr=[[MAX(#age)]]\
+ let expected = "Projection: #MAX(person.age) AS max_age\
+ \n Filter: #MAX(person.age) Lt Int64(30)\
+ \n Aggregate: groupBy=[[]], aggr=[[MAX(#person.age)]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -1968,9 +1970,9 @@ mod tests {
FROM person
GROUP BY first_name
HAVING first_name = 'M'";
- let expected = "Projection: #first_name, #MAX(age)\
- \n Filter: #first_name Eq Utf8(\"M\")\
- \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\
+ let expected = "Projection: #person.first_name, #MAX(person.age)\
+ \n Filter: #person.first_name Eq Utf8(\"M\")\
+ \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -1982,10 +1984,10 @@ mod tests {
WHERE id > 5
GROUP BY first_name
HAVING MAX(age) < 100";
- let expected = "Projection: #first_name, #MAX(age)\
- \n Filter: #MAX(age) Lt Int64(100)\
- \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\
- \n Filter: #id Gt Int64(5)\
+ let expected = "Projection: #person.first_name, #MAX(person.age)\
+ \n Filter: #MAX(person.age) Lt Int64(100)\
+ \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\
+ \n Filter: #person.id Gt Int64(5)\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -1998,10 +2000,10 @@ mod tests {
WHERE id > 5 AND age > 18
GROUP BY first_name
HAVING MAX(age) < 100";
- let expected = "Projection: #first_name, #MAX(age)\
- \n Filter: #MAX(age) Lt Int64(100)\
- \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\
- \n Filter: #id Gt Int64(5) And #age Gt Int64(18)\
+ let expected = "Projection: #person.first_name, #MAX(person.age)\
+ \n Filter: #MAX(person.age) Lt Int64(100)\
+ \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\
+ \n Filter: #person.id Gt Int64(5) And #person.age Gt Int64(18)\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -2012,9 +2014,9 @@ mod tests {
FROM person
GROUP BY first_name
HAVING MAX(age) > 2 AND fn = 'M'";
- let expected = "Projection: #first_name AS fn, #MAX(age)\
- \n Filter: #MAX(age) Gt Int64(2) And #first_name Eq Utf8(\"M\")\
- \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\
+ let expected = "Projection: #person.first_name AS fn, #MAX(person.age)\
+ \n Filter: #MAX(person.age) Gt Int64(2) And #person.first_name Eq Utf8(\"M\")\
+ \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -2026,9 +2028,9 @@ mod tests {
FROM person
GROUP BY first_name
HAVING MAX(age) > 2 AND max_age < 5 AND first_name = 'M' AND fn = 'N'";
- let expected = "Projection: #first_name AS fn, #MAX(age) AS max_age\
- \n Filter: #MAX(age) Gt Int64(2) And #MAX(age) Lt Int64(5) And #first_name Eq Utf8(\"M\") And #first_name Eq Utf8(\"N\")\
- \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\
+ let expected = "Projection: #person.first_name AS fn, #MAX(person.age) AS max_age\
+ \n Filter: #MAX(person.age) Gt Int64(2) And #MAX(person.age) Lt Int64(5) And #person.first_name Eq Utf8(\"M\") And #person.first_name Eq Utf8(\"N\")\
+ \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -2039,9 +2041,9 @@ mod tests {
FROM person
GROUP BY first_name
HAVING MAX(age) > 100";
- let expected = "Projection: #first_name, #MAX(age)\
- \n Filter: #MAX(age) Gt Int64(100)\
- \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\
+ let expected = "Projection: #person.first_name, #MAX(person.age)\
+ \n Filter: #MAX(person.age) Gt Int64(100)\
+ \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -2065,9 +2067,9 @@ mod tests {
FROM person
GROUP BY first_name
HAVING MAX(age) > 100 AND MAX(age) < 200";
- let expected = "Projection: #first_name, #MAX(age)\
- \n Filter: #MAX(age) Gt Int64(100) And #MAX(age) Lt Int64(200)\
- \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\
+ let expected = "Projection: #person.first_name, #MAX(person.age)\
+ \n Filter: #MAX(person.age) Gt Int64(100) And #MAX(person.age) Lt Int64(200)\
+ \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -2078,9 +2080,9 @@ mod tests {
FROM person
GROUP BY first_name
HAVING MAX(age) > 100 AND MIN(id) < 50";
- let expected = "Projection: #first_name, #MAX(age)\
- \n Filter: #MAX(age) Gt Int64(100) And #MIN(id) Lt Int64(50)\
- \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age), MIN(#id)]]\
+ let expected = "Projection: #person.first_name, #MAX(person.age)\
+ \n Filter: #MAX(person.age) Gt Int64(100) And #MIN(person.id) Lt Int64(50)\
+ \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age), MIN(#person.id)]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -2092,9 +2094,9 @@ mod tests {
FROM person
GROUP BY first_name
HAVING max_age > 100";
- let expected = "Projection: #first_name, #MAX(age) AS max_age\
- \n Filter: #MAX(age) Gt Int64(100)\
- \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\
+ let expected = "Projection: #person.first_name, #MAX(person.age) AS max_age\
+ \n Filter: #MAX(person.age) Gt Int64(100)\
+ \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -2107,9 +2109,9 @@ mod tests {
GROUP BY first_name
HAVING max_age_plus_one > 100";
let expected =
- "Projection: #first_name, #MAX(age) Plus Int64(1) AS max_age_plus_one\
- \n Filter: #MAX(age) Plus Int64(1) Gt Int64(100)\
- \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age)]]\
+ "Projection: #person.first_name, #MAX(person.age) Plus Int64(1) AS max_age_plus_one\
+ \n Filter: #MAX(person.age) Plus Int64(1) Gt Int64(100)\
+ \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age)]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -2121,9 +2123,9 @@ mod tests {
FROM person
GROUP BY first_name
HAVING MAX(age) > 100 AND MIN(id - 2) < 50";
- let expected = "Projection: #first_name, #MAX(age)\
- \n Filter: #MAX(age) Gt Int64(100) And #MIN(id Minus Int64(2)) Lt Int64(50)\
- \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age), MIN(#id Minus Int64(2))]]\
+ let expected = "Projection: #person.first_name, #MAX(person.age)\
+ \n Filter: #MAX(person.age) Gt Int64(100) And #MIN(person.id Minus Int64(2)) Lt Int64(50)\
+ \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age), MIN(#person.id Minus Int64(2))]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -2134,9 +2136,9 @@ mod tests {
FROM person
GROUP BY first_name
HAVING MAX(age) > 100 AND COUNT(*) < 50";
- let expected = "Projection: #first_name, #MAX(age)\
- \n Filter: #MAX(age) Gt Int64(100) And #COUNT(UInt8(1)) Lt Int64(50)\
- \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#age), COUNT(UInt8(1))]]\
+ let expected = "Projection: #person.first_name, #MAX(person.age)\
+ \n Filter: #MAX(person.age) Gt Int64(100) And #COUNT(UInt8(1)) Lt Int64(50)\
+ \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.age), COUNT(UInt8(1))]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -2144,7 +2146,7 @@ mod tests {
#[test]
fn select_binary_expr() {
let sql = "SELECT age + salary from person";
- let expected = "Projection: #age Plus #salary\
+ let expected = "Projection: #person.age Plus #person.salary\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -2152,7 +2154,7 @@ mod tests {
#[test]
fn select_binary_expr_nested() {
let sql = "SELECT (age + salary)/2 from person";
- let expected = "Projection: #age Plus #salary Divide Int64(2)\
+ let expected = "Projection: #person.age Plus #person.salary Divide Int64(2)\
\n TableScan: person projection=None";
quick_test(sql, expected);
}
@@ -2161,15 +2163,15 @@ mod tests {
fn select_wildcard_with_groupby() {
quick_test(
r#"SELECT * FROM person GROUP BY id, first_name, last_name, age, state, salary, birth_date, "😀""#,
- "Projection: #id, #first_name, #last_name, #age, #state, #salary, #birth_date, #😀\
- \n Aggregate: groupBy=[[#id, #first_name, #last_name, #age, #state, #salary, #birth_date, #😀]], aggr=[[]]\
+ "Projection: #person.id, #person.first_name, #person.last_name, #person.age, #person.state, #person.salary, #person.birth_date, #person.😀\
+ \n Aggregate: groupBy=[[#person.id, #person.first_name, #person.last_name, #person.age, #person.state, #person.salary, #person.birth_date, #person.😀]], aggr=[[]]\
\n TableScan: person projection=None",
);
quick_test(
"SELECT * FROM (SELECT first_name, last_name FROM person) GROUP BY first_name, last_name",
- "Projection: #first_name, #last_name\
- \n Aggregate: groupBy=[[#first_name, #last_name]], aggr=[[]]\
- \n Projection: #first_name, #last_name\
+ "Projection: #person.first_name, #person.last_name\
+ \n Aggregate: groupBy=[[#person.first_name, #person.last_name]], aggr=[[]]\
+ \n Projection: #person.first_name, #person.last_name\
\n TableScan: person projection=None",
);
}
@@ -2178,8 +2180,8 @@ mod tests {
fn select_simple_aggregate() {
quick_test(
"SELECT MIN(age) FROM person",
- "Projection: #MIN(age)\
- \n Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\
+ "Projection: #MIN(person.age)\
+ \n Aggregate: groupBy=[[]], aggr=[[MIN(#person.age)]]\
\n TableScan: person projection=None",
);
}
@@ -2188,8 +2190,8 @@ mod tests {
fn test_sum_aggregate() {
quick_test(
"SELECT SUM(age) from person",
- "Projection: #SUM(age)\
- \n Aggregate: groupBy=[[]], aggr=[[SUM(#age)]]\
+ "Projection: #SUM(person.age)\
+ \n Aggregate: groupBy=[[]], aggr=[[SUM(#person.age)]]\
\n TableScan: person projection=None",
);
}
@@ -2198,13 +2200,10 @@ mod tests {
fn select_simple_aggregate_column_does_not_exist() {
let sql = "SELECT MIN(doesnotexist) FROM person";
let err = logical_plan(sql).expect_err("query should have failed");
- assert_eq!(
- format!(
- r#"Plan("Invalid identifier 'doesnotexist' for schema {}")"#,
- PERSON_COLUMN_NAMES
- ),
- format!("{:?}", err)
- );
+ assert!(matches!(
+ err,
+ DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'",
+ ));
}
#[test]
@@ -2212,7 +2211,7 @@ mod tests {
let sql = "SELECT MIN(age), MIN(age) FROM person";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
- r##"Plan("Projections require unique expression names but the expression \"#MIN(age)\" at position 0 and \"#MIN(age)\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##,
+ r##"Plan("Projections require unique expression names but the expression \"MIN(#person.age)\" at position 0 and \"MIN(#person.age)\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##,
format!("{:?}", err)
);
}
@@ -2221,8 +2220,8 @@ mod tests {
fn select_simple_aggregate_repeated_aggregate_with_single_alias() {
quick_test(
"SELECT MIN(age), MIN(age) AS a FROM person",
- "Projection: #MIN(age), #MIN(age) AS a\
- \n Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\
+ "Projection: #MIN(person.age), #MIN(person.age) AS a\
+ \n Aggregate: groupBy=[[]], aggr=[[MIN(#person.age)]]\
\n TableScan: person projection=None",
);
}
@@ -2231,8 +2230,8 @@ mod tests {
fn select_simple_aggregate_repeated_aggregate_with_unique_aliases() {
quick_test(
"SELECT MIN(age) AS a, MIN(age) AS b FROM person",
- "Projection: #MIN(age) AS a, #MIN(age) AS b\
- \n Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\
+ "Projection: #MIN(person.age) AS a, #MIN(person.age) AS b\
+ \n Aggregate: groupBy=[[]], aggr=[[MIN(#person.age)]]\
\n TableScan: person projection=None",
);
}
@@ -2242,7 +2241,7 @@ mod tests {
let sql = "SELECT MIN(age) AS a, MIN(age) AS a FROM person";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
- r##"Plan("Projections require unique expression names but the expression \"#MIN(age) AS a\" at position 0 and \"#MIN(age) AS a\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##,
+ r##"Plan("Projections require unique expression names but the expression \"MIN(#person.age) AS a\" at position 0 and \"MIN(#person.age) AS a\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##,
format!("{:?}", err)
);
}
@@ -2251,8 +2250,8 @@ mod tests {
fn select_simple_aggregate_with_groupby() {
quick_test(
"SELECT state, MIN(age), MAX(age) FROM person GROUP BY state",
- "Projection: #state, #MIN(age), #MAX(age)\
- \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age), MAX(#age)]]\
+ "Projection: #person.state, #MIN(person.age), #MAX(person.age)\
+ \n Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age), MAX(#person.age)]]\
\n TableScan: person projection=None",
);
}
@@ -2261,8 +2260,8 @@ mod tests {
fn select_simple_aggregate_with_groupby_with_aliases() {
quick_test(
"SELECT state AS a, MIN(age) AS b FROM person GROUP BY state",
- "Projection: #state AS a, #MIN(age) AS b\
- \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age)]]\
+ "Projection: #person.state AS a, #MIN(person.age) AS b\
+ \n Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age)]]\
\n TableScan: person projection=None",
);
}
@@ -2272,7 +2271,7 @@ mod tests {
let sql = "SELECT state AS a, MIN(age) AS a FROM person GROUP BY state";
let err = logical_plan(sql).expect_err("query should have failed");
assert_eq!(
- r##"Plan("Projections require unique expression names but the expression \"#state AS a\" at position 0 and \"#MIN(age) AS a\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##,
+ r##"Plan("Projections require unique expression names but the expression \"#person.state AS a\" at position 0 and \"MIN(#person.age) AS a\" at position 1 have the same name. Consider aliasing (\"AS\") one of them.")"##,
format!("{:?}", err)
);
}
@@ -2281,8 +2280,8 @@ mod tests {
fn select_simple_aggregate_with_groupby_column_unselected() {
quick_test(
"SELECT MIN(age), MAX(age) FROM person GROUP BY state",
- "Projection: #MIN(age), #MAX(age)\
- \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age), MAX(#age)]]\
+ "Projection: #MIN(person.age), #MAX(person.age)\
+ \n Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age), MAX(#person.age)]]\
\n TableScan: person projection=None",
);
}
@@ -2291,26 +2290,20 @@ mod tests {
fn select_simple_aggregate_with_groupby_and_column_in_group_by_does_not_exist() {
let sql = "SELECT SUM(age) FROM person GROUP BY doesnotexist";
let err = logical_plan(sql).expect_err("query should have failed");
- assert_eq!(
- format!(
- r#"Plan("Invalid identifier 'doesnotexist' for schema {}")"#,
- PERSON_COLUMN_NAMES
- ),
- format!("{:?}", err)
- );
+ assert!(matches!(
+ err,
+ DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'",
+ ));
}
#[test]
fn select_simple_aggregate_with_groupby_and_column_in_aggregate_does_not_exist() {
let sql = "SELECT SUM(doesnotexist) FROM person GROUP BY first_name";
let err = logical_plan(sql).expect_err("query should have failed");
- assert_eq!(
- format!(
- r#"Plan("Invalid identifier 'doesnotexist' for schema {}")"#,
- PERSON_COLUMN_NAMES
- ),
- format!("{:?}", err)
- );
+ assert!(matches!(
+ err,
+ DataFusionError::Plan(msg) if msg == "No field with unqualified name 'doesnotexist'",
+ ));
}
#[test]
@@ -2327,18 +2320,18 @@ mod tests {
fn select_unsupported_complex_interval() {
let sql = "SELECT INTERVAL '1 year 1 day'";
let err = logical_plan(sql).expect_err("query should have failed");
- assert_eq!(
- r#"NotImplemented("DF does not support intervals that have both a Year/Month part as well as Days/Hours/Mins/Seconds: \"1 year 1 day\". Hint: try breaking the interval into two parts, one with Year/Month and the other with Days/Hours/Mins/Seconds - e.g. (NOW() + INTERVAL '1 year') + INTERVAL '1 day'")"#,
- format!("{:?}", err)
- );
+ assert!(matches!(
+ err,
+ DataFusionError::NotImplemented(msg) if msg == "DF does not support intervals that have both a Year/Month part as well as Days/Hours/Mins/Seconds: \"1 year 1 day\". Hint: try breaking the interval into two parts, one with Year/Month and the other with Days/Hours/Mins/Seconds - e.g. (NOW() + INTERVAL '1 year') + INTERVAL '1 day'",
+ ));
}
#[test]
fn select_simple_aggregate_with_groupby_and_column_is_in_aggregate_and_groupby() {
quick_test(
"SELECT MAX(first_name) FROM person GROUP BY first_name",
- "Projection: #MAX(first_name)\
- \n Aggregate: groupBy=[[#first_name]], aggr=[[MAX(#first_name)]]\
+ "Projection: #MAX(person.first_name)\
+ \n Aggregate: groupBy=[[#person.first_name]], aggr=[[MAX(#person.first_name)]]\
\n TableScan: person projection=None",
);
}
@@ -2347,14 +2340,14 @@ mod tests {
fn select_simple_aggregate_with_groupby_can_use_positions() {
quick_test(
"SELECT state, age AS b, COUNT(1) FROM person GROUP BY 1, 2",
- "Projection: #state, #age AS b, #COUNT(UInt8(1))\
- \n Aggregate: groupBy=[[#state, #age]], aggr=[[COUNT(UInt8(1))]]\
+ "Projection: #person.state, #person.age AS b, #COUNT(UInt8(1))\
+ \n Aggregate: groupBy=[[#person.state, #person.age]], aggr=[[COUNT(UInt8(1))]]\
\n TableScan: person projection=None",
);
quick_test(
"SELECT state, age AS b, COUNT(1) FROM person GROUP BY 2, 1",
- "Projection: #state, #age AS b, #COUNT(UInt8(1))\
- \n Aggregate: groupBy=[[#age, #state]], aggr=[[COUNT(UInt8(1))]]\
+ "Projection: #person.state, #person.age AS b, #COUNT(UInt8(1))\
+ \n Aggregate: groupBy=[[#person.age, #person.state]], aggr=[[COUNT(UInt8(1))]]\
\n TableScan: person projection=None",
);
}
@@ -2380,8 +2373,8 @@ mod tests {
fn select_simple_aggregate_with_groupby_can_use_alias() {
quick_test(
"SELECT state AS a, MIN(age) AS b FROM person GROUP BY a",
- "Projection: #state AS a, #MIN(age) AS b\
- \n Aggregate: groupBy=[[#state]], aggr=[[MIN(#age)]]\
+ "Projection: #person.state AS a, #MIN(person.age) AS b\
+ \n Aggregate: groupBy=[[#person.state]], aggr=[[MIN(#person.age)]]\
\n TableScan: person projection=None",
... 1095 lines suppressed ...