You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2023/05/10 12:54:00 UTC
[arrow-datafusion-python] branch main updated: Expand Expr to include RexType basic support (#378)
This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 433dbca Expand Expr to include RexType basic support (#378)
433dbca is described below
commit 433dbca009a25a78e698d2c703ef9f1e4177dae0
Author: Jeremy Dyer <jd...@gmail.com>
AuthorDate: Wed May 10 08:53:55 2023 -0400
Expand Expr to include RexType basic support (#378)
* Make expr member of PyExpr public
* Add RexType to Expr
* Add utility functions for mapping ScalarValue instances to DataTypeMap instances
* Add function to get python_value from Expr instance
* Fix syntax problems
* Add function to get the operands for a Rex::Call
* Add function to get operator for RexType::Call
* expand types function to include variant support for BinaryExpr
* Add variant coverage for Decimal128 and Decimal256
* add function for getting the column name of an Expr from a LogicalPlan
* Make PyProjection::projection member public
* Add projected_expressions to projection node
* Adjust function signature
* Add Distinct variant to to_variant function in PyLogicalPlan
* Fill in variants for DataType::Timestamp
* Address syntax issues
* Refactor types() function to extend support for CAST
* Update CAST variant handling
* Cargo fmt
* Cargo clippy
* Coverage for INTERVAL in DataType
* More cargo fmt changes
---
src/common/data_type.rs | 119 +++++++++++++++----
src/expr.rs | 300 +++++++++++++++++++++++++++++++++++++++++++++++-
src/expr/projection.rs | 18 ++-
src/sql/logical.rs | 2 +
4 files changed, 414 insertions(+), 25 deletions(-)
diff --git a/src/common/data_type.rs b/src/common/data_type.rs
index d55a0e8..622e1aa 100644
--- a/src/common/data_type.rs
+++ b/src/common/data_type.rs
@@ -15,8 +15,8 @@
// specific language governing permissions and limitations
// under the License.
-use datafusion::arrow::datatypes::DataType;
-use datafusion_common::DataFusionError;
+use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
+use datafusion_common::{DataFusionError, ScalarValue};
use pyo3::prelude::*;
use crate::errors::py_datafusion_err;
@@ -130,9 +130,11 @@ impl DataTypeMap {
PythonType::Float,
SqlType::FLOAT,
)),
- DataType::Timestamp(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
- format!("{:?}", arrow_type),
- ))),
+ DataType::Timestamp(unit, tz) => Ok(DataTypeMap::new(
+ DataType::Timestamp(unit.clone(), tz.clone()),
+ PythonType::Datetime,
+ SqlType::DATE,
+ )),
DataType::Date32 => Ok(DataTypeMap::new(
DataType::Date32,
PythonType::Datetime,
@@ -143,18 +145,28 @@ impl DataTypeMap {
PythonType::Datetime,
SqlType::DATE,
)),
- DataType::Time32(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
- format!("{:?}", arrow_type),
- ))),
- DataType::Time64(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
- format!("{:?}", arrow_type),
- ))),
+ DataType::Time32(unit) => Ok(DataTypeMap::new(
+ DataType::Time32(unit.clone()),
+ PythonType::Datetime,
+ SqlType::DATE,
+ )),
+ DataType::Time64(unit) => Ok(DataTypeMap::new(
+ DataType::Time64(unit.clone()),
+ PythonType::Datetime,
+ SqlType::DATE,
+ )),
DataType::Duration(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
format!("{:?}", arrow_type),
))),
- DataType::Interval(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
- format!("{:?}", arrow_type),
- ))),
+ DataType::Interval(interval_unit) => Ok(DataTypeMap::new(
+ DataType::Interval(interval_unit.clone()),
+ PythonType::Datetime,
+ match interval_unit {
+ IntervalUnit::DayTime => SqlType::INTERVAL_DAY,
+ IntervalUnit::MonthDayNano => SqlType::INTERVAL_MONTH,
+ IntervalUnit::YearMonth => SqlType::INTERVAL_YEAR_MONTH,
+ },
+ )),
DataType::Binary => Ok(DataTypeMap::new(
DataType::Binary,
PythonType::Bytes,
@@ -197,12 +209,16 @@ impl DataTypeMap {
DataType::Dictionary(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
format!("{:?}", arrow_type),
))),
- DataType::Decimal128(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
- format!("{:?}", arrow_type),
- ))),
- DataType::Decimal256(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
- format!("{:?}", arrow_type),
- ))),
+ DataType::Decimal128(precision, scale) => Ok(DataTypeMap::new(
+ DataType::Decimal128(*precision, *scale),
+ PythonType::Float,
+ SqlType::DECIMAL,
+ )),
+ DataType::Decimal256(precision, scale) => Ok(DataTypeMap::new(
+ DataType::Decimal256(*precision, *scale),
+ PythonType::Float,
+ SqlType::DECIMAL,
+ )),
DataType::Map(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
format!("{:?}", arrow_type),
))),
@@ -211,6 +227,69 @@ impl DataTypeMap {
)),
}
}
+
+ /// Generate the `DataTypeMap` from a `ScalarValue` instance
+ pub fn map_from_scalar_value(scalar_val: &ScalarValue) -> Result<DataTypeMap, PyErr> {
+ DataTypeMap::map_from_arrow_type(&DataTypeMap::map_from_scalar_to_arrow(scalar_val)?)
+ }
+
+ /// Maps a `ScalarValue` to an Arrow `DataType`
+ pub fn map_from_scalar_to_arrow(scalar_val: &ScalarValue) -> Result<DataType, PyErr> {
+ match scalar_val {
+ ScalarValue::Boolean(_) => Ok(DataType::Boolean),
+ ScalarValue::Float32(_) => Ok(DataType::Float32),
+ ScalarValue::Float64(_) => Ok(DataType::Float64),
+ ScalarValue::Decimal128(_, precision, scale) => {
+ Ok(DataType::Decimal128(*precision, *scale))
+ }
+ ScalarValue::Dictionary(data_type, scalar_type) => {
+ // Call this function again to map the dictionary scalar_value to an Arrow type
+ Ok(DataType::Dictionary(
+ Box::new(*data_type.clone()),
+ Box::new(DataTypeMap::map_from_scalar_to_arrow(scalar_type)?),
+ ))
+ }
+ ScalarValue::Int8(_) => Ok(DataType::Int8),
+ ScalarValue::Int16(_) => Ok(DataType::Int16),
+ ScalarValue::Int32(_) => Ok(DataType::Int32),
+ ScalarValue::Int64(_) => Ok(DataType::Int64),
+ ScalarValue::UInt8(_) => Ok(DataType::UInt8),
+ ScalarValue::UInt16(_) => Ok(DataType::UInt16),
+ ScalarValue::UInt32(_) => Ok(DataType::UInt32),
+ ScalarValue::UInt64(_) => Ok(DataType::UInt64),
+ ScalarValue::Utf8(_) => Ok(DataType::Utf8),
+ ScalarValue::LargeUtf8(_) => Ok(DataType::LargeUtf8),
+ ScalarValue::Binary(_) => Ok(DataType::Binary),
+ ScalarValue::LargeBinary(_) => Ok(DataType::LargeBinary),
+ ScalarValue::Date32(_) => Ok(DataType::Date32),
+ ScalarValue::Date64(_) => Ok(DataType::Date64),
+ ScalarValue::Time32Second(_) => Ok(DataType::Time32(TimeUnit::Second)),
+ ScalarValue::Time32Millisecond(_) => Ok(DataType::Time32(TimeUnit::Millisecond)),
+ ScalarValue::Time64Microsecond(_) => Ok(DataType::Time64(TimeUnit::Microsecond)),
+ ScalarValue::Time64Nanosecond(_) => Ok(DataType::Time64(TimeUnit::Nanosecond)),
+ ScalarValue::Null => Ok(DataType::Null),
+ ScalarValue::TimestampSecond(_, tz) => {
+ Ok(DataType::Timestamp(TimeUnit::Second, tz.to_owned()))
+ }
+ ScalarValue::TimestampMillisecond(_, tz) => {
+ Ok(DataType::Timestamp(TimeUnit::Millisecond, tz.to_owned()))
+ }
+ ScalarValue::TimestampMicrosecond(_, tz) => {
+ Ok(DataType::Timestamp(TimeUnit::Microsecond, tz.to_owned()))
+ }
+ ScalarValue::TimestampNanosecond(_, tz) => {
+ Ok(DataType::Timestamp(TimeUnit::Nanosecond, tz.to_owned()))
+ }
+ ScalarValue::IntervalYearMonth(..) => Ok(DataType::Interval(IntervalUnit::YearMonth)),
+ ScalarValue::IntervalDayTime(..) => Ok(DataType::Interval(IntervalUnit::DayTime)),
+ ScalarValue::IntervalMonthDayNano(..) => {
+ Ok(DataType::Interval(IntervalUnit::MonthDayNano))
+ }
+ ScalarValue::List(_val, field_ref) => Ok(DataType::List(field_ref.to_owned())),
+ ScalarValue::Struct(_, fields) => Ok(DataType::Struct(fields.to_owned())),
+ ScalarValue::FixedSizeBinary(size, _) => Ok(DataType::FixedSizeBinary(*size)),
+ }
+ }
}
#[pymethods]
diff --git a/src/expr.rs b/src/expr.rs
index 4ada4c1..c002b32 100644
--- a/src/expr.rs
+++ b/src/expr.rs
@@ -15,19 +15,26 @@
// specific language governing permissions and limitations
// under the License.
+use datafusion_common::DFField;
+use datafusion_expr::expr::{AggregateFunction, Sort, WindowFunction};
+use datafusion_expr::utils::exprlist_to_fields;
use pyo3::{basic::CompareOp, prelude::*};
use std::convert::{From, Into};
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::pyarrow::PyArrowType;
-use datafusion_expr::{col, lit, Cast, Expr, GetIndexedField};
+use datafusion_expr::{
+ col, lit, Between, BinaryExpr, Case, Cast, Expr, GetIndexedField, Like, LogicalPlan, Operator,
+ TryCast,
+};
-use crate::common::data_type::RexType;
-use crate::errors::py_runtime_err;
+use crate::common::data_type::{DataTypeMap, RexType};
+use crate::errors::{py_runtime_err, py_type_err, DataFusionError};
use crate::expr::aggregate_expr::PyAggregateFunction;
use crate::expr::binary_expr::PyBinaryExpr;
use crate::expr::column::PyColumn;
use crate::expr::literal::PyLiteral;
+use crate::sql::logical::PyLogicalPlan;
use datafusion::scalar::ScalarValue;
use self::alias::PyAlias;
@@ -274,11 +281,296 @@ impl PyExpr {
Expr::ScalarSubquery(..) => RexType::ScalarSubquery,
})
}
+
+ /// Given the current `Expr`, return the `DataTypeMap` that represents its
+ /// PythonType, Arrow DataType, and SqlType enum values
+ pub fn types(&self) -> PyResult<DataTypeMap> {
+ Self::_types(&self.expr)
+ }
+
+ /// Extracts the Expr value into a PyObject that can be shared with Python
+ pub fn python_value(&self, py: Python) -> PyResult<PyObject> {
+ match &self.expr {
+ Expr::Literal(scalar_value) => Ok(match scalar_value {
+ ScalarValue::Null => todo!(),
+ ScalarValue::Boolean(v) => v.into_py(py),
+ ScalarValue::Float32(v) => v.into_py(py),
+ ScalarValue::Float64(v) => v.into_py(py),
+ ScalarValue::Decimal128(_, _, _) => todo!(),
+ ScalarValue::Int8(v) => v.into_py(py),
+ ScalarValue::Int16(v) => v.into_py(py),
+ ScalarValue::Int32(v) => v.into_py(py),
+ ScalarValue::Int64(v) => v.into_py(py),
+ ScalarValue::UInt8(v) => v.into_py(py),
+ ScalarValue::UInt16(v) => v.into_py(py),
+ ScalarValue::UInt32(v) => v.into_py(py),
+ ScalarValue::UInt64(v) => v.into_py(py),
+ ScalarValue::Utf8(v) => v.clone().into_py(py),
+ ScalarValue::LargeUtf8(v) => v.clone().into_py(py),
+ ScalarValue::Binary(v) => v.clone().into_py(py),
+ ScalarValue::FixedSizeBinary(_, _) => todo!(),
+ ScalarValue::LargeBinary(v) => v.clone().into_py(py),
+ ScalarValue::List(_, _) => todo!(),
+ ScalarValue::Date32(v) => v.into_py(py),
+ ScalarValue::Date64(v) => v.into_py(py),
+ ScalarValue::Time32Second(v) => v.into_py(py),
+ ScalarValue::Time32Millisecond(v) => v.into_py(py),
+ ScalarValue::Time64Microsecond(v) => v.into_py(py),
+ ScalarValue::Time64Nanosecond(v) => v.into_py(py),
+ ScalarValue::TimestampSecond(_, _) => todo!(),
+ ScalarValue::TimestampMillisecond(_, _) => todo!(),
+ ScalarValue::TimestampMicrosecond(_, _) => todo!(),
+ ScalarValue::TimestampNanosecond(_, _) => todo!(),
+ ScalarValue::IntervalYearMonth(v) => v.into_py(py),
+ ScalarValue::IntervalDayTime(v) => v.into_py(py),
+ ScalarValue::IntervalMonthDayNano(v) => v.into_py(py),
+ ScalarValue::Struct(_, _) => todo!(),
+ ScalarValue::Dictionary(_, _) => todo!(),
+ }),
+ _ => Err(py_type_err(format!(
+ "Non Expr::Literal encountered in types: {:?}",
+ &self.expr
+ ))),
+ }
+ }
+
+ /// Row expressions, Rex(s), operate on the concept of operands. Different variants of Expressions, Expr(s),
+ /// store those operands in different data structures. This function examines the Expr variant and returns
+ /// the operands to the calling logic as a Vec of PyExpr instances.
+ pub fn rex_call_operands(&self) -> PyResult<Vec<PyExpr>> {
+ match &self.expr {
+ // Expr variants that are themselves the operand to return
+ Expr::Column(..) | Expr::ScalarVariable(..) | Expr::Literal(..) => {
+ Ok(vec![PyExpr::from(self.expr.clone())])
+ }
+
+ // Expr(s) that house the Expr instance to return in their bounded params
+ Expr::Alias(expr, ..)
+ | Expr::Not(expr)
+ | Expr::IsNull(expr)
+ | Expr::IsNotNull(expr)
+ | Expr::IsTrue(expr)
+ | Expr::IsFalse(expr)
+ | Expr::IsUnknown(expr)
+ | Expr::IsNotTrue(expr)
+ | Expr::IsNotFalse(expr)
+ | Expr::IsNotUnknown(expr)
+ | Expr::Negative(expr)
+ | Expr::GetIndexedField(GetIndexedField { expr, .. })
+ | Expr::Cast(Cast { expr, .. })
+ | Expr::TryCast(TryCast { expr, .. })
+ | Expr::Sort(Sort { expr, .. })
+ | Expr::InSubquery { expr, .. } => Ok(vec![PyExpr::from(*expr.clone())]),
+
+ // Expr variants containing a collection of Expr(s) for operands
+ Expr::AggregateFunction(AggregateFunction { args, .. })
+ | Expr::AggregateUDF { args, .. }
+ | Expr::ScalarFunction { args, .. }
+ | Expr::ScalarUDF { args, .. }
+ | Expr::WindowFunction(WindowFunction { args, .. }) => {
+ Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect())
+ }
+
+ // Expr(s) that require more specific processing
+ Expr::Case(Case {
+ expr,
+ when_then_expr,
+ else_expr,
+ }) => {
+ let mut operands: Vec<PyExpr> = Vec::new();
+
+ if let Some(e) = expr {
+ operands.push(PyExpr::from(*e.clone()));
+ };
+
+ for (when, then) in when_then_expr {
+ operands.push(PyExpr::from(*when.clone()));
+ operands.push(PyExpr::from(*then.clone()));
+ }
+
+ if let Some(e) = else_expr {
+ operands.push(PyExpr::from(*e.clone()));
+ };
+
+ Ok(operands)
+ }
+ Expr::InList { expr, list, .. } => {
+ let mut operands: Vec<PyExpr> = vec![PyExpr::from(*expr.clone())];
+ for list_elem in list {
+ operands.push(PyExpr::from(list_elem.clone()));
+ }
+
+ Ok(operands)
+ }
+ Expr::BinaryExpr(BinaryExpr { left, right, .. }) => Ok(vec![
+ PyExpr::from(*left.clone()),
+ PyExpr::from(*right.clone()),
+ ]),
+ Expr::Like(Like { expr, pattern, .. }) => Ok(vec![
+ PyExpr::from(*expr.clone()),
+ PyExpr::from(*pattern.clone()),
+ ]),
+ Expr::ILike(Like { expr, pattern, .. }) => Ok(vec![
+ PyExpr::from(*expr.clone()),
+ PyExpr::from(*pattern.clone()),
+ ]),
+ Expr::SimilarTo(Like { expr, pattern, .. }) => Ok(vec![
+ PyExpr::from(*expr.clone()),
+ PyExpr::from(*pattern.clone()),
+ ]),
+ Expr::Between(Between {
+ expr,
+ negated: _,
+ low,
+ high,
+ }) => Ok(vec![
+ PyExpr::from(*expr.clone()),
+ PyExpr::from(*low.clone()),
+ PyExpr::from(*high.clone()),
+ ]),
+
+ // Expr variants that are currently unsupported/unimplemented for Rex Call operations
+ Expr::GroupingSet(..)
+ | Expr::OuterReferenceColumn(_, _)
+ | Expr::Wildcard
+ | Expr::QualifiedWildcard { .. }
+ | Expr::ScalarSubquery(..)
+ | Expr::Placeholder { .. }
+ | Expr::Exists { .. } => Err(py_runtime_err(format!(
+ "Unimplemented Expr type: {}",
+ self.expr
+ ))),
+ }
+ }
+
+ /// Extracts the operator associated with a RexType::Call
+ pub fn rex_call_operator(&self) -> PyResult<String> {
+ Ok(match &self.expr {
+ Expr::BinaryExpr(BinaryExpr {
+ left: _,
+ op,
+ right: _,
+ }) => format!("{op}"),
+ Expr::ScalarFunction { fun, args: _ } => format!("{fun}"),
+ Expr::ScalarUDF { fun, .. } => fun.name.clone(),
+ Expr::Cast { .. } => "cast".to_string(),
+ Expr::Between { .. } => "between".to_string(),
+ Expr::Case { .. } => "case".to_string(),
+ Expr::IsNull(..) => "is null".to_string(),
+ Expr::IsNotNull(..) => "is not null".to_string(),
+ Expr::IsTrue(_) => "is true".to_string(),
+ Expr::IsFalse(_) => "is false".to_string(),
+ Expr::IsUnknown(_) => "is unknown".to_string(),
+ Expr::IsNotTrue(_) => "is not true".to_string(),
+ Expr::IsNotFalse(_) => "is not false".to_string(),
+ Expr::IsNotUnknown(_) => "is not unknown".to_string(),
+ Expr::InList { .. } => "in list".to_string(),
+ Expr::Negative(..) => "negative".to_string(),
+ Expr::Not(..) => "not".to_string(),
+ Expr::Like(Like { negated, .. }) => {
+ if *negated {
+ "not like".to_string()
+ } else {
+ "like".to_string()
+ }
+ }
+ Expr::ILike(Like { negated, .. }) => {
+ if *negated {
+ "not ilike".to_string()
+ } else {
+ "ilike".to_string()
+ }
+ }
+ Expr::SimilarTo(Like { negated, .. }) => {
+ if *negated {
+ "not similar to".to_string()
+ } else {
+ "similar to".to_string()
+ }
+ }
+ _ => {
+ return Err(py_type_err(format!(
+ "Catch all triggered in get_operator_name: {:?}",
+ &self.expr
+ )))
+ }
+ })
+ }
+
+ pub fn column_name(&self, plan: PyLogicalPlan) -> PyResult<String> {
+ self._column_name(&plan.plan()).map_err(py_runtime_err)
+ }
+}
+
+impl PyExpr {
+ pub fn _column_name(&self, plan: &LogicalPlan) -> Result<String, DataFusionError> {
+ let field = Self::expr_to_field(&self.expr, plan)?;
+ Ok(field.qualified_column().flat_name())
+ }
+
+ /// Create a [DFField] representing an [Expr], given an input [LogicalPlan] to resolve against
+ pub fn expr_to_field(
+ expr: &Expr,
+ input_plan: &LogicalPlan,
+ ) -> Result<DFField, DataFusionError> {
+ match expr {
+ Expr::Sort(Sort { expr, .. }) => {
+ // DataFusion does not support create_name for sort expressions (since they never
+ // appear in projections) so we just delegate to the contained expression instead
+ Self::expr_to_field(expr, input_plan)
+ }
+ _ => {
+ let fields =
+ exprlist_to_fields(&[expr.clone()], input_plan).map_err(PyErr::from)?;
+ Ok(fields[0].clone())
+ }
+ }
+ }
+
+ fn _types(expr: &Expr) -> PyResult<DataTypeMap> {
+ match expr {
+ Expr::BinaryExpr(BinaryExpr {
+ left: _,
+ op,
+ right: _,
+ }) => match op {
+ Operator::Eq
+ | Operator::NotEq
+ | Operator::Lt
+ | Operator::LtEq
+ | Operator::Gt
+ | Operator::GtEq
+ | Operator::And
+ | Operator::Or
+ | Operator::IsDistinctFrom
+ | Operator::IsNotDistinctFrom
+ | Operator::RegexMatch
+ | Operator::RegexIMatch
+ | Operator::RegexNotMatch
+ | Operator::RegexNotIMatch => DataTypeMap::map_from_arrow_type(&DataType::Boolean),
+ Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Modulo => {
+ DataTypeMap::map_from_arrow_type(&DataType::Int64)
+ }
+ Operator::Divide => DataTypeMap::map_from_arrow_type(&DataType::Float64),
+ Operator::StringConcat => DataTypeMap::map_from_arrow_type(&DataType::Utf8),
+ Operator::BitwiseShiftLeft
+ | Operator::BitwiseShiftRight
+ | Operator::BitwiseXor
+ | Operator::BitwiseAnd
+ | Operator::BitwiseOr => DataTypeMap::map_from_arrow_type(&DataType::Binary),
+ },
+ Expr::Cast(Cast { expr: _, data_type }) => DataTypeMap::map_from_arrow_type(data_type),
+ Expr::Literal(scalar_value) => DataTypeMap::map_from_scalar_value(scalar_value),
+ _ => Err(py_type_err(format!(
+ "Non Expr::Literal encountered in types: {:?}",
+ expr
+ ))),
+ }
+ }
}
/// Initializes the `expr` module to match the pattern of `datafusion-expr` https://docs.rs/datafusion-expr/latest/datafusion_expr/
pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
- // expressions
m.add_class::<PyExpr>()?;
m.add_class::<PyColumn>()?;
m.add_class::<PyLiteral>()?;
diff --git a/src/expr/projection.rs b/src/expr/projection.rs
index f5ba12d..b329661 100644
--- a/src/expr/projection.rs
+++ b/src/expr/projection.rs
@@ -16,6 +16,7 @@
// under the License.
use datafusion_expr::logical_plan::Projection;
+use datafusion_expr::Expr;
use pyo3::prelude::*;
use std::fmt::{self, Display, Formatter};
@@ -27,7 +28,7 @@ use crate::sql::logical::PyLogicalPlan;
#[pyclass(name = "Projection", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyProjection {
- projection: Projection,
+ pub projection: Projection,
}
impl PyProjection {
@@ -92,6 +93,21 @@ impl PyProjection {
}
}
+impl PyProjection {
+ /// Projection: Gets the expressions that should be projected, recursively unwrapping any `Alias`
+ pub fn projected_expressions(local_expr: &PyExpr) -> Vec<PyExpr> {
+ let mut projs: Vec<PyExpr> = Vec::new();
+ match &local_expr.expr {
+ Expr::Alias(expr, _name) => {
+ let py_expr: PyExpr = PyExpr::from(*expr.clone());
+ projs.extend_from_slice(Self::projected_expressions(&py_expr).as_slice());
+ }
+ _ => projs.push(local_expr.clone()),
+ }
+ projs
+ }
+}
+
impl LogicalNode for PyProjection {
fn inputs(&self) -> Vec<PyLogicalPlan> {
vec![PyLogicalPlan::from((*self.projection.input).clone())]
diff --git a/src/sql/logical.rs b/src/sql/logical.rs
index a75315d..07a3f65 100644
--- a/src/sql/logical.rs
+++ b/src/sql/logical.rs
@@ -20,6 +20,7 @@ use std::sync::Arc;
use crate::errors::py_unsupported_variant_err;
use crate::expr::aggregate::PyAggregate;
use crate::expr::analyze::PyAnalyze;
+use crate::expr::distinct::PyDistinct;
use crate::expr::empty_relation::PyEmptyRelation;
use crate::expr::explain::PyExplain;
use crate::expr::extension::PyExtension;
@@ -62,6 +63,7 @@ impl PyLogicalPlan {
LogicalPlan::EmptyRelation(plan) => PyEmptyRelation::from(plan.clone()).to_variant(py),
LogicalPlan::Explain(plan) => PyExplain::from(plan.clone()).to_variant(py),
LogicalPlan::Extension(plan) => PyExtension::from(plan.clone()).to_variant(py),
+ LogicalPlan::Distinct(plan) => PyDistinct::from(plan.clone()).to_variant(py),
LogicalPlan::Filter(plan) => PyFilter::from(plan.clone()).to_variant(py),
LogicalPlan::Limit(plan) => PyLimit::from(plan.clone()).to_variant(py),
LogicalPlan::Projection(plan) => PyProjection::from(plan.clone()).to_variant(py),