You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2020/11/23 17:09:44 UTC

[GitHub] [arrow] jorgecarleitao commented on a change in pull request #8740: ARROW-10679: [Rust] [DataFusion] Implement CASE WHEN physical expression

jorgecarleitao commented on a change in pull request #8740:
URL: https://github.com/apache/arrow/pull/8740#discussion_r528860991



##########
File path: rust/datafusion/src/physical_plan/mod.rs
##########
@@ -126,10 +126,10 @@ impl ColumnarValue {
         }
     }
 
-    fn into_array(self, batch: &RecordBatch) -> ArrayRef {
+    fn into_array(self, num_rows: usize) -> ArrayRef {

Review comment:
       Good. I had the same concern when reviewing this the other day.

##########
File path: rust/datafusion/src/physical_plan/expressions.rs
##########
@@ -1697,6 +1700,409 @@ pub fn is_not_null(arg: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>>
     Ok(Arc::new(IsNotNullExpr::new(arg)))
 }
 
+/// The CASE expression is similar to a series of nested if/else and there are two forms that
+/// can be used. The first form consists of a series of boolean "when" expressions with
+/// corresponding "then" expressions, and an optional "else" expression.
+///
+/// ```text
+/// CASE WHEN condition THEN result
+///      [WHEN ...]
+///      [ELSE result]
+/// END
+/// ```
+///
+/// The second form uses a base expression and then a series of "when" clauses that match on a
+/// literal value.
+///
+/// ```text
+/// CASE expression
+///     WHEN value THEN result
+///     [WHEN ...]
+///     [ELSE result]
+/// END
+/// ```
+#[derive(Debug)]
+struct CaseExpr {
+    /// Optional base expression (second form) that can be compared to the values in the
+    /// "when" expressions; `None` selects the first form.
+    expr: Option<Arc<dyn PhysicalExpr>>,
+    /// One or more (when, then) expression pairs; `try_new` rejects an empty list.
+    when_then_expr: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)>,
+    /// Optional "else" expression
+    else_expr: Option<Arc<dyn PhysicalExpr>>,
+}
+
+impl fmt::Display for CaseExpr {
+    /// Render the expression in SQL-like form, e.g. `CASE <base> WHEN <w> THEN <t> ELSE <e> END`,
+    /// where the base and ELSE parts appear only when present.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        f.write_str("CASE ")?;
+        if let Some(base) = self.expr.as_ref() {
+            write!(f, "{} ", base)?;
+        }
+        self.when_then_expr
+            .iter()
+            .try_for_each(|(when, then)| write!(f, "WHEN {} THEN {} ", when, then))?;
+        match &self.else_expr {
+            Some(otherwise) => write!(f, "ELSE {} END", otherwise),
+            None => f.write_str("END"),
+        }
+    }
+}
+
+impl CaseExpr {
+    /// Create a new CASE expression.
+    ///
+    /// `expr` is the optional base expression (second CASE form), `when_then_expr` holds the
+    /// (when, then) pairs, and `else_expr` is the optional ELSE expression. The pairs are
+    /// copied into the new expression; only the `Arc` handles are cloned.
+    ///
+    /// # Errors
+    ///
+    /// Returns `DataFusionError::Execution` if `when_then_expr` is empty, because a CASE
+    /// expression needs at least one WHEN clause.
+    pub fn try_new(
+        expr: Option<Arc<dyn PhysicalExpr>>,
+        when_then_expr: &[(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)],
+        else_expr: Option<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Self> {
+        if when_then_expr.is_empty() {
+            return Err(DataFusionError::Execution(
+                "There must be at least one WHEN clause".to_string(),
+            ));
+        }
+        Ok(Self {
+            expr,
+            when_then_expr: when_then_expr.to_vec(),
+            else_expr,
+        })
+    }
+}
+
+/// Build a CASE physical expression from an optional base expression, a non-empty list of
+/// (when, then) pairs and an optional ELSE expression.
+///
+/// Propagates the error from `CaseExpr::try_new` when `when_thens` is empty.
+pub fn case(
+    expr: Option<Arc<dyn PhysicalExpr>>,
+    when_thens: &[(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)],
+    else_expr: Option<Arc<dyn PhysicalExpr>>,
+) -> Result<Arc<dyn PhysicalExpr>> {
+    let case_expr = CaseExpr::try_new(expr, when_thens, else_expr)?;
+    Ok(Arc::new(case_expr))
+}
+
+// Row-wise selection between two arrays of the same concrete type `$ARRAY_TYPE`: row `i`
+// takes `$TRUE[i]` where `$BOOLS[i]` is true and `$FALSE[i]` where it is false; a null
+// condition or a null selected value yields null. Evaluates to `Ok` of the finished array
+// wrapped in an `Arc`. The downcasts panic if the value arrays are not `$ARRAY_TYPE`, and the
+// enclosing function returns an execution error if the three arrays differ in length.
+macro_rules! if_then_else {
+    ($BUILDER_TYPE:ty, $ARRAY_TYPE:ty, $BOOLS:expr, $TRUE:expr, $FALSE:expr) => {{
+        let conditions = $BOOLS;
+        let true_values = $TRUE
+            .as_ref()
+            .as_any()
+            .downcast_ref::<$ARRAY_TYPE>()
+            .expect("true_values downcast failed");
+
+        let false_values = $FALSE
+            .as_ref()
+            .as_any()
+            .downcast_ref::<$ARRAY_TYPE>()
+            .expect("false_values downcast failed");
+
+        let num_rows = conditions.len();
+        // Each row of `conditions` reads the same row of both value arrays, so a shorter
+        // value array would be indexed out of bounds; reject mismatched inputs up front.
+        if true_values.len() != num_rows || false_values.len() != num_rows {
+            return Err(DataFusionError::Execution(format!(
+                "CASE arrays must have equal lengths: {} conditions, {} true values, {} false values",
+                num_rows,
+                true_values.len(),
+                false_values.len()
+            )));
+        }
+
+        let mut builder = <$BUILDER_TYPE>::new(num_rows);
+        for i in 0..num_rows {
+            if conditions.is_null(i) {
+                builder.append_null()?;
+            } else if conditions.value(i) {
+                if true_values.is_null(i) {
+                    builder.append_null()?;
+                } else {
+                    builder.append_value(true_values.value(i))?;
+                }
+            } else if false_values.is_null(i) {
+                builder.append_null()?;
+            } else {
+                builder.append_value(false_values.value(i))?;
+            }
+        }
+        Ok(Arc::new(builder.finish()))
+    }};
+}
+
+/// Select, row by row, between `true_values` and `false_values` according to `bools`.
+///
+/// Rows where `bools` is null, or where the selected value is null, produce null.
+/// `true_values` and `false_values` must both be arrays of `data_type`; the downcast in the
+/// `if_then_else!` macro panics otherwise.
+///
+/// # Errors
+///
+/// Returns `DataFusionError::Execution` when `data_type` is not one of the supported types
+/// (8/16/32/64-bit signed and unsigned integers, `Float32`, `Float64`, `Boolean`, `Utf8`),
+/// and propagates any error raised while building the result array.
+fn if_then_else(
+    bools: &BooleanArray,
+    true_values: ArrayRef,
+    false_values: ArrayRef,
+    data_type: &DataType,
+) -> Result<ArrayRef> {
+    match data_type {
+        DataType::UInt8 => if_then_else!(
+            array::UInt8Builder,
+            array::UInt8Array,
+            bools,
+            true_values,
+            false_values
+        ),
+        DataType::UInt16 => if_then_else!(
+            array::UInt16Builder,
+            array::UInt16Array,
+            bools,
+            true_values,
+            false_values
+        ),
+        DataType::UInt32 => if_then_else!(
+            array::UInt32Builder,
+            array::UInt32Array,
+            bools,
+            true_values,
+            false_values
+        ),
+        DataType::UInt64 => if_then_else!(
+            array::UInt64Builder,
+            array::UInt64Array,
+            bools,
+            true_values,
+            false_values
+        ),
+        DataType::Int8 => if_then_else!(
+            array::Int8Builder,
+            array::Int8Array,
+            bools,
+            true_values,
+            false_values
+        ),
+        DataType::Int16 => if_then_else!(
+            array::Int16Builder,
+            array::Int16Array,
+            bools,
+            true_values,
+            false_values
+        ),
+        DataType::Int32 => if_then_else!(
+            array::Int32Builder,
+            array::Int32Array,
+            bools,
+            true_values,
+            false_values
+        ),
+        DataType::Int64 => if_then_else!(
+            array::Int64Builder,
+            array::Int64Array,
+            bools,
+            true_values,
+            false_values
+        ),
+        DataType::Float32 => if_then_else!(
+            array::Float32Builder,
+            array::Float32Array,
+            bools,
+            true_values,
+            false_values
+        ),
+        DataType::Float64 => if_then_else!(
+            array::Float64Builder,
+            array::Float64Array,
+            bools,
+            true_values,
+            false_values
+        ),
+        DataType::Boolean => if_then_else!(
+            array::BooleanBuilder,
+            array::BooleanArray,
+            bools,
+            true_values,
+            false_values
+        ),
+        DataType::Utf8 => if_then_else!(
+            array::StringBuilder,
+            array::StringArray,
+            bools,
+            true_values,
+            false_values
+        ),
+        other => Err(DataFusionError::Execution(format!(
+            "CASE does not support '{:?}'",
+            other
+        ))),
+    }
+}
+
+// Build an array of `$N` nulls with the builder type `$TY`; evaluates to `Ok` of the
+// finished array wrapped in an `Arc`, propagating the first builder error with `?`.
+macro_rules! make_null_array {
+    ($TY:ty, $N:expr) => {{
+        let len = $N;
+        let mut nulls = <$TY>::new(len);
+        (0..len).try_for_each(|_| nulls.append_null())?;
+        Ok(Arc::new(nulls.finish()))
+    }};
+}
+
+/// Build an array of `num_rows` nulls whose type is `data_type`.
+///
+/// # Errors
+///
+/// Returns `DataFusionError::Execution` when `data_type` is not one of the supported types
+/// (8/16/32/64-bit signed and unsigned integers, `Float32`, `Float64`, `Boolean`, `Utf8`),
+/// and propagates any error raised while appending nulls.
+fn build_null_array(data_type: &DataType, num_rows: usize) -> Result<ArrayRef> {
+    match data_type {
+        DataType::UInt8 => make_null_array!(array::UInt8Builder, num_rows),
+        DataType::UInt16 => make_null_array!(array::UInt16Builder, num_rows),
+        DataType::UInt32 => make_null_array!(array::UInt32Builder, num_rows),
+        DataType::UInt64 => make_null_array!(array::UInt64Builder, num_rows),
+        DataType::Int8 => make_null_array!(array::Int8Builder, num_rows),
+        DataType::Int16 => make_null_array!(array::Int16Builder, num_rows),
+        DataType::Int32 => make_null_array!(array::Int32Builder, num_rows),
+        DataType::Int64 => make_null_array!(array::Int64Builder, num_rows),
+        DataType::Float32 => make_null_array!(array::Float32Builder, num_rows),
+        DataType::Float64 => make_null_array!(array::Float64Builder, num_rows),
+        DataType::Boolean => make_null_array!(array::BooleanBuilder, num_rows),
+        DataType::Utf8 => make_null_array!(array::StringBuilder, num_rows),
+        other => Err(DataFusionError::Execution(format!(
+            "CASE does not support '{:?}'",
+            other
+        ))),
+    }
+}
+
+// Compare two arrays of concrete type `$TY` row by row for equality, producing a
+// `BooleanArray` that is null wherever either input row is null. Evaluates to a `Result`:
+// an execution error if the arrays differ in length, otherwise `Ok` of the comparison.
+// The downcasts panic if either input is not a `$TY`.
+// NOTE(review): arrow's `compute::kernels::comparison::eq` / `eq_utf8` may make this macro
+// unnecessary; confirm their null semantics match before switching.
+macro_rules! array_equals {
+    ($TY:ty, $L:expr, $R:expr) => {{
+        let when_value = $L
+            .as_ref()
+            .as_any()
+            .downcast_ref::<$TY>()
+            .expect("array_equals downcast failed");
+
+        let base_value = $R
+            .as_ref()
+            .as_any()
+            .downcast_ref::<$TY>()
+            .expect("array_equals downcast failed");
+
+        // Every row of `when_value` is compared with the same row of `base_value`, so the
+        // arrays must have equal length to avoid reading past the end of the shorter one.
+        if when_value.len() != base_value.len() {
+            Err(DataFusionError::Execution(format!(
+                "CASE cannot compare arrays of different lengths ({} and {})",
+                when_value.len(),
+                base_value.len()
+            )))
+        } else {
+            let mut builder = BooleanBuilder::new(when_value.len());
+            for row in 0..when_value.len() {
+                if when_value.is_valid(row) && base_value.is_valid(row) {
+                    builder.append_value(when_value.value(row) == base_value.value(row))?;
+                } else {
+                    builder.append_null()?;
+                }
+            }
+            Ok(builder.finish())
+        }
+    }};
+}

Review comment:
       Doesn't this exist as `arrow::compute::kernels::comparison::eq`?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org