You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2022/07/04 02:07:53 UTC
[arrow-datafusion] branch master updated: InList: fix bug for comparing with Null in the list using the set optimization (#2809)
This is an automated email from the ASF dual-hosted git repository.
liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 57f47ab92 InList: fix bug for comparing with Null in the list using the set optimization (#2809)
57f47ab92 is described below
commit 57f47ab9230a9a12b3244191dcf1623f8b69fd61
Author: Kun Liu <li...@apache.org>
AuthorDate: Mon Jul 4 10:07:47 2022 +0800
InList: fix bug for comparing with Null in the list using the set optimization (#2809)
* inlist: remove check path for UTF8::(None) for NULL value
* fix bug: inlist set for null case
* address comments
---
.../physical-expr/src/expressions/in_list.rs | 568 +++++++++++++++------
1 file changed, 411 insertions(+), 157 deletions(-)
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs
index 392f382c6..c9a3e419a 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -27,7 +27,6 @@ use arrow::array::{
Int64Array, Int8Array, OffsetSizeTrait, UInt16Array, UInt32Array, UInt64Array,
UInt8Array,
};
-use arrow::datatypes::ArrowPrimitiveType;
use arrow::{
datatypes::{DataType, Schema},
record_batch::RecordBatch,
@@ -37,7 +36,10 @@ use crate::{expressions, PhysicalExpr};
use arrow::array::*;
use arrow::buffer::{Buffer, MutableBuffer};
use datafusion_common::ScalarValue;
-use datafusion_common::ScalarValue::Decimal128;
+use datafusion_common::ScalarValue::{
+ Boolean, Decimal128, Int16, Int32, Int64, Int8, LargeUtf8, UInt16, UInt32, UInt64,
+ UInt8, Utf8,
+};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
@@ -119,33 +121,7 @@ macro_rules! make_contains {
})
.collect::<Vec<_>>();
- Ok(ColumnarValue::Array(Arc::new(
- array
- .iter()
- .map(|x| {
- let contains = x.map(|x| values.contains(&x));
- match contains {
- Some(true) => {
- if $NEGATED {
- Some(false)
- } else {
- Some(true)
- }
- }
- Some(false) => {
- if contains_null {
- None
- } else if $NEGATED {
- Some(true)
- } else {
- Some(false)
- }
- }
- None => None,
- }
- })
- .collect::<BooleanArray>(),
- )))
+ collection_contains_check!(array, values, $NEGATED, contains_null)
}};
}
@@ -170,86 +146,114 @@ macro_rules! make_contains_primitive {
})
.collect::<Vec<_>>();
- if $NEGATED {
+ Ok(collection_contains_check!(array, values, $NEGATED, contains_null))
+ }};
+}
+
+macro_rules! set_contains_for_float {
+ ($ARRAY:expr, $SET_VALUES:expr, $SCALAR_VALUE:ident, $NEGATED:expr, $PHY_TYPE:ty) => {{
+ let contains_null = $SET_VALUES.iter().any(|s| s.is_null());
+ let bool_array = if $NEGATED {
+ // Not in
if contains_null {
- Ok(ColumnarValue::Array(Arc::new(
- array
- .iter()
- .map(|x| match x.map(|v| !values.contains(&v)) {
+ $ARRAY
+ .iter()
+ .map(|vop| {
+ match vop.map(|v| !$SET_VALUES.contains(&v.try_into().unwrap())) {
Some(true) => None,
x => x,
- })
- .collect::<BooleanArray>(),
- )))
+ }
+ })
+ .collect::<BooleanArray>()
} else {
- Ok(ColumnarValue::Array(Arc::new(
- not_in_list_primitive(array, &values)?,
- )))
+ $ARRAY
+ .iter()
+ .map(|vop| vop.map(|v| !$SET_VALUES.contains(&v.try_into().unwrap())))
+ .collect::<BooleanArray>()
}
} else {
+ // In
if contains_null {
- Ok(ColumnarValue::Array(Arc::new(
- array
- .iter()
- .map(|x| match x.map(|v| values.contains(&v)) {
+ $ARRAY
+ .iter()
+ .map(|vop| {
+ match vop.map(|v| $SET_VALUES.contains(&v.try_into().unwrap())) {
Some(false) => None,
x => x,
- })
- .collect::<BooleanArray>(),
- )))
+ }
+ })
+ .collect::<BooleanArray>()
} else {
- Ok(ColumnarValue::Array(Arc::new(in_list_primitive(
- array, &values,
- )?)))
+ $ARRAY
+ .iter()
+ .map(|vop| vop.map(|v| $SET_VALUES.contains(&v.try_into().unwrap())))
+ .collect::<BooleanArray>()
}
- }
+ };
+ ColumnarValue::Array(Arc::new(bool_array))
+ }};
+}
+
+macro_rules! set_contains_for_primitive {
+ ($ARRAY:expr, $SET_VALUES:expr, $SCALAR_VALUE:ident, $NEGATED:expr, $PHY_TYPE:ty) => {{
+ let contains_null = $SET_VALUES.iter().any(|s| s.is_null());
+ let native_array = $SET_VALUES
+ .iter()
+ .flat_map(|v| match v {
+ $SCALAR_VALUE(value) => *value,
+ datatype => {
+ unreachable!(
+ "InList can't reach other data type {} for {}.",
+ datatype, v
+ )
+ }
+ })
+ .collect::<Vec<_>>();
+ let native_set: HashSet<$PHY_TYPE> = HashSet::from_iter(native_array);
+
+ collection_contains_check!($ARRAY, native_set, $NEGATED, contains_null)
}};
}
-macro_rules! set_contains_with_negated {
- ($ARRAY:expr, $LIST_VALUES:expr, $NEGATED:expr) => {{
- if $NEGATED {
- return Ok(ColumnarValue::Array(Arc::new(
+macro_rules! collection_contains_check {
+ ($ARRAY:expr, $VALUES:expr, $NEGATED:expr, $CONTAINS_NULL:expr) => {{
+ let bool_array = if $NEGATED {
+ // Not in
+ if $CONTAINS_NULL {
$ARRAY
.iter()
- .map(|x| x.map(|v| !$LIST_VALUES.contains(&v.try_into().unwrap())))
- .collect::<BooleanArray>(),
- )));
+ .map(|vop| match vop.map(|v| !$VALUES.contains(&v)) {
+ Some(true) => None,
+ x => x,
+ })
+ .collect::<BooleanArray>()
+ } else {
+ $ARRAY
+ .iter()
+ .map(|vop| vop.map(|v| !$VALUES.contains(&v)))
+ .collect::<BooleanArray>()
+ }
} else {
- return Ok(ColumnarValue::Array(Arc::new(
+ // In
+ if $CONTAINS_NULL {
$ARRAY
.iter()
- .map(|x| x.map(|v| $LIST_VALUES.contains(&v.try_into().unwrap())))
- .collect::<BooleanArray>(),
- )));
- }
+ .map(|vop| match vop.map(|v| $VALUES.contains(&v)) {
+ Some(false) => None,
+ x => x,
+ })
+ .collect::<BooleanArray>()
+ } else {
+ $ARRAY
+ .iter()
+ .map(|vop| vop.map(|v| $VALUES.contains(&v)))
+ .collect::<BooleanArray>()
+ }
+ };
+ ColumnarValue::Array(Arc::new(bool_array))
}};
}
-// whether each value on the left (can be null) is contained in the non-null list
-fn in_list_primitive<T: ArrowPrimitiveType>(
- array: &PrimitiveArray<T>,
- values: &[<T as ArrowPrimitiveType>::Native],
-) -> Result<BooleanArray> {
- compare_op_scalar!(
- array,
- values,
- |x, v: &[<T as ArrowPrimitiveType>::Native]| v.contains(&x)
- )
-}
-
-// whether each value on the left (can be null) is contained in the non-null list
-fn not_in_list_primitive<T: ArrowPrimitiveType>(
- array: &PrimitiveArray<T>,
- values: &[<T as ArrowPrimitiveType>::Native],
-) -> Result<BooleanArray> {
- compare_op_scalar!(
- array,
- values,
- |x, v: &[<T as ArrowPrimitiveType>::Native]| !v.contains(&x)
- )
-}
-
// whether each value on the left (can be null) is contained in the non-null list
fn in_list_utf8<OffsetSize: OffsetSizeTrait>(
array: &GenericStringArray<OffsetSize>,
@@ -305,7 +309,7 @@ fn make_list_contains_decimal(
array: &DecimalArray,
list: Vec<ColumnarValue>,
negated: bool,
-) -> BooleanArray {
+) -> ColumnarValue {
let contains_null = list
.iter()
.any(|v| matches!(v, ColumnarValue::Scalar(s) if s.is_null()));
@@ -325,32 +329,14 @@ fn make_list_contains_decimal(
})
.collect::<Vec<_>>();
- if !negated {
- // In
- array
- .iter()
- .map(|v| v.map(|v128| values.contains(&v128)))
- .collect::<BooleanArray>()
- } else {
- // Not in
- if contains_null {
- // If the expr is NOT IN and the list contains NULL value
- // All the result must be NONE
- BooleanArray::from(vec![None; array.len()])
- } else {
- array
- .iter()
- .map(|v| v.map(|v128| !values.contains(&v128)))
- .collect::<BooleanArray>()
- }
- }
+ collection_contains_check!(array, values, negated, contains_null)
}
fn make_set_contains_decimal(
array: &DecimalArray,
set: &HashSet<ScalarValue>,
negated: bool,
-) -> BooleanArray {
+) -> ColumnarValue {
let contains_null = set.iter().any(|v| v.is_null());
let native_array = set
.iter()
@@ -363,25 +349,28 @@ fn make_set_contains_decimal(
.collect::<Vec<_>>();
let native_set: HashSet<i128> = HashSet::from_iter(native_array);
- if !negated {
- // In
- array
- .iter()
- .map(|v| v.map(|v128| native_set.contains(&v128)))
- .collect::<BooleanArray>()
- } else {
- // Not in
- if contains_null {
- // If the expr is NOT IN and the list contains NULL value
- // All the result must be NONE
- BooleanArray::from(vec![None; array.len()])
- } else {
- array
- .iter()
- .map(|v| v.map(|v128| !native_set.contains(&v128)))
- .collect::<BooleanArray>()
- }
- }
+ collection_contains_check!(array, native_set, negated, contains_null)
+}
+
+fn set_contains_utf8<OffsetSize: OffsetSizeTrait>(
+ array: &GenericStringArray<OffsetSize>,
+ set: &HashSet<ScalarValue>,
+ negated: bool,
+) -> ColumnarValue {
+ let contains_null = set.iter().any(|v| v.is_null());
+ let native_array = set
+ .iter()
+ .flat_map(|v| match v {
+ Utf8(v) => v.as_deref(),
+ LargeUtf8(v) => v.as_deref(),
+ datatype => {
+ unreachable!("InList can't reach other data type {} for {}.", datatype, v)
+ }
+ })
+ .collect::<Vec<_>>();
+ let native_set: HashSet<&str> = HashSet::from_iter(native_array);
+
+ collection_contains_check!(array, native_set, negated, contains_null)
}
impl InListExpr {
@@ -532,66 +521,131 @@ impl PhysicalExpr for InListExpr {
match value_data_type {
DataType::Boolean => {
let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
- set_contains_with_negated!(array, set, self.negated)
+ Ok(set_contains_for_primitive!(
+ array,
+ set,
+ Boolean,
+ self.negated,
+ bool
+ ))
}
DataType::Int8 => {
let array = array.as_any().downcast_ref::<Int8Array>().unwrap();
- set_contains_with_negated!(array, set, self.negated)
+ Ok(set_contains_for_primitive!(
+ array,
+ set,
+ Int8,
+ self.negated,
+ i8
+ ))
}
DataType::Int16 => {
let array = array.as_any().downcast_ref::<Int16Array>().unwrap();
- set_contains_with_negated!(array, set, self.negated)
+ Ok(set_contains_for_primitive!(
+ array,
+ set,
+ Int16,
+ self.negated,
+ i16
+ ))
}
DataType::Int32 => {
let array = array.as_any().downcast_ref::<Int32Array>().unwrap();
- set_contains_with_negated!(array, set, self.negated)
+ Ok(set_contains_for_primitive!(
+ array,
+ set,
+ Int32,
+ self.negated,
+ i32
+ ))
}
DataType::Int64 => {
let array = array.as_any().downcast_ref::<Int64Array>().unwrap();
- set_contains_with_negated!(array, set, self.negated)
+ Ok(set_contains_for_primitive!(
+ array,
+ set,
+ Int64,
+ self.negated,
+ i64
+ ))
}
DataType::UInt8 => {
let array = array.as_any().downcast_ref::<UInt8Array>().unwrap();
- set_contains_with_negated!(array, set, self.negated)
+ Ok(set_contains_for_primitive!(
+ array,
+ set,
+ UInt8,
+ self.negated,
+ u8
+ ))
}
DataType::UInt16 => {
let array = array.as_any().downcast_ref::<UInt16Array>().unwrap();
- set_contains_with_negated!(array, set, self.negated)
+ Ok(set_contains_for_primitive!(
+ array,
+ set,
+ UInt16,
+ self.negated,
+ u16
+ ))
}
DataType::UInt32 => {
let array = array.as_any().downcast_ref::<UInt32Array>().unwrap();
- set_contains_with_negated!(array, set, self.negated)
+ Ok(set_contains_for_primitive!(
+ array,
+ set,
+ UInt32,
+ self.negated,
+ u32
+ ))
}
DataType::UInt64 => {
let array = array.as_any().downcast_ref::<UInt64Array>().unwrap();
- set_contains_with_negated!(array, set, self.negated)
+ Ok(set_contains_for_primitive!(
+ array,
+ set,
+ UInt64,
+ self.negated,
+ u64
+ ))
}
DataType::Float32 => {
let array = array.as_any().downcast_ref::<Float32Array>().unwrap();
- set_contains_with_negated!(array, set, self.negated)
+ Ok(set_contains_for_float!(
+ array,
+ set,
+ Float32,
+ self.negated,
+ f32
+ ))
}
DataType::Float64 => {
let array = array.as_any().downcast_ref::<Float64Array>().unwrap();
- set_contains_with_negated!(array, set, self.negated)
+ Ok(set_contains_for_float!(
+ array,
+ set,
+ Float64,
+ self.negated,
+ f64
+ ))
}
DataType::Utf8 => {
let array = array
.as_any()
.downcast_ref::<GenericStringArray<i32>>()
.unwrap();
- set_contains_with_negated!(array, set, self.negated)
+ Ok(set_contains_utf8(array, set, self.negated))
}
DataType::LargeUtf8 => {
let array = array
.as_any()
.downcast_ref::<GenericStringArray<i64>>()
.unwrap();
- set_contains_with_negated!(array, set, self.negated)
+ Ok(set_contains_utf8(array, set, self.negated))
}
DataType::Decimal(_, _) => {
let array = array.as_any().downcast_ref::<DecimalArray>().unwrap();
- let result = make_set_contains_decimal(array, set, self.negated);
- Ok(ColumnarValue::Array(Arc::new(result)))
+ Ok(make_set_contains_decimal(array, set, self.negated))
}
datatype => Result::Err(DataFusionError::NotImplemented(format!(
"InSet does not support datatype {:?}.",
@@ -701,15 +755,13 @@ impl PhysicalExpr for InListExpr {
UInt8Array
)
}
- DataType::Boolean => {
- make_contains!(
- array,
- list_values,
- self.negated,
- Boolean,
- BooleanArray
- )
- }
+ DataType::Boolean => Ok(make_contains!(
+ array,
+ list_values,
+ self.negated,
+ Boolean,
+ BooleanArray
+ )),
DataType::Utf8 => {
self.compare_utf8::<i32>(array, list_values, self.negated)
}
@@ -723,12 +775,11 @@ impl PhysicalExpr for InListExpr {
DataType::Decimal(_, _) => {
let decimal_array =
array.as_any().downcast_ref::<DecimalArray>().unwrap();
- let result = make_list_contains_decimal(
+ Ok(make_list_contains_decimal(
decimal_array,
list_values,
self.negated,
- );
- Ok(ColumnarValue::Array(Arc::new(result)))
+ ))
}
datatype => Result::Err(DataFusionError::NotImplemented(format!(
"InList does not support datatype {:?}.",
@@ -1022,9 +1073,24 @@ mod tests {
);
// expression: "a in (100,NULL)", the data type of list is INT32 AND NULL
- // TODO support: NULL data type to decimal in arrow-rs
- // let list = vec![lit(100i32), lit(ScalarValue::Null)];
- // in_list!(batch, list, &false, vec![Some(true), None, Some(false)], col_a.clone(), &schema);
+ let list = vec![lit(ScalarValue::Int32(Some(100))), lit(ScalarValue::Null)];
+ in_list!(
+ batch,
+ list.clone(),
+ &false,
+ vec![Some(true), None, None],
+ col_a.clone(),
+ &schema
+ );
+ // expression: "a not in (100,NULL)", the data type of list is INT32 AND NULL
+ in_list!(
+ batch,
+ list,
+ &true,
+ vec![Some(false), None, None],
+ col_a.clone(),
+ &schema
+ );
// expression: "a in (200.5, 100), the data type of list is FLOAT32 and INT32
let list = vec![lit(200.50f32), lit(100i32)];
@@ -1072,4 +1138,192 @@ mod tests {
Ok(())
}
+
+ #[test]
+ fn in_list_set_bool() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Boolean, true)]);
+ let a = BooleanArray::from(vec![Some(true), None, Some(false)]);
+ let col_a = col("a", &schema)?;
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
+
+ // expression: "a in (true,null,true.....)"
+ let mut list = vec![
+ lit(ScalarValue::Boolean(Some(true))),
+ lit(ScalarValue::Boolean(None)),
+ ];
+ for _ in 0..OPTIMIZER_INSET_THRESHOLD {
+ list.push(lit(ScalarValue::Boolean(Some(true))));
+ }
+ in_list!(
+ batch,
+ list.clone(),
+ &false,
+ vec![Some(true), None, None],
+ col_a.clone(),
+ &schema
+ );
+ in_list!(
+ batch,
+ list,
+ &true,
+ vec![Some(false), None, None],
+ col_a.clone(),
+ &schema
+ );
+ Ok(())
+ }
+
+ #[test]
+ fn in_list_set_int64() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
+ let a = Int64Array::from(vec![Some(0), Some(2), None]);
+ let col_a = col("a", &schema)?;
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
+
+ // expression: "a in (0,NULL,3,4....)"
+ let mut list = vec![
+ lit(ScalarValue::Int64(Some(0))),
+ lit(ScalarValue::Int64(None)),
+ lit(ScalarValue::Int64(Some(3))),
+ ];
+ for v in 4..(OPTIMIZER_INSET_THRESHOLD + 4) {
+ list.push(lit(ScalarValue::Int64(Some(v as i64))));
+ }
+
+ in_list!(
+ batch,
+ list.clone(),
+ &false,
+ vec![Some(true), None, None],
+ col_a.clone(),
+ &schema
+ );
+
+ in_list!(
+ batch,
+ list.clone(),
+ &true,
+ vec![Some(false), None, None],
+ col_a.clone(),
+ &schema
+ );
+
+ Ok(())
+ }
+
+ #[test]
+ fn in_list_set_float64() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
+ let a = Float64Array::from(vec![Some(0.0), Some(2.0), None]);
+ let col_a = col("a", &schema)?;
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
+
+ // expression: "a in (0.0,NULL,3.0,4.0 ....)"
+ let mut list = vec![
+ lit(ScalarValue::Float64(Some(0.0))),
+ lit(ScalarValue::Float64(None)),
+ lit(ScalarValue::Float64(Some(3.0))),
+ ];
+ for v in 4..(OPTIMIZER_INSET_THRESHOLD + 4) {
+ list.push(lit(ScalarValue::Float64(Some(v as f64))));
+ }
+
+ in_list!(
+ batch,
+ list.clone(),
+ &false,
+ vec![Some(true), None, None],
+ col_a.clone(),
+ &schema
+ );
+
+ in_list!(
+ batch,
+ list.clone(),
+ &true,
+ vec![Some(false), None, None],
+ col_a.clone(),
+ &schema
+ );
+
+ Ok(())
+ }
+
+ #[test]
+ fn in_list_set_utf8() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
+ let a = StringArray::from(vec![Some("a"), Some("b"), None]);
+ let col_a = col("a", &schema)?;
+ let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?;
+
+ // expression: "a in ("a", NULL, "4c", "5c", ....)"
+ let mut list = vec![
+ lit(ScalarValue::Utf8(Some("a".to_string()))),
+ lit(ScalarValue::Utf8(None)),
+ ];
+ for v in 4..(OPTIMIZER_INSET_THRESHOLD + 4) {
+ let value = v.to_string() + "c";
+ list.push(lit(ScalarValue::Utf8(Some(value))));
+ }
+ in_list!(
+ batch,
+ list.clone(),
+ &false,
+ vec![Some(true), None, None],
+ col_a.clone(),
+ &schema
+ );
+
+ in_list!(
+ batch,
+ list.clone(),
+ &true,
+ vec![Some(false), None, None],
+ col_a.clone(),
+ &schema
+ );
+
+ Ok(())
+ }
+
+ #[test]
+ fn in_list_set_decimal() -> Result<()> {
+ let schema = Schema::new(vec![Field::new("a", DataType::Decimal(13, 4), true)]);
+ let array = vec![Some(100_0000_i128), Some(200_5000_i128), None]
+ .into_iter()
+ .collect::<DecimalArray>();
+ let array = array.with_precision_and_scale(13, 4).unwrap();
+ let col_a = col("a", &schema)?;
+ let batch =
+ RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(array)])?;
+
+ // expression: "a in (100.0000, NULL, 100.0004, 100.0005...)"
+ let mut list = vec![
+ lit(ScalarValue::Decimal128(Some(100_0000_i128), 13, 4)),
+ lit(ScalarValue::Decimal128(None, 13, 4)),
+ ];
+ for v in 4..(OPTIMIZER_INSET_THRESHOLD + 4) {
+ let value = 100_0000_i128 + v as i128;
+ list.push(lit(ScalarValue::Decimal128(Some(value), 13, 4)));
+ }
+
+ in_list!(
+ batch,
+ list.clone(),
+ &false,
+ vec![Some(true), None, None],
+ col_a.clone(),
+ &schema
+ );
+
+ in_list!(
+ batch,
+ list,
+ &true,
+ vec![Some(false), None, None],
+ col_a.clone(),
+ &schema
+ );
+ Ok(())
+ }
}