You are viewing a plain-text version of this content. The canonical link for it ("here") is a hyperlink that is not preserved in this plain-text rendering.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/11/03 08:33:24 UTC
[arrow-datafusion] branch master updated: Support Dictionary in InListExpr (#4070)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 61429f839 Support Dictionary in InListExpr (#4070)
61429f839 is described below
commit 61429f839eb07bf50f36147d5b1d065194a45114
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Thu Nov 3 21:33:18 2022 +1300
Support Dictionary in InListExpr (#4070)
* Support dictionary in InList (#3936)
* Update datafusion-cli
---
datafusion-cli/Cargo.lock | 1 +
datafusion/core/tests/sql/predicates.rs | 99 +++++++++++++++++++++-
datafusion/physical-expr/Cargo.toml | 1 +
.../physical-expr/src/expressions/in_list.rs | 30 +++++--
4 files changed, 120 insertions(+), 11 deletions(-)
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index 899ea2af2..89f4cc78a 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -656,6 +656,7 @@ dependencies = [
"itertools",
"lazy_static",
"md-5",
+ "num-traits",
"ordered-float 3.3.0",
"paste",
"rand",
diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs
index 3eea94300..21d46f7c8 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -428,8 +428,101 @@ async fn csv_in_set_test() -> Result<()> {
}
#[tokio::test]
-#[ignore]
-// https://github.com/apache/arrow-datafusion/issues/3936
+async fn in_list_string_dictionaries() -> Result<()> {
+ // Null-free dictionary input; the NULL case is covered by in_list_string_dictionaries_with_null
+ let input = vec![Some("foo"), Some("bar"), Some("fazzz")]
+ .into_iter()
+ .collect::<DictionaryArray<Int32Type>>();
+
+ let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap();
+
+ let ctx = SessionContext::new();
+ ctx.register_batch("test", batch)?;
+ // Matching is case-sensitive: 'Bar' must not match the stored value "bar"
+ let sql = "SELECT * FROM test WHERE c1 IN ('Bar')";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec!["++", "++"];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "SELECT * FROM test WHERE c1 IN ('foo')";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec!["+-----+", "| c1 |", "+-----+", "| foo |", "+-----+"];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "SELECT * FROM test WHERE c1 IN ('bar', 'foo')";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-----+", "| c1 |", "+-----+", "| foo |", "| bar |", "+-----+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "SELECT * FROM test WHERE c1 IN ('Bar', 'foo')";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec!["+-----+", "| c1 |", "+-----+", "| foo |", "+-----+"];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "SELECT * FROM test WHERE c1 IN ('foo', 'Bar', 'fazzz')";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-------+",
+ "| c1 |",
+ "+-------+",
+ "| foo |",
+ "| fazzz |",
+ "+-------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn in_list_string_dictionaries_with_null() -> Result<()> {
+ let input = vec![Some("foo"), Some("bar"), None, Some("fazzz")]
+ .into_iter()
+ .collect::<DictionaryArray<Int32Type>>();
+
+ let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap();
+
+ let ctx = SessionContext::new();
+ ctx.register_batch("test", batch)?;
+ // The NULL row never satisfies `IN (...)`, so it must not appear in any result below
+ let sql = "SELECT * FROM test WHERE c1 IN ('Bar')";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec!["++", "++"];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "SELECT * FROM test WHERE c1 IN ('foo')";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec!["+-----+", "| c1 |", "+-----+", "| foo |", "+-----+"];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "SELECT * FROM test WHERE c1 IN ('bar', 'foo')";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-----+", "| c1 |", "+-----+", "| foo |", "| bar |", "+-----+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "SELECT * FROM test WHERE c1 IN ('Bar', 'foo')";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec!["+-----+", "| c1 |", "+-----+", "| foo |", "+-----+"];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "SELECT * FROM test WHERE c1 IN ('foo', 'Bar', 'fazzz')";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-------+",
+ "| c1 |",
+ "+-------+",
+ "| foo |",
+ "| fazzz |",
+ "+-------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
async fn in_set_string_dictionaries() -> Result<()> {
let input = vec![Some("foo"), Some("bar"), None, Some("fazzz")]
.into_iter()
@@ -440,7 +533,7 @@ async fn in_set_string_dictionaries() -> Result<()> {
let ctx = SessionContext::new();
ctx.register_batch("test", batch)?;
- let sql = "SELECT * FROM test WHERE c1 IN ('foo', 'Bar', 'fazz')";
+ let sql = "SELECT * FROM test WHERE c1 IN ('foo', 'Bar', 'fazzz')";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+-------+",
diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml
index 6fc6f4176..0b2d72dd3 100644
--- a/datafusion/physical-expr/Cargo.toml
+++ b/datafusion/physical-expr/Cargo.toml
@@ -54,6 +54,7 @@ hashbrown = { version = "0.12", features = ["raw"] }
itertools = { version = "0.10", features = ["use_std"] }
lazy_static = { version = "^1.4.0" }
md-5 = { version = "^0.10.0", optional = true }
+num-traits = { version = "0.2", default-features = false }
ordered-float = "3.0"
paste = "^1.0"
rand = "0.8"
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs
index 9406b42ee..26503bec7 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -27,10 +27,11 @@ use crate::physical_expr::down_cast_any_ref;
use crate::utils::expr_list_eq_any_order;
use crate::PhysicalExpr;
use arrow::array::*;
+use arrow::compute::take;
use arrow::datatypes::*;
-use arrow::downcast_primitive_array;
use arrow::record_batch::RecordBatch;
use arrow::util::bit_iterator::BitIndexIterator;
+use arrow::{downcast_dictionary_array, downcast_primitive_array};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::ColumnarValue;
use hashbrown::hash_map::RawEntryMut;
@@ -57,7 +58,7 @@ impl Debug for InListExpr {
/// A type-erased container of array elements
trait Set: Send + Sync {
- fn contains(&self, v: &dyn Array, negated: bool) -> BooleanArray;
+ fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray>;
}
struct ArrayHashSet {
@@ -92,13 +93,22 @@ where
for<'a> &'a T: ArrayAccessor,
for<'a> <&'a T as ArrayAccessor>::Item: PartialEq + HashValue,
{
- fn contains(&self, v: &dyn Array, negated: bool) -> BooleanArray {
+ fn contains(&self, v: &dyn Array, negated: bool) -> Result<BooleanArray> {
+ downcast_dictionary_array! {
+ v => {
+ let values_contains = self.contains(v.values().as_ref(), negated)?;
+ let result = take(&values_contains, v.keys(), None)?;
+ return Ok(BooleanArray::from(result.data().clone()))
+ }
+ _ => {}
+ }
+
let v = v.as_any().downcast_ref::<T>().unwrap();
let in_data = self.array.data();
let in_array = &self.array;
let has_nulls = in_data.null_count() != 0;
- ArrayIter::new(v)
+ Ok(ArrayIter::new(v)
.map(|v| {
v.and_then(|v| {
let hash = v.hash_one(&self.hash_set.state);
@@ -116,7 +126,7 @@ where
}
})
})
- .collect()
+ .collect())
}
}
@@ -188,10 +198,12 @@ fn make_set(array: &dyn Array) -> Result<Box<dyn Set>> {
let array = as_generic_binary_array::<i64>(array);
Box::new(ArraySet::new(array, make_hash_set(array)))
}
+ DataType::Dictionary(_, _) => unreachable!("dictionary should have been flattened"),
d => return Err(DataFusionError::NotImplemented(format!("DataType::{} not supported in InList", d)))
})
}
+/// Evaluates the list of expressions into an array, flattening any dictionaries
fn evaluate_list(
list: &[Arc<dyn PhysicalExpr>],
batch: &RecordBatch,
@@ -203,6 +215,8 @@ fn evaluate_list(
ColumnarValue::Array(_) => Err(DataFusionError::Execution(
"InList expression must evaluate to a scalar".to_string(),
)),
+ // Flatten dictionary values
+ ColumnarValue::Scalar(ScalarValue::Dictionary(_, v)) => Ok(*v),
ColumnarValue::Scalar(s) => Ok(s),
})
})
@@ -286,10 +300,10 @@ impl PhysicalExpr for InListExpr {
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
let value = self.expr.evaluate(batch)?.into_array(1);
let r = match &self.static_filter {
- Some(f) => f.contains(value.as_ref(), self.negated),
+ Some(f) => f.contains(value.as_ref(), self.negated)?,
None => {
let list = evaluate_list(&self.list, batch)?;
- make_set(list.as_ref())?.contains(value.as_ref(), self.negated)
+ make_set(list.as_ref())?.contains(value.as_ref(), self.negated)?
}
};
Ok(ColumnarValue::Array(Arc::new(r)))
@@ -947,7 +961,7 @@ mod tests {
let result = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();
let array = Int64Array::from(vec![1, 2, 3, 4]);
- let r = result.contains(&array, false);
+ let r = result.contains(&array, false).unwrap();
assert_eq!(r, BooleanArray::from(vec![true, true, true, false]));
try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();