You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2023/05/10 12:54:00 UTC

[arrow-datafusion-python] branch main updated: Expand Expr to include RexType basic support (#378)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 433dbca  Expand Expr to include RexType basic support (#378)
433dbca is described below

commit 433dbca009a25a78e698d2c703ef9f1e4177dae0
Author: Jeremy Dyer <jd...@gmail.com>
AuthorDate: Wed May 10 08:53:55 2023 -0400

    Expand Expr to include RexType basic support (#378)
    
    * Make expr member of PyExpr public
    
    * Add RexType to Expr
    
    * Add utility functions for mapping ScalarValue instances to DataTypeMap instances
    
    * Add function to get python_value from Expr instance
    
    * Fix syntax problems
    
    * Add function to get the operands for a Rex::Call
    
    * Add function to get operator for RexType::Call
    
    * expand types function to include variant support for BinaryExpr
    
    * Add variant coverage for Decimal128 and Decimal256
    
    * add function for getting the column name of an Expr from a LogicalPlan
    
    * Make PyProjection::projection member public
    
    * Add projected_expressions to projection node
    
    * Adjust function signature
    
    * Add Distinct variant to to_variant function in PyLogicalPlan
    
    * Fill in variants for DataType::Timestamp
    
    * Address syntax issues
    
    * Refactor types() function to extend support for CAST
    
    * Update CAST variant handling
    
    * Cargo fmt
    
    * Cargo clippy
    
    * Coverage for INTERVAL in DataType
    
    * More cargo fmt changes
---
 src/common/data_type.rs | 119 +++++++++++++++----
 src/expr.rs             | 300 +++++++++++++++++++++++++++++++++++++++++++++++-
 src/expr/projection.rs  |  18 ++-
 src/sql/logical.rs      |   2 +
 4 files changed, 414 insertions(+), 25 deletions(-)

diff --git a/src/common/data_type.rs b/src/common/data_type.rs
index d55a0e8..622e1aa 100644
--- a/src/common/data_type.rs
+++ b/src/common/data_type.rs
@@ -15,8 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use datafusion::arrow::datatypes::DataType;
-use datafusion_common::DataFusionError;
+use datafusion::arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
+use datafusion_common::{DataFusionError, ScalarValue};
 use pyo3::prelude::*;
 
 use crate::errors::py_datafusion_err;
@@ -130,9 +130,11 @@ impl DataTypeMap {
                 PythonType::Float,
                 SqlType::FLOAT,
             )),
-            DataType::Timestamp(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
-                format!("{:?}", arrow_type),
-            ))),
+            DataType::Timestamp(unit, tz) => Ok(DataTypeMap::new(
+                DataType::Timestamp(unit.clone(), tz.clone()),
+                PythonType::Datetime,
+                SqlType::DATE,
+            )),
             DataType::Date32 => Ok(DataTypeMap::new(
                 DataType::Date32,
                 PythonType::Datetime,
@@ -143,18 +145,28 @@ impl DataTypeMap {
                 PythonType::Datetime,
                 SqlType::DATE,
             )),
-            DataType::Time32(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
-                format!("{:?}", arrow_type),
-            ))),
-            DataType::Time64(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
-                format!("{:?}", arrow_type),
-            ))),
+            DataType::Time32(unit) => Ok(DataTypeMap::new(
+                DataType::Time32(unit.clone()),
+                PythonType::Datetime,
+                SqlType::DATE,
+            )),
+            DataType::Time64(unit) => Ok(DataTypeMap::new(
+                DataType::Time64(unit.clone()),
+                PythonType::Datetime,
+                SqlType::DATE,
+            )),
             DataType::Duration(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
                 format!("{:?}", arrow_type),
             ))),
