You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/05/06 17:24:50 UTC
[arrow-datafusion] branch master updated: Add proper support for `null` literal by introducing `ScalarValue::Null` (#2364)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new fcc35e88b Add proper support for `null` literal by introducing `ScalarValue::Null` (#2364)
fcc35e88b is described below
commit fcc35e88b1874d72783949a8d3a2c967c182cc7e
Author: DuRipeng <45...@qq.com>
AuthorDate: Sat May 7 01:24:46 2022 +0800
Add proper support for `null` literal by introducing `ScalarValue::Null` (#2364)
* introduce null
* fix fmt
---
datafusion/common/src/scalar.rs | 30 +++++++++++++-
datafusion/core/src/logical_plan/builder.rs | 2 +-
datafusion/core/src/physical_plan/hash_join.rs | 6 ++-
datafusion/core/src/physical_plan/hash_utils.rs | 16 ++++++++
datafusion/core/src/sql/planner.rs | 4 +-
datafusion/core/tests/sql/expr.rs | 46 ++++++++++-----------
datafusion/core/tests/sql/functions.rs | 10 ++---
datafusion/core/tests/sql/joins.rs | 6 ++-
datafusion/core/tests/sql/select.rs | 34 +++++++++++++---
datafusion/expr/src/binary_rule.rs | 22 ++++++++++
datafusion/expr/src/function.rs | 2 +
datafusion/expr/src/type_coercion.rs | 40 ++++++++++++------
datafusion/physical-expr/src/expressions/binary.rs | 47 +++++++++++++++++++++-
.../physical-expr/src/expressions/in_list.rs | 4 ++
datafusion/physical-expr/src/expressions/nullif.rs | 2 +-
15 files changed, 217 insertions(+), 54 deletions(-)
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 03a59ff6d..4a7bc5337 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -39,6 +39,8 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
/// This is the single-valued counter-part of arrow’s `Array`.
#[derive(Clone)]
pub enum ScalarValue {
+ /// represents `DataType::Null` (castable to/from any other type)
+ Null,
/// true or false value
Boolean(Option<bool>),
/// 32bit float
@@ -170,6 +172,8 @@ impl PartialEq for ScalarValue {
(IntervalMonthDayNano(_), _) => false,
(Struct(v1, t1), Struct(v2, t2)) => v1.eq(v2) && t1.eq(t2),
(Struct(_, _), _) => false,
+ (Null, Null) => true,
+ (Null, _) => false,
}
}
}
@@ -270,6 +274,8 @@ impl PartialOrd for ScalarValue {
}
}
(Struct(_, _), _) => None,
+ (Null, Null) => Some(Ordering::Equal),
+ (Null, _) => None,
}
}
}
@@ -325,6 +331,8 @@ impl std::hash::Hash for ScalarValue {
v.hash(state);
t.hash(state);
}
+ // stable hash for Null value
+ Null => 1.hash(state),
}
}
}
@@ -594,6 +602,7 @@ impl ScalarValue {
DataType::Interval(IntervalUnit::MonthDayNano)
}
ScalarValue::Struct(_, fields) => DataType::Struct(fields.as_ref().clone()),
+ ScalarValue::Null => DataType::Null,
}
}
@@ -623,7 +632,8 @@ impl ScalarValue {
pub fn is_null(&self) -> bool {
matches!(
*self,
- ScalarValue::Boolean(None)
+ ScalarValue::Null
+ | ScalarValue::Boolean(None)
| ScalarValue::UInt8(None)
| ScalarValue::UInt16(None)
| ScalarValue::UInt32(None)
@@ -836,6 +846,7 @@ impl ScalarValue {
ScalarValue::iter_to_decimal_array(scalars, precision, scale)?;
Arc::new(decimal_array)
}
+ DataType::Null => ScalarValue::iter_to_null_array(scalars),
DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
DataType::Float32 => build_array_primitive!(Float32Array, Float32),
DataType::Float64 => build_array_primitive!(Float64Array, Float64),
@@ -968,6 +979,17 @@ impl ScalarValue {
Ok(array)
}
+ fn iter_to_null_array(scalars: impl IntoIterator<Item = ScalarValue>) -> ArrayRef { // build a DataType::Null array with one slot per input scalar
+ let length = // number of scalars consumed from the iterator
+ scalars
+ .into_iter()
+ .fold(0usize, |r, element: ScalarValue| match element {
+ ScalarValue::Null => r + 1,
+ _ => unreachable!(), // NOTE(review): panics on any non-Null scalar; presumably callers only route all-Null input here (L94) — confirm
+ });
+ new_null_array(&DataType::Null, length)
+ }
+
fn iter_to_decimal_array(
scalars: impl IntoIterator<Item = ScalarValue>,
precision: &usize,
@@ -1241,6 +1263,7 @@ impl ScalarValue {
Arc::new(StructArray::from(field_values))
}
},
+ ScalarValue::Null => new_null_array(&DataType::Null, size),
}
}
@@ -1266,6 +1289,7 @@ impl ScalarValue {
}
Ok(match array.data_type() {
+ DataType::Null => ScalarValue::Null,
DataType::Decimal(precision, scale) => {
ScalarValue::get_decimal_value_from_array(array, index, precision, scale)
}
@@ -1522,6 +1546,7 @@ impl ScalarValue {
eq_array_primitive!(array, index, IntervalMonthDayNanoArray, val)
}
ScalarValue::Struct(_, _) => unimplemented!(),
+ ScalarValue::Null => array.data().is_null(index),
}
}
@@ -1743,6 +1768,7 @@ impl TryFrom<&DataType> for ScalarValue {
DataType::Struct(fields) => {
ScalarValue::Struct(None, Box::new(fields.clone()))
}
+ DataType::Null => ScalarValue::Null,
_ => {
return Err(DataFusionError::NotImplemented(format!(
"Can't create a scalar from data_type \"{:?}\"",
@@ -1835,6 +1861,7 @@ impl fmt::Display for ScalarValue {
)?,
None => write!(f, "NULL")?,
},
+ ScalarValue::Null => write!(f, "NULL")?,
};
Ok(())
}
@@ -1902,6 +1929,7 @@ impl fmt::Debug for ScalarValue {
None => write!(f, "Struct(NULL)"),
}
}
+ ScalarValue::Null => write!(f, "NULL"),
}
}
}
diff --git a/datafusion/core/src/logical_plan/builder.rs b/datafusion/core/src/logical_plan/builder.rs
index 1fbb1f5f9..8a0ea6d66 100644
--- a/datafusion/core/src/logical_plan/builder.rs
+++ b/datafusion/core/src/logical_plan/builder.rs
@@ -155,7 +155,7 @@ impl LogicalPlanBuilder {
.iter()
.enumerate()
.map(|(j, expr)| {
- if let Expr::Literal(ScalarValue::Utf8(None)) = expr {
+ if let Expr::Literal(ScalarValue::Null) = expr {
nulls.push((i, j));
Ok(field_types[j].clone())
} else {
diff --git a/datafusion/core/src/physical_plan/hash_join.rs b/datafusion/core/src/physical_plan/hash_join.rs
index 8a4a342c1..ee763241a 100644
--- a/datafusion/core/src/physical_plan/hash_join.rs
+++ b/datafusion/core/src/physical_plan/hash_join.rs
@@ -817,7 +817,11 @@ fn equal_rows(
.iter()
.zip(right_arrays)
.all(|(l, r)| match l.data_type() {
- DataType::Null => true,
+ DataType::Null => {
+ // lhs and rhs are both `DataType::Null`, so the euqal result
+ // is dependent on `null_equals_null`
+ null_equals_null
+ }
DataType::Boolean => {
equal_rows_elem!(BooleanArray, l, r, left, right, null_equals_null)
}
diff --git a/datafusion/core/src/physical_plan/hash_utils.rs b/datafusion/core/src/physical_plan/hash_utils.rs
index 4e503b19e..2ca1fa3df 100644
--- a/datafusion/core/src/physical_plan/hash_utils.rs
+++ b/datafusion/core/src/physical_plan/hash_utils.rs
@@ -39,6 +39,19 @@ fn combine_hashes(l: u64, r: u64) -> u64 {
hash.wrapping_mul(37).wrapping_add(r)
}
+fn hash_null(random_state: &RandomState, hashes_buffer: &'_ mut [u64], mul_col: bool) { // give every row of a DataType::Null column the same hash value
+ if mul_col { // mul_col == true: fold the Null hash into the hash already stored in each slot
+ hashes_buffer.iter_mut().for_each(|hash| {
+ // stable hash for null value: every row uses the hash of the constant 1i128
+ *hash = combine_hashes(i128::get_hash(&1, random_state), *hash);
+ })
+ } else { // mul_col == false: overwrite each slot with the Null hash
+ hashes_buffer.iter_mut().for_each(|hash| {
+ *hash = i128::get_hash(&1, random_state); // NOTE(review): loop-invariant; the hash could be computed once outside the loop
+ })
+ }
+}
+
fn hash_decimal128<'a>(
array: &ArrayRef,
random_state: &RandomState,
@@ -284,6 +297,9 @@ pub fn create_hashes<'a>(
for col in arrays {
match col.data_type() {
+ DataType::Null => {
+ hash_null(random_state, hashes_buffer, multi_col);
+ }
DataType::Decimal(_, _) => {
hash_decimal128(col, random_state, hashes_buffer, multi_col);
}
diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs
index fe737d6e8..33915300c 100644
--- a/datafusion/core/src/sql/planner.rs
+++ b/datafusion/core/src/sql/planner.rs
@@ -1531,7 +1531,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n),
SQLExpr::Value(Value::SingleQuotedString(s)) => Ok(lit(s)),
SQLExpr::Value(Value::Null) => {
- Ok(Expr::Literal(ScalarValue::Utf8(None)))
+ Ok(Expr::Literal(ScalarValue::Null))
}
SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)),
SQLExpr::UnaryOp { op, expr } => self.parse_sql_unary_op(
@@ -1569,7 +1569,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
SQLExpr::Value(Value::Number(n, _)) => parse_sql_number(&n),
SQLExpr::Value(Value::SingleQuotedString(ref s)) => Ok(lit(s.clone())),
SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)),
- SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Utf8(None))),
+ SQLExpr::Value(Value::Null) => Ok(Expr::Literal(ScalarValue::Null)),
SQLExpr::Extract { field, expr } => Ok(Expr::ScalarFunction {
fun: BuiltinScalarFunction::DatePart,
args: vec![
diff --git a/datafusion/core/tests/sql/expr.rs b/datafusion/core/tests/sql/expr.rs
index e62acc502..1dffc2eb9 100644
--- a/datafusion/core/tests/sql/expr.rs
+++ b/datafusion/core/tests/sql/expr.rs
@@ -121,25 +121,25 @@ async fn case_when_else_with_null_contant() -> Result<()> {
FROM t1";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+----------------------------------------------------------------------------------------------+",
- "| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN Utf8(NULL) THEN Int64(2) ELSE Int64(999) END |",
- "+----------------------------------------------------------------------------------------------+",
- "| 1 |",
- "| 999 |",
- "| 999 |",
- "| 999 |",
- "+----------------------------------------------------------------------------------------------+",
+ "+----------------------------------------------------------------------------------------+",
+ "| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN NULL THEN Int64(2) ELSE Int64(999) END |",
+ "+----------------------------------------------------------------------------------------+",
+ "| 1 |",
+ "| 999 |",
+ "| 999 |",
+ "| 999 |",
+ "+----------------------------------------------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
let sql = "SELECT CASE WHEN NULL THEN 'foo' ELSE 'bar' END";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+------------------------------------------------------------+",
- "| CASE WHEN Utf8(NULL) THEN Utf8(\"foo\") ELSE Utf8(\"bar\") END |",
- "+------------------------------------------------------------+",
- "| bar |",
- "+------------------------------------------------------------+",
+ "+------------------------------------------------------+",
+ "| CASE WHEN NULL THEN Utf8(\"foo\") ELSE Utf8(\"bar\") END |",
+ "+------------------------------------------------------+",
+ "| bar |",
+ "+------------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
@@ -347,11 +347,11 @@ async fn test_string_concat_operator() -> Result<()> {
let sql = "SELECT 'aa' || NULL || 'd'";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+---------------------------------------+",
- "| Utf8(\"aa\") || Utf8(NULL) || Utf8(\"d\") |",
- "+---------------------------------------+",
- "| |",
- "+---------------------------------------+",
+ "+---------------------------------+",
+ "| Utf8(\"aa\") || NULL || Utf8(\"d\") |",
+ "+---------------------------------+",
+ "| |",
+ "+---------------------------------+",
];
assert_batches_eq!(expected, &actual);
@@ -387,11 +387,11 @@ async fn test_not_expressions() -> Result<()> {
let sql = "SELECT null, not(null)";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+------------+----------------+",
- "| Utf8(NULL) | NOT Utf8(NULL) |",
- "+------------+----------------+",
- "| | |",
- "+------------+----------------+",
+ "+------+----------+",
+ "| NULL | NOT NULL |",
+ "+------+----------+",
+ "| | |",
+ "+------+----------+",
];
assert_batches_eq!(expected, &actual);
diff --git a/datafusion/core/tests/sql/functions.rs b/datafusion/core/tests/sql/functions.rs
index 857781aa3..396bd1194 100644
--- a/datafusion/core/tests/sql/functions.rs
+++ b/datafusion/core/tests/sql/functions.rs
@@ -176,11 +176,11 @@ async fn coalesce_static_value_with_null() -> Result<()> {
let sql = "SELECT COALESCE(NULL, 'test')";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+-----------------------------------+",
- "| coalesce(Utf8(NULL),Utf8(\"test\")) |",
- "+-----------------------------------+",
- "| test |",
- "+-----------------------------------+",
+ "+-----------------------------+",
+ "| coalesce(NULL,Utf8(\"test\")) |",
+ "+-----------------------------+",
+ "| test |",
+ "+-----------------------------+",
];
assert_batches_eq!(expected, &actual);
Ok(())
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index aaa8adac5..312b687a6 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -829,7 +829,11 @@ async fn inner_join_nulls() {
let sql = "SELECT * FROM (SELECT null AS id1) t1
INNER JOIN (SELECT null AS id2) t2 ON id1 = id2";
- let expected = vec!["++", "++"];
+ #[rustfmt::skip]
+ let expected = vec![
+ "++",
+ "++",
+ ];
let ctx = create_join_context_qualified().unwrap();
let actual = execute_to_batches(&ctx, sql).await;
diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs
index 747a9e05a..4ab3a83be 100644
--- a/datafusion/core/tests/sql/select.rs
+++ b/datafusion/core/tests/sql/select.rs
@@ -398,15 +398,37 @@ async fn select_distinct_from() {
1 IS NOT DISTINCT FROM CAST(NULL as INT) as c,
1 IS NOT DISTINCT FROM 1 as d,
NULL IS DISTINCT FROM NULL as e,
- NULL IS NOT DISTINCT FROM NULL as f
+ NULL IS NOT DISTINCT FROM NULL as f,
+ NULL is DISTINCT FROM 1 as g,
+ NULL is NOT DISTINCT FROM 1 as h
";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
- "+------+-------+-------+------+-------+------+",
- "| a | b | c | d | e | f |",
- "+------+-------+-------+------+-------+------+",
- "| true | false | false | true | false | true |",
- "+------+-------+-------+------+-------+------+",
+ "+------+-------+-------+------+-------+------+------+-------+",
+ "| a | b | c | d | e | f | g | h |",
+ "+------+-------+-------+------+-------+------+------+-------+",
+ "| true | false | false | true | false | true | true | false |",
+ "+------+-------+-------+------+-------+------+------+-------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "select
+ NULL IS DISTINCT FROM NULL as a,
+ NULL IS NOT DISTINCT FROM NULL as b,
+ NULL is DISTINCT FROM 1 as c,
+ NULL is NOT DISTINCT FROM 1 as d,
+ 1 IS DISTINCT FROM CAST(NULL as INT) as e,
+ 1 IS DISTINCT FROM 1 as f,
+ 1 IS NOT DISTINCT FROM CAST(NULL as INT) as g,
+ 1 IS NOT DISTINCT FROM 1 as h
+ ";
+ let actual = execute_to_batches(&ctx, sql).await;
+ let expected = vec![
+ "+-------+------+------+-------+------+-------+-------+------+",
+ "| a | b | c | d | e | f | g | h |",
+ "+-------+------+------+-------+------+-------+-------+------+",
+ "| false | true | true | false | true | false | false | true |",
+ "+-------+------+------+-------+------+-------+-------+------+",
];
assert_batches_eq!(expected, &actual);
}
diff --git a/datafusion/expr/src/binary_rule.rs b/datafusion/expr/src/binary_rule.rs
index ad46770f1..63a9712fd 100644
--- a/datafusion/expr/src/binary_rule.rs
+++ b/datafusion/expr/src/binary_rule.rs
@@ -590,8 +590,30 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
numerical_coercion(lhs_type, rhs_type)
.or_else(|| dictionary_coercion(lhs_type, rhs_type))
.or_else(|| temporal_coercion(lhs_type, rhs_type))
+ .or_else(|| null_coercion(lhs_type, rhs_type))
}
+/// Coercion rules for the NULL type. Since NULL can be cast to most arrow types, when either
+/// lhs or rhs is NULL and NULL can be cast to the other side's type, coerce to that type.
+fn null_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
+ match (lhs_type, rhs_type) {
+ (DataType::Null, _) => { // (Null, Null) also lands in this arm
+ if can_cast_types(&DataType::Null, rhs_type) {
+ Some(rhs_type.clone())
+ } else {
+ None
+ }
+ }
+ (_, DataType::Null) => {
+ if can_cast_types(&DataType::Null, lhs_type) {
+ Some(lhs_type.clone())
+ } else {
+ None
+ }
+ }
+ _ => None, // neither side is Null: this rule does not apply
+ }
+}
#[cfg(test)]
mod tests {
use super::*;
diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs
index 385e247bd..d631e0f83 100644
--- a/datafusion/expr/src/function.rs
+++ b/datafusion/expr/src/function.rs
@@ -58,6 +58,7 @@ macro_rules! make_utf8_to_return_type {
Ok(match arg_type {
DataType::LargeUtf8 => $largeUtf8Type,
DataType::Utf8 => $utf8Type,
+ DataType::Null => DataType::Null,
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(format!(
@@ -209,6 +210,7 @@ pub fn return_type(
DataType::Utf8 => {
DataType::List(Box::new(Field::new("item", DataType::Utf8, true)))
}
+ DataType::Null => DataType::Null,
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(
diff --git a/datafusion/expr/src/type_coercion.rs b/datafusion/expr/src/type_coercion.rs
index 8cea256f1..33a540d6f 100644
--- a/datafusion/expr/src/type_coercion.rs
+++ b/datafusion/expr/src/type_coercion.rs
@@ -31,7 +31,10 @@
//!
use crate::{Signature, TypeSignature};
-use arrow::datatypes::{DataType, TimeUnit};
+use arrow::{
+ compute::can_cast_types,
+ datatypes::{DataType, TimeUnit},
+};
use datafusion_common::{DataFusionError, Result};
/// Returns the data types that each argument must be coerced to match
@@ -142,25 +145,35 @@ fn maybe_data_types(
/// See the module level documentation for more detail on coercion.
pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
use self::DataType::*;
+ // Null can convert to most of types
match type_into {
- Int8 => matches!(type_from, Int8),
- Int16 => matches!(type_from, Int8 | Int16 | UInt8),
- Int32 => matches!(type_from, Int8 | Int16 | Int32 | UInt8 | UInt16),
+ Int8 => matches!(type_from, Null | Int8),
+ Int16 => matches!(type_from, Null | Int8 | Int16 | UInt8),
+ Int32 => matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16),
Int64 => matches!(
type_from,
- Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32
+ Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32
),
- UInt8 => matches!(type_from, UInt8),
- UInt16 => matches!(type_from, UInt8 | UInt16),
- UInt32 => matches!(type_from, UInt8 | UInt16 | UInt32),
- UInt64 => matches!(type_from, UInt8 | UInt16 | UInt32 | UInt64),
+ UInt8 => matches!(type_from, Null | UInt8),
+ UInt16 => matches!(type_from, Null | UInt8 | UInt16),
+ UInt32 => matches!(type_from, Null | UInt8 | UInt16 | UInt32),
+ UInt64 => matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64),
Float32 => matches!(
type_from,
- Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | Float32
+ Null | Int8
+ | Int16
+ | Int32
+ | Int64
+ | UInt8
+ | UInt16
+ | UInt32
+ | UInt64
+ | Float32
),
Float64 => matches!(
type_from,
- Int8 | Int16
+ Null | Int8
+ | Int16
| Int32
| Int64
| UInt8
@@ -171,8 +184,11 @@ pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
| Float64
| Decimal(_, _)
),
- Timestamp(TimeUnit::Nanosecond, None) => matches!(type_from, Timestamp(_, None)),
+ Timestamp(TimeUnit::Nanosecond, None) => {
+ matches!(type_from, Null | Timestamp(_, None))
+ }
Utf8 | LargeUtf8 => true,
+ Null => can_cast_types(type_from, type_into),
_ => false,
}
}
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index 6dafb43f9..060f30cb2 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -604,6 +604,20 @@ macro_rules! compute_decimal_op {
}};
}
+macro_rules! compute_null_op { // invoke the `<$OP>_null` kernel on a pair of arrays downcast to $DT
+ ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
+ let ll = $LEFT // panics via expect if $LEFT is not a $DT
+ .as_any()
+ .downcast_ref::<$DT>()
+ .expect("compute_op failed to downcast array");
+ let rr = $RIGHT // panics via expect if $RIGHT is not a $DT, so both sides must share the type
+ .as_any()
+ .downcast_ref::<$DT>()
+ .expect("compute_op failed to downcast array");
+ Ok(Arc::new(paste::expr! {[<$OP _null>]}(&ll, &rr)?)) // e.g. $OP = is_distinct_from calls is_distinct_from_null
+ }};
+}
+
/// Invoke a compute kernel on a pair of binary data arrays
macro_rules! compute_utf8_op {
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
@@ -909,6 +923,7 @@ macro_rules! binary_array_op_scalar {
macro_rules! binary_array_op {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
match $LEFT.data_type() {
+ DataType::Null => compute_null_op!($LEFT, $RIGHT, $OP, NullArray),
DataType::Decimal(_,_) => compute_decimal_op!($LEFT, $RIGHT, $OP, DecimalArray),
DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array),
DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array),
@@ -1261,7 +1276,16 @@ impl BinaryExpr {
Operator::GtEq => gt_eq_dyn(&left, &right),
Operator::Eq => eq_dyn(&left, &right),
Operator::NotEq => neq_dyn(&left, &right),
- Operator::IsDistinctFrom => binary_array_op!(left, right, is_distinct_from),
+ Operator::IsDistinctFrom => {
+ match (left_data_type, right_data_type) {
+ // exchange lhs and rhs when lhs is Null, since `binary_array_op` is
+ // always try to down cast array according to $LEFT expression.
+ (DataType::Null, _) => {
+ binary_array_op!(right, left, is_distinct_from)
+ }
+ _ => binary_array_op!(left, right, is_distinct_from),
+ }
+ }
Operator::IsNotDistinctFrom => {
binary_array_op!(left, right, is_not_distinct_from)
}
@@ -1336,6 +1360,27 @@ fn is_distinct_from_utf8<OffsetSize: StringOffsetSizeTrait>(
.collect())
}
+fn is_distinct_from_null(left: &NullArray, _right: &NullArray) -> Result<BooleanArray> { // NULL IS DISTINCT FROM NULL: false in every row
+ let length = left.len(); // NOTE(review): only left's length is used; assumes right.len() == left.len()
+ make_boolean_array(length, false)
+}
+
+fn is_not_distinct_from_null( // NULL IS NOT DISTINCT FROM NULL: true in every row
+ left: &NullArray,
+ _right: &NullArray,
+) -> Result<BooleanArray> {
+ let length = left.len();
+ make_boolean_array(length, true)
+}
+
+pub fn eq_null(left: &NullArray, _right: &NullArray) -> Result<BooleanArray> { // NULL = NULL: a null (unknown) result in every row
+ Ok((0..left.len()).into_iter().map(|_| None).collect())
+}
+
+fn make_boolean_array(length: usize, value: bool) -> Result<BooleanArray> { // `length` non-null copies of `value`
+ Ok((0..length).into_iter().map(|_| Some(value)).collect())
+}
+
fn is_not_distinct_from<T>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs
index a6894b938..7094a718d 100644
--- a/datafusion/physical-expr/src/expressions/in_list.rs
+++ b/datafusion/physical-expr/src/expressions/in_list.rs
@@ -629,6 +629,10 @@ impl PhysicalExpr for InListExpr {
DataType::LargeUtf8 => {
self.compare_utf8::<i64>(array, list_values, self.negated)
}
+ DataType::Null => {
+ let null_array = new_null_array(&DataType::Boolean, array.len());
+ Ok(ColumnarValue::Array(Arc::new(null_array)))
+ }
datatype => Result::Err(DataFusionError::NotImplemented(format!(
"InList does not support datatype {:?}.",
datatype
diff --git a/datafusion/physical-expr/src/expressions/nullif.rs b/datafusion/physical-expr/src/expressions/nullif.rs
index 307e3a07f..2d1f3654d 100644
--- a/datafusion/physical-expr/src/expressions/nullif.rs
+++ b/datafusion/physical-expr/src/expressions/nullif.rs
@@ -17,7 +17,7 @@
use std::sync::Arc;
-use crate::expressions::binary::{eq_decimal, eq_decimal_scalar};
+use crate::expressions::binary::{eq_decimal, eq_decimal_scalar, eq_null};
use arrow::array::Array;
use arrow::array::*;
use arrow::compute::kernels::boolean::nullif;