You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/11/02 20:29:41 UTC
[arrow-datafusion] branch master updated: Simplify InListExpr ~20-70% Faster (#4057)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 12875292a Simplify InListExpr ~20-70% Faster (#4057)
12875292a is described below
commit 12875292a290feb174b5cd97d123c06bebcf8179
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Thu Nov 3 09:29:35 2022 +1300
Simplify InListExpr ~20-70% Faster (#4057)
* Simplify InList expression
* Simplify
* Hash floats as integers
* Fix tests
* Format
* Update datafusion-cli lockfile
* Sort Cargo.toml
* Update datafusion/physical-expr/src/expressions/in_list.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
datafusion-cli/Cargo.lock | 6 +-
datafusion/core/Cargo.toml | 4 +-
datafusion/core/src/physical_plan/mod.rs | 5 +-
datafusion/core/src/physical_plan/planner.rs | 49 +-
datafusion/physical-expr/Cargo.toml | 3 +
.../physical-expr/src/expressions/in_list.rs | 1277 +++-----------------
.../src}/hash_utils.rs | 10 +-
datafusion/physical-expr/src/lib.rs | 1 +
8 files changed, 215 insertions(+), 1140 deletions(-)
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index a693295e9..899ea2af2 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -547,8 +547,6 @@ version = "13.0.0"
dependencies = [
"ahash 0.8.0",
"arrow",
- "arrow-buffer",
- "arrow-schema",
"async-compression",
"async-trait",
"bytes",
@@ -563,7 +561,6 @@ dependencies = [
"flate2",
"futures",
"glob",
- "half",
"hashbrown",
"itertools",
"lazy_static",
@@ -646,12 +643,15 @@ version = "13.0.0"
dependencies = [
"ahash 0.8.0",
"arrow",
+ "arrow-buffer",
+ "arrow-schema",
"blake2",
"blake3",
"chrono",
"datafusion-common",
"datafusion-expr",
"datafusion-row",
+ "half",
"hashbrown",
"itertools",
"lazy_static",
diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml
index 2764273ce..4a8e48960 100644
--- a/datafusion/core/Cargo.toml
+++ b/datafusion/core/Cargo.toml
@@ -57,8 +57,7 @@ unicode_expressions = ["datafusion-physical-expr/regex_expressions", "datafusion
ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] }
apache-avro = { version = "0.14", optional = true }
arrow = { version = "25.0.0", features = ["prettyprint"] }
-arrow-buffer = "25.0.0"
-arrow-schema = "25.0.0"
+
async-compression = { version = "0.3.14", features = ["bzip2", "gzip", "futures-io", "tokio"] }
async-trait = "0.1.41"
bytes = "1.1"
@@ -74,7 +73,6 @@ datafusion-sql = { path = "../sql", version = "13.0.0" }
flate2 = "1.0.24"
futures = "0.3"
glob = "0.3.0"
-half = { version = "2.1", default-features = false }
hashbrown = { version = "0.12", features = ["raw"] }
itertools = "0.10"
lazy_static = { version = "^1.4.0" }
diff --git a/datafusion/core/src/physical_plan/mod.rs b/datafusion/core/src/physical_plan/mod.rs
index 9e36c3ec8..55b46c991 100644
--- a/datafusion/core/src/physical_plan/mod.rs
+++ b/datafusion/core/src/physical_plan/mod.rs
@@ -525,7 +525,6 @@ pub mod empty;
pub mod explain;
pub mod file_format;
pub mod filter;
-pub mod hash_utils;
pub mod joins;
pub mod limit;
pub mod memory;
@@ -541,4 +540,6 @@ pub mod values;
pub mod windows;
use crate::execution::context::TaskContext;
-pub use datafusion_physical_expr::{expressions, functions, type_coercion, udf};
+pub use datafusion_physical_expr::{
+ expressions, functions, hash_utils, type_coercion, udf,
+};
diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs
index a389bb65b..085677785 100644
--- a/datafusion/core/src/physical_plan/planner.rs
+++ b/datafusion/core/src/physical_plan/planner.rs
@@ -2035,8 +2035,9 @@ mod tests {
.build()?;
let execution_plan = plan(&logical_plan).await?;
// verify that the plan correctly adds cast from Int64(1) to Utf8, and the const will be evaluated.
- let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"1\") }], negated: false, set: None }";
- assert!(format!("{:?}", execution_plan).contains(expected));
+ let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"1\") }], negated: false }";
+ let actual = format!("{:?}", execution_plan);
+ assert!(actual.contains(expected), "{}", actual);
Ok(())
}
@@ -2068,50 +2069,6 @@ mod tests {
lit(struct_literal)
}
- #[tokio::test]
- async fn in_set_test() -> Result<()> {
- // OPTIMIZER_INSET_THRESHOLD = 10
- // expression: "a in ('a', 1, 2, ..30)"
- let mut list = vec![Expr::Literal(ScalarValue::Utf8(Some("a".to_string())))];
- for i in 1..31 {
- list.push(Expr::Literal(ScalarValue::Int64(Some(i))));
- }
- let logical_plan = test_csv_scan()
- .await?
- .filter(col("c12").lt(lit(0.05)))?
- .project(vec![col("c1").in_list(list, false)])?
- .build()?;
- let execution_plan = plan(&logical_plan).await?;
- let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(\"a\") }, Literal { value: Utf8(\"1\") }, Literal { value: Utf8(\"2\") },";
- assert!(format!("{:?}", execution_plan).contains(expected));
- let expected =
- "Literal { value: Utf8(\"30\") }], negated: false, set: Some(InSet { set: ";
- assert!(format!("{:?}", execution_plan).contains(expected));
- Ok(())
- }
-
- #[tokio::test]
- async fn in_set_null_test() -> Result<()> {
- // test NULL
- let mut list = vec![Expr::Literal(ScalarValue::Int64(None))];
- for i in 1..31 {
- list.push(Expr::Literal(ScalarValue::Int64(Some(i))));
- }
-
- let logical_plan = test_csv_scan()
- .await?
- .filter(col("c12").lt(lit(0.05)))?
- .project(vec![col("c1").in_list(list, false)])?
- .build()?;
- let execution_plan = plan(&logical_plan).await?;
- let expected = "expr: [(InListExpr { expr: Column { name: \"c1\", index: 0 }, list: [Literal { value: Utf8(NULL) }, Literal { value: Utf8(\"1\") }, Literal { value: Utf8(\"2\") }";
- assert!(format!("{:?}", execution_plan).contains(expected));
- let expected =
- "Literal { value: Utf8(\"30\") }], negated: false, set: Some(InSet";
- assert!(format!("{:?}", execution_plan).contains(expected));
- Ok(())
- }
-
#[tokio::test]
async fn hash_agg_input_schema() -> Result<()> {
let logical_plan = test_csv_scan_with_name("aggregate_test_100")
diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml
index 12cb7311b..6fc6f4176 100644
--- a/datafusion/physical-expr/Cargo.toml
+++ b/datafusion/physical-expr/Cargo.toml
@@ -41,12 +41,15 @@ unicode_expressions = ["unicode-segmentation"]
[dependencies]
ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] }
arrow = { version = "25.0.0", features = ["prettyprint"] }
+arrow-buffer = "25.0.0"
+arrow-schema = "25.0.0"
blake2 = { version = "^0.10.2", optional = true }
blake3 = { version = "1.0", optional = true }
chrono = { version = "0.4.22", default-features = false }
datafusion-common = { path = "../common", version = "13.0.0" }
datafusion-expr = { path = "../expr", version = "13.0.0" }
datafusion-row = { path = "../row", version = "13.0.0" }
+half = { version = "2.1", default-features = false }
hashbrown = { version = "0.12", features = ["raw"] }
itertools = { version = "0.10", features = ["use_std"] }
lazy_static = { version = "^1.4.0" }
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs
index 783b898e1..9406b42ee 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -17,50 +17,31 @@
//! InList expression
+use ahash::RandomState;
use std::any::Any;
-use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::Arc;
-use arrow::array::GenericStringArray;
-use arrow::array::{
- ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
- Int64Array, Int8Array, OffsetSizeTrait, TimestampMicrosecondArray,
- TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
- UInt16Array, UInt32Array, UInt64Array, UInt8Array,
-};
-use arrow::{
- datatypes::{DataType, Schema},
- record_batch::RecordBatch,
-};
-
+use crate::hash_utils::HashValue;
use crate::physical_expr::down_cast_any_ref;
use crate::utils::expr_list_eq_any_order;
use crate::PhysicalExpr;
use arrow::array::*;
-use arrow::datatypes::TimeUnit;
-use datafusion_common::cast::as_date32_array;
-use datafusion_common::ScalarValue;
-use datafusion_common::ScalarValue::{
- Binary, Boolean, Date32, Date64, Decimal128, Int16, Int32, Int64, Int8, LargeBinary,
- LargeUtf8, TimestampMicrosecond, TimestampMillisecond, TimestampNanosecond,
- TimestampSecond, UInt16, UInt32, UInt64, UInt8, Utf8,
-};
-use datafusion_common::{DataFusionError, Result};
+use arrow::datatypes::*;
+use arrow::downcast_primitive_array;
+use arrow::record_batch::RecordBatch;
+use arrow::util::bit_iterator::BitIndexIterator;
+use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::ColumnarValue;
-
-/// Size at which to use a Set rather than Vec for `IN` / `NOT IN`
-/// Value chosen by the benchmark at
-/// https://github.com/apache/arrow-datafusion/pull/2156#discussion_r845198369
-/// TODO: add switch codeGen in In_List
-static OPTIMIZER_INSET_THRESHOLD: usize = 30;
+use hashbrown::hash_map::RawEntryMut;
+use hashbrown::HashMap;
/// InList
pub struct InListExpr {
expr: Arc<dyn PhysicalExpr>,
list: Vec<Arc<dyn PhysicalExpr>>,
negated: bool,
- set: Option<InSet>,
+ static_filter: Option<Box<dyn Set>>,
input_schema: Schema,
}
@@ -70,320 +51,172 @@ impl Debug for InListExpr {
.field("expr", &self.expr)
.field("list", &self.list)
.field("negated", &self.negated)
- .field("set", &self.set)
.finish()
}
}
-/// InSet
-#[derive(Debug, PartialEq, Eq)]
-pub struct InSet {
- // TODO: optimization: In the `IN` or `NOT IN` we don't need to consider the NULL value
- // The data type is same, we can use set: HashSet<T>
- set: HashSet<ScalarValue>,
+/// A type-erased container of array elements
+trait Set: Send + Sync {
+ fn contains(&self, v: &dyn Array, negated: bool) -> BooleanArray;
}
-impl InSet {
- pub fn new(set: HashSet<ScalarValue>) -> Self {
- Self { set }
- }
-
- pub fn get_set(&self) -> &HashSet<ScalarValue> {
- &self.set
- }
+struct ArrayHashSet {
+ state: RandomState,
+ /// Used to provide a lookup from value to in list index
+ ///
+ /// Note: usize::hash is not used, instead the raw entry
+ /// API is used to store entries w.r.t their value
+ map: HashMap<usize, (), ()>,
}
-macro_rules! make_contains {
- ($ARRAY:expr, $LIST_VALUES:expr, $NEGATED:expr, $SCALAR_VALUE:ident, $ARRAY_TYPE:ident) => {{
- let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap();
-
- let contains_null = $LIST_VALUES
- .iter()
- .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null()));
- let values = $LIST_VALUES
- .iter()
- .flat_map(|expr| match expr {
- ColumnarValue::Scalar(s) => match s {
- ScalarValue::$SCALAR_VALUE(Some(v)) => Some(*v),
- ScalarValue::$SCALAR_VALUE(None) => None,
- datatype => unreachable!("InList can't reach other data type {} for {}.", datatype, s),
- },
- ColumnarValue::Array(_) => {
- unimplemented!("InList does not yet support nested columns.")
- }
- })
- .collect::<Vec<_>>();
-
- collection_contains_check!(array, values, $NEGATED, contains_null)
- }};
+struct ArraySet<T> {
+ array: T,
+ hash_set: ArrayHashSet,
}
-macro_rules! make_contains_primitive {
- ($ARRAY:expr, $LIST_VALUES:expr, $NEGATED:expr, $SCALAR_VALUE:ident, $ARRAY_TYPE:ident) => {{
- let array = $ARRAY.as_any().downcast_ref::<$ARRAY_TYPE>().unwrap();
-
- let contains_null = $LIST_VALUES
- .iter()
- .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null()));
- let values = $LIST_VALUES
- .iter()
- .flat_map(|expr| match expr {
- ColumnarValue::Scalar(s) => match s {
- ScalarValue::$SCALAR_VALUE(Some(v), ..) => Some(*v),
- ScalarValue::$SCALAR_VALUE(None, ..) => None,
- datatype => unreachable!("InList can't reach other data type {} for {}.", datatype, s),
- },
- ColumnarValue::Array(_) => {
- unimplemented!("InList does not yet support nested columns.")
- }
- })
- .collect::<Vec<_>>();
-
- Ok(collection_contains_check!(array, values, $NEGATED, contains_null))
- }};
-}
-
-macro_rules! set_contains_for_float {
- ($ARRAY:expr, $SET_VALUES:expr, $SCALAR_VALUE:ident, $NEGATED:expr) => {{
- let contains_null = $SET_VALUES.iter().any(|s| s.is_null());
- let bool_array = if $NEGATED {
- // Not in
- if contains_null {
- $ARRAY
- .iter()
- .map(|vop| {
- match vop.map(|v| !$SET_VALUES.contains(&v.try_into().unwrap())) {
- Some(true) => None,
- x => x,
- }
- })
- .collect::<BooleanArray>()
- } else {
- $ARRAY
- .iter()
- .map(|vop| vop.map(|v| !$SET_VALUES.contains(&v.try_into().unwrap())))
- .collect::<BooleanArray>()
- }
- } else {
- // In
- if contains_null {
- $ARRAY
- .iter()
- .map(|vop| {
- match vop.map(|v| $SET_VALUES.contains(&v.try_into().unwrap())) {
- Some(false) => None,
- x => x,
- }
- })
- .collect::<BooleanArray>()
- } else {
- $ARRAY
- .iter()
- .map(|vop| vop.map(|v| $SET_VALUES.contains(&v.try_into().unwrap())))
- .collect::<BooleanArray>()
- }
- };
- ColumnarValue::Array(Arc::new(bool_array))
- }};
+impl<T> ArraySet<T>
+where
+ T: Array + From<ArrayData>,
+{
+ fn new(array: &T, hash_set: ArrayHashSet) -> Self {
+ Self {
+ array: T::from(array.data().clone()),
+ hash_set,
+ }
+ }
}
-macro_rules! set_contains_for_primitive {
- ($ARRAY:expr, $SET_VALUES:expr, $SCALAR_VALUE:ident, $NEGATED:expr) => {{
- let contains_null = $SET_VALUES.iter().any(|s| s.is_null());
- let native_set = $SET_VALUES
- .iter()
- .flat_map(|v| match v {
- $SCALAR_VALUE(value, ..) => *value,
- datatype => {
- unreachable!(
- "InList can't reach other data type {} for {}.",
- datatype, v
- )
- }
+impl<T> Set for ArraySet<T>
+where
+ T: Array + 'static,
+ for<'a> &'a T: ArrayAccessor,
+ for<'a> <&'a T as ArrayAccessor>::Item: PartialEq + HashValue,
+{
+ fn contains(&self, v: &dyn Array, negated: bool) -> BooleanArray {
+ let v = v.as_any().downcast_ref::<T>().unwrap();
+ let in_data = self.array.data();
+ let in_array = &self.array;
+ let has_nulls = in_data.null_count() != 0;
+
+ ArrayIter::new(v)
+ .map(|v| {
+ v.and_then(|v| {
+ let hash = v.hash_one(&self.hash_set.state);
+ let contains = self
+ .hash_set
+ .map
+ .raw_entry()
+ .from_hash(hash, |idx| in_array.value(*idx) == v)
+ .is_some();
+
+ match contains {
+ true => Some(!negated),
+ false if has_nulls => None,
+ false => Some(negated),
+ }
+ })
})
- .collect::<HashSet<_>>();
-
- collection_contains_check!($ARRAY, native_set, $NEGATED, contains_null)
- }};
-}
-
-macro_rules! collection_contains_check {
- ($ARRAY:expr, $VALUES:expr, $NEGATED:expr, $CONTAINS_NULL:expr) => {{
- let bool_array = if $NEGATED {
- // Not in
- if $CONTAINS_NULL {
- $ARRAY
- .iter()
- .map(|vop| match vop.map(|v| !$VALUES.contains(&v)) {
- Some(true) => None,
- x => x,
- })
- .collect::<BooleanArray>()
- } else {
- $ARRAY
- .iter()
- .map(|vop| vop.map(|v| !$VALUES.contains(&v)))
- .collect::<BooleanArray>()
- }
- } else {
- // In
- if $CONTAINS_NULL {
- $ARRAY
- .iter()
- .map(|vop| match vop.map(|v| $VALUES.contains(&v)) {
- Some(false) => None,
- x => x,
- })
- .collect::<BooleanArray>()
- } else {
- $ARRAY
- .iter()
- .map(|vop| vop.map(|v| $VALUES.contains(&v)))
- .collect::<BooleanArray>()
- }
- };
- ColumnarValue::Array(Arc::new(bool_array))
- }};
-}
-
-macro_rules! collection_contains_check_decimal {
- ($ARRAY:expr, $VALUES:expr, $NEGATED:expr, $CONTAINS_NULL:expr) => {{
- let bool_array = if $NEGATED {
- // Not in
- if $CONTAINS_NULL {
- $ARRAY
- .iter()
- .map(|vop| match vop.map(|v| !$VALUES.contains(&v)) {
- Some(true) => None,
- x => x,
- })
- .collect::<BooleanArray>()
- } else {
- $ARRAY
- .iter()
- .map(|vop| vop.map(|v| !$VALUES.contains(&v)))
- .collect::<BooleanArray>()
- }
- } else {
- // In
- if $CONTAINS_NULL {
- $ARRAY
- .iter()
- .map(|vop| match vop.map(|v| $VALUES.contains(&v)) {
- Some(false) => None,
- x => x,
- })
- .collect::<BooleanArray>()
- } else {
- $ARRAY
- .iter()
- .map(|vop| vop.map(|v| $VALUES.contains(&v)))
- .collect::<BooleanArray>()
- }
- };
- ColumnarValue::Array(Arc::new(bool_array))
- }};
+ .collect()
+ }
}
-// try evaluate all list exprs and check if the exprs are constants or not
-fn try_cast_static_filter_to_set(
- list: &[Arc<dyn PhysicalExpr>],
- schema: &Schema,
-) -> Result<HashSet<ScalarValue>> {
- let batch = RecordBatch::new_empty(Arc::new(schema.to_owned()));
- list.iter()
- .map(|expr| match expr.evaluate(&batch) {
- Ok(ColumnarValue::Array(_)) => Err(DataFusionError::NotImplemented(
- "InList doesn't support to evaluate the array result".to_string(),
- )),
- Ok(ColumnarValue::Scalar(s)) => Ok(s),
- Err(e) => Err(e),
- })
- .collect::<Result<HashSet<_>>>()
-}
+/// Computes an [`ArrayHashSet`] for the provided [`Array`], storing the index of
+/// the first occurrence of each distinct non-null value
+///
+/// Note: This is split into a separate function as higher-rank trait bounds currently
+/// cause type inference to misbehave
+fn make_hash_set<T>(array: T) -> ArrayHashSet
+where
+ T: ArrayAccessor,
+ T::Item: PartialEq + HashValue,
+{
+ let data = array.data();
+
+ let state = RandomState::new();
+ let mut map: HashMap<usize, (), ()> =
+ HashMap::with_capacity_and_hasher(data.len(), ());
+
+ let insert_value = |idx| {
+ let value = array.value(idx);
+ let hash = value.hash_one(&state);
+ if let RawEntryMut::Vacant(v) = map
+ .raw_entry_mut()
+ .from_hash(hash, |x| array.value(*x) == value)
+ {
+ v.insert_with_hasher(hash, idx, (), |x| array.value(*x).hash_one(&state));
+ }
+ };
-fn make_list_contains_decimal(
- array: &Decimal128Array,
- list: Vec<ColumnarValue>,
- negated: bool,
-) -> ColumnarValue {
- let contains_null = list
- .iter()
- .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null()));
- let values = list
- .iter()
- .flat_map(|v| match v {
- ColumnarValue::Scalar(s) => match s {
- Decimal128(v128op, _, _) => *v128op,
- datatype => unreachable!(
- "InList can't reach other data type {} for {}.",
- datatype, s
- ),
- },
- ColumnarValue::Array(_) => {
- unimplemented!("InList does not yet support nested columns.")
- }
- })
- .collect::<Vec<_>>();
+ match data.null_buffer() {
+ Some(buffer) => BitIndexIterator::new(buffer.as_ref(), data.offset(), data.len())
+ .for_each(insert_value),
+ None => (0..data.len()).for_each(insert_value),
+ }
- collection_contains_check_decimal!(array, values, negated, contains_null)
+ ArrayHashSet { state, map }
}
-fn make_set_contains_decimal(
- array: &Decimal128Array,
- set: &HashSet<ScalarValue>,
- negated: bool,
-) -> ColumnarValue {
- let contains_null = set.iter().any(|v| v.is_null());
- let native_set = set
- .iter()
- .flat_map(|v| match v {
- Decimal128(v128op, _, _) => *v128op,
- datatype => {
- unreachable!("InList can't reach other data type {} for {}.", datatype, v)
- }
- })
- .collect::<HashSet<_>>();
-
- collection_contains_check_decimal!(array, native_set, negated, contains_null)
+/// Creates a `Box<dyn Set>` from the already-evaluated array of `IN` list values
+fn make_set(array: &dyn Array) -> Result<Box<dyn Set>> {
+ Ok(downcast_primitive_array! {
+ array => Box::new(ArraySet::new(array, make_hash_set(array))),
+ DataType::Boolean => {
+ let array = as_boolean_array(array);
+ Box::new(ArraySet::new(array, make_hash_set(array)))
+ },
+ DataType::Decimal128(_, _) => {
+ let array = as_primitive_array::<Decimal128Type>(array);
+ Box::new(ArraySet::new(array, make_hash_set(array)))
+ }
+ DataType::Decimal256(_, _) => {
+ let array = as_primitive_array::<Decimal256Type>(array);
+ Box::new(ArraySet::new(array, make_hash_set(array)))
+ }
+ DataType::Utf8 => {
+ let array = as_string_array(array);
+ Box::new(ArraySet::new(array, make_hash_set(array)))
+ }
+ DataType::LargeUtf8 => {
+ let array = as_largestring_array(array);
+ Box::new(ArraySet::new(array, make_hash_set(array)))
+ }
+ DataType::Binary => {
+ let array = as_generic_binary_array::<i32>(array);
+ Box::new(ArraySet::new(array, make_hash_set(array)))
+ }
+ DataType::LargeBinary => {
+ let array = as_generic_binary_array::<i64>(array);
+ Box::new(ArraySet::new(array, make_hash_set(array)))
+ }
+ d => return Err(DataFusionError::NotImplemented(format!("DataType::{} not supported in InList", d)))
+ })
}
-fn set_contains_utf8<OffsetSize: OffsetSizeTrait>(
- array: &GenericStringArray<OffsetSize>,
- set: &HashSet<ScalarValue>,
- negated: bool,
-) -> ColumnarValue {
- let contains_null = set.iter().any(|v| v.is_null());
- let native_set = set
+fn evaluate_list(
+ list: &[Arc<dyn PhysicalExpr>],
+ batch: &RecordBatch,
+) -> Result<ArrayRef> {
+ let scalars = list
.iter()
- .flat_map(|v| match v {
- Utf8(v) | LargeUtf8(v) => v.as_deref(),
- datatype => {
- unreachable!("InList can't reach other data type {} for {}.", datatype, v)
- }
+ .map(|expr| {
+ expr.evaluate(batch).and_then(|r| match r {
+ ColumnarValue::Array(_) => Err(DataFusionError::Execution(
+ "InList expression must evaluate to a scalar".to_string(),
+ )),
+ ColumnarValue::Scalar(s) => Ok(s),
+ })
})
- .collect::<HashSet<_>>();
+ .collect::<Result<Vec<_>>>()?;
- collection_contains_check!(array, native_set, negated, contains_null)
+ ScalarValue::iter_to_array(scalars)
}
-fn set_contains_binary<OffsetSize: OffsetSizeTrait>(
- array: &GenericBinaryArray<OffsetSize>,
- set: &HashSet<ScalarValue>,
- negated: bool,
-) -> ColumnarValue {
- let contains_null = set.iter().any(|v| v.is_null());
- let native_set = set
- .iter()
- .flat_map(|v| match v {
- Binary(v) | LargeBinary(v) => v.as_deref(),
- datatype => {
- unreachable!("InList can't reach other data type {} for {}.", datatype, v)
- }
- })
- .collect::<HashSet<_>>();
-
- collection_contains_check!(array, native_set, negated, contains_null)
+fn try_cast_static_filter_to_set(
+ list: &[Arc<dyn PhysicalExpr>],
+ schema: &Schema,
+) -> Result<Box<dyn Set>> {
+ let batch = RecordBatch::new_empty(Arc::new(schema.clone()));
+ make_set(evaluate_list(list, &batch)?.as_ref())
}
impl InListExpr {
@@ -394,22 +227,12 @@ impl InListExpr {
negated: bool,
schema: &Schema,
) -> Self {
- if list.len() > OPTIMIZER_INSET_THRESHOLD {
- if let Ok(set) = try_cast_static_filter_to_set(&list, schema) {
- return Self {
- expr,
- set: Some(InSet::new(set)),
- list,
- negated,
- input_schema: schema.clone(),
- };
- }
- }
+ let static_filter = try_cast_static_filter_to_set(&list, schema).ok();
Self {
expr,
list,
negated,
- set: None,
+ static_filter,
input_schema: schema.clone(),
}
}
@@ -428,95 +251,17 @@ impl InListExpr {
pub fn negated(&self) -> bool {
self.negated
}
-
- /// Compare for specific utf8 types
- #[allow(clippy::unnecessary_wraps)]
- fn compare_utf8<T: OffsetSizeTrait>(
- &self,
- array: ArrayRef,
- list_values: Vec<ColumnarValue>,
- negated: bool,
- ) -> Result<ColumnarValue> {
- let array = array
- .as_any()
- .downcast_ref::<GenericStringArray<T>>()
- .unwrap();
-
- let contains_null = list_values
- .iter()
- .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null()));
- let values = list_values
- .iter()
- .flat_map(|expr| match expr {
- ColumnarValue::Scalar(s) => match s {
- ScalarValue::Utf8(Some(v)) => Some(v.as_str()),
- ScalarValue::Utf8(None) => None,
- ScalarValue::LargeUtf8(Some(v)) => Some(v.as_str()),
- ScalarValue::LargeUtf8(None) => None,
- datatype => unimplemented!("Unexpected type {} for InList", datatype),
- },
- ColumnarValue::Array(_) => {
- unimplemented!("InList does not yet support nested columns.")
- }
- })
- .collect::<Vec<&str>>();
-
- Ok(collection_contains_check!(
- array,
- values,
- negated,
- contains_null
- ))
- }
-
- fn compare_binary<T: OffsetSizeTrait>(
- &self,
- array: ArrayRef,
- list_values: Vec<ColumnarValue>,
- negated: bool,
- ) -> Result<ColumnarValue> {
- let array = array
- .as_any()
- .downcast_ref::<GenericBinaryArray<T>>()
- .unwrap();
-
- let contains_null = list_values
- .iter()
- .any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null()));
- let values = list_values
- .iter()
- .flat_map(|expr| match expr {
- ColumnarValue::Scalar(s) => match s {
- ScalarValue::Binary(Some(v)) | ScalarValue::LargeBinary(Some(v)) => {
- Some(v.as_slice())
- }
- ScalarValue::Binary(None) | ScalarValue::LargeBinary(None) => None,
- datatype => unimplemented!("Unexpected type {} for InList", datatype),
- },
- ColumnarValue::Array(_) => {
- unimplemented!("InList does not yet support nested columns.")
- }
- })
- .collect::<Vec<&[u8]>>();
-
- Ok(collection_contains_check!(
- array,
- values,
- negated,
- contains_null
- ))
- }
}
impl std::fmt::Display for InListExpr {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
if self.negated {
- if self.set.is_some() {
+ if self.static_filter.is_some() {
write!(f, "{} NOT IN (SET) ({:?})", self.expr, self.list)
} else {
write!(f, "{} NOT IN ({:?})", self.expr, self.list)
}
- } else if self.set.is_some() {
+ } else if self.static_filter.is_some() {
write!(f, "Use {} IN (SET) ({:?})", self.expr, self.list)
} else {
write!(f, "{} IN ({:?})", self.expr, self.list)
@@ -539,382 +284,15 @@ impl PhysicalExpr for InListExpr {
}
fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
- let value = self.expr.evaluate(batch)?;
- let value_data_type = value.data_type();
-
- if let Some(in_set) = &self.set {
- let array = match value {
- ColumnarValue::Array(array) => array,
- ColumnarValue::Scalar(scalar) => scalar.to_array(),
- };
- let set = in_set.get_set();
- match value_data_type {
- DataType::Boolean => {
- let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
- Ok(set_contains_for_primitive!(
- array,
- set,
- Boolean,
- self.negated
- ))
- }
- DataType::Int8 => {
- let array = array.as_any().downcast_ref::<Int8Array>().unwrap();
- Ok(set_contains_for_primitive!(array, set, Int8, self.negated))
- }
- DataType::Int16 => {
- let array = array.as_any().downcast_ref::<Int16Array>().unwrap();
- Ok(set_contains_for_primitive!(array, set, Int16, self.negated))
- }
- DataType::Int32 => {
- let array = array.as_any().downcast_ref::<Int32Array>().unwrap();
- Ok(set_contains_for_primitive!(array, set, Int32, self.negated))
- }
- DataType::Int64 => {
- let array = array.as_any().downcast_ref::<Int64Array>().unwrap();
- Ok(set_contains_for_primitive!(array, set, Int64, self.negated))
- }
- DataType::UInt8 => {
- let array = array.as_any().downcast_ref::<UInt8Array>().unwrap();
- Ok(set_contains_for_primitive!(array, set, UInt8, self.negated))
- }
- DataType::UInt16 => {
- let array = array.as_any().downcast_ref::<UInt16Array>().unwrap();
- Ok(set_contains_for_primitive!(
- array,
- set,
- UInt16,
- self.negated
- ))
- }
- DataType::UInt32 => {
- let array = array.as_any().downcast_ref::<UInt32Array>().unwrap();
- Ok(set_contains_for_primitive!(
- array,
- set,
- UInt32,
- self.negated
- ))
- }
- DataType::UInt64 => {
- let array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
- Ok(set_contains_for_primitive!(
- array,
- set,
- UInt64,
- self.negated
- ))
- }
- DataType::Date32 => {
- let array = as_date32_array(&array)?;
- Ok(set_contains_for_primitive!(
- array,
- set,
- Date32,
- self.negated
- ))
- }
- DataType::Date64 => {
- let array = array.as_any().downcast_ref::<Date64Array>().unwrap();
- Ok(set_contains_for_primitive!(
- array,
- set,
- Date64,
- self.negated
- ))
- }
- DataType::Float32 => {
- let array = array.as_any().downcast_ref::<Float32Array>().unwrap();
- Ok(set_contains_for_float!(array, set, Float32, self.negated))
- }
- DataType::Float64 => {
- let array = array.as_any().downcast_ref::<Float64Array>().unwrap();
- Ok(set_contains_for_float!(array, set, Float64, self.negated))
- }
- DataType::Utf8 => {
- let array = array
- .as_any()
- .downcast_ref::<GenericStringArray<i32>>()
- .unwrap();
- Ok(set_contains_utf8(array, set, self.negated))
- }
- DataType::LargeUtf8 => {
- let array = array
- .as_any()
- .downcast_ref::<GenericStringArray<i64>>()
- .unwrap();
- Ok(set_contains_utf8(array, set, self.negated))
- }
- DataType::Binary => {
- let array = array
- .as_any()
- .downcast_ref::<GenericBinaryArray<i32>>()
- .unwrap();
- Ok(set_contains_binary(array, set, self.negated))
- }
- DataType::LargeBinary => {
- let array = array
- .as_any()
- .downcast_ref::<GenericBinaryArray<i64>>()
- .unwrap();
- Ok(set_contains_binary(array, set, self.negated))
- }
- DataType::Decimal128(_, _) => {
- let array = array.as_any().downcast_ref::<Decimal128Array>().unwrap();
- Ok(make_set_contains_decimal(array, set, self.negated))
- }
- DataType::Timestamp(unit, _) => match unit {
- TimeUnit::Second => {
- let array = array
- .as_any()
- .downcast_ref::<TimestampSecondArray>()
- .unwrap();
- Ok(set_contains_for_primitive!(
- array,
- set,
- TimestampSecond,
- self.negated
- ))
- }
- TimeUnit::Millisecond => {
- let array = array
- .as_any()
- .downcast_ref::<TimestampMillisecondArray>()
- .unwrap();
- Ok(set_contains_for_primitive!(
- array,
- set,
- TimestampMillisecond,
- self.negated
- ))
- }
- TimeUnit::Microsecond => {
- let array = array
- .as_any()
- .downcast_ref::<TimestampMicrosecondArray>()
- .unwrap();
- Ok(set_contains_for_primitive!(
- array,
- set,
- TimestampMicrosecond,
- self.negated
- ))
- }
- TimeUnit::Nanosecond => {
- let array = array
- .as_any()
- .downcast_ref::<TimestampNanosecondArray>()
- .unwrap();
- Ok(set_contains_for_primitive!(
- array,
- set,
- TimestampNanosecond,
- self.negated
- ))
- }
- },
- datatype => Result::Err(DataFusionError::NotImplemented(format!(
- "InSet does not support datatype {:?}.",
- datatype
- ))),
+ let value = self.expr.evaluate(batch)?.into_array(1);
+ let r = match &self.static_filter {
+ Some(f) => f.contains(value.as_ref(), self.negated),
+ None => {
+ let list = evaluate_list(&self.list, batch)?;
+ make_set(list.as_ref())?.contains(value.as_ref(), self.negated)
}
- } else {
- let list_values = self
- .list
- .iter()
- .map(|expr| expr.evaluate(batch))
- .collect::<Result<Vec<_>>>()?;
-
- let array = match value {
- ColumnarValue::Array(array) => array,
- ColumnarValue::Scalar(scalar) => scalar.to_array(),
- };
-
- match value_data_type {
- DataType::Float32 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Float32,
- Float32Array
- )
- }
- DataType::Float64 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Float64,
- Float64Array
- )
- }
- DataType::Int16 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Int16,
- Int16Array
- )
- }
- DataType::Int32 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Int32,
- Int32Array
- )
- }
- DataType::Int64 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Int64,
- Int64Array
- )
- }
- DataType::Int8 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Int8,
- Int8Array
- )
- }
- DataType::UInt16 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- UInt16,
- UInt16Array
- )
- }
- DataType::UInt32 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- UInt32,
- UInt32Array
- )
- }
- DataType::UInt64 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- UInt64,
- UInt64Array
- )
- }
- DataType::UInt8 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- UInt8,
- UInt8Array
- )
- }
- DataType::Date32 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Date32,
- Date32Array
- )
- }
- DataType::Date64 => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- Date64,
- Date64Array
- )
- }
- DataType::Boolean => Ok(make_contains!(
- array,
- list_values,
- self.negated,
- Boolean,
- BooleanArray
- )),
- DataType::Utf8 => {
- self.compare_utf8::<i32>(array, list_values, self.negated)
- }
- DataType::LargeUtf8 => {
- self.compare_utf8::<i64>(array, list_values, self.negated)
- }
- DataType::Binary => {
- self.compare_binary::<i32>(array, list_values, self.negated)
- }
- DataType::LargeBinary => {
- self.compare_binary::<i64>(array, list_values, self.negated)
- }
- DataType::Null => {
- let null_array = new_null_array(&DataType::Boolean, array.len());
- Ok(ColumnarValue::Array(Arc::new(null_array)))
- }
- DataType::Decimal128(_, _) => {
- let decimal_array =
- array.as_any().downcast_ref::<Decimal128Array>().unwrap();
- Ok(make_list_contains_decimal(
- decimal_array,
- list_values,
- self.negated,
- ))
- }
- DataType::Timestamp(unit, _) => match unit {
- TimeUnit::Second => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- TimestampSecond,
- TimestampSecondArray
- )
- }
- TimeUnit::Millisecond => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- TimestampMillisecond,
- TimestampMillisecondArray
- )
- }
- TimeUnit::Microsecond => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- TimestampMicrosecond,
- TimestampMicrosecondArray
- )
- }
- TimeUnit::Nanosecond => {
- make_contains_primitive!(
- array,
- list_values,
- self.negated,
- TimestampNanosecond,
- TimestampNanosecondArray
- )
- }
- },
- datatype => Result::Err(DataFusionError::NotImplemented(format!(
- "InList does not support datatype {:?}.",
- datatype
- ))),
- }
- }
+ };
+ Ok(ColumnarValue::Array(Arc::new(r)))
}
fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
@@ -945,7 +323,6 @@ impl PartialEq<dyn Any> for InListExpr {
self.expr.eq(&x.expr)
&& expr_list_eq_any_order(&self.list, &x.list)
&& self.negated == x.negated
- && self.set == x.set
})
.unwrap_or(false)
}
@@ -1326,7 +703,10 @@ mod tests {
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
// expression: "a in (0, 1)"
- let list = vec![lit(Date64(Some(0))), lit(Date64(Some(1)))];
+ let list = vec![
+ lit(ScalarValue::Date64(Some(0))),
+ lit(ScalarValue::Date64(Some(1))),
+ ];
in_list!(
batch,
list,
@@ -1337,7 +717,10 @@ mod tests {
);
// expression: "a not in (0, 1)"
- let list = vec![lit(Date64(Some(0))), lit(Date64(Some(1)))];
+ let list = vec![
+ lit(ScalarValue::Date64(Some(0))),
+ lit(ScalarValue::Date64(Some(1))),
+ ];
in_list!(
batch,
list,
@@ -1349,8 +732,8 @@ mod tests {
// expression: "a in (0, 1, NULL)"
let list = vec![
- lit(Date64(Some(0))),
- lit(Date64(Some(1))),
+ lit(ScalarValue::Date64(Some(0))),
+ lit(ScalarValue::Date64(Some(1))),
lit(ScalarValue::Null),
];
in_list!(
@@ -1364,8 +747,8 @@ mod tests {
// expression: "a not in (0, 1, NULL)"
let list = vec![
- lit(Date64(Some(0))),
- lit(Date64(Some(1))),
+ lit(ScalarValue::Date64(Some(0))),
+ lit(ScalarValue::Date64(Some(1))),
lit(ScalarValue::Null),
];
in_list!(
@@ -1388,7 +771,10 @@ mod tests {
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
// expression: "a in (0, 1)"
- let list = vec![lit(Date32(Some(0))), lit(Date32(Some(1)))];
+ let list = vec![
+ lit(ScalarValue::Date32(Some(0))),
+ lit(ScalarValue::Date32(Some(1))),
+ ];
in_list!(
batch,
list,
@@ -1399,7 +785,10 @@ mod tests {
);
// expression: "a not in (0, 1)"
- let list = vec![lit(Date32(Some(0))), lit(Date32(Some(1)))];
+ let list = vec![
+ lit(ScalarValue::Date32(Some(0))),
+ lit(ScalarValue::Date32(Some(1))),
+ ];
in_list!(
batch,
list,
@@ -1411,8 +800,8 @@ mod tests {
// expression: "a in (0, 1, NULL)"
let list = vec![
- lit(Date32(Some(0))),
- lit(Date32(Some(1))),
+ lit(ScalarValue::Date32(Some(0))),
+ lit(ScalarValue::Date32(Some(1))),
lit(ScalarValue::Null),
];
in_list!(
@@ -1426,8 +815,8 @@ mod tests {
// expression: "a not in (0, 1, NULL)"
let list = vec![
- lit(Date32(Some(0))),
- lit(Date32(Some(1))),
+ lit(ScalarValue::Date32(Some(0))),
+ lit(ScalarValue::Date32(Some(1))),
lit(ScalarValue::Null),
];
in_list!(
@@ -1543,237 +932,12 @@ mod tests {
Ok(())
}
- #[test]
- fn in_list_set_bool() -> Result<()> {
- let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
- let a = BooleanArray::from(vec![Some(true), None, Some(false)]);
- let col_a = col("a", &schema)?;
- let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
-
- // expression: "a in (true,null,true.....)"
- let mut list = vec![
- lit(ScalarValue::Boolean(Some(true))),
- lit(ScalarValue::Boolean(None)),
- ];
- for _ in 0..OPTIMIZER_INSET_THRESHOLD {
- list.push(lit(ScalarValue::Boolean(Some(true))));
- }
- in_list!(
- batch,
- list.clone(),
- &false,
- vec![Some(true), None, None],
- col_a.clone(),
- &schema
- );
- in_list!(
- batch,
- list,
- &true,
- vec![Some(false), None, None],
- col_a.clone(),
- &schema
- );
- Ok(())
- }
-
- #[test]
- fn in_list_set_int64() -> Result<()> {
- let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
- let a = Int64Array::from(vec![Some(0), Some(2), None]);
- let col_a = col("a", &schema)?;
- let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
-
- // expression: "a in (0,NULL,3,4....)"
- let mut list = vec![
- lit(ScalarValue::Int64(Some(0))),
- lit(ScalarValue::Int64(None)),
- lit(ScalarValue::Int64(Some(3))),
- ];
- for v in 4..(OPTIMIZER_INSET_THRESHOLD + 4) {
- list.push(lit(ScalarValue::Int64(Some(v as i64))));
- }
-
- in_list!(
- batch,
- list.clone(),
- &false,
- vec![Some(true), None, None],
- col_a.clone(),
- &schema
- );
-
- in_list!(
- batch,
- list.clone(),
- &true,
- vec![Some(false), None, None],
- col_a.clone(),
- &schema
- );
-
- Ok(())
- }
-
- #[test]
- fn in_list_set_float64() -> Result<()> {
- let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
- let a = Float64Array::from(vec![Some(0.0), Some(2.0), None]);
- let col_a = col("a", &schema)?;
- let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
-
- // expression: "a in (0.0,NULL,3.0,4.0 ....)"
- let mut list = vec![
- lit(ScalarValue::Float64(Some(0.0))),
- lit(ScalarValue::Float64(None)),
- lit(ScalarValue::Float64(Some(3.0))),
- ];
- for v in 4..(OPTIMIZER_INSET_THRESHOLD + 4) {
- list.push(lit(ScalarValue::Float64(Some(v as f64))));
- }
-
- in_list!(
- batch,
- list.clone(),
- &false,
- vec![Some(true), None, None],
- col_a.clone(),
- &schema
- );
-
- in_list!(
- batch,
- list.clone(),
- &true,
- vec![Some(false), None, None],
- col_a.clone(),
- &schema
- );
-
- Ok(())
- }
-
- #[test]
- fn in_list_set_utf8() -> Result<()> {
- let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
- let a = StringArray::from(vec![Some("a"), Some("b"), None]);
- let col_a = col("a", &schema)?;
- let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
-
- // expression: "a in ("a", NULL, "4c", "5c", ....)"
- let mut list = vec![
- lit(ScalarValue::Utf8(Some("a".to_string()))),
- lit(ScalarValue::Utf8(None)),
- ];
- for v in 4..(OPTIMIZER_INSET_THRESHOLD + 4) {
- let value = v.to_string() + "c";
- list.push(lit(ScalarValue::Utf8(Some(value))));
- }
- in_list!(
- batch,
- list.clone(),
- &false,
- vec![Some(true), None, None],
- col_a.clone(),
- &schema
- );
-
- in_list!(
- batch,
- list.clone(),
- &true,
- vec![Some(false), None, None],
- col_a.clone(),
- &schema
- );
-
- Ok(())
- }
-
- #[test]
- fn in_list_set_binary() -> Result<()> {
- let schema = Schema::new(vec![Field::new("a", DataType::Binary, true)]);
- let a = BinaryArray::from(vec![
- Some([1, 2, 3].as_slice()),
- Some([3, 2, 1].as_slice()),
- None,
- ]);
- let col_a = col("a", &schema)?;
- let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
-
- let mut list = vec![lit([1, 2, 3].as_slice()), lit(ScalarValue::Binary(None))];
- for v in 0..OPTIMIZER_INSET_THRESHOLD {
- list.push(lit([v as u8].as_slice()));
- }
-
- in_list!(
- batch,
- list.clone(),
- &false,
- vec![Some(true), None, None],
- col_a.clone(),
- &schema
- );
-
- in_list!(
- batch,
- list.clone(),
- &true,
- vec![Some(false), None, None],
- col_a.clone(),
- &schema
- );
-
- Ok(())
- }
-
- #[test]
- fn in_list_set_decimal() -> Result<()> {
- let schema =
- Schema::new(vec![Field::new("a", DataType::Decimal128(13, 4), true)]);
- let array = vec![Some(100_0000_i128), Some(200_5000_i128), None]
- .into_iter()
- .collect::<Decimal128Array>();
- let array = array.with_precision_and_scale(13, 4).unwrap();
- let col_a = col("a", &schema)?;
- let batch =
- RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)])?;
-
- // expression: "a in (100.0000, Null, 100.0004, 100.0005...)
- let mut list = vec![
- lit(ScalarValue::Decimal128(Some(100_0000_i128), 13, 4)),
- lit(ScalarValue::Decimal128(None, 13, 4)),
- ];
- for v in 4..(OPTIMIZER_INSET_THRESHOLD + 4) {
- let value = 100_0000_i128 + v as i128;
- list.push(lit(ScalarValue::Decimal128(Some(value), 13, 4)));
- }
-
- in_list!(
- batch,
- list.clone(),
- &false,
- vec![Some(true), None, None],
- col_a.clone(),
- &schema
- );
-
- in_list!(
- batch,
- list,
- &true,
- vec![Some(false), None, None],
- col_a.clone(),
- &schema
- );
- Ok(())
- }
-
#[test]
fn test_cast_static_filter_to_set() -> Result<()> {
// random schema
let schema =
Schema::new(vec![Field::new("a", DataType::Decimal128(13, 4), true)]);
+
// list of phy expr
let mut phy_exprs = vec![
lit(1i64),
@@ -1782,25 +946,28 @@ mod tests {
];
let result = try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();
- assert!(result.contains(&1i64.try_into().unwrap()));
- assert!(result.contains(&2i64.try_into().unwrap()));
- assert!(result.contains(&3i64.try_into().unwrap()));
+ let array = Int64Array::from(vec![1, 2, 3, 4]);
+ let r = result.contains(&array, false);
+ assert_eq!(r, BooleanArray::from(vec![true, true, true, false]));
- assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_ok());
+ try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();
// cast(cast(lit())), but the cast to the same data type, one case will be ignored
phy_exprs.push(expressions::cast(
expressions::cast(lit(2i32), &schema, DataType::Int64)?,
&schema,
DataType::Int64,
)?);
- assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_ok());
+ try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();
+
+ phy_exprs.clear();
+
// case(cast(lit())), the cast to the diff data type
phy_exprs.push(expressions::cast(
expressions::cast(lit(2i32), &schema, DataType::Int64)?,
&schema,
DataType::Int32,
)?);
- assert!(try_cast_static_filter_to_set(&phy_exprs, &schema).is_ok());
+ try_cast_static_filter_to_set(&phy_exprs, &schema).unwrap();
// column
phy_exprs.push(expressions::col("a", &schema)?);
@@ -1809,58 +976,6 @@ mod tests {
Ok(())
}
- #[test]
- fn in_list_set_timestamp() -> Result<()> {
- let schema = Schema::new(vec![Field::new(
- "a",
- DataType::Timestamp(TimeUnit::Microsecond, None),
- true,
- )]);
- let a = TimestampMicrosecondArray::from(vec![
- Some(1388588401000000000),
- Some(1288588501000000000),
- None,
- ]);
- let col_a = col("a", &schema)?;
- let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
-
- let mut list = vec![
- lit(ScalarValue::TimestampMicrosecond(
- Some(1388588401000000000),
- None,
- )),
- lit(ScalarValue::TimestampMicrosecond(None, None)),
- lit(ScalarValue::TimestampMicrosecond(
- Some(1388588401000000001),
- None,
- )),
- ];
- let start_ts = 1388588401000000001;
- for v in start_ts..(start_ts + OPTIMIZER_INSET_THRESHOLD + 4) {
- list.push(lit(ScalarValue::TimestampMicrosecond(Some(v as i64), None)));
- }
-
- in_list!(
- batch,
- list.clone(),
- &false,
- vec![Some(true), None, None],
- col_a.clone(),
- &schema
- );
-
- in_list!(
- batch,
- list.clone(),
- &true,
- vec![Some(false), None, None],
- col_a.clone(),
- &schema
- );
-
- Ok(())
- }
-
#[test]
fn in_list_timestamp() -> Result<()> {
let schema = Schema::new(vec![Field::new(
diff --git a/datafusion/core/src/physical_plan/hash_utils.rs b/datafusion/physical-expr/src/hash_utils.rs
similarity index 98%
rename from datafusion/core/src/physical_plan/hash_utils.rs
rename to datafusion/physical-expr/src/hash_utils.rs
index 75f5bab46..4b1cb23d0 100644
--- a/datafusion/core/src/physical_plan/hash_utils.rs
+++ b/datafusion/physical-expr/src/hash_utils.rs
@@ -17,12 +17,12 @@
//! Functionality used both on logical and physical plans
-use crate::error::{DataFusionError, Result};
use ahash::RandomState;
use arrow::array::*;
use arrow::datatypes::*;
use arrow::{downcast_dictionary_array, downcast_primitive_array};
use arrow_buffer::i256;
+use datafusion_common::{DataFusionError, Result};
use std::sync::Arc;
// Combines two hashes into one hash
@@ -45,7 +45,7 @@ fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col:
}
}
-trait HashValue {
+pub(crate) trait HashValue {
fn hash_one(&self, state: &RandomState) -> u64;
}
@@ -68,15 +68,15 @@ hash_value!(i8, i16, i32, i64, i128, i256, u8, u16, u32, u64);
hash_value!(bool, str, [u8]);
macro_rules! hash_float_value {
- ($($t:ty),+) => {
+ ($(($t:ty, $i:ty)),+) => {
$(impl HashValue for $t {
fn hash_one(&self, state: &RandomState) -> u64 {
- state.hash_one(self.to_le_bytes())
+ state.hash_one(<$i>::from_ne_bytes(self.to_ne_bytes()))
}
})+
};
}
-hash_float_value!(half::f16, f32, f64);
+hash_float_value!((half::f16, u16), (f32, u32), (f64, u64));
fn hash_array<T>(
array: T,
diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs
index 9087c2c2d..d2b899dca 100644
--- a/datafusion/physical-expr/src/lib.rs
+++ b/datafusion/physical-expr/src/lib.rs
@@ -24,6 +24,7 @@ pub mod datetime_expressions;
pub mod execution_props;
pub mod expressions;
pub mod functions;
+pub mod hash_utils;
pub mod math_expressions;
mod physical_expr;
pub mod planner;