You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink was not preserved in this plain-text version).
Posted to commits@arrow.apache.org by jo...@apache.org on 2020/11/25 16:18:25 UTC

[arrow] branch master updated: ARROW-10689: [Rust] [DataFusion] Add SQL support for CASE WHEN

This is an automated email from the ASF dual-hosted git repository.

jorgecarleitao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 47b2dd5  ARROW-10689: [Rust] [DataFusion] Add SQL support for CASE WHEN
47b2dd5 is described below

commit 47b2dd57cdb7098b3c533fb25055df66e1b9c3d0
Author: Andy Grove <an...@gmail.com>
AuthorDate: Wed Nov 25 17:17:04 2020 +0100

    ARROW-10689: [Rust] [DataFusion] Add SQL support for CASE WHEN
    
    This follows on from https://github.com/apache/arrow/pull/8746 and adds SQL support for CASE WHEN.
    
    Closes #8749 from andygrove/case-when-sql
    
    Authored-by: Andy Grove <an...@gmail.com>
    Signed-off-by: Jorge C. Leitao <jo...@gmail.com>
---
 rust/datafusion/src/logical_plan/expr.rs         | 18 ++++++
 rust/datafusion/src/physical_plan/expressions.rs | 12 +++-
 rust/datafusion/src/sql/planner.rs               | 36 ++++++++++++
 rust/datafusion/tests/sql.rs                     | 73 ++++++++++++++++++++++++
 4 files changed, 136 insertions(+), 3 deletions(-)

diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs
index 07e5f34..509285b 100644
--- a/rust/datafusion/src/logical_plan/expr.rs
+++ b/rust/datafusion/src/logical_plan/expr.rs
@@ -800,6 +800,24 @@ fn create_name(e: &Expr, input_schema: &Schema) -> Result<String> {
             let right = create_name(right, input_schema)?;
             Ok(format!("{} {:?} {}", left, op, right))
         }
