You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by jo...@apache.org on 2020/12/07 06:01:53 UTC
[arrow] branch master updated: ARROW-10821 [Rust][Datafusion]
support negative expression
This is an automated email from the ASF dual-hosted git repository.
jorgecarleitao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new e1c1e05 ARROW-10821 [Rust][Datafusion] support negative expression
e1c1e05 is described below
commit e1c1e054ff8a00353a975e5139277573921eac0a
Author: Qingping Hou <da...@gmail.com>
AuthorDate: Mon Dec 7 07:00:46 2020 +0100
ARROW-10821 [Rust][Datafusion] support negative expression
To support queries like `SELECT c3 FROM aggregate_test_100 WHERE -c4 > 0`:
* add negate compute kernel in arrow
* add negative expression in datafusion
* support negative and positive operators in datafusion's sql planner
Closes #8846 from houqp/qp_negative
Authored-by: Qingping Hou <da...@gmail.com>
Signed-off-by: Jorge C. Leitao <jo...@gmail.com>
---
rust/arrow/src/array/cast.rs | 1 +
rust/arrow/src/array/mod.rs | 4 +-
rust/arrow/src/compute/kernels/arithmetic.rs | 88 +++++++++++++++++--
rust/arrow/src/datatypes.rs | 75 ++++++++++++++++
rust/datafusion/src/logical_plan/expr.rs | 9 ++
rust/datafusion/src/optimizer/utils.rs | 6 +-
rust/datafusion/src/physical_plan/expressions.rs | 105 +++++++++++++++++++++--
rust/datafusion/src/physical_plan/planner.rs | 4 +
rust/datafusion/src/scalar.rs | 19 ++++
rust/datafusion/src/sql/planner.rs | 49 ++++++++---
rust/datafusion/tests/sql.rs | 11 +++
11 files changed, 346 insertions(+), 25 deletions(-)
diff --git a/rust/arrow/src/array/cast.rs b/rust/arrow/src/array/cast.rs
index 56e5d3a..a0ef7e2 100644
--- a/rust/arrow/src/array/cast.rs
+++ b/rust/arrow/src/array/cast.rs
@@ -59,5 +59,6 @@ macro_rules! array_downcast_fn {
}
array_downcast_fn!(as_string_array, StringArray);
+// Downcast helper for `LargeStringArray`, added alongside `as_string_array`.
+array_downcast_fn!(as_largestring_array, LargeStringArray);
array_downcast_fn!(as_boolean_array, BooleanArray);
array_downcast_fn!(as_null_array, NullArray);
diff --git a/rust/arrow/src/array/mod.rs b/rust/arrow/src/array/mod.rs
index fb0b302..cb1c13e 100644
--- a/rust/arrow/src/array/mod.rs
+++ b/rust/arrow/src/array/mod.rs
@@ -270,8 +270,8 @@ pub use self::ord::{build_compare, DynComparator};
// --------------------- Array downcast helper functions ---------------------
pub use self::cast::{
- as_boolean_array, as_dictionary_array, as_null_array, as_primitive_array,
- as_string_array,
+ as_boolean_array, as_dictionary_array, as_largestring_array, as_null_array,
+ as_primitive_array, as_string_array,
};
// ------------------------------ C Data Interface ---------------------------
diff --git a/rust/arrow/src/compute/kernels/arithmetic.rs b/rust/arrow/src/compute/kernels/arithmetic.rs
index fe1bda5..e0bd37d 100644
--- a/rust/arrow/src/compute/kernels/arithmetic.rs
+++ b/rust/arrow/src/compute/kernels/arithmetic.rs
@@ -24,7 +24,7 @@
#[cfg(feature = "simd")]
use std::mem;
-use std::ops::{Add, Div, Mul, Sub};
+use std::ops::{Add, Div, Mul, Neg, Sub};
#[cfg(feature = "simd")]
use std::slice::from_raw_parts_mut;
use std::sync::Arc;
@@ -44,6 +44,72 @@ use crate::datatypes::ToByteSlice;
use crate::error::{ArrowError, Result};
use crate::{array::*, util::bit_util};
+/// Helper function to perform math lambda function on values from single array of signed numeric
+/// type. If value is null then the output value is also null, so `-null` is `null`.
+///
+/// `op` is applied to every slot, including null slots (whose value is unspecified). The output
+/// array always has offset 0, so its validity bitmap is re-based with `unary_null_buffer`; this
+/// keeps nulls in the right positions when `array` is a slice with a non-zero offset.
+pub fn signed_unary_math_op<T, F>(
+    array: &PrimitiveArray<T>,
+    op: F,
+) -> Result<PrimitiveArray<T>>
+where
+    T: datatypes::ArrowSignedNumericType,
+    T::Native: Neg<Output = T::Native>,
+    F: Fn(T::Native) -> T::Native,
+{
+    let values = (0..array.len())
+        .map(|i| op(array.value(i)))
+        .collect::<Vec<T::Native>>();
+
+    let data = ArrayData::new(
+        T::DATA_TYPE,
+        array.len(),
+        None,
+        unary_null_buffer(array),
+        0,
+        vec![Buffer::from(values.to_byte_slice())],
+        vec![],
+    );
+    Ok(PrimitiveArray::<T>::from(Arc::new(data)))
+}
+
+/// Returns the validity bitmap of `array` re-based so that bit `i` describes logical slot `i`.
+///
+/// Returns `None` when `array` has no null buffer. When `array.offset()` is 0 the existing
+/// buffer is reused as-is. Otherwise the raw buffer is shifted by `offset` bits, so a fresh
+/// bitmap is built from `Array::is_valid`, which accounts for the offset.
+fn unary_null_buffer<T>(array: &PrimitiveArray<T>) -> Option<Buffer>
+where
+    T: datatypes::ArrowPrimitiveType,
+{
+    let nulls = array.data_ref().null_buffer()?;
+    if array.offset() == 0 {
+        return Some(nulls.clone());
+    }
+
+    let mut bitmap = vec![0u8; (array.len() + 7) / 8];
+    for i in (0..array.len()).filter(|&i| array.is_valid(i)) {
+        bit_util::set_bit(&mut bitmap, i);
+    }
+    Some(Buffer::from(&bitmap[..]))
+}
+
+/// SIMD vectorized version of `signed_unary_math_op` above.
+///
+/// Processes `T::lanes()` values per iteration. The validity bitmap is re-based with
+/// `unary_null_buffer`, exactly as in the scalar version.
+#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
+fn simd_signed_unary_math_op<T, F>(
+    array: &PrimitiveArray<T>,
+    op: F,
+) -> Result<PrimitiveArray<T>>
+where
+    T: datatypes::ArrowSignedNumericType,
+    F: Fn(T::SignedSimd) -> T::SignedSimd,
+{
+    let lanes = T::lanes();
+    let buffer_size = array.len() * mem::size_of::<T::Native>();
+    let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false);
+
+    for i in (0..array.len()).step_by(lanes) {
+        // NOTE(review): when `array.len()` is not a multiple of `lanes`, the final iteration
+        // loads a full vector past the logical end of `array`. This relies on the input
+        // allocation being padded to at least a whole vector — confirm for sliced inputs.
+        let simd_result =
+            T::signed_unary_op(T::load_signed(array.value_slice(i, lanes)), &op);
+
+        // SAFETY (NOTE(review)): this views `lanes` elements starting at `i`, which can extend
+        // past `buffer_size` on the final chunk; it assumes `MutableBuffer`'s capacity is
+        // padded to cover a whole vector.
+        let result_slice: &mut [T::Native] = unsafe {
+            from_raw_parts_mut(
+                (result.data_mut().as_mut_ptr() as *mut T::Native).add(i),
+                lanes,
+            )
+        };
+        T::write_signed(simd_result, result_slice);
+    }
+
+    let data = ArrayData::new(
+        T::DATA_TYPE,
+        array.len(),
+        None,
+        unary_null_buffer(array),
+        0,
+        vec![result.freeze()],
+        vec![],
+    );
+    Ok(PrimitiveArray::<T>::from(Arc::new(data)))
+}
+
/// Helper function to perform math lambda function on values from two arrays. If either
/// left or right value is null then the output value is also null, so `1 + null` is
/// `null`.
@@ -159,10 +225,6 @@ fn simd_math_op<T, F>(
) -> Result<PrimitiveArray<T>>
where
T: datatypes::ArrowNumericType,
- T::Simd: Add<Output = T::Simd>
- + Sub<Output = T::Simd>
- + Mul<Output = T::Simd>
- + Div<Output = T::Simd>,
F: Fn(T::Simd, T::Simd) -> T::Simd,
{
if left.len() != right.len() {
@@ -304,6 +366,22 @@ where
math_op(left, right, |a, b| a - b)
}
+/// Perform `-` operation on an array. If value is null then the result is also null.
+///
+/// On x86/x86_64 with the `simd` feature enabled this uses `simd_signed_unary_math_op`;
+/// otherwise it uses `signed_unary_math_op`.
+///
+/// NOTE(review): negating the minimum value of a signed integer type overflows. In the
+/// non-SIMD path `-x` panics in debug builds and wraps in release builds — confirm the
+/// intended semantics.
+pub fn negate<T>(array: &PrimitiveArray<T>) -> Result<PrimitiveArray<T>>
+where
+ T: datatypes::ArrowSignedNumericType,
+ T::Native: Neg<Output = T::Native>,
+{
+ // The two cfg conditions below are complementary, so exactly one `return` is compiled in.
+ #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
+ return simd_signed_unary_math_op(array, |x| -x);
+
+ #[cfg(any(
+ not(any(target_arch = "x86", target_arch = "x86_64")),
+ not(feature = "simd")
+ ))]
+ return signed_unary_math_op(array, |x| -x);
+}
+
/// Perform `left * right` operation on two arrays. If either left or right value is null
/// then the result is also null.
pub fn multiply<T>(
diff --git a/rust/arrow/src/datatypes.rs b/rust/arrow/src/datatypes.rs
index c4b0109..e1e99b9 100644
--- a/rust/arrow/src/datatypes.rs
+++ b/rust/arrow/src/datatypes.rs
@@ -26,6 +26,7 @@ use std::collections::HashMap;
use std::default::Default;
use std::fmt;
use std::mem::size_of;
+use std::ops::Neg;
#[cfg(feature = "simd")]
use std::ops::{Add, Div, Mul, Sub};
use std::slice::from_raw_parts;
@@ -810,6 +811,80 @@ make_numeric_type!(DurationMillisecondType, i64, i64x8, m64x8);
make_numeric_type!(DurationMicrosecondType, i64, i64x8, m64x8);
make_numeric_type!(DurationNanosecondType, i64, i64x8, m64x8);
+/// A subtype of primitive type that represents signed numeric values.
+///
+/// SIMD operations are defined in this trait if available on the target system.
+///
+/// This is the SIMD variant, compiled on x86/x86_64 with the `simd` feature enabled. Its
+/// `where` clause requires the associated vector type to support unary `-`.
+#[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
+pub trait ArrowSignedNumericType: ArrowNumericType
+where
+ Self::SignedSimd: Neg<Output = Self::SignedSimd>,
+{
+ /// Defines the SIMD type that should be used for this numeric type
+ type SignedSimd;
+
+ /// Loads a slice of signed numeric type into a SIMD register
+ ///
+ /// The impls generated by `make_signed_numeric_type!` do not bounds-check: `slice` must
+ /// hold at least `Self::lanes()` elements.
+ fn load_signed(slice: &[Self::Native]) -> Self::SignedSimd;
+
+ /// Performs a SIMD unary operation on signed numeric type
+ fn signed_unary_op<F: Fn(Self::SignedSimd) -> Self::SignedSimd>(
+ a: Self::SignedSimd,
+ op: F,
+ ) -> Self::SignedSimd;
+
+ /// Writes a signed SIMD result back to a slice
+ ///
+ /// As with `load_signed`, the generated impls do not bounds-check: `slice` must hold at
+ /// least `Self::lanes()` elements.
+ fn write_signed(simd_result: Self::SignedSimd, slice: &mut [Self::Native]);
+}
+
+/// Fallback variant used when SIMD is unavailable (non-x86 target or `simd` feature off).
+/// It is a marker trait whose `where` clause only requires the native type to support unary `-`.
+#[cfg(any(
+ not(any(target_arch = "x86", target_arch = "x86_64")),
+ not(feature = "simd")
+))]
+pub trait ArrowSignedNumericType: ArrowNumericType
+where
+ Self::Native: Neg<Output = Self::Native>,
+{
+}
+
+/// Implements `ArrowSignedNumericType` for `$impl_ty`.
+///
+/// On SIMD-enabled x86/x86_64 builds, `$simd_ty` becomes the `SignedSimd` vector type and the
+/// load/op/write methods are generated. Otherwise an empty marker impl is generated.
+macro_rules! make_signed_numeric_type {
+ ($impl_ty:ty, $simd_ty:ident) => {
+ #[cfg(all(any(target_arch = "x86", target_arch = "x86_64"), feature = "simd"))]
+ impl ArrowSignedNumericType for $impl_ty {
+ type SignedSimd = $simd_ty;
+
+ #[inline]
+ fn load_signed(slice: &[Self::Native]) -> Self::SignedSimd {
+ // SAFETY: the `_unchecked` load does not verify that `slice` holds at least
+ // `Self::lanes()` elements; callers must guarantee it.
+ unsafe { Self::SignedSimd::from_slice_unaligned_unchecked(slice) }
+ }
+
+ #[inline]
+ fn signed_unary_op<F: Fn(Self::SignedSimd) -> Self::SignedSimd>(
+ a: Self::SignedSimd,
+ op: F,
+ ) -> Self::SignedSimd {
+ op(a)
+ }
+
+ #[inline]
+ fn write_signed(simd_result: Self::SignedSimd, slice: &mut [Self::Native]) {
+ // SAFETY: the `_unchecked` store does not verify that `slice` holds at least
+ // `Self::lanes()` elements; callers must guarantee it.
+ unsafe { simd_result.write_to_slice_unaligned_unchecked(slice) };
+ }
+ }
+
+ #[cfg(any(
+ not(any(target_arch = "x86", target_arch = "x86_64")),
+ not(feature = "simd")
+ ))]
+ impl ArrowSignedNumericType for $impl_ty {}
+ };
+}
+
+// Signed integer and floating point types only; unsigned types get no impl, so `negate` cannot
+// be called on them. Each vector type used here spans 512 bits (e.g. 64 x i8, 8 x f64).
+make_signed_numeric_type!(Int8Type, i8x64);
+make_signed_numeric_type!(Int16Type, i16x32);
+make_signed_numeric_type!(Int32Type, i32x16);
+make_signed_numeric_type!(Int64Type, i64x8);
+make_signed_numeric_type!(Float32Type, f32x16);
+make_signed_numeric_type!(Float64Type, f64x8);
+
/// A subtype of primitive type that represents temporal values.
pub trait ArrowTemporalType: ArrowPrimitiveType {}
diff --git a/rust/datafusion/src/logical_plan/expr.rs b/rust/datafusion/src/logical_plan/expr.rs
index 78b2e60..43edf62 100644
--- a/rust/datafusion/src/logical_plan/expr.rs
+++ b/rust/datafusion/src/logical_plan/expr.rs
@@ -76,6 +76,8 @@ pub enum Expr {
IsNotNull(Box<Expr>),
/// Whether an expression is Null. This expression is never null.
IsNull(Box<Expr>),
+ /// arithmetic negation of an expression, the operand must be of a signed numeric data type
+ Negative(Box<Expr>),
/// The CASE expression is similar to a series of nested if/else and there are two forms that
/// can be used. The first form consists of a series of boolean "when" expressions with
/// corresponding "then" expressions, and an optional "else" expression.
@@ -196,6 +198,7 @@ impl Expr {
Ok((fun.return_type)(&data_types)?.as_ref().clone())
}
Expr::Not(_) => Ok(DataType::Boolean),
+ Expr::Negative(expr) => expr.get_type(schema),
Expr::IsNull(_) => Ok(DataType::Boolean),
Expr::IsNotNull(_) => Ok(DataType::Boolean),
Expr::BinaryExpr {
@@ -250,6 +253,7 @@ impl Expr {
Expr::AggregateFunction { .. } => Ok(true),
Expr::AggregateUDF { .. } => Ok(true),
Expr::Not(expr) => expr.nullable(input_schema),
+ Expr::Negative(expr) => expr.nullable(input_schema),
Expr::IsNull(_) => Ok(false),
Expr::IsNotNull(_) => Ok(false),
Expr::BinaryExpr {
@@ -729,6 +733,7 @@ impl fmt::Debug for Expr {
write!(f, "CAST({:?} AS {:?})", expr, data_type)
}
Expr::Not(expr) => write!(f, "NOT {:?}", expr),
+ Expr::Negative(expr) => write!(f, "(- {:?})", expr),
Expr::IsNull(expr) => write!(f, "{:?} IS NULL", expr),
Expr::IsNotNull(expr) => write!(f, "{:?} IS NOT NULL", expr),
Expr::BinaryExpr { left, op, right } => {
@@ -826,6 +831,10 @@ fn create_name(e: &Expr, input_schema: &Schema) -> Result<String> {
let expr = create_name(expr, input_schema)?;
Ok(format!("NOT {}", expr))
}
+ Expr::Negative(expr) => {
+ let expr = create_name(expr, input_schema)?;
+ Ok(format!("(- {})", expr))
+ }
Expr::IsNull(expr) => {
let expr = create_name(expr, input_schema)?;
Ok(format!("{} IS NULL", expr))
diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs
index 7b269be..ade195a 100644
--- a/rust/datafusion/src/optimizer/utils.rs
+++ b/rust/datafusion/src/optimizer/utils.rs
@@ -60,6 +60,7 @@ pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet<String>) -> Result<
Ok(())
}
Expr::Not(e) => expr_to_column_names(e, accum),
+ Expr::Negative(e) => expr_to_column_names(e, accum),
Expr::IsNull(e) => expr_to_column_names(e, accum),
Expr::IsNotNull(e) => expr_to_column_names(e, accum),
Expr::BinaryExpr { left, right, .. } => {
@@ -277,6 +278,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
Expr::Literal(_) => Ok(vec![]),
Expr::ScalarVariable(_) => Ok(vec![]),
Expr::Not(expr) => Ok(vec![expr.as_ref().to_owned()]),
+ Expr::Negative(expr) => Ok(vec![expr.as_ref().to_owned()]),
Expr::Sort { expr, .. } => Ok(vec![expr.as_ref().to_owned()]),
Expr::Wildcard { .. } => Err(DataFusionError::Internal(
"Wildcard expressions are not valid in a logical query plan".to_owned(),
@@ -284,7 +286,8 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
}
}
-/// returns a new expression where the expressions in expr are replaced by the ones in `expr`.
+/// returns a new expression where the expressions in `expr` are replaced by the ones in
+/// `expressions`.
/// This is used in conjunction with ``expr_expressions`` to re-write expressions.
pub fn rewrite_expression(expr: &Expr, expressions: &Vec<Expr>) -> Result<Expr> {
match expr {
@@ -356,6 +359,7 @@ pub fn rewrite_expression(expr: &Expr, expressions: &Vec<Expr>) -> Result<Expr>
Ok(Expr::Alias(Box::new(expressions[0].clone()), alias.clone()))
}
Expr::Not(_) => Ok(Expr::Not(Box::new(expressions[0].clone()))),
+ Expr::Negative(_) => Ok(Expr::Negative(Box::new(expressions[0].clone()))),
Expr::Column(_) => Ok(expr.clone()),
Expr::Literal(_) => Ok(expr.clone()),
Expr::ScalarVariable(_) => Ok(expr.clone()),
diff --git a/rust/datafusion/src/physical_plan/expressions.rs b/rust/datafusion/src/physical_plan/expressions.rs
index d82c63b..9718f46 100644
--- a/rust/datafusion/src/physical_plan/expressions.rs
+++ b/rust/datafusion/src/physical_plan/expressions.rs
@@ -29,7 +29,7 @@ use crate::scalar::ScalarValue;
use arrow::array::{self, Array, BooleanBuilder, LargeStringArray};
use arrow::compute;
use arrow::compute::kernels;
-use arrow::compute::kernels::arithmetic::{add, divide, multiply, subtract};
+use arrow::compute::kernels::arithmetic::{add, divide, multiply, negate, subtract};
use arrow::compute::kernels::boolean::{and, nullif, or};
use arrow::compute::kernels::comparison::{eq, gt, gt_eq, lt, lt_eq, neq};
use arrow::compute::kernels::comparison::{
@@ -1009,8 +1009,9 @@ macro_rules! compute_op_scalar {
}};
}
-/// Invoke a compute kernel on a pair of arrays
+/// Invoke a compute kernel on array(s)
macro_rules! compute_op {
+ // invoke binary operator
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
let ll = $LEFT
.as_any()
@@ -1022,6 +1023,14 @@ macro_rules! compute_op {
.expect("compute_op failed to downcast array");
Ok(Arc::new($OP(&ll, &rr)?))
}};
+ // invoke unary operator
+ ($OPERAND:expr, $OP:ident, $DT:ident) => {{
+ let operand = $OPERAND
+ .as_any()
+ .downcast_ref::<$DT>()
+ .expect("compute_op failed to downcast array");
+ Ok(Arc::new($OP(&operand)?))
+ }};
}
macro_rules! binary_string_array_op_scalar {
@@ -1682,6 +1691,82 @@ pub fn not(
}
}
+/// Negative expression: arithmetic negation (`-arg`) of its argument
+#[derive(Debug)]
+pub struct NegativeExpr {
+ /// The expression whose value is negated
+ arg: Arc<dyn PhysicalExpr>,
+}
+
+impl NegativeExpr {
+ /// Create a new negative expression wrapping `arg`
+ pub fn new(arg: Arc<dyn PhysicalExpr>) -> Self {
+ Self { arg }
+ }
+}
+
+/// Formats the expression as `(- <arg>)`.
+impl fmt::Display for NegativeExpr {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ write!(f, "(- {})", self.arg)
+ }
+}
+
+impl PhysicalExpr for NegativeExpr {
+    /// Negation does not change the type: the result has the argument's data type.
+    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
+        self.arg.data_type(input_schema)
+    }
+
+    /// The result is nullable exactly when the argument is nullable.
+    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
+        self.arg.nullable(input_schema)
+    }
+
+    /// Evaluates the argument, then negates it element-wise (arrays) or directly (scalars).
+    ///
+    /// # Errors
+    ///
+    /// Returns an internal error if the argument evaluates to an array whose type is not one of
+    /// Int8/Int16/Int32/Int64/Float32/Float64. Scalar arguments are delegated to
+    /// `ScalarValue::arithmetic_negate`.
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let arg = self.arg.evaluate(batch)?;
+        match arg {
+            ColumnarValue::Array(array) => {
+                let result: Result<ArrayRef> = match array.data_type() {
+                    DataType::Int8 => compute_op!(array, negate, Int8Array),
+                    DataType::Int16 => compute_op!(array, negate, Int16Array),
+                    DataType::Int32 => compute_op!(array, negate, Int32Array),
+                    DataType::Int64 => compute_op!(array, negate, Int64Array),
+                    DataType::Float32 => compute_op!(array, negate, Float32Array),
+                    DataType::Float64 => compute_op!(array, negate, Float64Array),
+                    // Report the operand (not the whole NegativeExpr), consistent with `negative()`.
+                    _ => Err(DataFusionError::Internal(format!(
+                        "(- '{:?}') can't be evaluated because the expression's type is {:?}, not signed numeric",
+                        self.arg,
+                        array.data_type(),
+                    ))),
+                };
+                result.map(ColumnarValue::Array)
+            }
+            ColumnarValue::Scalar(scalar) => {
+                Ok(ColumnarValue::Scalar(scalar.arithmetic_negate()))
+            }
+        }
+    }
+}
+
+/// Creates a unary expression NEGATIVE
+///
+/// The argument's type is checked once here, at planning time, against `input_schema`.
+///
+/// # Errors
+///
+/// This function errors when the argument's type is not signed numeric
+///
+/// NOTE(review): `is_signed_numeric` accepts Float16, but `NegativeExpr::evaluate` has no Float16
+/// arm, so a Float16 argument passes this check and then fails at execution time.
+pub fn negative(
+ arg: Arc<dyn PhysicalExpr>,
+ input_schema: &Schema,
+) -> Result<Arc<dyn PhysicalExpr>> {
+ let data_type = arg.data_type(input_schema)?;
+ if !is_signed_numeric(&data_type) {
+ Err(DataFusionError::Internal(
+ format!(
+ "(- '{:?}') can't be evaluated because the expression's type is {:?}, not signed numeric",
+ arg, data_type,
+ ),
+ ))
+ } else {
+ Ok(Arc::new(NegativeExpr::new(arg)))
+ }
+}
+
/// IS NULL expression
#[derive(Debug)]
pub struct IsNullExpr {
@@ -2189,16 +2274,26 @@ pub struct CastExpr {
cast_type: DataType,
}
-/// Determine if a DataType is numeric or not
-pub fn is_numeric(dt: &DataType) -> bool {
+/// Determine if a DataType is signed numeric or not
+///
+/// True for the signed integer types (Int8..Int64) and the floating point types
+/// (Float16, Float32, Float64); false for everything else, including unsigned integers.
+pub fn is_signed_numeric(dt: &DataType) -> bool {
match dt {
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => true,
- DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => true,
DataType::Float16 | DataType::Float32 | DataType::Float64 => true,
_ => false,
}
}
+/// Determine if a DataType is numeric or not
+///
+/// Numeric means signed numeric (see `is_signed_numeric`) or one of the unsigned integer
+/// types UInt8..UInt64.
+pub fn is_numeric(dt: &DataType) -> bool {
+ is_signed_numeric(dt)
+ || match dt {
+ DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
+ true
+ }
+ _ => false,
+ }
+}
+
impl fmt::Display for CastExpr {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "CAST({} AS {:?})", self.expr, self.cast_type)
diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs
index 20e1a29..4060383 100644
--- a/rust/datafusion/src/physical_plan/planner.rs
+++ b/rust/datafusion/src/physical_plan/planner.rs
@@ -504,6 +504,10 @@ impl DefaultPhysicalPlanner {
self.create_physical_expr(expr, input_schema, ctx_state)?,
input_schema,
),
+ Expr::Negative(expr) => expressions::negative(
+ self.create_physical_expr(expr, input_schema, ctx_state)?,
+ input_schema,
+ ),
Expr::IsNull(expr) => expressions::is_null(self.create_physical_expr(
expr,
input_schema,
diff --git a/rust/datafusion/src/scalar.rs b/rust/datafusion/src/scalar.rs
index 06309ab..6baf802 100644
--- a/rust/datafusion/src/scalar.rs
+++ b/rust/datafusion/src/scalar.rs
@@ -142,6 +142,25 @@ impl ScalarValue {
}
}
+    /// Calculate arithmetic negation for a scalar value.
+    ///
+    /// Null values of the supported types are returned unchanged (`-NULL` is `NULL`).
+    ///
+    /// # Panics
+    ///
+    /// Panics if `self` is not a signed integer or floating point variant (or a null boolean).
+    /// Negating the minimum value of a signed integer type overflows, which panics in debug
+    /// builds and wraps in release builds.
+    pub fn arithmetic_negate(&self) -> Self {
+        match self {
+            ScalarValue::Boolean(None)
+            | ScalarValue::Int8(None)
+            | ScalarValue::Int16(None)
+            | ScalarValue::Int32(None)
+            | ScalarValue::Int64(None)
+            | ScalarValue::Float32(None)
+            | ScalarValue::Float64(None) => self.clone(),
+            ScalarValue::Float64(Some(v)) => ScalarValue::Float64(Some(-v)),
+            ScalarValue::Float32(Some(v)) => ScalarValue::Float32(Some(-v)),
+            ScalarValue::Int8(Some(v)) => ScalarValue::Int8(Some(-v)),
+            ScalarValue::Int16(Some(v)) => ScalarValue::Int16(Some(-v)),
+            ScalarValue::Int32(Some(v)) => ScalarValue::Int32(Some(-v)),
+            ScalarValue::Int64(Some(v)) => ScalarValue::Int64(Some(-v)),
+            _ => panic!("Cannot run arithmetic negate on scalar value: {:?}", self),
+        }
+    }
+
/// whether this value is null or not.
pub fn is_null(&self) -> bool {
match *self {
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs
index c727765..7ea81a0 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -638,22 +638,29 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> {
Ok(Expr::IsNotNull(Box::new(self.sql_to_rex(expr, schema)?)))
}
- SQLExpr::UnaryOp { ref op, ref expr } => match (op, expr.as_ref()) {
- (UnaryOperator::Not, _) => {
+ SQLExpr::UnaryOp { ref op, ref expr } => match op {
+ UnaryOperator::Not => {
Ok(Expr::Not(Box::new(self.sql_to_rex(expr, schema)?)))
}
- (UnaryOperator::Minus, SQLExpr::Value(Value::Number(n))) =>
- // Parse negative numbers properly
- {
- match n.parse::<i64>() {
- Ok(n) => Ok(lit(-n)),
- Err(_) => Ok(lit(-n.parse::<f64>().unwrap())),
+ // unary plus is the identity: plan the operand unchanged
+ UnaryOperator::Plus => Ok(self.sql_to_rex(expr, schema)?),
+ UnaryOperator::Minus => {
+ match expr.as_ref() {
+ // optimization: if it's a number literal, we apply the negative operator
+ // here directly to calculate the new literal (Int64 if it parses as i64,
+ // otherwise Float64).
+ SQLExpr::Value(Value::Number(n)) => match n.parse::<i64>() {
+ Ok(n) => Ok(lit(-n)),
+ Err(_) => Ok(lit(-n
+ .parse::<f64>()
+ .map_err(|_e| {
+ DataFusionError::Internal(format!(
+ "negative operator can be only applied to integer and float operands, got: {}",
+ n))
+ })?)),
+ },
+ // not a literal, apply negative operator on expression
+ _ => Ok(Expr::Negative(Box::new(self.sql_to_rex(expr, schema)?))),
}
}
- _ => Err(DataFusionError::Internal(
- "SQL binary operator cannot be interpreted as a unary operator"
- .to_string(),
- )),
},
SQLExpr::BinaryOp {
@@ -1139,6 +1146,24 @@ mod tests {
}
#[test]
+ fn select_where_with_negative_operator() {
+ // `-0.1` is folded into a Float64 literal, while `-c4` is planned as `(- #c4)`.
+ let sql = "SELECT c3 FROM aggregate_test_100 WHERE c3 > -0.1 AND -c4 > 0";
+ let expected = "Projection: #c3\
+ \n Filter: #c3 Gt Float64(-0.1) And (- #c4) Gt Int64(0)\
+ \n TableScan: aggregate_test_100 projection=None";
+ quick_test(sql, expected);
+ }
+
+ #[test]
+ fn select_where_with_positive_operator() {
+ // unary plus leaves no trace in the plan: `+0.1` is Float64(0.1) and `+c4` is #c4.
+ let sql = "SELECT c3 FROM aggregate_test_100 WHERE c3 > +0.1 AND +c4 > 0";
+ let expected = "Projection: #c3\
+ \n Filter: #c3 Gt Float64(0.1) And #c4 Gt Int64(0)\
+ \n TableScan: aggregate_test_100 projection=None";
+ quick_test(sql, expected);
+ }
+
+ #[test]
fn select_order_by() {
let sql = "SELECT id FROM person ORDER BY id";
let expected = "Sort: #id ASC NULLS FIRST\
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index 63e60ee..bf23b9c 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -281,6 +281,17 @@ async fn csv_query_with_predicate() -> Result<()> {
}
#[tokio::test]
+async fn csv_query_with_negative_predicate() -> Result<()> {
+ let mut ctx = ExecutionContext::new();
+ register_aggregate_csv(&mut ctx)?;
+ // `c3 < -55` compares against a negative literal; `-c4 > 30000` evaluates the negative
+ // expression on a column at execution time.
+ let sql = "SELECT c1, c4 FROM aggregate_test_100 WHERE c3 < -55 AND -c4 > 30000";
+ let actual = execute(&mut ctx, sql).await;
+ let expected = vec![vec!["e", "-31500"], vec!["c", "-30187"]];
+ assert_eq!(expected, actual);
+ Ok(())
+}
+
+#[tokio::test]
async fn csv_query_with_negated_predicate() -> Result<()> {
let mut ctx = ExecutionContext::new();
register_aggregate_csv(&mut ctx)?;