You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/05/12 20:56:02 UTC

[GitHub] [arrow-datafusion] alamb commented on a diff in pull request #2516: Sum refactor draft

alamb commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r871802742


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(
     }
 }
 
+/// Downcasts `$ARG` (any value exposing `as_any()`) to the concrete array
+/// type `$ARRAY_TYPE`, evaluating to a reference of that type.
+///
+/// On failure the macro returns early from the *enclosing* function via `?`
+/// with a `DataFusionError::Internal` whose message names `$NAME` and the
+/// target type, so it can only be expanded inside a function whose error
+/// type accepts a `DataFusionError`.
+macro_rules! downcast_arg {
+    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
+        // NOTE(review): relies on `type_name` (presumably `std::any::type_name`)
+        // and `DataFusionError` being in scope at the expansion site — confirm.
+        $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "could not cast {} to {}",
+                $NAME,
+                type_name::<$ARRAY_TYPE>()
+            ))
+        })?
+    }};
+}
+
+/// Combines two scalar values into one array of type `$ARR_DTYPE`.
+///
+/// Each scalar is converted with `to_array()`, cast to `$DTYPE`, and downcast
+/// (via `downcast_arg!`) to `$ARR_DTYPE`. The left-hand elements are then
+/// followed by the right-hand elements and collected into a new `$ARR_DTYPE`.
+/// That array is returned wrapped in an `Arc`, and callers bind it as an
+/// `ArrayRef`. `$NAME` is used only in the downcast error message.
+///
+/// Errors from `cast` or from the downcast propagate out of the enclosing
+/// function via `?`.
+///
+/// NOTE(review): every expansion materializes intermediate arrays for just
+/// two scalars, which may be costly if this runs per value. Measure before
+/// relying on it in hot paths.
+macro_rules! union_arrays {
+    ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => {{
+        // Bring both operands to the common target type so they can be concatenated.
+        let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?;
+        let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?;
+        let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE);
+        let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE);
+
+        // Left-hand elements first, then right-hand ones, in a single new array.
+        let chained = lhs_prim_array
+            .iter()
+            .chain(rhs_prim_array.iter())
+            .collect::<$ARR_DTYPE>();
+
+        Arc::new(chained)
+    }};
+}
+
 pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
-    Ok(match (lhs, rhs) {
-        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+    let result = match (lhs.get_datatype(), rhs.get_datatype()) {
+        (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => {
             let max_precision = p1.max(p2);
-            if s1.eq(s2) {
-                // s1 = s2
-                sum_decimal(v1, v2, max_precision, s1)
-            } else if s1.gt(s2) {
-                // s1 > s2
-                sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
-            } else {
-                // s1 < s2
-                sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+
+            match (lhs, rhs) {
+                (
+                    ScalarValue::Decimal128(v1, _, _),
+                    ScalarValue::Decimal128(v2, _, _),
+                ) => {
+                    Ok(if s1.eq(&s2) {
+                        // s1 = s2
+                        sum_decimal(v1, v2, &max_precision, &s1)
+                    } else if s1.gt(&s2) {
+                        // s1 > s2
+                        sum_decimal_with_diff_scale(v1, v2, &max_precision, &s1, &s2)
+                    } else {
+                        // s1 < s2
+                        sum_decimal_with_diff_scale(v2, v1, &max_precision, &s2, &s1)
+                    })
+                }
+                _ => Err(DataFusionError::Internal(
+                    "Internal state error on sum decimals ".to_string(),
+                )),
             }
         }
-        // float64 coerces everything to f64
-        (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
+        (DataType::Float64, _) | (_, DataType::Float64) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, "f64");
+            sum_batch(&data, &arrow::datatypes::DataType::Float64)
         }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        // float32 has no cast
-        (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float32, f32)
-        }
-        // u64 coerces u* to u64
-        (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, UInt64, u64)
+        (DataType::Float32, _) | (_, DataType::Float32) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, "f32");

Review Comment:
   this is an interesting idea, but I suspect the performance will be fairly low (as it creates arrays for each value 🤔 )
   
   I wonder if we could move the `sum` logic into `scalar.rs` and instead add some sort of coercion logic.
   
   Not sure.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org