+        Expr::Case {
+            expr,
+            when_then_expr,
+            else_expr,
+        } => {
+            let mut name = "CASE ".to_string();
+            if let Some(e) = expr {
+                name += &format!("{:?} ", e).to_string();
+            }
+            for (w, t) in when_then_expr {
+                name += &format!("WHEN {:?} THEN {:?} ", w, t).to_string();
+            }
+            if let Some(e) = else_expr {
+                name += &format!("ELSE {:?} ", e).to_string();
+            }
+            name += "END";
+            Ok(name)
+        }
         Expr::Cast { expr, data_type } => {
             let expr = create_name(expr, input_schema)?;
             Ok(format!("CAST({} AS {:?})", expr, data_type))
diff --git a/rust/datafusion/src/physical_plan/expressions.rs b/rust/datafusion/src/physical_plan/expressions.rs
index f642da7..8a1e3f0 100644
--- a/rust/datafusion/src/physical_plan/expressions.rs
+++ b/rust/datafusion/src/physical_plan/expressions.rs
@@ -1790,7 +1790,11 @@ macro_rules! if_then_else {
         let mut builder = <$BUILDER_TYPE>::new($BOOLS.len());
         for i in 0..$BOOLS.len() {
             if $BOOLS.is_null(i) {
-                builder.append_null()?;
+                if false_values.is_null(i) {
+                    builder.append_null()?;
+                } else {
+                    builder.append_value(false_values.value(i))?;
+                }
             } else if $BOOLS.value(i) {
                 if true_values.is_null(i) {
                     builder.append_null()?;
@@ -3403,7 +3407,8 @@ mod tests {
             .downcast_ref::<Int32Array>()
             .expect("failed to downcast to Int32Array");
 
-        let expected = &Int32Array::from(vec![Some(123), Some(999), None, Some(456)]);
+        let expected =
+            &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
 
         assert_eq!(expected, result);
 
@@ -3472,7 +3477,8 @@ mod tests {
             .downcast_ref::<Int32Array>()
             .expect("failed to downcast to Int32Array");
 
-        let expected = &Int32Array::from(vec![Some(123), Some(999), None, Some(456)]);
+        let expected =
+            &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
 
         assert_eq!(expected, result);
 
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs
index 002b53c..f5d2f88 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -447,6 +447,42 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> {
 
             SQLExpr::Wildcard => Ok(Expr::Wildcard),
 
+            SQLExpr::Case {
+                operand,
+                conditions,
+                results,
+                else_result,
+            } => {
+                let expr = if let Some(e) = operand {
+                    Some(Box::new(self.sql_to_rex(e, schema)?))
+                } else {
+                    None
+                };
+                let when_expr = conditions
+                    .iter()
+                    .map(|e| self.sql_to_rex(e, schema))
+                    .collect::<Result<Vec<_>>>()?;
+                let then_expr = results
+                    .iter()
+                    .map(|e| self.sql_to_rex(e, schema))
+                    .collect::<Result<Vec<_>>>()?;
+                let else_expr = if let Some(e) = else_result {
+                    Some(Box::new(self.sql_to_rex(e, schema)?))
+                } else {
+                    None
+                };
+
+                Ok(Expr::Case {
+                    expr,
+                    when_then_expr: when_expr
+                        .iter()
+                        .zip(then_expr.iter())
+                        .map(|(w, t)| (Box::new(w.to_owned()), Box::new(t.to_owned())))
+                        .collect(),
+                    else_expr,
+                })
+            }
+
             SQLExpr::Cast {
                 ref expr,
                 ref data_type,
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index bbf8934..13459f9 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -939,6 +939,79 @@ async fn csv_query_count_one() {
 }
 
 #[tokio::test]
+async fn case_when() -> Result<()> {
+    // Searched CASE without ELSE: rows matching neither branch
+    // (the "c" row and the NULL row) evaluate to NULL.
+    let sql = "SELECT CASE WHEN c1 = 'a' THEN 1 WHEN c1 = 'b' THEN 2 END FROM t1";
+    let expected = vec![vec!["1"], vec!["2"], vec!["NULL"], vec!["NULL"]];
+
+    let mut ctx = create_case_context()?;
+    let actual = execute(&mut ctx, sql).await;
+    assert_eq!(expected, actual);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn case_when_else() -> Result<()> {
+    // Searched CASE with ELSE: every row matching neither branch,
+    // including the NULL row, takes the ELSE value 999.
+    let sql = "SELECT CASE WHEN c1 = 'a' THEN 1 WHEN c1 = 'b' THEN 2 ELSE 999 END FROM t1";
+    let expected = vec![vec!["1"], vec!["2"], vec!["999"], vec!["999"]];
+
+    let mut ctx = create_case_context()?;
+    let actual = execute(&mut ctx, sql).await;
+    assert_eq!(expected, actual);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn case_when_with_base_expr() -> Result<()> {
+    // Simple CASE on `c1` without ELSE: "c" matches no WHEN value and
+    // NULL compares equal to nothing, so both rows evaluate to NULL.
+    let sql = "SELECT CASE c1 WHEN 'a' THEN 1 WHEN 'b' THEN 2 END FROM t1";
+    let expected = vec![vec!["1"], vec!["2"], vec!["NULL"], vec!["NULL"]];
+
+    let mut ctx = create_case_context()?;
+    let actual = execute(&mut ctx, sql).await;
+    assert_eq!(expected, actual);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn case_when_else_with_base_expr() -> Result<()> {
+    // Simple CASE on `c1` with ELSE: the "c" row and the NULL row both
+    // fall through to the ELSE value 999.
+    let sql = "SELECT CASE c1 WHEN 'a' THEN 1 WHEN 'b' THEN 2 ELSE 999 END FROM t1";
+    let expected = vec![vec!["1"], vec!["2"], vec!["999"], vec!["999"]];
+
+    let mut ctx = create_case_context()?;
+    let actual = execute(&mut ctx, sql).await;
+    assert_eq!(expected, actual);
+
+    Ok(())
+}
+
+/// Builds an `ExecutionContext` with an in-memory table `t1` registered.
+///
+/// `t1` has a single nullable Utf8 column `c1` holding "a", "b", "c" and NULL,
+/// so the CASE tests cover two distinct matches, a miss, and a NULL input.
+fn create_case_context() -> Result<ExecutionContext> {
+    let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, true)]));
+    let column = StringArray::from(vec![Some("a"), Some("b"), Some("c"), None]);
+    let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(column)])?;
+
+    // Back table `t1` with the batch built above.
+    let table = MemTable::new(schema, vec![vec![batch]])?;
+
+    let mut ctx = ExecutionContext::new();
+    ctx.register_table("t1", Box::new(table));
+    Ok(ctx)
+}
+
+#[tokio::test]
 async fn csv_explain() {
     let mut ctx = ExecutionContext::new();
     register_aggregate_csv_by_sql(&mut ctx).await;