You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/04/13 12:35:11 UTC

[arrow-datafusion] branch master updated: `case when` supports `NULL` constant (#2197)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new d81657de0 `case when` supports `NULL`  constant (#2197)
d81657de0 is described below

commit d81657de04b3fe511cba35e8cabee17de1578117
Author: DuRipeng <45...@qq.com>
AuthorDate: Wed Apr 13 20:35:04 2022 +0800

    `case when` supports `NULL`  constant (#2197)
    
    * case when support  literal
    
    * fmt fix
    
    * code clean
    
    Co-authored-by: duripeng <du...@baidu.com>
---
 datafusion/core/tests/sql/expr.rs                | 34 ++++++++++++++++++++++++
 datafusion/physical-expr/src/expressions/case.rs |  7 +++++
 2 files changed, 41 insertions(+)

diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs
index de71764ba..4bb2cafad 100644
--- a/datafusion/core/tests/sql/expr.rs
+++ b/datafusion/core/tests/sql/expr.rs
@@ -109,6 +109,40 @@ async fn case_when_else_with_base_expr() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test] // checks that a literal NULL used as a WHEN condition never selects its THEN branch
+async fn case_when_else_with_null_contant() -> Result<()> {
+    let ctx = create_case_context()?;
+    let sql = "SELECT \
+        CASE WHEN c1 = 'a' THEN 1 \
+             WHEN NULL THEN 2 \
+             ELSE 999 END \
+        FROM t1"; // the NULL condition sits between a real predicate and the ELSE
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![ // expected output holds only 1 and 999; the THEN 2 branch is never produced
+        "+----------------------------------------------------------------------------------------------+",
+        "| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN Utf8(NULL) THEN Int64(2) ELSE Int64(999) END |",
+        "+----------------------------------------------------------------------------------------------+",
+        "| 1                                                                                            |",
+        "| 999                                                                                          |",
+        "| 999                                                                                          |",
+        "| 999                                                                                          |",
+        "+----------------------------------------------------------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT CASE WHEN NULL THEN 'foo' ELSE 'bar' END"; // same check with NULL as the only WHEN condition and no FROM clause
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+------------------------------------------------------------+",
+        "| CASE WHEN Utf8(NULL) THEN Utf8(\"foo\") ELSE Utf8(\"bar\") END |",
+        "+------------------------------------------------------------+",
+        "| bar                                                        |", // the ELSE value is expected, not 'foo'
+        "+------------------------------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
 #[tokio::test]
 async fn query_not() -> Result<()> {
     let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Boolean, true)]));
diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs
index df763ec9a..e7db10d17 100644
--- a/datafusion/physical-expr/src/expressions/case.rs
+++ b/datafusion/physical-expr/src/expressions/case.rs
@@ -178,6 +178,13 @@ impl CaseExpr {
             let when_value = self.when_then_expr[i]
                 .0
                 .evaluate_selection(batch, &remainder)?;
+            // Treat 'NULL' as false value
+            let when_value = match when_value {
+                ColumnarValue::Scalar(value) if value.is_null() => {
+                    continue;
+                }
+                _ => when_value,
+            };
             let when_value = when_value.into_array(batch.num_rows());
             let when_value = when_value
                 .as_ref()