You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/02/03 17:05:41 UTC

[arrow-datafusion] branch master updated: API to get Expr's type and nullability without a `DFSchema` (#1726)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new b2eaee3  API to get Expr's type and nullability without a `DFSchema` (#1726)
b2eaee3 is described below

commit b2eaee385035d78ec89ea3ef8e5772eb9f9018fe
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Thu Feb 3 12:04:40 2022 -0500

    API to get Expr's type and nullability without a `DFSchema` (#1726)
    
    * API to get Expr type and nullability without a `DFSchema`
    
    * Add test
    
    * publicly export
    
    * Improve docs
---
 datafusion/src/logical_plan/expr.rs | 123 ++++++++++++++++++++++++++++++++----
 datafusion/src/logical_plan/mod.rs  |   4 +-
 2 files changed, 113 insertions(+), 14 deletions(-)

diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index 63bf72c..300c751 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -392,20 +392,59 @@ impl PartialOrd for Expr {
     }
 }
 
+/// Provides schema information needed by [Expr] methods such as
+/// [Expr::nullable] and [Expr::data_type].
+///
+/// Note that this trait is implemented for [DFSchema] itself and for
+/// any type that is `AsRef<DFSchema>` (such as `Arc<DFSchema>`).
+pub trait ExprSchema {
+    /// Is this column reference nullable?
+    fn nullable(&self, col: &Column) -> Result<bool>;
+
+    /// What is the datatype of this column?
+    fn data_type(&self, col: &Column) -> Result<&DataType>;
+}
+
+// Implement `ExprSchema` for any `P: AsRef<DFSchema>` (e.g. `Arc<DFSchema>`)
+impl<P: AsRef<DFSchema>> ExprSchema for P {
+    fn nullable(&self, col: &Column) -> Result<bool> {
+        self.as_ref().nullable(col)
+    }
+
+    fn data_type(&self, col: &Column) -> Result<&DataType> {
+        self.as_ref().data_type(col)
+    }
+}
+
+impl ExprSchema for DFSchema {
+    fn nullable(&self, col: &Column) -> Result<bool> {
+        Ok(self.field_from_column(col)?.is_nullable())
+    }
+
+    fn data_type(&self, col: &Column) -> Result<&DataType> {
+        Ok(self.field_from_column(col)?.data_type())
+    }
+}
+
 impl Expr {
-    /// Returns the [arrow::datatypes::DataType] of the expression based on [arrow::datatypes::Schema].
+    /// Returns the [arrow::datatypes::DataType] of the expression
+    /// based on [ExprSchema]
+    ///
+    /// Note: [DFSchema] implements [ExprSchema].
     ///
     /// # Errors
     ///
-    /// This function errors when it is not possible to compute its [arrow::datatypes::DataType].
-    /// This happens when e.g. the expression refers to a column that does not exist in the schema, or when
-    /// the expression is incorrectly typed (e.g. `[utf8] + [bool]`).
-    pub fn get_type(&self, schema: &DFSchema) -> Result<DataType> {
+    /// This function errors when it is not possible to compute its
+    /// [arrow::datatypes::DataType].  This happens when e.g. the
+    /// expression refers to a column that does not exist in the
+    /// schema, or when the expression is incorrectly typed
+    /// (e.g. `[utf8] + [bool]`).
+    pub fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> {
         match self {
             Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => {
                 expr.get_type(schema)
             }
-            Expr::Column(c) => Ok(schema.field_from_column(c)?.data_type().clone()),
+            Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
             Expr::ScalarVariable(_) => Ok(DataType::Utf8),
             Expr::Literal(l) => Ok(l.get_datatype()),
             Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema),
@@ -472,13 +511,16 @@ impl Expr {
         }
     }
 
-    /// Returns the nullability of the expression based on [arrow::datatypes::Schema].
+    /// Returns the nullability of the expression based on [ExprSchema].
+    ///
+    /// Note: [DFSchema] implements [ExprSchema].
     ///
     /// # Errors
     ///
-    /// This function errors when it is not possible to compute its nullability.
-    /// This happens when the expression refers to a column that does not exist in the schema.
-    pub fn nullable(&self, input_schema: &DFSchema) -> Result<bool> {
+    /// This function errors when it is not possible to compute its
+    /// nullability.  This happens when the expression refers to a
+    /// column that does not exist in the schema.
+    pub fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool> {
         match self {
             Expr::Alias(expr, _)
             | Expr::Not(expr)
@@ -486,7 +528,7 @@ impl Expr {
             | Expr::Sort { expr, .. }
             | Expr::Between { expr, .. }
             | Expr::InList { expr, .. } => expr.nullable(input_schema),
-            Expr::Column(c) => Ok(input_schema.field_from_column(c)?.is_nullable()),
+            Expr::Column(c) => input_schema.nullable(c),
             Expr::Literal(value) => Ok(value.is_null()),
             Expr::Case {
                 when_then_expr,
@@ -561,7 +603,11 @@ impl Expr {
     ///
     /// This function errors when it is impossible to cast the
     /// expression to the target [arrow::datatypes::DataType].
-    pub fn cast_to(self, cast_to_type: &DataType, schema: &DFSchema) -> Result<Expr> {
+    pub fn cast_to<S: ExprSchema>(
+        self,
+        cast_to_type: &DataType,
+        schema: &S,
+    ) -> Result<Expr> {
         // TODO(kszucs): most of the operations do not validate the type correctness
         // like all of the binary expressions below. Perhaps Expr should track the
         // type of the expression?
@@ -2557,4 +2603,57 @@ mod tests {
             combine_filters(&[filter1.clone(), filter2.clone(), filter3.clone()]);
         assert_eq!(result, Some(and(and(filter1, filter2), filter3)));
     }
+
+    #[test]
+    fn expr_schema_nullability() {
+        let expr = col("foo").eq(lit(1));
+        assert!(!expr.nullable(&MockExprSchema::new()).unwrap());
+        assert!(expr
+            .nullable(&MockExprSchema::new().with_nullable(true))
+            .unwrap());
+    }
+
+    #[test]
+    fn expr_schema_data_type() {
+        let expr = col("foo");
+        assert_eq!(
+            DataType::Utf8,
+            expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8))
+                .unwrap()
+        );
+    }
+
+    struct MockExprSchema {
+        nullable: bool,
+        data_type: DataType,
+    }
+
+    impl MockExprSchema {
+        fn new() -> Self {
+            Self {
+                nullable: false,
+                data_type: DataType::Null,
+            }
+        }
+
+        fn with_nullable(mut self, nullable: bool) -> Self {
+            self.nullable = nullable;
+            self
+        }
+
+        fn with_data_type(mut self, data_type: DataType) -> Self {
+            self.data_type = data_type;
+            self
+        }
+    }
+
+    impl ExprSchema for MockExprSchema {
+        fn nullable(&self, _col: &Column) -> Result<bool> {
+            Ok(self.nullable)
+        }
+
+        fn data_type(&self, _col: &Column) -> Result<&DataType> {
+            Ok(&self.data_type)
+        }
+    }
 }
diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs
index 2571451..22521a1 100644
--- a/datafusion/src/logical_plan/mod.rs
+++ b/datafusion/src/logical_plan/mod.rs
@@ -46,8 +46,8 @@ pub use expr::{
     rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512,
     signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex,
     translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when,
-    Column, Expr, ExprRewriter, ExpressionVisitor, Literal, Recursion, RewriteRecursion,
-    SimplifyInfo,
+    Column, Expr, ExprRewriter, ExprSchema, ExpressionVisitor, Literal, Recursion,
+    RewriteRecursion, SimplifyInfo,
 };
 pub use extension::UserDefinedLogicalNode;
 pub use operators::Operator;