-            DataType::Interval(_) => Err(py_datafusion_err(DataFusionError::NotImplemented(
-                format!("{:?}", arrow_type),
-            ))),
+            DataType::Interval(interval_unit) => Ok(DataTypeMap::new(
+                DataType::Interval(interval_unit.clone()),
+                PythonType::Datetime,
+                match interval_unit {
+                    IntervalUnit::DayTime => SqlType::INTERVAL_DAY,
+                    IntervalUnit::MonthDayNano => SqlType::INTERVAL_MONTH,
+                    IntervalUnit::YearMonth => SqlType::INTERVAL_YEAR_MONTH,
+                },
+            )),
             DataType::Binary => Ok(DataTypeMap::new(
                 DataType::Binary,
                 PythonType::Bytes,
@@ -197,12 +209,16 @@ impl DataTypeMap {
             DataType::Dictionary(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
                 format!("{:?}", arrow_type),
             ))),
-            DataType::Decimal128(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
-                format!("{:?}", arrow_type),
-            ))),
-            DataType::Decimal256(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
-                format!("{:?}", arrow_type),
-            ))),
+            DataType::Decimal128(precision, scale) => Ok(DataTypeMap::new(
+                DataType::Decimal128(*precision, *scale),
+                PythonType::Float,
+                SqlType::DECIMAL,
+            )),
+            DataType::Decimal256(precision, scale) => Ok(DataTypeMap::new(
+                DataType::Decimal256(*precision, *scale),
+                PythonType::Float,
+                SqlType::DECIMAL,
+            )),
             DataType::Map(_, _) => Err(py_datafusion_err(DataFusionError::NotImplemented(
                 format!("{:?}", arrow_type),
             ))),
@@ -211,6 +227,69 @@ impl DataTypeMap {
             )),
         }
     }
+
+    /// Generate the `DataTypeMap` from a `ScalarValue` instance
+    pub fn map_from_scalar_value(scalar_val: &ScalarValue) -> Result<DataTypeMap, PyErr> {
+        DataTypeMap::map_from_arrow_type(&DataTypeMap::map_from_scalar_to_arrow(scalar_val)?)
+    }
+
+    /// Maps a `ScalarValue` to an Arrow `DataType`
+    pub fn map_from_scalar_to_arrow(scalar_val: &ScalarValue) -> Result<DataType, PyErr> {
+        match scalar_val {
+            ScalarValue::Boolean(_) => Ok(DataType::Boolean),
+            ScalarValue::Float32(_) => Ok(DataType::Float32),
+            ScalarValue::Float64(_) => Ok(DataType::Float64),
+            ScalarValue::Decimal128(_, precision, scale) => {
+                Ok(DataType::Decimal128(*precision, *scale))
+            }
+            ScalarValue::Dictionary(data_type, scalar_type) => {
+                // Call this function again to map the dictionary scalar_value to an Arrow type
+                Ok(DataType::Dictionary(
+                    Box::new(*data_type.clone()),
+                    Box::new(DataTypeMap::map_from_scalar_to_arrow(scalar_type)?),
+                ))
+            }
+            ScalarValue::Int8(_) => Ok(DataType::Int8),
+            ScalarValue::Int16(_) => Ok(DataType::Int16),
+            ScalarValue::Int32(_) => Ok(DataType::Int32),
+            ScalarValue::Int64(_) => Ok(DataType::Int64),
+            ScalarValue::UInt8(_) => Ok(DataType::UInt8),
+            ScalarValue::UInt16(_) => Ok(DataType::UInt16),
+            ScalarValue::UInt32(_) => Ok(DataType::UInt32),
+            ScalarValue::UInt64(_) => Ok(DataType::UInt64),
+            ScalarValue::Utf8(_) => Ok(DataType::Utf8),
+            ScalarValue::LargeUtf8(_) => Ok(DataType::LargeUtf8),
+            ScalarValue::Binary(_) => Ok(DataType::Binary),
+            ScalarValue::LargeBinary(_) => Ok(DataType::LargeBinary),
+            ScalarValue::Date32(_) => Ok(DataType::Date32),
+            ScalarValue::Date64(_) => Ok(DataType::Date64),
+            ScalarValue::Time32Second(_) => Ok(DataType::Time32(TimeUnit::Second)),
+            ScalarValue::Time32Millisecond(_) => Ok(DataType::Time32(TimeUnit::Millisecond)),
+            ScalarValue::Time64Microsecond(_) => Ok(DataType::Time64(TimeUnit::Microsecond)),
+            ScalarValue::Time64Nanosecond(_) => Ok(DataType::Time64(TimeUnit::Nanosecond)),
+            ScalarValue::Null => Ok(DataType::Null),
+            ScalarValue::TimestampSecond(_, tz) => {
+                Ok(DataType::Timestamp(TimeUnit::Second, tz.to_owned()))
+            }
+            ScalarValue::TimestampMillisecond(_, tz) => {
+                Ok(DataType::Timestamp(TimeUnit::Millisecond, tz.to_owned()))
+            }
+            ScalarValue::TimestampMicrosecond(_, tz) => {
+                Ok(DataType::Timestamp(TimeUnit::Microsecond, tz.to_owned()))
+            }
+            ScalarValue::TimestampNanosecond(_, tz) => {
+                Ok(DataType::Timestamp(TimeUnit::Nanosecond, tz.to_owned()))
+            }
+            ScalarValue::IntervalYearMonth(..) => Ok(DataType::Interval(IntervalUnit::YearMonth)),
+            ScalarValue::IntervalDayTime(..) => Ok(DataType::Interval(IntervalUnit::DayTime)),
+            ScalarValue::IntervalMonthDayNano(..) => {
+                Ok(DataType::Interval(IntervalUnit::MonthDayNano))
+            }
+            ScalarValue::List(_val, field_ref) => Ok(DataType::List(field_ref.to_owned())),
+            ScalarValue::Struct(_, fields) => Ok(DataType::Struct(fields.to_owned())),
+            ScalarValue::FixedSizeBinary(size, _) => Ok(DataType::FixedSizeBinary(*size)),
+        }
+    }
 }
 
 #[pymethods]
