You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2021/04/25 13:47:56 UTC

[arrow-datafusion] branch master updated: Use arrow eq kernels in CaseWhen (#52)

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 1f1130e  Use arrow eq kernels in CaseWhen (#52)
1f1130e is described below

commit 1f1130e5c51bba05bd55d0495bbe0d952841a1d7
Author: Daniël Heres <da...@gmail.com>
AuthorDate: Sun Apr 25 15:47:50 2021 +0200

    Use arrow eq kernels in CaseWhen (#52)
---
 datafusion/src/physical_plan/expressions/case.rs | 62 +++++++++++++++---------
 1 file changed, 38 insertions(+), 24 deletions(-)

diff --git a/datafusion/src/physical_plan/expressions/case.rs b/datafusion/src/physical_plan/expressions/case.rs
index e8c500e..723438d 100644
--- a/datafusion/src/physical_plan/expressions/case.rs
+++ b/datafusion/src/physical_plan/expressions/case.rs
@@ -17,13 +17,13 @@
 
 use std::{any::Any, sync::Arc};
 
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::{ColumnarValue, PhysicalExpr};
 use arrow::array::{self, *};
+use arrow::compute::{eq, eq_utf8};
 use arrow::datatypes::{DataType, Schema};
 use arrow::record_batch::RecordBatch;
 
-use crate::error::{DataFusionError, Result};
-use crate::physical_plan::{ColumnarValue, PhysicalExpr};
-
 /// The CASE expression is similar to a series of nested if/else and there are two forms that
 /// can be used. The first form consists of a series of boolean "when" expressions with
 /// corresponding "then" expressions, and an optional "else" expression.
@@ -265,7 +265,7 @@ fn build_null_array(data_type: &DataType, num_rows: usize) -> Result<ArrayRef> {
 }
 
 macro_rules! array_equals {
-    ($TY:ty, $L:expr, $R:expr) => {{
+    ($TY:ty, $L:expr, $R:expr, $eq_fn:expr) => {{
         let when_value = $L
             .as_ref()
             .as_any()
@@ -278,15 +278,7 @@ macro_rules! array_equals {
             .downcast_ref::<$TY>()
             .expect("array_equals downcast failed");
 
-        let mut builder = BooleanBuilder::new(when_value.len());
-        for row in 0..when_value.len() {
-            if when_value.is_valid(row) && base_value.is_valid(row) {
-                builder.append_value(when_value.value(row) == base_value.value(row))?;
-            } else {
-                builder.append_null()?;
-            }
-        }
-        Ok(builder.finish())
+        $eq_fn(when_value, base_value).map_err(DataFusionError::from)
     }};
 }
 
@@ -296,17 +288,39 @@ fn array_equals(
     base_value: ArrayRef,
 ) -> Result<BooleanArray> {
     match data_type {
-        DataType::UInt8 => array_equals!(array::UInt8Array, when_value, base_value),
-        DataType::UInt16 => array_equals!(array::UInt16Array, when_value, base_value),
-        DataType::UInt32 => array_equals!(array::UInt32Array, when_value, base_value),
-        DataType::UInt64 => array_equals!(array::UInt64Array, when_value, base_value),
-        DataType::Int8 => array_equals!(array::Int8Array, when_value, base_value),
-        DataType::Int16 => array_equals!(array::Int16Array, when_value, base_value),
-        DataType::Int32 => array_equals!(array::Int32Array, when_value, base_value),
-        DataType::Int64 => array_equals!(array::Int64Array, when_value, base_value),
-        DataType::Float32 => array_equals!(array::Float32Array, when_value, base_value),
-        DataType::Float64 => array_equals!(array::Float64Array, when_value, base_value),
-        DataType::Utf8 => array_equals!(array::StringArray, when_value, base_value),
+        DataType::UInt8 => {
+            array_equals!(array::UInt8Array, when_value, base_value, eq)
+        }
+        DataType::UInt16 => {
+            array_equals!(array::UInt16Array, when_value, base_value, eq)
+        }
+        DataType::UInt32 => {
+            array_equals!(array::UInt32Array, when_value, base_value, eq)
+        }
+        DataType::UInt64 => {
+            array_equals!(array::UInt64Array, when_value, base_value, eq)
+        }
+        DataType::Int8 => {
+            array_equals!(array::Int8Array, when_value, base_value, eq)
+        }
+        DataType::Int16 => {
+            array_equals!(array::Int16Array, when_value, base_value, eq)
+        }
+        DataType::Int32 => {
+            array_equals!(array::Int32Array, when_value, base_value, eq)
+        }
+        DataType::Int64 => {
+            array_equals!(array::Int64Array, when_value, base_value, eq)
+        }
+        DataType::Float32 => {
+            array_equals!(array::Float32Array, when_value, base_value, eq)
+        }
+        DataType::Float64 => {
+            array_equals!(array::Float64Array, when_value, base_value, eq)
+        }
+        DataType::Utf8 => {
+            array_equals!(array::StringArray, when_value, base_value, eq_utf8)
+        }
         other => Err(DataFusionError::Execution(format!(
             "CASE does not support '{:?}'",
             other