You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by jo...@apache.org on 2020/11/25 16:18:25 UTC
[arrow] branch master updated: ARROW-10689: [Rust] [DataFusion] Add SQL support for CASE WHEN
This is an automated email from the ASF dual-hosted git repository.
jorgecarleitao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 47b2dd5 ARROW-10689: [Rust] [DataFusion] Add SQL support for CASE WHEN
47b2dd5 is described below
commit 47b2dd57cdb7098b3c533fb25055df66e1b9c3d0
Author: Andy Grove <an...@gmail.com>
AuthorDate: Wed Nov 25 17:17:04 2020 +0100
ARROW-10689: [Rust] [DataFusion] Add SQL support for CASE WHEN
This follows on from https://github.com/apache/arrow/pull/8746 and adds SQL support for CASE WHEN.
Closes #8749 from andygrove/case-when-sql
Authored-by: Andy Grove <an...@gmail.com>
Signed-off-by: Jorge C. Leitao <jo...@gmail.com>
---
rust/datafusion/src/logical_plan/expr.rs | 18 ++++++
rust/datafusion/src/physical_plan/expressions.rs | 12 +++-
rust/datafusion/src/sql/planner.rs | 36 ++++++++++++
rust/datafusion/tests/sql.rs | 73 ++++++++++++++++++++++++
4 files changed, 136 insertions(+), 3 deletions(-)
diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs
index 07e5f34..509285b 100644
--- a/rust/datafusion/src/logical_plan/expr.rs
+++ b/rust/datafusion/src/logical_plan/expr.rs
@@ -800,6 +800,24 @@ fn create_name(e: &Expr, input_schema: &Schema) -> Result<String> {
let right = create_name(right, input_schema)?;
Ok(format!("{} {:?} {}", left, op, right))
}
+ Expr::Case {
+ expr,
+ when_then_expr,
+ else_expr,
+ } => {
+ let mut name = "CASE ".to_string();
+ if let Some(e) = expr {
+ name += &format!("{:?} ", e).to_string();
+ }
+ for (w, t) in when_then_expr {
+ name += &format!("WHEN {:?} THEN {:?} ", w, t).to_string();
+ }
+ if let Some(e) = else_expr {
+ name += &format!("ELSE {:?} ", e).to_string();
+ }
+ name += "END";
+ Ok(name)
+ }
Expr::Cast { expr, data_type } => {
let expr = create_name(expr, input_schema)?;
Ok(format!("CAST({} AS {:?})", expr, data_type))
diff --git a/rust/datafusion/src/physical_plan/expressions.rs b/rust/datafusion/src/physical_plan/expressions.rs
index f642da7..8a1e3f0 100644
--- a/rust/datafusion/src/physical_plan/expressions.rs
+++ b/rust/datafusion/src/physical_plan/expressions.rs
@@ -1790,7 +1790,11 @@ macro_rules! if_then_else {
let mut builder = <$BUILDER_TYPE>::new($BOOLS.len());
for i in 0..$BOOLS.len() {
if $BOOLS.is_null(i) {
- builder.append_null()?;
+ if false_values.is_null(i) {
+ builder.append_null()?;
+ } else {
+ builder.append_value(false_values.value(i))?;
+ }
} else if $BOOLS.value(i) {
if true_values.is_null(i) {
builder.append_null()?;
@@ -3403,7 +3407,8 @@ mod tests {
.downcast_ref::<Int32Array>()
.expect("failed to downcast to Int32Array");
- let expected = &Int32Array::from(vec![Some(123), Some(999), None, Some(456)]);
+ let expected =
+ &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
assert_eq!(expected, result);
@@ -3472,7 +3477,8 @@ mod tests {
.downcast_ref::<Int32Array>()
.expect("failed to downcast to Int32Array");
- let expected = &Int32Array::from(vec![Some(123), Some(999), None, Some(456)]);
+ let expected =
+ &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(456)]);
assert_eq!(expected, result);
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs
index 002b53c..f5d2f88 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -447,6 +447,42 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> {
SQLExpr::Wildcard => Ok(Expr::Wildcard),
+ SQLExpr::Case {
+ operand,
+ conditions,
+ results,
+ else_result,
+ } => {
+ let expr = if let Some(e) = operand {
+ Some(Box::new(self.sql_to_rex(e, schema)?))
+ } else {
+ None
+ };
+ let when_expr = conditions
+ .iter()
+ .map(|e| self.sql_to_rex(e, schema))
+ .collect::<Result<Vec<_>>>()?;
+ let then_expr = results
+ .iter()
+ .map(|e| self.sql_to_rex(e, schema))
+ .collect::<Result<Vec<_>>>()?;
+ let else_expr = if let Some(e) = else_result {
+ Some(Box::new(self.sql_to_rex(e, schema)?))
+ } else {
+ None
+ };
+
+ Ok(Expr::Case {
+ expr,
+ when_then_expr: when_expr
+ .iter()
+ .zip(then_expr.iter())
+ .map(|(w, t)| (Box::new(w.to_owned()), Box::new(t.to_owned())))
+ .collect(),
+ else_expr,
+ })
+ }
+
SQLExpr::Cast {
ref expr,
ref data_type,
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index bbf8934..13459f9 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -939,6 +939,79 @@ async fn csv_query_count_one() {
}
#[tokio::test]
+async fn case_when() -> Result<()> {
+ let mut ctx = create_case_context()?;
+ let sql = "SELECT \
+ CASE WHEN c1 = 'a' THEN 1 \
+ WHEN c1 = 'b' THEN 2 \
+ END \
+ FROM t1";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![vec!["1"], vec!["2"], vec!["NULL"], vec!["NULL"]];
+ assert_eq!(expected, actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn case_when_else() -> Result<()> {
+ let mut ctx = create_case_context()?;
+ let sql = "SELECT \
+ CASE WHEN c1 = 'a' THEN 1 \
+ WHEN c1 = 'b' THEN 2 \
+ ELSE 999 END \
+ FROM t1";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![vec!["1"], vec!["2"], vec!["999"], vec!["999"]];
+ assert_eq!(expected, actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn case_when_with_base_expr() -> Result<()> {
+ let mut ctx = create_case_context()?;
+ let sql = "SELECT \
+ CASE c1 WHEN 'a' THEN 1 \
+ WHEN 'b' THEN 2 \
+ END \
+ FROM t1";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![vec!["1"], vec!["2"], vec!["NULL"], vec!["NULL"]];
+ assert_eq!(expected, actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn case_when_else_with_base_expr() -> Result<()> {
+ let mut ctx = create_case_context()?;
+ let sql = "SELECT \
+ CASE c1 WHEN 'a' THEN 1 \
+ WHEN 'b' THEN 2 \
+ ELSE 999 END \
+ FROM t1";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![vec!["1"], vec!["2"], vec!["999"], vec!["999"]];
+ assert_eq!(expected, actual);
+ Ok(())
+}
+
+fn create_case_context() -> Result<ExecutionContext> {
+ let mut ctx = ExecutionContext::new();
+ let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, true)]));
+ let data = RecordBatch::try_new(
+ schema.clone(),
+ vec![Arc::new(StringArray::from(vec![
+ Some("a"),
+ Some("b"),
+ Some("c"),
+ None,
+ ]))],
+ )?;
+ let table = MemTable::new(schema, vec![vec![data]])?;
+ ctx.register_table("t1", Box::new(table));
+ Ok(ctx)
+}
+
+#[tokio::test]
async fn csv_explain() {
let mut ctx = ExecutionContext::new();
register_aggregate_csv_by_sql(&mut ctx).await;