diff --git a/src/expr.rs b/src/expr.rs
index 4ada4c1..c002b32 100644
--- a/src/expr.rs
+++ b/src/expr.rs
@@ -15,19 +15,26 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use datafusion_common::DFField;
+use datafusion_expr::expr::{AggregateFunction, Sort, WindowFunction};
+use datafusion_expr::utils::exprlist_to_fields;
 use pyo3::{basic::CompareOp, prelude::*};
 use std::convert::{From, Into};
 
 use datafusion::arrow::datatypes::DataType;
 use datafusion::arrow::pyarrow::PyArrowType;
-use datafusion_expr::{col, lit, Cast, Expr, GetIndexedField};
+use datafusion_expr::{
+    col, lit, Between, BinaryExpr, Case, Cast, Expr, GetIndexedField, Like, LogicalPlan, Operator,
+    TryCast,
+};
 
-use crate::common::data_type::RexType;
-use crate::errors::py_runtime_err;
+use crate::common::data_type::{DataTypeMap, RexType};
+use crate::errors::{py_runtime_err, py_type_err, DataFusionError};
 use crate::expr::aggregate_expr::PyAggregateFunction;
 use crate::expr::binary_expr::PyBinaryExpr;
 use crate::expr::column::PyColumn;
 use crate::expr::literal::PyLiteral;
+use crate::sql::logical::PyLogicalPlan;
 use datafusion::scalar::ScalarValue;
 
 use self::alias::PyAlias;
@@ -274,11 +281,296 @@ impl PyExpr {
             Expr::ScalarSubquery(..) => RexType::ScalarSubquery,
         })
     }
