You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/11/03 08:33:24 UTC

[arrow-datafusion] branch master updated: Support Dictionary in InListExpr (#4070)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 61429f839 Support Dictionary in InListExpr (#4070)
61429f839 is described below

commit 61429f839eb07bf50f36147d5b1d065194a45114
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Thu Nov 3 21:33:18 2022 +1300

    Support Dictionary in InListExpr (#4070)
    
    * Support dictionary in InList (#3936)
    
    * Update datafusion-cli
---
 datafusion-cli/Cargo.lock                          |  1 +
 datafusion/core/tests/sql/predicates.rs            | 99 +++++++++++++++++++++-
 datafusion/physical-expr/Cargo.toml                |  1 +
 .../physical-expr/src/expressions/in_list.rs       | 30 +++++--
 4 files changed, 120 insertions(+), 11 deletions(-)

diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 899ea2af2..89f4cc78a 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -656,6 +656,7 @@ dependencies = [
  "itertools",
  "lazy_static",
  "md-5",
+ "num-traits",
  "ordered-float 3.3.0",
  "paste",
  "rand",
diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs
index 3eea94300..21d46f7c8 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -428,8 +428,101 @@ async fn csv_in_set_test() -> Result<()> {
 }
 
 #[tokio::test]
-#[ignore]
-// https://github.com/apache/arrow-datafusion/issues/3936
+async fn in_list_string_dictionaries() -> Result<()> {
+    // let input = vec![Some("foo"), Some("bar"), None, Some("fazzz")]
+    let input = vec![Some("foo"), Some("bar"), Some("fazzz")]
+        .into_iter()
+        .collect::<DictionaryArray<Int32Type>>();
+
+    let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap();
+
+    let ctx = SessionContext::new();
+    ctx.register_batch("test", batch)?;
+
+    let sql = "SELECT * FROM test WHERE c1 IN ('Bar')";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["++", "++"];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT * FROM test WHERE c1 IN ('foo')";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["+-----+", "| c1  |", "+-----+", "| foo |", "+-----+"];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT * FROM test WHERE c1 IN ('bar', 'foo')";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-----+", "| c1  |", "+-----+", "| foo |", "| bar |", "+-----+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT * FROM test WHERE c1 IN ('Bar', 'foo')";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["+-----+", "| c1  |", "+-----+", "| foo |", "+-----+"];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT * FROM test WHERE c1 IN ('foo', 'Bar', 'fazzz')";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-------+",
+        "| c1    |",
+        "+-------+",
+        "| foo   |",
+        "| fazzz |",
+        "+-------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn in_list_string_dictionaries_with_null() -> Result<()> {
+    let input = vec![Some("foo"), Some("bar"), None, Some("fazzz")]
+        .into_iter()
+        .collect::<DictionaryArray<Int32Type>>();
+
+    let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap();
+
+    let ctx = SessionContext::new();
+    ctx.register_batch("test", batch)?;
+
+    let sql = "SELECT * FROM test WHERE c1 IN ('Bar')";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["++", "++"];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT * FROM test WHERE c1 IN ('foo')";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["+-----+", "| c1  |", "+-----+", "| foo |", "+-----+"];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT * FROM test WHERE c1 IN ('bar', 'foo')";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-----+", "| c1  |", "+-----+", "| foo |", "| bar |", "+-----+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT * FROM test WHERE c1 IN ('Bar', 'foo')";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec!["+-----+", "| c1  |", "+-----+", "| foo |", "+-----+"];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT * FROM test WHERE c1 IN ('foo', 'Bar', 'fazzz')";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+-------+",
+        "| c1    |",
+        "+-------+",
+        "| foo   |",
+        "| fazzz |",
+        "+-------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
 async fn in_set_string_dictionaries() -> Result<()> {
     let input = vec![Some("foo"), Some("bar"), None, Some("fazzz")]
         .into_iter()
@@ -440,7 +533,7 @@ async fn in_set_string_dictionaries() -> Result<()> {
     let ctx = SessionContext::new();
     ctx.register_batch("test", batch)?;
 
-    let sql = "SELECT * FROM test WHERE c1 IN ('foo', 'Bar', 'fazz')";
+    let sql = "SELECT * FROM test WHERE c1 IN ('foo', 'Bar', 'fazzz')";
     let actual = execute_to_batches(&ctx, sql).await;
     let expected = vec![
         "+-------+",
diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml
index 6fc6f4176..0b2d72dd3 100644
--- a/datafusion/physical-expr/Cargo.toml
+++ b/datafusion/physical-expr/Cargo.toml
@@ -54,6 +54,7 @@ hashbrown = { version = "0.12", features = ["raw"] }
 itertools = { version = "0.10", features = ["use_std"] }
 lazy_static = { version = "^1.4.0" }
 md-5 = { version = "^0.10.0", optional = true }
+num-traits = { version = "0.2", default-features = false }
 ordered-float = "3.0"
 paste = "^1.0"
 rand = "0.8"
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs
index 9406b42ee..26503bec7 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -27,10 +27,11 @@ use crate::physical_expr::down_cast_any_ref;
 use crate::utils::expr_list_eq_any_order;
 use crate::PhysicalExpr;
 use arrow::array::*;
+use arrow::compute::take;
 use arrow::datatypes::*;
-use arrow::downcast_primitive_array;
 use arrow::record_batch::RecordBatch;
 use arrow::util::bit_iterator::BitIndexIterator;
+use arrow::{downcast_dictionary_array, downcast_primitive_array};
 use datafusion_common::{DataFusionError, Result, ScalarValue};
 use datafusion_expr::ColumnarValue;
 use hashbrown::hash_map::RawEntryMut;
@@ -57,7 +58,7 @@ impl Debug for InListExpr {
 
 /// A type-erased container of array elements
 trait Set: Send + Sync {
-    fn contains(&self, v: &dyn Array, negated: bool) -> BooleanArray;
+    fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray>;
 }
 
 struct ArrayHashSet {
@@ -92,13 +93,22 @@ where
     for<'a> &'a T: ArrayAccessor,
     for<'a> <&'a T as ArrayAccessor>::Item: PartialEq + HashValue,
 {
-    fn contains(&self, v: &dyn Array, negated: bool) -> BooleanArray {
+    fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
+        downcast_dictionary_array! {
+            v => {
+                let values_contains = self.contains(v.values().as_ref(), negated)?;
+                let result = take(&values_contains, v.keys(), None)?;
+                return Ok(BooleanArray::from(result.data().clone()))
+            }
+            _ => {}
+        }
+
         let v = v.as_any().downcast_ref::<T>().unwrap();
         let in_data = self.array.data();
         let in_array = &self.array;
         let has_nulls = in_data.null_count() != 0;
 
-        ArrayIter::new(v)
+        Ok(ArrayIter::new(v)
             .map(|v| {
                 v.and_then(|v| {
                     let hash = v.hash_one(&self.hash_set.state);
@@ -116,7 +126,7 @@ where
                     }
                 })
             })
-            .collect()
+            .collect())
     }
 }
 
@@ -188,10 +198,12 @@ fn make_set(array: &dyn Array) -> Result<Box<dyn Set>> {
             let array = as_generic_binary_array::<i64>(array);
             Box::new(ArraySet::new(array, make_hash_set(array)))
         }
+        DataType::Dictionary(_, _) => unreachable!("dictionary should have been flattened"),
         d => return Err(DataFusionError::NotImplemented(format!("DataType::{} not supported in InList", d)))
     })
 }
 
+/// Evaluates the list of expressions into an array, flattening any dictionaries
 fn evaluate_list(
     list: &[Arc<dyn PhysicalExpr>],
     batch: &RecordBatch,
@@ -203,6 +215,8 @@ fn evaluate_list(
                 ColumnarValue::Array(_) => Err(DataFusionError::Execution(
                     "InList expression must evaluate to a scalar".to_string(),
                 )),
+                // Flatten dictionary values
+                ColumnarValue::Scalar(ScalarValue::Dictionary(_, v)) => Ok(*v),
                 ColumnarValue::Scalar(s) => Ok(s),
             })
         })
@@ -286,10 +300,10 @@ impl PhysicalExpr for InListExpr {
     fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
         let value = self.expr.evaluate(batch)?.into_array(1);
         let r = match &self.static_filter {
-            Some(f) => f.contains(value.as_ref(), self.negated),
+            Some(f) => f.contains(value.as_ref(), self.negated)?,
             None => {
                 let list = evaluate_list(&self.list, batch)?;
-                make_set(list.as_ref())?.contains(value.as_ref(), self.negated)
+                make_set(list.as_ref())?.contains(value.as_ref(), self.negated)?
             }
         };
         Ok(ColumnarValue::Array(Arc::new(r)))
@@ -947,7 +961,7 @@ mod tests {
         let result = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();
 
         let array = Int64Array::from(vec![1, 2, 3, 4]);
-        let r = result.contains(&array, false);
+        let r = result.contains(&array, false).unwrap();
         assert_eq!(r, BooleanArray::from(vec![true, true, true, false]));
 
         try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();