You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original page and was not preserved in this copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/06/30 01:08:54 UTC

[arrow-datafusion] branch master updated: Use specialized dictionary kernels (#1178) (#2808)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 6e0bb8476 Use specialized dictionary kernels (#1178) (#2808)
6e0bb8476 is described below

commit 6e0bb8476d783c1caaf6bf011487c92ae9352f78
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Thu Jun 30 02:08:49 2022 +0100

    Use specialized dictionary kernels (#1178) (#2808)
    
    * Use specialized dictionary kernels (#1178)
    
    * Fix tests
---
 datafusion/expr/src/binary_rule.rs                 | 56 ++++++++++++++--------
 datafusion/physical-expr/src/expressions/binary.rs | 27 +++++++++--
 2 files changed, 57 insertions(+), 26 deletions(-)

diff --git a/datafusion/expr/src/binary_rule.rs b/datafusion/expr/src/binary_rule.rs
index b7b2c57e8..5b404d8a2 100644
--- a/datafusion/expr/src/binary_rule.rs
+++ b/datafusion/expr/src/binary_rule.rs
@@ -155,14 +155,12 @@ pub fn comparison_eq_coercion(
     lhs_type: &DataType,
     rhs_type: &DataType,
 ) -> Option<DataType> {
-    // can't compare dictionaries directly due to
-    // https://github.com/apache/arrow-rs/issues/1201
-    if lhs_type == rhs_type && !is_dictionary(lhs_type) {
+    if lhs_type == rhs_type {
         // same type => equality is possible
         return Some(lhs_type.clone());
     }
     comparison_binary_numeric_coercion(lhs_type, rhs_type)
-        .or_else(|| dictionary_coercion(lhs_type, rhs_type))
+        .or_else(|| dictionary_coercion(lhs_type, rhs_type, true))
         .or_else(|| temporal_coercion(lhs_type, rhs_type))
         .or_else(|| string_coercion(lhs_type, rhs_type))
         .or_else(|| null_coercion(lhs_type, rhs_type))
@@ -173,15 +171,13 @@ fn comparison_order_coercion(
     lhs_type: &DataType,
     rhs_type: &DataType,
 ) -> Option<DataType> {
-    // can't compare dictionaries directly due to
-    // https://github.com/apache/arrow-rs/issues/1201
-    if lhs_type == rhs_type && !is_dictionary(lhs_type) {
+    if lhs_type == rhs_type {
         // same type => all good
         return Some(lhs_type.clone());
     }
     comparison_binary_numeric_coercion(lhs_type, rhs_type)
         .or_else(|| string_coercion(lhs_type, rhs_type))
-        .or_else(|| dictionary_coercion(lhs_type, rhs_type))
+        .or_else(|| dictionary_coercion(lhs_type, rhs_type, true))
         .or_else(|| temporal_coercion(lhs_type, rhs_type))
         .or_else(|| null_coercion(lhs_type, rhs_type))
 }
@@ -448,17 +444,24 @@ fn dictionary_value_coercion(
 /// Coercion rules for Dictionaries: the type that both lhs and rhs
 /// can be casted to for the purpose of a computation.
 ///
-/// It would likely be preferable to cast primitive values to
-/// dictionaries, and thus avoid unpacking dictionary as well as doing
-/// faster comparisons. However, the arrow compute kernels (e.g. eq)
-/// don't have DictionaryArray support yet, so fall back to unpacking
-/// the dictionaries
-fn dictionary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
+/// Not all operators support dictionaries, if `preserve_dictionaries` is true
+/// dictionaries will be preserved if possible
+fn dictionary_coercion(
+    lhs_type: &DataType,
+    rhs_type: &DataType,
+    preserve_dictionaries: bool,
+) -> Option<DataType> {
     match (lhs_type, rhs_type) {
         (
             DataType::Dictionary(_lhs_index_type, lhs_value_type),
             DataType::Dictionary(_rhs_index_type, rhs_value_type),
         ) => dictionary_value_coercion(lhs_value_type, rhs_value_type),
+        (d @ DataType::Dictionary(_, value_type), other_type)
+        | (other_type, d @ DataType::Dictionary(_, value_type))
+            if preserve_dictionaries && value_type.as_ref() == other_type =>
+        {
+            Some(d.clone())
+        }
         (DataType::Dictionary(_index_type, value_type), _) => {
             dictionary_value_coercion(value_type, rhs_type)
         }
@@ -514,7 +517,7 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
 /// This is a union of string coercion rules and dictionary coercion rules
 fn like_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
     string_coercion(lhs_type, rhs_type)
-        .or_else(|| dictionary_coercion(lhs_type, rhs_type))
+        .or_else(|| dictionary_coercion(lhs_type, rhs_type, false))
         .or_else(|| null_coercion(lhs_type, rhs_type))
 }
 
@@ -616,7 +619,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
         return Some(lhs_type.clone());
     }
     numerical_coercion(lhs_type, rhs_type)
-        .or_else(|| dictionary_coercion(lhs_type, rhs_type))
+        .or_else(|| dictionary_coercion(lhs_type, rhs_type, true))
         .or_else(|| temporal_coercion(lhs_type, rhs_type))
         .or_else(|| null_coercion(lhs_type, rhs_type))
 }
@@ -779,21 +782,32 @@ mod tests {
     fn test_dictionary_type_coersion() {
         use DataType::*;
 
-        // TODO: In the future, this would ideally return Dictionary types and avoid unpacking
         let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32));
         let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16));
-        assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Int32));
+        assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), Some(Int32));
+        assert_eq!(
+            dictionary_coercion(&lhs_type, &rhs_type, false),
+            Some(Int32)
+        );
 
         let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
         let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16));
-        assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), None);
+        assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, true), None);
 
         let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
         let rhs_type = Utf8;
-        assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8));
+        assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, false), Some(Utf8));
+        assert_eq!(
+            dictionary_coercion(&lhs_type, &rhs_type, true),
+            Some(lhs_type.clone())
+        );
 
         let lhs_type = Utf8;
         let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
-        assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8));
+        assert_eq!(dictionary_coercion(&lhs_type, &rhs_type, false), Some(Utf8));
+        assert_eq!(
+            dictionary_coercion(&lhs_type, &rhs_type, true),
+            Some(rhs_type.clone())
+        );
     }
 }
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index c0876a722..417306221 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -1009,11 +1009,28 @@ impl PhysicalExpr for BinaryExpr {
         let left_data_type = left_value.data_type();
         let right_data_type = right_value.data_type();
 
-        if left_data_type != right_data_type {
-            return Err(DataFusionError::Internal(format!(
-                "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
-                self.op, left_data_type, right_data_type
-            )));
+        match (&left_value, &left_data_type, &right_value, &right_data_type) {
+            // Types are equal => valid
+            (_, l, _, r) if l == r => {}
+            // Allow comparing a dictionary value with its corresponding scalar value
+            (
+                ColumnarValue::Array(_),
+                DataType::Dictionary(_, dict_t),
+                ColumnarValue::Scalar(_),
+                scalar_t,
+            )
+            | (
+                ColumnarValue::Scalar(_),
+                scalar_t,
+                ColumnarValue::Array(_),
+                DataType::Dictionary(_, dict_t),
+            ) if dict_t.as_ref() == scalar_t => {}
+            _ => {
+                return Err(DataFusionError::Internal(format!(
+                    "Cannot evaluate binary expression {:?} with types {:?} and {:?}",
+                    self.op, left_data_type, right_data_type
+                )));
+            }
         }
 
         // Attempt to use special kernels if one input is scalar and the other is an array