+
+    /// Given the current `Expr`, return the `DataTypeMap` which holds the
+    /// PythonType, Arrow DataType, and SqlType enum that represent its type
+    pub fn types(&self) -> PyResult<DataTypeMap> {
+        Self::_types(&self.expr)
+    }
+
+    /// Extracts the Expr value into a PyObject that can be shared with Python
+    pub fn python_value(&self, py: Python) -> PyResult<PyObject> {
+        match &self.expr {
+            Expr::Literal(scalar_value) => Ok(match scalar_value {
+                ScalarValue::Null => todo!(),
+                ScalarValue::Boolean(v) => v.into_py(py),
+                ScalarValue::Float32(v) => v.into_py(py),
+                ScalarValue::Float64(v) => v.into_py(py),
+                ScalarValue::Decimal128(_, _, _) => todo!(),
+                ScalarValue::Int8(v) => v.into_py(py),
+                ScalarValue::Int16(v) => v.into_py(py),
+                ScalarValue::Int32(v) => v.into_py(py),
+                ScalarValue::Int64(v) => v.into_py(py),
+                ScalarValue::UInt8(v) => v.into_py(py),
+                ScalarValue::UInt16(v) => v.into_py(py),
+                ScalarValue::UInt32(v) => v.into_py(py),
+                ScalarValue::UInt64(v) => v.into_py(py),
+                ScalarValue::Utf8(v) => v.clone().into_py(py),
+                ScalarValue::LargeUtf8(v) => v.clone().into_py(py),
+                ScalarValue::Binary(v) => v.clone().into_py(py),
+                ScalarValue::FixedSizeBinary(_, _) => todo!(),
+                ScalarValue::LargeBinary(v) => v.clone().into_py(py),
+                ScalarValue::List(_, _) => todo!(),
+                ScalarValue::Date32(v) => v.into_py(py),
+                ScalarValue::Date64(v) => v.into_py(py),
+                ScalarValue::Time32Second(v) => v.into_py(py),
+                ScalarValue::Time32Millisecond(v) => v.into_py(py),
+                ScalarValue::Time64Microsecond(v) => v.into_py(py),
+                ScalarValue::Time64Nanosecond(v) => v.into_py(py),
+                ScalarValue::TimestampSecond(_, _) => todo!(),
+                ScalarValue::TimestampMillisecond(_, _) => todo!(),
+                ScalarValue::TimestampMicrosecond(_, _) => todo!(),
+                ScalarValue::TimestampNanosecond(_, _) => todo!(),
+                ScalarValue::IntervalYearMonth(v) => v.into_py(py),
+                ScalarValue::IntervalDayTime(v) => v.into_py(py),
+                ScalarValue::IntervalMonthDayNano(v) => v.into_py(py),
+                ScalarValue::Struct(_, _) => todo!(),
+                ScalarValue::Dictionary(_, _) => todo!(),
+            }),
+            _ => Err(py_type_err(format!(
+                "Non Expr::Literal encountered in types: {:?}",
+                &self.expr
+            ))),
+        }
+    }
+
+    /// Row expressions, Rex(s), operate on the concept of operands. Different variants of Expressions, Expr(s),
+    /// store those operands in different data structures. This function examines the Expr variant and returns
+    /// the operands to the calling logic as a Vec of PyExpr instances.
+    pub fn rex_call_operands(&self) -> PyResult<Vec<PyExpr>> {
+        match &self.expr {
+            // Expr variants that are themselves the operand to return
+            Expr::Column(..) | Expr::ScalarVariable(..) | Expr::Literal(..) => {
+                Ok(vec![PyExpr::from(self.expr.clone())])
+            }
+
+            // Expr(s) that wrap a single inner Expr, which is the operand to return
+            Expr::Alias(expr, ..)
+            | Expr::Not(expr)
+            | Expr::IsNull(expr)
+            | Expr::IsNotNull(expr)
+            | Expr::IsTrue(expr)
+            | Expr::IsFalse(expr)
+            | Expr::IsUnknown(expr)
+            | Expr::IsNotTrue(expr)
+            | Expr::IsNotFalse(expr)
+            | Expr::IsNotUnknown(expr)
+            | Expr::Negative(expr)
+            | Expr::GetIndexedField(GetIndexedField { expr, .. })
+            | Expr::Cast(Cast { expr, .. })
+            | Expr::TryCast(TryCast { expr, .. })
+            | Expr::Sort(Sort { expr, .. })
+            | Expr::InSubquery { expr, .. } => Ok(vec![PyExpr::from(*expr.clone())]),
+
+            // Expr variants containing a collection of Expr(s) for operands
+            Expr::AggregateFunction(AggregateFunction { args, .. })
+            | Expr::AggregateUDF { args, .. }
+            | Expr::ScalarFunction { args, .. }
+            | Expr::ScalarUDF { args, .. }
+            | Expr::WindowFunction(WindowFunction { args, .. }) => {
+                Ok(args.iter().map(|arg| PyExpr::from(arg.clone())).collect())
+            }
+
+            // Expr(s) that require more specific processing
+            Expr::Case(Case {
+                expr,
+                when_then_expr,
+                else_expr,
+            }) => {
+                let mut operands: Vec<PyExpr> = Vec::new();
+
+                if let Some(e) = expr {
+                    operands.push(PyExpr::from(*e.clone()));
+                };
+
+                for (when, then) in when_then_expr {
+                    operands.push(PyExpr::from(*when.clone()));
+                    operands.push(PyExpr::from(*then.clone()));
+                }
+
+                if let Some(e) = else_expr {
+                    operands.push(PyExpr::from(*e.clone()));
+                };
+
+                Ok(operands)
+            }
+            Expr::InList { expr, list, .. } => {
+                let mut operands: Vec<PyExpr> = vec![PyExpr::from(*expr.clone())];
+                for list_elem in list {
+                    operands.push(PyExpr::from(list_elem.clone()));
+                }
+
+                Ok(operands)
+            }
+            Expr::BinaryExpr(BinaryExpr { left, right, .. }) => Ok(vec![
+                PyExpr::from(*left.clone()),
+                PyExpr::from(*right.clone()),
+            ]),
+            Expr::Like(Like { expr, pattern, .. }) => Ok(vec![
+                PyExpr::from(*expr.clone()),
+                PyExpr::from(*pattern.clone()),
+            ]),
+            Expr::ILike(Like { expr, pattern, .. }) => Ok(vec![
+                PyExpr::from(*expr.clone()),
+                PyExpr::from(*pattern.clone()),
+            ]),
+            Expr::SimilarTo(Like { expr, pattern, .. }) => Ok(vec![
+                PyExpr::from(*expr.clone()),
+                PyExpr::from(*pattern.clone()),
+            ]),
+            Expr::Between(Between {
+                expr,
+                negated: _,
+                low,
+                high,
+            }) => Ok(vec![
+                PyExpr::from(*expr.clone()),
+                PyExpr::from(*low.clone()),
+                PyExpr::from(*high.clone()),
+            ]),
+
+            // Expr types that are currently unsupported/unimplemented for Rex Call operations
+            Expr::GroupingSet(..)
+            | Expr::OuterReferenceColumn(_, _)
+            | Expr::Wildcard
+            | Expr::QualifiedWildcard { .. }
+            | Expr::ScalarSubquery(..)
+            | Expr::Placeholder { .. }
+            | Expr::Exists { .. } => Err(py_runtime_err(format!(
+                "Unimplemented Expr type: {}",
+                self.expr
+            ))),
+        }
+    }
+
+    /// Extracts the operator associated with a RexType::Call
+    pub fn rex_call_operator(&self) -> PyResult<String> {
+        Ok(match &self.expr {
+            Expr::BinaryExpr(BinaryExpr {
+                left: _,
+                op,
+                right: _,
+            }) => format!("{op}"),
+            Expr::ScalarFunction { fun, args: _ } => format!("{fun}"),
+            Expr::ScalarUDF { fun, .. } => fun.name.clone(),
+            Expr::Cast { .. } => "cast".to_string(),
+            Expr::Between { .. } => "between".to_string(),
+            Expr::Case { .. } => "case".to_string(),
+            Expr::IsNull(..) => "is null".to_string(),
+            Expr::IsNotNull(..) => "is not null".to_string(),
+            Expr::IsTrue(_) => "is true".to_string(),
+            Expr::IsFalse(_) => "is false".to_string(),
+            Expr::IsUnknown(_) => "is unknown".to_string(),
+            Expr::IsNotTrue(_) => "is not true".to_string(),
+            Expr::IsNotFalse(_) => "is not false".to_string(),
+            Expr::IsNotUnknown(_) => "is not unknown".to_string(),
+            Expr::InList { .. } => "in list".to_string(),
+            Expr::Negative(..) => "negative".to_string(),
+            Expr::Not(..) => "not".to_string(),
+            Expr::Like(Like { negated, .. }) => {
+                if *negated {
+                    "not like".to_string()
+                } else {
+                    "like".to_string()
+                }
+            }
+            Expr::ILike(Like { negated, .. }) => {
+                if *negated {
+                    "not ilike".to_string()
+                } else {
+                    "ilike".to_string()
+                }
+            }
+            Expr::SimilarTo(Like { negated, .. }) => {
+                if *negated {
+                    "not similar to".to_string()
+                } else {
+                    "similar to".to_string()
+                }
+            }
+            _ => {
+                return Err(py_type_err(format!(
+                    "Catch all triggered in get_operator_name: {:?}",
+                    &self.expr
+                )))
+            }
+        })
+    }
+
+    pub fn column_name(&self, plan: PyLogicalPlan) -> PyResult<String> {
+        self._column_name(&plan.plan()).map_err(py_runtime_err)
+    }
+}
+
+impl PyExpr {
+    pub fn _column_name(&self, plan: &LogicalPlan) -> Result<String, DataFusionError> {
+        let field = Self::expr_to_field(&self.expr, plan)?;
+        Ok(field.qualified_column().flat_name())
+    }
+
+    /// Create a [DFField] representing an [Expr], given an input [LogicalPlan] to resolve against
+    pub fn expr_to_field(
+        expr: &Expr,
+        input_plan: &LogicalPlan,
+    ) -> Result<DFField, DataFusionError> {
+        match expr {
+            Expr::Sort(Sort { expr, .. }) => {
+                // DataFusion does not support create_name for sort expressions (since they never
+                // appear in projections) so we just delegate to the contained expression instead
+                Self::expr_to_field(expr, input_plan)
+            }
+            _ => {
+                let fields =
+                    exprlist_to_fields(&[expr.clone()], input_plan).map_err(PyErr::from)?;
+                Ok(fields[0].clone())
+            }
+        }
+    }
+
+    fn _types(expr: &Expr) -> PyResult<DataTypeMap> {
+        match expr {
+            Expr::BinaryExpr(BinaryExpr {
+                left: _,
+                op,
+                right: _,
+            }) => match op {
+                Operator::Eq
+                | Operator::NotEq
+                | Operator::Lt
+                | Operator::LtEq
+                | Operator::Gt
+                | Operator::GtEq
+                | Operator::And
+                | Operator::Or
+                | Operator::IsDistinctFrom
+                | Operator::IsNotDistinctFrom
+                | Operator::RegexMatch
+                | Operator::RegexIMatch
+                | Operator::RegexNotMatch
+                | Operator::RegexNotIMatch => DataTypeMap::map_from_arrow_type(&DataType::Boolean),
+                Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Modulo => {
+                    DataTypeMap::map_from_arrow_type(&DataType::Int64)
+                }
+                Operator::Divide => DataTypeMap::map_from_arrow_type(&DataType::Float64),
+                Operator::StringConcat => DataTypeMap::map_from_arrow_type(&DataType::Utf8),
+                Operator::BitwiseShiftLeft
+                | Operator::BitwiseShiftRight
+                | Operator::BitwiseXor
+                | Operator::BitwiseAnd
+                | Operator::BitwiseOr => DataTypeMap::map_from_arrow_type(&DataType::Binary),
+            },
+            Expr::Cast(Cast { expr: _, data_type }) => DataTypeMap::map_from_arrow_type(data_type),
+            Expr::Literal(scalar_value) => DataTypeMap::map_from_scalar_value(scalar_value),
+            _ => Err(py_type_err(format!(
+                "Non Expr::Literal encountered in types: {:?}",
+                expr
+            ))),
+        }
+    }
 }
 
 /// Initializes the `expr` module to match the pattern of `datafusion-expr` https://docs.rs/datafusion-expr/latest/datafusion_expr/
 pub(crate) fn init_module(m: &PyModule) -> PyResult<()> {
-    // expressions
     m.add_class::<PyExpr>()?;
     m.add_class::<PyColumn>()?;
     m.add_class::<PyLiteral>()?;
diff --git a/src/expr/projection.rs b/src/expr/projection.rs
index f5ba12d..b329661 100644
--- a/src/expr/projection.rs
+++ b/src/expr/projection.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 use datafusion_expr::logical_plan::Projection;
+use datafusion_expr::Expr;
 use pyo3::prelude::*;
 use std::fmt::{self, Display, Formatter};
 
@@ -27,7 +28,7 @@ use crate::sql::logical::PyLogicalPlan;
 #[pyclass(name = "Projection", module = "datafusion.expr", subclass)]
 #[derive(Clone)]
 pub struct PyProjection {
-    projection: Projection,
+    pub projection: Projection,
 }
 
 impl PyProjection {
@@ -92,6 +93,21 @@ impl PyProjection {
     }
 }
 
+impl PyProjection {
+    /// Projection: Gets the expressions to project, unwrapping any `Expr::Alias` to its inner expression
+    pub fn projected_expressions(local_expr: &PyExpr) -> Vec<PyExpr> {
+        let mut projs: Vec<PyExpr> = Vec::new();
+        match &local_expr.expr {
+            Expr::Alias(expr, _name) => {
+                let py_expr: PyExpr = PyExpr::from(*expr.clone());
+                projs.extend_from_slice(Self::projected_expressions(&py_expr).as_slice());
+            }
+            _ => projs.push(local_expr.clone()),
+        }
+        projs
+    }
+}
+
 impl LogicalNode for PyProjection {
     fn inputs(&self) -> Vec<PyLogicalPlan> {
         vec![PyLogicalPlan::from((*self.projection.input).clone())]
diff --git a/src/sql/logical.rs b/src/sql/logical.rs
index a75315d..07a3f65 100644
--- a/src/sql/logical.rs
+++ b/src/sql/logical.rs
@@ -20,6 +20,7 @@ use std::sync::Arc;
 use crate::errors::py_unsupported_variant_err;
 use crate::expr::aggregate::PyAggregate;
 use crate::expr::analyze::PyAnalyze;
+use crate::expr::distinct::PyDistinct;
 use crate::expr::empty_relation::PyEmptyRelation;
 use crate::expr::explain::PyExplain;
 use crate::expr::extension::PyExtension;
@@ -62,6 +63,7 @@ impl PyLogicalPlan {
             LogicalPlan::EmptyRelation(plan) => PyEmptyRelation::from(plan.clone()).to_variant(py),
             LogicalPlan::Explain(plan) => PyExplain::from(plan.clone()).to_variant(py),
             LogicalPlan::Extension(plan) => PyExtension::from(plan.clone()).to_variant(py),
+            LogicalPlan::Distinct(plan) => PyDistinct::from(plan.clone()).to_variant(py),
             LogicalPlan::Filter(plan) => PyFilter::from(plan.clone()).to_variant(py),
             LogicalPlan::Limit(plan) => PyLimit::from(plan.clone()).to_variant(py),
             LogicalPlan::Projection(plan) => PyProjection::from(plan.clone()).to_variant(py),