You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/10/05 21:12:36 UTC

[GitHub] [arrow-datafusion] alamb commented on a diff in pull request #3705: Remove type coercions from ScalarValue and aggregation function code

alamb commented on code in PR #3705:
URL: https://github.com/apache/arrow-datafusion/pull/3705#discussion_r988356383


##########
datafusion/physical-expr/src/aggregate/min_max.rs:
##########
@@ -296,41 +290,18 @@ fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
         _ => min_max_batch!(values, max),
     })
 }
-macro_rules! typed_min_max_decimal {
-    ($VALUE:expr, $DELTA:expr, $PRECISION:expr, $SCALE:expr, $SCALAR:ident, $OP:ident) => {{
-        ScalarValue::$SCALAR(
-            match ($VALUE, $DELTA) {
-                (None, None) => None,
-                (Some(a), None) => Some(a.clone()),
-                (None, Some(b)) => Some(b.clone()),
-                (Some(a), Some(b)) => Some((*a).$OP(*b)),
-            },
-            $PRECISION.clone(),
-            $SCALE.clone(),
-        )
-    }};
-}
 
 // min/max of two non-string scalar values.
 macro_rules! typed_min_max {
-    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{
-        ScalarValue::$SCALAR(match ($VALUE, $DELTA) {
-            (None, None) => None,
-            (Some(a), None) => Some(a.clone()),
-            (None, Some(b)) => Some(b.clone()),
-            (Some(a), Some(b)) => Some((*a).$OP(*b)),
-        })
-    }};
-
-    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident, $TZ:expr) => {{
+    ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{

Review Comment:
   these are drive-by cleanups to reduce duplication in the macros, right?
   
   (BTW, if you submit these as individual free-standing PRs, you might find the reviews are faster -- finding enough contiguous time to review large PR changes can be challenging at times)



##########
datafusion/common/src/scalar.rs:
##########
@@ -312,155 +312,186 @@ impl Eq for ScalarValue {}
 // TODO implement this in arrow-rs with simd
 // https://github.com/apache/arrow-rs/issues/1010
 macro_rules! decimal_op {
-    ($LHS:expr, $RHS:expr, $PRECISION:expr, $LHS_SCALE:expr, $RHS_SCALE:expr, $OPERATION:tt ) => {{
-    let (difference, side) = if $LHS_SCALE > $RHS_SCALE {
-        ($LHS_SCALE - $RHS_SCALE, true)
-    } else {
-        ($RHS_SCALE - $LHS_SCALE, false)
-    };
-    let scale = max($LHS_SCALE, $RHS_SCALE);
-    match ($LHS, $RHS, difference) {
-        (None, None, _) => ScalarValue::Decimal128(None, $PRECISION, scale),
-        (None, Some(rhs_value), 0) => ScalarValue::Decimal128(Some((0 as i128) $OPERATION rhs_value), $PRECISION, scale),
-        (None, Some(rhs_value), _) => {
-            let mut new_value = ((0 as i128) $OPERATION rhs_value);
-            if side {
-                new_value *= 10_i128.pow((difference) as u32)
-            };
-            ScalarValue::Decimal128(Some(new_value), $PRECISION, scale)
-        }
-        (Some(lhs_value), None, 0) => ScalarValue::Decimal128(Some(lhs_value $OPERATION (0 as i128)), $PRECISION, scale),
-        (Some(lhs_value), None, _) => {
-            let mut new_value = (lhs_value $OPERATION (0 as i128));
-            if !!!side {
-                new_value *= 10_i128.pow((difference) as u32)
+    ($LHS:expr, $RHS:expr, $PRECISION:expr, $LHS_SCALE:expr, $RHS_SCALE:expr, $OPERATION:tt) => {{
+        let (difference, side) = if $LHS_SCALE > $RHS_SCALE {
+            ($LHS_SCALE - $RHS_SCALE, true)
+        } else {
+            ($RHS_SCALE - $LHS_SCALE, false)
+        };
+        let scale = max($LHS_SCALE, $RHS_SCALE);
+        Ok(match ($LHS, $RHS, difference) {
+            (None, None, _) => ScalarValue::Decimal128(None, $PRECISION, scale),
+            (lhs, None, 0) => ScalarValue::Decimal128(*lhs, $PRECISION, scale),
+            (Some(lhs_value), None, _) => {
+                let mut new_value = *lhs_value;
+                if !side {
+                    new_value *= 10_i128.pow(difference as u32)
+                }
+                ScalarValue::Decimal128(Some(new_value), $PRECISION, scale)
             }
-            ScalarValue::Decimal128(Some(new_value), $PRECISION, scale)
-        }
-        (Some(lhs_value), Some(rhs_value), 0) => {
-            ScalarValue::Decimal128(Some(lhs_value $OPERATION rhs_value), $PRECISION, scale)
-        }
-        (Some(lhs_value), Some(rhs_value), _) => {
-            let new_value = if side {
-                rhs_value * 10_i128.pow((difference) as u32) $OPERATION lhs_value
-            } else {
-                lhs_value * 10_i128.pow((difference) as u32) $OPERATION rhs_value
-            };
-            ScalarValue::Decimal128(Some(new_value), $PRECISION, scale)
-        }
-    }}
+            (None, Some(rhs_value), 0) => {
+                let value = decimal_right!(*rhs_value, $OPERATION);
+                ScalarValue::Decimal128(Some(value), $PRECISION, scale)
+            }
+            (None, Some(rhs_value), _) => {
+                let mut new_value = decimal_right!(*rhs_value, $OPERATION);
+                if side {
+                    new_value *= 10_i128.pow(difference as u32)
+                };
+                ScalarValue::Decimal128(Some(new_value), $PRECISION, scale)
+            }
+            (Some(lhs_value), Some(rhs_value), 0) => {
+                decimal_binary_op!(lhs_value, rhs_value, $OPERATION, $PRECISION, scale)
+            }
+            (Some(lhs_value), Some(rhs_value), _) => {
+                let (left_arg, right_arg) = if side {
+                    (*lhs_value, rhs_value * 10_i128.pow(difference as u32))
+                } else {
+                    (lhs_value * 10_i128.pow(difference as u32), *rhs_value)
+                };
+                decimal_binary_op!(left_arg, right_arg, $OPERATION, $PRECISION, scale)
+            }
+        })
+    }};
+}
 
-    }
+macro_rules! decimal_binary_op {
+    ($LHS:expr, $RHS:expr, $OPERATION:tt, $PRECISION:expr, $SCALE:expr) => {
+        // TODO: This simple implementation loses precision for calculations like
+        //       multiplication and division. Improve this implementation for such
+        //       operations.
+        ScalarValue::Decimal128(Some($LHS $OPERATION $RHS), $PRECISION, $SCALE)
+    };
 }
 
-// Returns the result of applying operation to two scalar values, including coercion into $TYPE.
-macro_rules! typed_op {
-    ($LEFT:expr, $RIGHT:expr, $SCALAR:ident, $TYPE:ident, $OPERATION:tt) => {
-        Some(ScalarValue::$SCALAR(match ($LEFT, $RIGHT) {
-            (None, None) => None,
-            (Some(a), None) => Some((*a as $TYPE) $OPERATION (0 as $TYPE)),
-            (None, Some(b)) => Some((0 as $TYPE) $OPERATION (*b as $TYPE)),
-            (Some(a), Some(b)) => Some((*a as $TYPE) $OPERATION (*b as $TYPE)),
-        }))
+macro_rules! decimal_right {
+    ($TERM:expr, +) => {
+        $TERM
+    };
+    ($TERM:expr, *) => {
+        $TERM
+    };
+    ($TERM:expr, -) => {
+        -$TERM
+    };
+    ($TERM:expr, /) => {
+        Err(DataFusionError::NotImplemented(format!(
+            "Decimal reciprocation not yet supported",
+        )))
     };
 }
 
-macro_rules! impl_common_symmetric_cases_op {
-    ($LHS:expr, $RHS:expr, $OPERATION:tt, [$([$L_TYPE:ident, $R_TYPE:ident, $O_TYPE:ident, $O_PRIM:ident]),+]) => {
-        match ($LHS, $RHS) {
-            $(
-                (ScalarValue::$L_TYPE(lhs), ScalarValue::$R_TYPE(rhs)) => {
-                    typed_op!(lhs, rhs, $O_TYPE, $O_PRIM, $OPERATION)
-                }
-                (ScalarValue::$R_TYPE(lhs), ScalarValue::$L_TYPE(rhs)) => {
-                    typed_op!(lhs, rhs, $O_TYPE, $O_PRIM, $OPERATION)
-                }
-            )+
-            _ => None
+// Returns the result of applying operation to two scalar values.
+macro_rules! primitive_op {
+    ($LEFT:expr, $RIGHT:expr, $SCALAR:ident, $OPERATION:tt) => {
+        match ($LEFT, $RIGHT) {
+            (lhs, None) => Ok(ScalarValue::$SCALAR(*lhs)),
+            #[allow(unused_variables)]
+            (None, Some(b)) => { primitive_right!(*b, $OPERATION, $SCALAR) },
+            (Some(a), Some(b)) => Ok(ScalarValue::$SCALAR(Some(*a $OPERATION *b))),
         }
-    }
+    };
 }
 
-macro_rules! impl_common_cases_op {
+macro_rules! primitive_right {
+    ($TERM:expr, +, $SCALAR:ident) => {
+        Ok(ScalarValue::$SCALAR(Some($TERM)))
+    };
+    ($TERM:expr, *, $SCALAR:ident) => {
+        Ok(ScalarValue::$SCALAR(Some($TERM)))
+    };
+    ($TERM:expr, -, UInt64) => {
+        unsigned_subtraction_error!("UInt64")
+    };
+    ($TERM:expr, -, UInt32) => {
+        unsigned_subtraction_error!("UInt32")
+    };
+    ($TERM:expr, -, UInt16) => {
+        unsigned_subtraction_error!("UInt16")
+    };
+    ($TERM:expr, -, UInt8) => {
+        unsigned_subtraction_error!("UInt8")
+    };
+    ($TERM:expr, -, $SCALAR:ident) => {
+        Ok(ScalarValue::$SCALAR(Some(-$TERM)))
+    };
+    ($TERM:expr, /, Float64) => {
+        Ok(ScalarValue::$SCALAR(Some($TERM.recip())))
+    };
+    ($TERM:expr, /, Float32) => {
+        Ok(ScalarValue::$SCALAR(Some($TERM.recip())))
+    };
+    ($TERM:expr, /, $SCALAR:ident) => {
+        Err(DataFusionError::Internal(format!(
+            "Can not divide an uninitialized value to a non-floating point value",
+        )))
+    };
+}
+
+macro_rules! unsigned_subtraction_error {
+    ($SCALAR:expr) => {{
+        let msg = format!(
+            "Can not subtract a {} value from an uninitialized value",
+            $SCALAR
+        );
+        Err(DataFusionError::Internal(msg))
+    }};
+}
+
+macro_rules! impl_op {
     ($LHS:expr, $RHS:expr, $OPERATION:tt) => {
         match ($LHS, $RHS) {
             (
                 ScalarValue::Decimal128(v1, p1, s1),
                 ScalarValue::Decimal128(v2, p2, s2),
             ) => {
-                let max_precision = *p1.max(p2);
-                Some(decimal_op!(v1, v2, max_precision, *s1, *s2, $OPERATION))
+                decimal_op!(v1, v2, *p1.max(p2), *s1, *s2, $OPERATION)
             }
             (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-                typed_op!(lhs, rhs, Float64, f64, $OPERATION)
+                primitive_op!(lhs, rhs, Float64, $OPERATION)
             }
             (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-                typed_op!(lhs, rhs, Float32, f32, $OPERATION)
+                primitive_op!(lhs, rhs, Float32, $OPERATION)
             }
             (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
-                typed_op!(lhs, rhs, UInt64, u64, $OPERATION)
+                primitive_op!(lhs, rhs, UInt64, $OPERATION)
             }
             (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => {
-                typed_op!(lhs, rhs, Int64, i64, $OPERATION)
+                primitive_op!(lhs, rhs, Int64, $OPERATION)
             }
             (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => {
-                typed_op!(lhs, rhs, UInt32, u32, $OPERATION)
+                primitive_op!(lhs, rhs, UInt32, $OPERATION)
             }
             (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => {
-                typed_op!(lhs, rhs, Int32, i32, $OPERATION)
+                primitive_op!(lhs, rhs, Int32, $OPERATION)
             }
             (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => {
-                typed_op!(lhs, rhs, UInt16, u16, $OPERATION)
+                primitive_op!(lhs, rhs, UInt16, $OPERATION)
             }
             (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => {
-                typed_op!(lhs, rhs, Int16, i16, $OPERATION)
+                primitive_op!(lhs, rhs, Int16, $OPERATION)
             }
             (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => {
-                typed_op!(lhs, rhs, UInt8, u8, $OPERATION)
+                primitive_op!(lhs, rhs, UInt8, $OPERATION)
             }
             (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => {
-                typed_op!(lhs, rhs, Int8, i8, $OPERATION)
+                primitive_op!(lhs, rhs, Int8, $OPERATION)
+            }
+            _ => {
+                impl_distinct_cases_op!($LHS, $RHS, $OPERATION)
             }
-            _ => impl_common_symmetric_cases_op!(

Review Comment:
   👍 



##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -206,87 +205,37 @@ pub(crate) fn sum_batch(values: &ArrayRef, sum_type: &DataType) -> Result<Scalar
 macro_rules! sum_row {
     ($INDEX:ident, $ACC:ident, $DELTA:expr, $TYPE:ident) => {{
         paste::item! {
-            match $DELTA {
-                None => {}
-                Some(v) => $ACC.[<add_ $TYPE>]($INDEX, *v as $TYPE)
+            if let Some(v) = $DELTA {
+                $ACC.[<add_ $TYPE>]($INDEX, *v)
             }
         }
     }};
 }
 
 pub(crate) fn add_to_row(
-    dt: &DataType,
     index: usize,
     accessor: &mut RowAccessor,
     s: &ScalarValue,
 ) -> Result<()> {
-    match (dt, s) {
-        // float64 coerces everything to f64
-        (DataType::Float64, ScalarValue::Float64(rhs)) => {
-            sum_row!(index, accessor, rhs, f64)
-        }
-        (DataType::Float64, ScalarValue::Float32(rhs)) => {
-            sum_row!(index, accessor, rhs, f64)
-        }
-        (DataType::Float64, ScalarValue::Int64(rhs)) => {
-            sum_row!(index, accessor, rhs, f64)
-        }
-        (DataType::Float64, ScalarValue::Int32(rhs)) => {
-            sum_row!(index, accessor, rhs, f64)
-        }
-        (DataType::Float64, ScalarValue::Int16(rhs)) => {
-            sum_row!(index, accessor, rhs, f64)
-        }
-        (DataType::Float64, ScalarValue::Int8(rhs)) => {
-            sum_row!(index, accessor, rhs, f64)
-        }
-        (DataType::Float64, ScalarValue::UInt64(rhs)) => {
-            sum_row!(index, accessor, rhs, f64)
-        }
-        (DataType::Float64, ScalarValue::UInt32(rhs)) => {
-            sum_row!(index, accessor, rhs, f64)
-        }
-        (DataType::Float64, ScalarValue::UInt16(rhs)) => {
-            sum_row!(index, accessor, rhs, f64)
-        }
-        (DataType::Float64, ScalarValue::UInt8(rhs)) => {
+    match s {

Review Comment:
   I think this is a good change -- to use the type of the input as the type of the accumulator.



##########
datafusion/physical-expr/src/aggregate/min_max.rs:
##########
@@ -154,16 +154,10 @@ macro_rules! typed_min_max_batch_string {
 
 // Statically-typed version of min/max(array) -> ScalarValue for non-string types.
 macro_rules! typed_min_max_batch {
-    ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{
-        let array = downcast_value!($VALUES, $ARRAYTYPE);
-        let value = compute::$OP(array);
-        ScalarValue::$SCALAR(value)
-    }};
-
-    ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident, $TZ:expr) => {{
+    ($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{

Review Comment:
   Macro 🧙 



##########
datafusion/common/src/scalar.rs:
##########
@@ -938,11 +960,13 @@ impl ScalarValue {
     }
 
     pub fn is_unsigned(&self) -> bool {
-        let value_type = self.get_datatype();
-        value_type == DataType::UInt64
-            || value_type == DataType::UInt32
-            || value_type == DataType::UInt16
-            || value_type == DataType::UInt8
+        matches!(

Review Comment:
   👍 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org