You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2020/08/19 17:14:17 UTC

[arrow] branch master updated: ARROW-9778: [Rust] [DataFusion] Implement Expr.nullable() and make consistent between logical and physical plans

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 01f06cf  ARROW-9778: [Rust] [DataFusion] Implement Expr.nullable() and make consistent between logical and physical plans
01f06cf is described below

commit 01f06cf27ef973f7e19127e8102178635a75a525
Author: Andy Grove <an...@gmail.com>
AuthorDate: Wed Aug 19 11:13:12 2020 -0600

    ARROW-9778: [Rust] [DataFusion] Implement Expr.nullable() and make consistent between logical and physical plans
    
    - Implements `Expr.nullable()`
    - Updates some `PhysicalExpr` and `AggregateExpr` `nullable` methods to be consistent with the logical plan
    - Updates the `sql.rs` integration test to assert that the schema is the same across the logical, optimized logical, and physical plans
    
    Closes #8005 from andygrove/nullability
    
    Lead-authored-by: Andy Grove <an...@gmail.com>
    Co-authored-by: Andy Grove <an...@users.noreply.github.com>
    Signed-off-by: Andy Grove <an...@gmail.com>
---
 .../src/execution/physical_plan/expressions.rs     | 30 ++++++++++++--
 .../src/execution/physical_plan/hash_aggregate.rs  | 12 +++++-
 rust/datafusion/src/execution/physical_plan/mod.rs |  4 +-
 rust/datafusion/src/logicalplan.rs                 | 48 +++++++++++++++++-----
 rust/datafusion/tests/sql.rs                       |  7 ++++
 5 files changed, 85 insertions(+), 16 deletions(-)

diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs
index 2891552..faeaade 100644
--- a/rust/datafusion/src/execution/physical_plan/expressions.rs
+++ b/rust/datafusion/src/execution/physical_plan/expressions.rs
@@ -125,6 +125,11 @@ impl AggregateExpr for Sum {
         }
     }
 
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        // null should be returned if no rows are aggregated
+        Ok(true)
+    }
+
     fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
         self.expr.evaluate(batch)
     }
@@ -321,6 +326,11 @@ impl AggregateExpr for Avg {
         }
     }
 
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        // null should be returned if no rows are aggregated
+        Ok(true)
+    }
+
     fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
         self.expr.evaluate(batch)
     }
@@ -422,6 +432,11 @@ impl AggregateExpr for Max {
         self.expr.data_type(input_schema)
     }
 
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        // null should be returned if no rows are aggregated
+        Ok(true)
+    }
+
     fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
         self.expr.evaluate(batch)
     }
@@ -606,6 +621,11 @@ impl AggregateExpr for Min {
         self.expr.data_type(input_schema)
     }
 
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        // null should be returned if no rows are aggregated
+        Ok(true)
+    }
+
     fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
         self.expr.evaluate(batch)
     }
@@ -791,6 +811,11 @@ impl AggregateExpr for Count {
         Ok(DataType::UInt64)
     }
 
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        // null should be returned if no rows are aggregated
+        Ok(true)
+    }
+
     fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
         self.expr.evaluate(batch)
     }
@@ -990,9 +1015,8 @@ impl PhysicalExpr for BinaryExpr {
         })
     }
 
-    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
-        // this is not correct
-        Ok(false)
+    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
+        Ok(self.left.nullable(input_schema)? || self.right.nullable(input_schema)?)
     }
 
     fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
diff --git a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
index d736639..19038d6 100644
--- a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
+++ b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
@@ -64,10 +64,18 @@ impl HashAggregateExec {
 
         let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len());
         for (expr, name) in &group_expr {
-            fields.push(Field::new(name, expr.data_type(&input_schema)?, true))
+            fields.push(Field::new(
+                name,
+                expr.data_type(&input_schema)?,
+                expr.nullable(&input_schema)?,
+            ))
         }
         for (expr, name) in &aggr_expr {
-            fields.push(Field::new(&name, expr.data_type(&input_schema)?, true))
+            fields.push(Field::new(
+                &name,
+                expr.data_type(&input_schema)?,
+                expr.nullable(&input_schema)?,
+            ))
         }
         let schema = Arc::new(Schema::new(fields));
 
diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs
index 98d2a91..191d256 100644
--- a/rust/datafusion/src/execution/physical_plan/mod.rs
+++ b/rust/datafusion/src/execution/physical_plan/mod.rs
@@ -63,7 +63,7 @@ pub trait Partition: Send + Sync + Debug {
 pub trait PhysicalExpr: Send + Sync + Display + Debug {
     /// Get the data type of this expression, given the schema of the input
     fn data_type(&self, input_schema: &Schema) -> Result<DataType>;
-    /// Decide whehter this expression is nullable, given the schema of the input
+    /// Determine whether this expression is nullable, given the schema of the input
     fn nullable(&self, input_schema: &Schema) -> Result<bool>;
     /// Evaluate an expression against a RecordBatch
     fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;
@@ -73,6 +73,8 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug {
 pub trait AggregateExpr: Send + Sync + Debug {
     /// Get the data type of this expression, given the schema of the input
     fn data_type(&self, input_schema: &Schema) -> Result<DataType>;
+    /// Determine whether this expression is nullable, given the schema of the input
+    fn nullable(&self, input_schema: &Schema) -> Result<bool>;
     /// Evaluate the expression being aggregated
     fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef>;
     /// Create an accumulator for this aggregate expression
diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs
index a2e1979..314283f 100644
--- a/rust/datafusion/src/logicalplan.rs
+++ b/rust/datafusion/src/logicalplan.rs
@@ -264,18 +264,9 @@ fn create_name(e: &Expr, input_schema: &Schema) -> Result<String> {
     }
 }
 
-/// Returns the datatype of the expression given the input schema
-// note: the physical plan derived from an expression must match the datatype on this function.
-pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result<Field> {
-    let data_type = e.get_type(input_schema)?;
-    Ok(Field::new(&e.name(input_schema)?, data_type, true))
-}
-
 /// Create field meta-data from an expression, for use in a result set schema
 pub fn exprlist_to_fields(expr: &[Expr], input_schema: &Schema) -> Result<Vec<Field>> {
-    expr.iter()
-        .map(|e| expr_to_field(e, input_schema))
-        .collect()
+    expr.iter().map(|e| e.to_field(input_schema)).collect()
 }
 
 /// Relation expression
@@ -419,6 +410,34 @@ impl Expr {
         }
     }
 
+    /// return true if this expression might produce null values
+    pub fn nullable(&self, input_schema: &Schema) -> Result<bool> {
+        match self {
+            Expr::Alias(expr, _) => expr.nullable(input_schema),
+            Expr::Column(name) => Ok(input_schema.field_with_name(name)?.is_nullable()),
+            Expr::Literal(value) => match value {
+                ScalarValue::Null => Ok(true),
+                _ => Ok(false),
+            },
+            Expr::Cast { expr, .. } => expr.nullable(input_schema),
+            Expr::ScalarFunction { .. } => Ok(true),
+            Expr::AggregateFunction { .. } => Ok(true),
+            Expr::Not(expr) => expr.nullable(input_schema),
+            Expr::IsNull(_) => Ok(false),
+            Expr::IsNotNull(_) => Ok(false),
+            Expr::BinaryExpr {
+                ref left,
+                ref right,
+                ..
+            } => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?),
+            Expr::Sort { ref expr, .. } => expr.nullable(input_schema),
+            Expr::Nested(e) => e.nullable(input_schema),
+            Expr::Wildcard => Err(ExecutionError::General(
+                "Wildcard expressions are not valid in a logical query plan".to_owned(),
+            )),
+        }
+    }
+
     /// Return the name of this expression
     ///
     /// This represents how a column with this expression is named when no alias is chosen
@@ -426,6 +445,15 @@ impl Expr {
         create_name(self, input_schema)
     }
 
+    /// Create a Field representing this expression
+    pub fn to_field(&self, input_schema: &Schema) -> Result<Field> {
+        Ok(Field::new(
+            &self.name(input_schema)?,
+            self.get_type(input_schema)?,
+            self.nullable(input_schema)?,
+        ))
+    }
+
     /// Perform a type cast on the expression value.
     ///
     /// Will `Err` if the type cast cannot be performed.
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index 2f5bc9d..4ae90cd 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -536,9 +536,16 @@ fn register_alltypes_parquet(ctx: &mut ExecutionContext) {
 /// Execute query and return result set as tab delimited string
 fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec<String> {
     let plan = ctx.create_logical_plan(&sql).unwrap();
+    let logical_schema = plan.schema().clone();
     let plan = ctx.optimize(&plan).unwrap();
+    let optimized_logical_schema = plan.schema().clone();
     let plan = ctx.create_physical_plan(&plan).unwrap();
+    let physical_schema = plan.schema().clone();
     let results = ctx.collect(plan.as_ref()).unwrap();
+
+    assert_eq!(logical_schema.as_ref(), optimized_logical_schema.as_ref());
+    assert_eq!(logical_schema.as_ref(), physical_schema.as_ref());
+
     result_str(&results)
 }