Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/05/11 15:08:04 UTC

[GitHub] [arrow-datafusion] comphead opened a new pull request, #2516: Sum refactor draft

comphead opened a new pull request, #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516

   # Which issue does this PR close?
   
   
   Closes #2355.
   
   # Rationale for this change
   
   The current `sum` function contains a lot of boilerplate; this is an attempt to refactor it.
   
   # What changes are included in this PR?
   
   # Are there any user-facing changes?
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [arrow-datafusion] comphead commented on a diff in pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
comphead commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r879159675


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(
     }
 }
 
+macro_rules! downcast_arg {
+    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
+        $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "could not cast {} to {}",
+                $NAME,
+                type_name::<$ARRAY_TYPE>()
+            ))
+        })?
+    }};
+}
+
+macro_rules! union_arrays {
+    ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => {{
+        let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?;
+        let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?;
+        let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE);
+        let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE);
+
+        let chained = lhs_prim_array
+            .iter()
+            .chain(rhs_prim_array.iter())
+            .collect::<$ARR_DTYPE>();
+
+        Arc::new(chained)
+    }};
+}
+
 pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
-    Ok(match (lhs, rhs) {
-        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+    let result = match (lhs.get_datatype(), rhs.get_datatype()) {
+        (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => {
             let max_precision = p1.max(p2);
-            if s1.eq(s2) {
-                // s1 = s2
-                sum_decimal(v1, v2, max_precision, s1)
-            } else if s1.gt(s2) {
-                // s1 > s2
-                sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
-            } else {
-                // s1 < s2
-                sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+
+            match (lhs, rhs) {
+                (
+                    ScalarValue::Decimal128(v1, _, _),
+                    ScalarValue::Decimal128(v2, _, _),
+                ) => {
+                    Ok(if s1.eq(&s2) {
+                        // s1 = s2
+                        sum_decimal(v1, v2, &max_precision, &s1)
+                    } else if s1.gt(&s2) {
+                        // s1 > s2
+                        sum_decimal_with_diff_scale(v1, v2, &max_precision, &s1, &s2)
+                    } else {
+                        // s1 < s2
+                        sum_decimal_with_diff_scale(v2, v1, &max_precision, &s2, &s1)
+                    })
+                }
+                _ => Err(DataFusionError::Internal(
+                    "Internal state error on sum decimals ".to_string(),
+                )),
             }
         }
-        // float64 coerces everything to f64
-        (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
+        (DataType::Float64, _) | (_, DataType::Float64) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, "f64");
+            sum_batch(&data, &arrow::datatypes::DataType::Float64)
         }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        // float32 has no cast
-        (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float32, f32)
-        }
-        // u64 coerces u* to u64
-        (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, UInt64, u64)
+        (DataType::Float32, _) | (_, DataType::Float32) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, "f32");

Review Comment:
   Thanks @waynexia for your response.
   
   I think we already have a similar coercion in `type_coercion.rs`:
   ```
   /// Returns the data types that each argument must be coerced to match
   /// `signature`.
   ///
   /// See the module level documentation for more detail on coercion.
   pub fn data_types(
       current_types: &[DataType],
       signature: &Signature,
   ) -> Result<Vec<DataType>> {
   ```
   
   IMHO, `AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]),` already implements the proposed solution.
   
   I'm still afraid the problem is not the result type coercion but how to operate on the underlying values using the correct data types.
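
The argument coercion described in the quoted doc comment can be illustrated with a minimal Rust sketch. This is not DataFusion's `data_types` implementation: `DType`, `can_coerce`, and `coerce_types` are hypothetical stand-ins, the "signature" is reduced to a plain list of accepted types, and the widening rules are assumed for illustration only.

```rust
/// Hypothetical, reduced stand-in for arrow's `DataType`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    Int32,
    Int64,
    Float32,
    Float64,
    Utf8,
}

/// Assumed implicit-cast rule for this sketch: identical types, or one of a
/// few hand-picked widenings.
fn can_coerce(from: DType, to: DType) -> bool {
    use DType::*;
    from == to
        || matches!(
            (from, to),
            (Int32, Int64) | (Int32, Float64) | (Int64, Float64) | (Float32, Float64)
        )
}

/// For each argument, pick the first accepted type it can be coerced to;
/// fail if some argument has no candidate.
fn coerce_types(current: &[DType], accepted: &[DType]) -> Result<Vec<DType>, String> {
    current
        .iter()
        .map(|&have| {
            accepted
                .iter()
                .copied()
                .find(|&want| can_coerce(have, want))
                .ok_or_else(|| format!("cannot coerce {:?} to any of {:?}", have, accepted))
        })
        .collect()
}
```

A call such as `coerce_types(&[DType::Int32, DType::Float32], &[DType::Int64, DType::Float64])` only settles which types to use. Adding the underlying values in those types is the separate problem raised in the comment.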





[GitHub] [arrow-datafusion] comphead commented on a diff in pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
comphead commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r878836083


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(

Review Comment:
   https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=a513ccd564230b3815d635c515236f5f





[GitHub] [arrow-datafusion] alamb closed pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
alamb closed pull request #2516: Sum refactor draft
URL: https://github.com/apache/arrow-datafusion/pull/2516




[GitHub] [arrow-datafusion] comphead commented on a diff in pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
comphead commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r874950755


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(

Review Comment:
   @alamb 
   
   I believe we are left with the following options:
   - Implementing an add function (in `scalar.rs` or elsewhere) leads to boilerplate code, because there is no concise way to get at the underlying value of a `ScalarValue`.
   - `arrow::compute::sum` gets rid of the boilerplate, but it requires an array as input, which may hurt performance: each value has to be wrapped in an array and then cast.
   - Tricky options, such as always casting the values to f64, summing them, and then relying on the result cast from `sum_return_type`.
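
The trade-off between the first and third options can be sketched in plain Rust. Everything here is an assumption made for illustration: `Scalar` is a hypothetical three-variant stand-in for `ScalarValue`, and `add_opt`, `add_by_match`, `as_f64`, and `add_via_f64` are not DataFusion functions. The array-based option is left out because it needs the arrow crate.

```rust
/// Hypothetical stand-in for `ScalarValue`, reduced to three variants.
#[derive(Debug, Clone, PartialEq)]
enum Scalar {
    Int64(Option<i64>),
    Float32(Option<f32>),
    Float64(Option<f64>),
}

/// Null-aware addition: a null operand is ignored, two nulls stay null.
fn add_opt<T: std::ops::Add<Output = T>>(a: Option<T>, b: Option<T>) -> Option<T> {
    match (a, b) {
        (Some(x), Some(y)) => Some(x + y),
        (x, None) => x,
        (None, y) => y,
    }
}

/// Option 1: one match arm per type pair. Exact, but the number of arms
/// grows with every supported combination.
fn add_by_match(lhs: &Scalar, rhs: &Scalar) -> Option<Scalar> {
    use Scalar::*;
    Some(match (lhs, rhs) {
        (Int64(a), Int64(b)) => Int64(add_opt(*a, *b)),
        (Float32(a), Float32(b)) => Float32(add_opt(*a, *b)),
        (Float64(a), Float64(b)) => Float64(add_opt(*a, *b)),
        (Float64(a), Float32(b)) | (Float32(b), Float64(a)) => {
            Float64(add_opt(*a, (*b).map(f64::from)))
        }
        (Float64(a), Int64(b)) | (Int64(b), Float64(a)) => {
            Float64(add_opt(*a, (*b).map(|v| v as f64)))
        }
        // Pairs not handled in this sketch (e.g. Float32 with Int64).
        _ => return None,
    })
}

/// Read any variant as f64. Lossy for i64 magnitudes above 2^53.
fn as_f64(v: &Scalar) -> Option<f64> {
    match v {
        Scalar::Int64(x) => (*x).map(|v| v as f64),
        Scalar::Float32(x) => (*x).map(f64::from),
        Scalar::Float64(x) => *x,
    }
}

/// Option 3: widen both sides to f64, add, then cast to the result type.
fn add_via_f64(lhs: &Scalar, rhs: &Scalar, int_result: bool) -> Scalar {
    let total = add_opt(as_f64(lhs), as_f64(rhs));
    if int_result {
        Scalar::Int64(total.map(|v| v as i64))
    } else {
        Scalar::Float64(total)
    }
}
```

The f64 route needs a single code path, but it silently loses precision for large integers: `add_via_f64` applied to `Int64(Some(9_007_199_254_740_993))` (2^53 + 1) and `Int64(Some(0))` with an integer result yields `9_007_199_254_740_992`. That is one reason to regard it as a tricky option.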
   





[GitHub] [arrow-datafusion] andygrove commented on pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
andygrove commented on PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#issuecomment-1125105056

   @comphead The AMD64 build issue can be fixed by merging the latest from master




[GitHub] [arrow-datafusion] comphead commented on a diff in pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
comphead commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r873778359


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(

Review Comment:
   Okay, I found that coercion of the sum result type is already implemented:
   
   ```
           AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]),
   ```
   
   I'll check; coercing the resulting sum will probably allow us to get rid of the coercion in `sum.rs`.
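
The idea of deriving the sum's result type from the coerced input type can be sketched as follows. This is a simplified illustration, not the real `sum_return_type`: `DType`, `sum_return_type_sketch`, and `sum_i32_as_i64` are hypothetical names, and the exact mapping (in particular for `Float32`) is an assumption that may differ from DataFusion's.

```rust
/// Hypothetical, reduced stand-in for arrow's `DataType`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum DType {
    Int8,
    Int16,
    Int32,
    Int64,
    UInt8,
    UInt16,
    UInt32,
    UInt64,
    Float32,
    Float64,
    Utf8,
}

/// Assumed mapping from input type to sum result type: integers widen to
/// the 64-bit type of the same signedness, floats keep their width.
fn sum_return_type_sketch(arg: DType) -> Option<DType> {
    use DType::*;
    match arg {
        Int8 | Int16 | Int32 | Int64 => Some(Int64),
        UInt8 | UInt16 | UInt32 | UInt64 => Some(UInt64),
        Float32 => Some(Float32),
        Float64 => Some(Float64),
        Utf8 => None, // summing strings is not supported
    }
}

/// If every input is widened to the result type first, the running sum only
/// ever adds two values of the same type.
fn sum_i32_as_i64(values: &[i32]) -> i64 {
    values.iter().map(|&v| i64::from(v)).sum()
}
```

For example, `sum_i32_as_i64(&[i32::MAX, 1])` returns `2_147_483_648`, a value that would overflow if the addition were done in `i32`.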
   





[GitHub] [arrow-datafusion] comphead commented on a diff in pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
comphead commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r879159675


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(

Review Comment:
   Thanks @waynexia for your response.
   
   I think we already have a similar coercion in `type_coercion.rs`:
   ```
   /// Returns the data types that each argument must be coerced to match
   /// `signature`.
   ///
   /// See the module level documentation for more detail on coercion.
   pub fn data_types(
       current_types: &[DataType],
       signature: &Signature,
   ) -> Result<Vec<DataType>> {
   ```
   
   IMHO, `AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]),` already implements the proposed solution.
   
   I'm still afraid the problem is not the result type coercion but how to operate on the underlying values using the correct data types without boilerplate.





[GitHub] [arrow-datafusion] comphead commented on a diff in pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
comphead commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r878835961


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(
     }
 }
 
+macro_rules! downcast_arg {
+    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
+        $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "could not cast {} to {}",
+                $NAME,
+                type_name::<$ARRAY_TYPE>()
+            ))
+        })?
+    }};
+}
+
+macro_rules! union_arrays {
+    ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => {{
+        let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?;
+        let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?;
+        let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE);
+        let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE);
+
+        let chained = lhs_prim_array
+            .iter()
+            .chain(rhs_prim_array.iter())
+            .collect::<$ARR_DTYPE>();
+
+        Arc::new(chained)
+    }};
+}
+
 pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
-    Ok(match (lhs, rhs) {
-        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+    let result = match (lhs.get_datatype(), rhs.get_datatype()) {
+        (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => {
             let max_precision = p1.max(p2);
-            if s1.eq(s2) {
-                // s1 = s2
-                sum_decimal(v1, v2, max_precision, s1)
-            } else if s1.gt(s2) {
-                // s1 > s2
-                sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
-            } else {
-                // s1 < s2
-                sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+
+            match (lhs, rhs) {
+                (
+                    ScalarValue::Decimal128(v1, _, _),
+                    ScalarValue::Decimal128(v2, _, _),
+                ) => {
+                    Ok(if s1.eq(&s2) {
+                        // s1 = s2
+                        sum_decimal(v1, v2, &max_precision, &s1)
+                    } else if s1.gt(&s2) {
+                        // s1 > s2
+                        sum_decimal_with_diff_scale(v1, v2, &max_precision, &s1, &s2)
+                    } else {
+                        // s1 < s2
+                        sum_decimal_with_diff_scale(v2, v1, &max_precision, &s2, &s1)
+                    })
+                }
+                _ => Err(DataFusionError::Internal(
+                    "Internal state error on sum decimals ".to_string(),
+                )),
             }
         }
-        // float64 coerces everything to f64
-        (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
+        (DataType::Float64, _) | (_, DataType::Float64) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, "f64");
+            sum_batch(&data, &arrow::datatypes::DataType::Float64)
         }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        // float32 has no cast
-        (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float32, f32)
-        }
-        // u64 coerces u* to u64
-        (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, UInt64, u64)
+        (DataType::Float32, _) | (_, DataType::Float32) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, "f32");

Review Comment:
   I came up with an idea like this: write the boilerplate once, then apply external functions over the values.
   
   ```rust
   use std::ops::{Add, Sub};
   
   // Simplified stand-in for a scalar value; this is not arrow's `DataType`.
   #[derive(Debug, Copy, Clone)]
   enum DataType {
       Int32(Option<i32>),
       Float64(Option<f64>),
   }
   
   // The boilerplate lives here, once: match on the operand variants,
   // widen the integer operand to f64, and apply the given function.
   macro_rules! op {
       ($ARG1:expr, $ARG2:expr, $FUNC:expr) => {{
           match ($ARG1, $ARG2) {
               (DataType::Int32(Some(v1)), DataType::Float64(Some(v2))) => {
                   DataType::Float64(Some($FUNC(v1 as f64, v2)))
               }
               // Only one operand combination is handled in this sketch.
               _ => panic!("unsupported operand types"),
           }
       }};
   }
   
   fn sum<T: Add<Output = T> + Copy>(num1: T, num2: T) -> T {
       num1 + num2
   }
   
   fn minus<T: Sub<Output = T> + Copy>(num1: T, num2: T) -> T {
       num1 - num2
   }
   
   fn main() {
       let i_32 = DataType::Int32(Some(1));
       let f_64 = DataType::Float64(Some(2.0));
   
       let s = op!(i_32, f_64, sum);
       dbg!(&s);
   
       let m = op!(i_32, f_64, minus);
       dbg!(&m);
   }
   ```
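   
   The same "boilerplate once, functions over values" idea can also be written without a macro. The following is only a rough sketch and not code from this PR: `Value`, `as_f64` and `apply` are hypothetical names, every result is widened to `f64` (even for two integers), and a null operand is assumed to give a null result.
   
   ```rust
   // Hypothetical, simplified scalar type (not DataFusion's `ScalarValue`).
   #[derive(Debug, Clone, Copy, PartialEq)]
   enum Value {
       Int32(Option<i32>),
       Float64(Option<f64>),
   }
   
   impl Value {
       // The only place that has to know about every variant:
       // widen the payload to f64, keeping nulls as `None`.
       fn as_f64(self) -> Option<f64> {
           match self {
               Value::Int32(v) => v.map(f64::from),
               Value::Float64(v) => v,
           }
       }
   }
   
   // Apply `f` when both operands are non-null.
   // Assumption for this sketch: a null operand produces a null result.
   fn apply(lhs: Value, rhs: Value, f: impl Fn(f64, f64) -> f64) -> Value {
       match (lhs.as_f64(), rhs.as_f64()) {
           (Some(a), Some(b)) => Value::Float64(Some(f(a, b))),
           _ => Value::Float64(None),
       }
   }
   ```
   
   With this shape, `apply(Value::Int32(Some(1)), Value::Float64(Some(2.0)), |a, b| a + b)` plays the role of `op!(i_32, f_64, sum)`. The trade-off is that integer sums lose their integer type, which a real implementation would probably not accept.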





[GitHub] [arrow-datafusion] alamb commented on a diff in pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
alamb commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r873628173


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(
     }
 }
 
+macro_rules! downcast_arg {
+    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
+        $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "could not cast {} to {}",
+                $NAME,
+                type_name::<$ARRAY_TYPE>()
+            ))
+        })?
+    }};
+}
+
+macro_rules! union_arrays {
+    ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => {{
+        let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?;
+        let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?;
+        let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE);
+        let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE);
+
+        let chained = lhs_prim_array
+            .iter()
+            .chain(rhs_prim_array.iter())
+            .collect::<$ARR_DTYPE>();
+
+        Arc::new(chained)
+    }};
+}
+
 pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
-    Ok(match (lhs, rhs) {
-        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+    let result = match (lhs.get_datatype(), rhs.get_datatype()) {
+        (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => {
             let max_precision = p1.max(p2);
-            if s1.eq(s2) {
-                // s1 = s2
-                sum_decimal(v1, v2, max_precision, s1)
-            } else if s1.gt(s2) {
-                // s1 > s2
-                sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
-            } else {
-                // s1 < s2
-                sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+
+            match (lhs, rhs) {
+                (
+                    ScalarValue::Decimal128(v1, _, _),
+                    ScalarValue::Decimal128(v2, _, _),
+                ) => {
+                    Ok(if s1.eq(&s2) {
+                        // s1 = s2
+                        sum_decimal(v1, v2, &max_precision, &s1)
+                    } else if s1.gt(&s2) {
+                        // s1 > s2
+                        sum_decimal_with_diff_scale(v1, v2, &max_precision, &s1, &s2)
+                    } else {
+                        // s1 < s2
+                        sum_decimal_with_diff_scale(v2, v1, &max_precision, &s2, &s1)
+                    })
+                }
+                _ => Err(DataFusionError::Internal(
+                    "Internal state error on sum decimals ".to_string(),
+                )),
             }
         }
-        // float64 coerces everything to f64
-        (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
+        (DataType::Float64, _) | (_, DataType::Float64) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, "f64");
+            sum_batch(&data, &arrow::datatypes::DataType::Float64)
         }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        // float32 has no cast
-        (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float32, f32)
-        }
-        // u64 coerces u* to u64
-        (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, UInt64, u64)
+        (DataType::Float32, _) | (_, DataType::Float32) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, "f32");

Review Comment:
   I am torn -- basically I worry that adding the common math functions to `ScalarValue` will result in them being used more (and they are very slow). See more discussion at https://github.com/apache/arrow-datafusion/pull/1525#pullrequestreview-847147545
   
   However, if the alternative is a bunch of replicated code across the codebase, consolidating it all into `ScalarValue` seems like a much better outcome.
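   
   To make the concern concrete, here is a toy comparison of the two styles. It is a sketch under stated assumptions, not DataFusion code and not a benchmark: `Scalar`, `add_scalar`, `sum_scalars` and `sum_f64_batch` are hypothetical names, and nulls are assumed to be skipped when summing.
   
   ```rust
   use std::ops::Add;
   
   // Hypothetical, simplified scalar type (not DataFusion's `ScalarValue`).
   #[derive(Debug, Clone, Copy, PartialEq)]
   enum Scalar {
       Int64(Option<i64>),
       Float64(Option<f64>),
   }
   
   // Null-skipping addition (assumption: a null operand is ignored).
   fn add_opt<T: Add<Output = T>>(a: Option<T>, b: Option<T>) -> Option<T> {
       match (a, b) {
           (Some(x), Some(y)) => Some(x + y),
           (Some(x), None) | (None, Some(x)) => Some(x),
           (None, None) => None,
       }
   }
   
   // Value-at-a-time: every addition inspects the variant of both operands again.
   fn add_scalar(lhs: Scalar, rhs: Scalar) -> Option<Scalar> {
       match (lhs, rhs) {
           (Scalar::Int64(a), Scalar::Int64(b)) => Some(Scalar::Int64(add_opt(a, b))),
           (Scalar::Float64(a), Scalar::Float64(b)) => Some(Scalar::Float64(add_opt(a, b))),
           _ => None, // mixed types are out of scope for this sketch
       }
   }
   
   // Fold a slice of scalars with `add_scalar`; `None` for empty or mixed input.
   fn sum_scalars(values: &[Scalar]) -> Option<Scalar> {
       let mut iter = values.iter().copied();
       let first = iter.next()?;
       iter.try_fold(first, add_scalar)
   }
   
   // Batch-at-a-time: the element type is known up front, so the loop
   // runs over plain `f64` values with no per-value variant matching.
   fn sum_f64_batch(values: &[Option<f64>]) -> Option<f64> {
       values.iter().flatten().copied().reduce(|a, b| a + b)
   }
   ```
   
   Both functions give the same answer for the same input; the difference is only in how much per-value dispatch happens inside the loop.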





[GitHub] [arrow-datafusion] alamb commented on pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
alamb commented on PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#issuecomment-1382717562

   This PR is more than 6 months old, so I am closing it for now to clean up the PR list. Please reopen it if this is a mistake and you plan to work on it more.




[GitHub] [arrow-datafusion] alamb commented on a diff in pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
alamb commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r874153124


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(
     }
 }
 
+macro_rules! downcast_arg {
+    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
+        $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "could not cast {} to {}",
+                $NAME,
+                type_name::<$ARRAY_TYPE>()
+            ))
+        })?
+    }};
+}
+
+macro_rules! union_arrays {
+    ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => {{
+        let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?;
+        let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?;
+        let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE);
+        let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE);
+
+        let chained = lhs_prim_array
+            .iter()
+            .chain(rhs_prim_array.iter())
+            .collect::<$ARR_DTYPE>();
+
+        Arc::new(chained)
+    }};
+}
+
 pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
-    Ok(match (lhs, rhs) {
-        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+    let result = match (lhs.get_datatype(), rhs.get_datatype()) {
+        (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => {
             let max_precision = p1.max(p2);
-            if s1.eq(s2) {
-                // s1 = s2
-                sum_decimal(v1, v2, max_precision, s1)
-            } else if s1.gt(s2) {
-                // s1 > s2
-                sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
-            } else {
-                // s1 < s2
-                sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+
+            match (lhs, rhs) {
+                (
+                    ScalarValue::Decimal128(v1, _, _),
+                    ScalarValue::Decimal128(v2, _, _),
+                ) => {
+                    Ok(if s1.eq(&s2) {
+                        // s1 = s2
+                        sum_decimal(v1, v2, &max_precision, &s1)
+                    } else if s1.gt(&s2) {
+                        // s1 > s2
+                        sum_decimal_with_diff_scale(v1, v2, &max_precision, &s1, &s2)
+                    } else {
+                        // s1 < s2
+                        sum_decimal_with_diff_scale(v2, v1, &max_precision, &s2, &s1)
+                    })
+                }
+                _ => Err(DataFusionError::Internal(
+                    "Internal state error on sum decimals ".to_string(),
+                )),
             }
         }
-        // float64 coerces everything to f64
-        (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
+        (DataType::Float64, _) | (_, DataType::Float64) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, "f64");
+            sum_batch(&data, &arrow::datatypes::DataType::Float64)
         }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        // float32 has no cast
-        (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float32, f32)
-        }
-        // u64 coerces u* to u64
-        (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, UInt64, u64)
+        (DataType::Float32, _) | (_, DataType::Float32) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, "f32");

Review Comment:
   thanks!



##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(
     }
 }
 
+macro_rules! downcast_arg {
+    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
+        $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "could not cast {} to {}",
+                $NAME,
+                type_name::<$ARRAY_TYPE>()
+            ))
+        })?
+    }};
+}
+
+macro_rules! union_arrays {
+    ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => {{
+        let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?;
+        let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?;
+        let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE);
+        let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE);
+
+        let chained = lhs_prim_array
+            .iter()
+            .chain(rhs_prim_array.iter())
+            .collect::<$ARR_DTYPE>();
+
+        Arc::new(chained)
+    }};
+}
+
 pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
-    Ok(match (lhs, rhs) {
-        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+    let result = match (lhs.get_datatype(), rhs.get_datatype()) {
+        (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => {
             let max_precision = p1.max(p2);
-            if s1.eq(s2) {
-                // s1 = s2
-                sum_decimal(v1, v2, max_precision, s1)
-            } else if s1.gt(s2) {
-                // s1 > s2
-                sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
-            } else {
-                // s1 < s2
-                sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+
+            match (lhs, rhs) {
+                (
+                    ScalarValue::Decimal128(v1, _, _),
+                    ScalarValue::Decimal128(v2, _, _),
+                ) => {
+                    Ok(if s1.eq(&s2) {
+                        // s1 = s2
+                        sum_decimal(v1, v2, &max_precision, &s1)
+                    } else if s1.gt(&s2) {
+                        // s1 > s2
+                        sum_decimal_with_diff_scale(v1, v2, &max_precision, &s1, &s2)
+                    } else {
+                        // s1 < s2
+                        sum_decimal_with_diff_scale(v2, v1, &max_precision, &s2, &s1)
+                    })
+                }
+                _ => Err(DataFusionError::Internal(
+                    "Internal state error on sum decimals ".to_string(),
+                )),
             }
         }
-        // float64 coerces everything to f64
-        (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
+        (DataType::Float64, _) | (_, DataType::Float64) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, "f64");
+            sum_batch(&data, &arrow::datatypes::DataType::Float64)
         }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        // float32 has no cast
-        (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float32, f32)
-        }
-        // u64 coerces u* to u64
-        (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, UInt64, u64)
+        (DataType::Float32, _) | (_, DataType::Float32) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, "f32");

Review Comment:
   thanks!





[GitHub] [arrow-datafusion] waynexia commented on a diff in pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
waynexia commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r879107316


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(
     }
 }
 
+macro_rules! downcast_arg {
+    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
+        $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "could not cast {} to {}",
+                $NAME,
+                type_name::<$ARRAY_TYPE>()
+            ))
+        })?
+    }};
+}
+
+macro_rules! union_arrays {
+    ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => {{
+        let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?;
+        let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?;
+        let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE);
+        let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE);
+
+        let chained = lhs_prim_array
+            .iter()
+            .chain(rhs_prim_array.iter())
+            .collect::<$ARR_DTYPE>();
+
+        Arc::new(chained)
+    }};
+}
+
 pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
-    Ok(match (lhs, rhs) {
-        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+    let result = match (lhs.get_datatype(), rhs.get_datatype()) {
+        (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => {
             let max_precision = p1.max(p2);
-            if s1.eq(s2) {
-                // s1 = s2
-                sum_decimal(v1, v2, max_precision, s1)
-            } else if s1.gt(s2) {
-                // s1 > s2
-                sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
-            } else {
-                // s1 < s2
-                sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+
+            match (lhs, rhs) {
+                (
+                    ScalarValue::Decimal128(v1, _, _),
+                    ScalarValue::Decimal128(v2, _, _),
+                ) => {
+                    Ok(if s1.eq(&s2) {
+                        // s1 = s2
+                        sum_decimal(v1, v2, &max_precision, &s1)
+                    } else if s1.gt(&s2) {
+                        // s1 > s2
+                        sum_decimal_with_diff_scale(v1, v2, &max_precision, &s1, &s2)
+                    } else {
+                        // s1 < s2
+                        sum_decimal_with_diff_scale(v2, v1, &max_precision, &s2, &s1)
+                    })
+                }
+                _ => Err(DataFusionError::Internal(
+                    "Internal state error on sum decimals ".to_string(),
+                )),
             }
         }
-        // float64 coerces everything to f64
-        (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
+        (DataType::Float64, _) | (_, DataType::Float64) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, "f64");
+            sum_batch(&data, &arrow::datatypes::DataType::Float64)
         }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        // float32 has no cast
-        (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float32, f32)
-        }
-        // u64 coerces u* to u64
-        (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, UInt64, u64)
+        (DataType::Float32, _) | (_, DataType::Float32) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, "f32");

Review Comment:
   What about just cutting all this coercion logic? I've investigated all the occurrences of `sum()`: it's only used to accumulate aggregator state in `sum`, `sum_distinct` and `average`, where the operand types of `sum()` are deterministic. And since `sum()` is an internal function (`pub(crate)`), changing its API is acceptable.
   
   I tried removing all the match arms with different operand types, and only two test cases fail (`sum_distinct_i32_with_nulls` and `sum_distinct_u32_with_nulls`). I think this is acceptable. I also found that the min/max calculation already does this.
   
   As for how to apply an operator to operands of different types, I think we can extract our coercion rule into something like
   ```rust
   fn coercion(lhs: DataType, rhs: DataType) -> DataType {}
   ```
   And cast both operands to the result type before calculation.
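   
   A minimal, self-contained sketch of that shape, using toy types rather than arrow's `DataType` or DataFusion's `ScalarValue`. The widening rules in `coercion`, the `cast` method, and the null-skipping `add_opt` helper are assumptions made for illustration, not the project's actual rules:
   
   ```rust
   use std::ops::Add;
   
   // Toy stand-ins for the real types (hypothetical, heavily simplified).
   #[derive(Debug, Clone, Copy, PartialEq)]
   enum DataType {
       Int32,
       Int64,
       Float64,
   }
   
   #[derive(Debug, Clone, Copy, PartialEq)]
   enum Scalar {
       Int32(Option<i32>),
       Int64(Option<i64>),
       Float64(Option<f64>),
   }
   
   impl Scalar {
       fn data_type(&self) -> DataType {
           match self {
               Scalar::Int32(_) => DataType::Int32,
               Scalar::Int64(_) => DataType::Int64,
               Scalar::Float64(_) => DataType::Float64,
           }
       }
   
       // Widening casts only; anything else is reported as an error.
       fn cast(self, to: DataType) -> Result<Scalar, String> {
           Ok(match (self, to) {
               (Scalar::Int32(v), DataType::Int32) => Scalar::Int32(v),
               (Scalar::Int32(v), DataType::Int64) => Scalar::Int64(v.map(i64::from)),
               (Scalar::Int32(v), DataType::Float64) => Scalar::Float64(v.map(f64::from)),
               (Scalar::Int64(v), DataType::Int64) => Scalar::Int64(v),
               (Scalar::Int64(v), DataType::Float64) => Scalar::Float64(v.map(|x| x as f64)),
               (Scalar::Float64(v), DataType::Float64) => Scalar::Float64(v),
               (s, t) => return Err(format!("cannot cast {:?} to {:?}", s, t)),
           })
       }
   }
   
   // The coercion rule, stated once: pick the wider of the two types.
   fn coercion(lhs: DataType, rhs: DataType) -> DataType {
       match (lhs, rhs) {
           (DataType::Float64, _) | (_, DataType::Float64) => DataType::Float64,
           (DataType::Int64, _) | (_, DataType::Int64) => DataType::Int64,
           (DataType::Int32, DataType::Int32) => DataType::Int32,
       }
   }
   
   // Null-skipping addition (assumption: a null operand is ignored).
   fn add_opt<T: Add<Output = T>>(a: Option<T>, b: Option<T>) -> Option<T> {
       match (a, b) {
           (Some(x), Some(y)) => Some(x + y),
           (Some(x), None) | (None, Some(x)) => Some(x),
           (None, None) => None,
       }
   }
   
   // Cast both operands to the coerced type, then only same-type arms remain.
   fn sum(lhs: Scalar, rhs: Scalar) -> Result<Scalar, String> {
       let target = coercion(lhs.data_type(), rhs.data_type());
       match (lhs.cast(target)?, rhs.cast(target)?) {
           (Scalar::Int32(a), Scalar::Int32(b)) => Ok(Scalar::Int32(add_opt(a, b))),
           (Scalar::Int64(a), Scalar::Int64(b)) => Ok(Scalar::Int64(add_opt(a, b))),
           (Scalar::Float64(a), Scalar::Float64(b)) => Ok(Scalar::Float64(add_opt(a, b))),
           (a, b) => Err(format!("mismatched types after cast: {:?} and {:?}", a, b)),
       }
   }
   ```
   
   The point of the split is that `sum` needs one arm per type instead of one per pair of types; a call such as `sum(Scalar::Int32(Some(1)), Scalar::Float64(Some(2.5)))` goes through the `Float64` arm after both operands are cast. Integer overflow is not handled here.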





[GitHub] [arrow-datafusion] alamb commented on a diff in pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
alamb commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r876912019


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(
     }
 }
 
+macro_rules! downcast_arg {
+    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
+        $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "could not cast {} to {}",
+                $NAME,
+                type_name::<$ARRAY_TYPE>()
+            ))
+        })?
+    }};
+}
+
+macro_rules! union_arrays {
+    ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => {{
+        let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?;
+        let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?;
+        let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE);
+        let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE);
+
+        let chained = lhs_prim_array
+            .iter()
+            .chain(rhs_prim_array.iter())
+            .collect::<$ARR_DTYPE>();
+
+        Arc::new(chained)
+    }};
+}
+
 pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
-    Ok(match (lhs, rhs) {
-        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+    let result = match (lhs.get_datatype(), rhs.get_datatype()) {
+        (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => {
             let max_precision = p1.max(p2);
-            if s1.eq(s2) {
-                // s1 = s2
-                sum_decimal(v1, v2, max_precision, s1)
-            } else if s1.gt(s2) {
-                // s1 > s2
-                sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
-            } else {
-                // s1 < s2
-                sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+
+            match (lhs, rhs) {
+                (
+                    ScalarValue::Decimal128(v1, _, _),
+                    ScalarValue::Decimal128(v2, _, _),
+                ) => {
+                    Ok(if s1.eq(&s2) {
+                        // s1 = s2
+                        sum_decimal(v1, v2, &max_precision, &s1)
+                    } else if s1.gt(&s2) {
+                        // s1 > s2
+                        sum_decimal_with_diff_scale(v1, v2, &max_precision, &s1, &s2)
+                    } else {
+                        // s1 < s2
+                        sum_decimal_with_diff_scale(v2, v1, &max_precision, &s2, &s1)
+                    })
+                }
+                _ => Err(DataFusionError::Internal(
+                    "Internal state error on sum decimals ".to_string(),
+                )),
             }
         }
-        // float64 coerces everything to f64
-        (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
+        (DataType::Float64, _) | (_, DataType::Float64) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, "f64");
+            sum_batch(&data, &arrow::datatypes::DataType::Float64)
         }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        // float32 has no cast
-        (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float32, f32)
-        }
-        // u64 coerces u* to u64
-        (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, UInt64, u64)
+        (DataType::Float32, _) | (_, DataType::Float32) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, "f32");

Review Comment:
   I think we had quite a bit of discussion on a related topic (how to handle constants and arrays uniformly) in https://github.com/apache/arrow-datafusion/issues/1248
   
   Maybe that will provide some insight.
   





[GitHub] [arrow-datafusion] alamb commented on a diff in pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
alamb commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r875922067


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(
     }
 }
 
+macro_rules! downcast_arg {
+    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
+        $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "could not cast {} to {}",
+                $NAME,
+                type_name::<$ARRAY_TYPE>()
+            ))
+        })?
+    }};
+}
+
+macro_rules! union_arrays {
+    ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => {{
+        let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?;
+        let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?;
+        let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE);
+        let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE);
+
+        let chained = lhs_prim_array
+            .iter()
+            .chain(rhs_prim_array.iter())
+            .collect::<$ARR_DTYPE>();
+
+        Arc::new(chained)
+    }};
+}
+
 pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
-    Ok(match (lhs, rhs) {
-        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+    let result = match (lhs.get_datatype(), rhs.get_datatype()) {
+        (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => {
             let max_precision = p1.max(p2);
-            if s1.eq(s2) {
-                // s1 = s2
-                sum_decimal(v1, v2, max_precision, s1)
-            } else if s1.gt(s2) {
-                // s1 > s2
-                sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
-            } else {
-                // s1 < s2
-                sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+
+            match (lhs, rhs) {
+                (
+                    ScalarValue::Decimal128(v1, _, _),
+                    ScalarValue::Decimal128(v2, _, _),
+                ) => {
+                    Ok(if s1.eq(&s2) {
+                        // s1 = s2
+                        sum_decimal(v1, v2, &max_precision, &s1)
+                    } else if s1.gt(&s2) {
+                        // s1 > s2
+                        sum_decimal_with_diff_scale(v1, v2, &max_precision, &s1, &s2)
+                    } else {
+                        // s1 < s2
+                        sum_decimal_with_diff_scale(v2, v1, &max_precision, &s2, &s1)
+                    })
+                }
+                _ => Err(DataFusionError::Internal(
+                    "Internal state error on sum decimals ".to_string(),
+                )),
             }
         }
-        // float64 coerces everything to f64
-        (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
+        (DataType::Float64, _) | (_, DataType::Float64) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, "f64");
+            sum_batch(&data, &arrow::datatypes::DataType::Float64)
         }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        // float32 has no cast
-        (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float32, f32)
-        }
-        // u64 coerces u* to u64
-        (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, UInt64, u64)
+        (DataType::Float32, _) | (_, DataType::Float32) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, "f32");

Review Comment:
   > implementing the add function(in scalar.rs or any other place) leads to boilerplate code, because there is no concise way how to get ScalarValue underlying value.
   
   I think this is the best we can do. I don't (really) mind boilerplate code, but there seem to be too many copies of almost the same boilerplate.
   
   I agree that using `arrow::compute::sum` is likely not the right thing to do.
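
To make the "one copy of the boilerplate" idea concrete, here is a minimal sketch of how a single macro could generate the same-type match arms. Everything in it is a hypothetical stand-in: `Scalar`, `add_opt`, and `same_type_sum!` are not part of DataFusion's actual `ScalarValue` API, and the NULL handling (a missing side contributes nothing to the sum) is an assumption about the intended aggregate semantics.

```rust
// Hypothetical stand-in for DataFusion's ScalarValue; only three variants are modeled.
#[derive(Debug, Clone, PartialEq)]
enum Scalar {
    Int64(Option<i64>),
    UInt64(Option<u64>),
    Float64(Option<f64>),
}

// Adds two optional values; a None (NULL) side contributes nothing (assumed semantics).
fn add_opt<T: std::ops::Add<Output = T>>(a: Option<T>, b: Option<T>) -> Option<T> {
    match (a, b) {
        (Some(x), Some(y)) => Some(x + y),
        (Some(x), None) | (None, Some(x)) => Some(x),
        (None, None) => None,
    }
}

// Expands to one match arm per listed variant, so the per-type code is written once.
macro_rules! same_type_sum {
    ($lhs:expr, $rhs:expr, $($variant:ident),+) => {
        match ($lhs, $rhs) {
            $((Scalar::$variant(a), Scalar::$variant(b)) => {
                Ok(Scalar::$variant(add_opt(*a, *b)))
            })+
            (l, r) => Err(format!("unsupported sum: {:?} + {:?}", l, r)),
        }
    };
}

// Sums two scalars of the same variant; mixed variants are rejected in this sketch.
fn sum(lhs: &Scalar, rhs: &Scalar) -> Result<Scalar, String> {
    same_type_sum!(lhs, rhs, Int64, UInt64, Float64)
}
```

Supporting another type in this sketch means adding one identifier to the macro call rather than another hand-written arm. Mixed-type pairs would still need a separate coercion step.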
   
   





[GitHub] [arrow-datafusion] comphead commented on a diff in pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
comphead commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r873131180


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(
     }
 }
 
+macro_rules! downcast_arg {
+    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
+        $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "could not cast {} to {}",
+                $NAME,
+                type_name::<$ARRAY_TYPE>()
+            ))
+        })?
+    }};
+}
+
+macro_rules! union_arrays {
+    ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => {{
+        let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?;
+        let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?;
+        let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE);
+        let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE);
+
+        let chained = lhs_prim_array
+            .iter()
+            .chain(rhs_prim_array.iter())
+            .collect::<$ARR_DTYPE>();
+
+        Arc::new(chained)
+    }};
+}
+
 pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
-    Ok(match (lhs, rhs) {
-        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+    let result = match (lhs.get_datatype(), rhs.get_datatype()) {
+        (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => {
             let max_precision = p1.max(p2);
-            if s1.eq(s2) {
-                // s1 = s2
-                sum_decimal(v1, v2, max_precision, s1)
-            } else if s1.gt(s2) {
-                // s1 > s2
-                sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
-            } else {
-                // s1 < s2
-                sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+
+            match (lhs, rhs) {
+                (
+                    ScalarValue::Decimal128(v1, _, _),
+                    ScalarValue::Decimal128(v2, _, _),
+                ) => {
+                    Ok(if s1.eq(&s2) {
+                        // s1 = s2
+                        sum_decimal(v1, v2, &max_precision, &s1)
+                    } else if s1.gt(&s2) {
+                        // s1 > s2
+                        sum_decimal_with_diff_scale(v1, v2, &max_precision, &s1, &s2)
+                    } else {
+                        // s1 < s2
+                        sum_decimal_with_diff_scale(v2, v1, &max_precision, &s2, &s1)
+                    })
+                }
+                _ => Err(DataFusionError::Internal(
+                    "Internal state error on sum decimals ".to_string(),
+                )),
             }
         }
-        // float64 coerces everything to f64
-        (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
+        (DataType::Float64, _) | (_, DataType::Float64) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, "f64");
+            sum_batch(&data, &arrow::datatypes::DataType::Float64)
         }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        // float32 has no cast
-        (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float32, f32)
-        }
-        // u64 coerces u* to u64
-        (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, UInt64, u64)
+        (DataType::Float32, _) | (_, DataType::Float32) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, "f32");

Review Comment:
   I went through those PRs. We also currently have a negate function at the ScalarValue level:
   ```
       pub fn arithmetic_negate(&self) -> Self {
   ```
   Perhaps we can extend this pattern and add other common math functions to ScalarValue too, such as sum, minus, multiply, and divide?
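
As a rough illustration of that suggestion, the sketch below puts the four operations on the value type itself, all routed through one shared helper. The `Scalar` enum, `binary_op`, and the `String` error type are hypothetical and are not DataFusion's API. The sketch also assumes SQL-style arithmetic where a NULL operand yields NULL, and it requires both operands to already have the same type.

```rust
// Hypothetical stand-in for ScalarValue with two variants only.
#[derive(Debug, Clone, PartialEq)]
enum Scalar {
    Int64(Option<i64>),
    Float64(Option<f64>),
}

impl Scalar {
    // Shared dispatch: the per-type match is written once and reused by every operator.
    fn binary_op(
        &self,
        other: &Scalar,
        int_op: fn(i64, i64) -> Option<i64>,
        float_op: fn(f64, f64) -> f64,
    ) -> Result<Scalar, String> {
        match (self, other) {
            (Scalar::Int64(Some(a)), Scalar::Int64(Some(b))) => int_op(*a, *b)
                .map(|v| Scalar::Int64(Some(v)))
                .ok_or_else(|| "integer overflow or division by zero".to_string()),
            // At least one side is NULL: the result is NULL (assumed semantics).
            (Scalar::Int64(_), Scalar::Int64(_)) => Ok(Scalar::Int64(None)),
            (Scalar::Float64(Some(a)), Scalar::Float64(Some(b))) => {
                Ok(Scalar::Float64(Some(float_op(*a, *b))))
            }
            (Scalar::Float64(_), Scalar::Float64(_)) => Ok(Scalar::Float64(None)),
            (l, r) => Err(format!("type mismatch: {:?} vs {:?}", l, r)),
        }
    }

    fn add(&self, other: &Scalar) -> Result<Scalar, String> {
        self.binary_op(other, i64::checked_add, |a, b| a + b)
    }

    fn sub(&self, other: &Scalar) -> Result<Scalar, String> {
        self.binary_op(other, i64::checked_sub, |a, b| a - b)
    }

    fn mul(&self, other: &Scalar) -> Result<Scalar, String> {
        self.binary_op(other, i64::checked_mul, |a, b| a * b)
    }

    fn div(&self, other: &Scalar) -> Result<Scalar, String> {
        self.binary_op(other, i64::checked_div, |a, b| a / b)
    }
}
```

The checked integer operations turn overflow and division by zero into errors instead of panics. Whether that is the desired behavior for an aggregate is a separate design question.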





[GitHub] [arrow-datafusion] alamb commented on a diff in pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
alamb commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r872653499


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(
     }
 }
 
+macro_rules! downcast_arg {
+    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
+        $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "could not cast {} to {}",
+                $NAME,
+                type_name::<$ARRAY_TYPE>()
+            ))
+        })?
+    }};
+}
+
+macro_rules! union_arrays {
+    ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => {{
+        let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?;
+        let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?;
+        let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE);
+        let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE);
+
+        let chained = lhs_prim_array
+            .iter()
+            .chain(rhs_prim_array.iter())
+            .collect::<$ARR_DTYPE>();
+
+        Arc::new(chained)
+    }};
+}
+
 pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
-    Ok(match (lhs, rhs) {
-        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+    let result = match (lhs.get_datatype(), rhs.get_datatype()) {
+        (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => {
             let max_precision = p1.max(p2);
-            if s1.eq(s2) {
-                // s1 = s2
-                sum_decimal(v1, v2, max_precision, s1)
-            } else if s1.gt(s2) {
-                // s1 > s2
-                sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
-            } else {
-                // s1 < s2
-                sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+
+            match (lhs, rhs) {
+                (
+                    ScalarValue::Decimal128(v1, _, _),
+                    ScalarValue::Decimal128(v2, _, _),
+                ) => {
+                    Ok(if s1.eq(&s2) {
+                        // s1 = s2
+                        sum_decimal(v1, v2, &max_precision, &s1)
+                    } else if s1.gt(&s2) {
+                        // s1 > s2
+                        sum_decimal_with_diff_scale(v1, v2, &max_precision, &s1, &s2)
+                    } else {
+                        // s1 < s2
+                        sum_decimal_with_diff_scale(v2, v1, &max_precision, &s2, &s1)
+                    })
+                }
+                _ => Err(DataFusionError::Internal(
+                    "Internal state error on sum decimals ".to_string(),
+                )),
             }
         }
-        // float64 coerces everything to f64
-        (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
+        (DataType::Float64, _) | (_, DataType::Float64) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, "f64");
+            sum_batch(&data, &arrow::datatypes::DataType::Float64)
         }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        // float32 has no cast
-        (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float32, f32)
-        }
-        // u64 coerces u* to u64
-        (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, UInt64, u64)
+        (DataType::Float32, _) | (_, DataType::Float32) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, "f32");

Review Comment:
   Something like that -- yes -- given that is effectively what this code is doing. 
   
   @realno added just such code in https://github.com/apache/arrow-datafusion/pull/1525, though I was trying to avoid it in https://github.com/apache/arrow-datafusion/pull/1550
   
   However, if it is truly required then having one copy of it seems better than several 🤔 





[GitHub] [arrow-datafusion] alamb commented on a diff in pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
alamb commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r871802742


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(
     }
 }
 
+macro_rules! downcast_arg {
+    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
+        $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "could not cast {} to {}",
+                $NAME,
+                type_name::<$ARRAY_TYPE>()
+            ))
+        })?
+    }};
+}
+
+macro_rules! union_arrays {
+    ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => {{
+        let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?;
+        let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?;
+        let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE);
+        let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE);
+
+        let chained = lhs_prim_array
+            .iter()
+            .chain(rhs_prim_array.iter())
+            .collect::<$ARR_DTYPE>();
+
+        Arc::new(chained)
+    }};
+}
+
 pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
-    Ok(match (lhs, rhs) {
-        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+    let result = match (lhs.get_datatype(), rhs.get_datatype()) {
+        (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => {
             let max_precision = p1.max(p2);
-            if s1.eq(s2) {
-                // s1 = s2
-                sum_decimal(v1, v2, max_precision, s1)
-            } else if s1.gt(s2) {
-                // s1 > s2
-                sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
-            } else {
-                // s1 < s2
-                sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+
+            match (lhs, rhs) {
+                (
+                    ScalarValue::Decimal128(v1, _, _),
+                    ScalarValue::Decimal128(v2, _, _),
+                ) => {
+                    Ok(if s1.eq(&s2) {
+                        // s1 = s2
+                        sum_decimal(v1, v2, &max_precision, &s1)
+                    } else if s1.gt(&s2) {
+                        // s1 > s2
+                        sum_decimal_with_diff_scale(v1, v2, &max_precision, &s1, &s2)
+                    } else {
+                        // s1 < s2
+                        sum_decimal_with_diff_scale(v2, v1, &max_precision, &s2, &s1)
+                    })
+                }
+                _ => Err(DataFusionError::Internal(
+                    "Internal state error on sum decimals ".to_string(),
+                )),
             }
         }
-        // float64 coerces everything to f64
-        (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
+        (DataType::Float64, _) | (_, DataType::Float64) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, "f64");
+            sum_batch(&data, &arrow::datatypes::DataType::Float64)
         }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        // float32 has no cast
-        (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float32, f32)
-        }
-        // u64 coerces u* to u64
-        (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, UInt64, u64)
+        (DataType::Float32, _) | (_, DataType::Float32) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, "f32");

Review Comment:
   This is an interesting idea, but I suspect the performance will be fairly poor, as it creates arrays for each value 🤔
   
   I wonder if we could move the `sum` logic into `scalar.rs` and instead add some sort of coercion logic.
   
   Not sure.
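
One possible shape for such coercion logic, sketched without any array allocation: widen each operand to a common representation first, then add the plain values. The names `Scalar`, `Widened`, `widen`, and `add_opt` are hypothetical, and the widening rules (all integers to `i64`, anything involving a float to `f64`) are a simplifying assumption rather than the coercion rules DataFusion actually uses.

```rust
// Hypothetical stand-in for ScalarValue.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Scalar {
    Int32(Option<i32>),
    Int64(Option<i64>),
    Float32(Option<f32>),
    Float64(Option<f64>),
}

// Intermediate form after widening: every input becomes either i64 or f64.
enum Widened {
    Int(Option<i64>),
    Float(Option<f64>),
}

// Lossless widening of each variant to its widest same-family type.
fn widen(s: &Scalar) -> Widened {
    match *s {
        Scalar::Int32(v) => Widened::Int(v.map(i64::from)),
        Scalar::Int64(v) => Widened::Int(v),
        Scalar::Float32(v) => Widened::Float(v.map(f64::from)),
        Scalar::Float64(v) => Widened::Float(v),
    }
}

// Adds two optional values; a None (NULL) side contributes nothing (assumed semantics).
fn add_opt<T: std::ops::Add<Output = T>>(a: Option<T>, b: Option<T>) -> Option<T> {
    match (a, b) {
        (Some(x), Some(y)) => Some(x + y),
        (Some(x), None) | (None, Some(x)) => Some(x),
        (None, None) => None,
    }
}

// Coerce both sides to a common type, then add the values directly.
fn sum(lhs: &Scalar, rhs: &Scalar) -> Scalar {
    match (widen(lhs), widen(rhs)) {
        (Widened::Int(a), Widened::Int(b)) => Scalar::Int64(add_opt(a, b)),
        (Widened::Float(a), Widened::Float(b)) => Scalar::Float64(add_opt(a, b)),
        // Mixed integer/float: promote the integer side to f64 (may lose precision for large i64).
        (Widened::Int(i), Widened::Float(f)) | (Widened::Float(f), Widened::Int(i)) => {
            Scalar::Float64(add_opt(i.map(|x| x as f64), f))
        }
    }
}
```

Compared with building two one-element arrays, casting them, and calling a batch kernel, this path touches only stack values. The cost is that the coercion rules now live in scalar code and must be kept consistent with the array casts.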





[GitHub] [arrow-datafusion] comphead commented on a diff in pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
comphead commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r872215794


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(
     }
 }
 
+macro_rules! downcast_arg {
+    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
+        $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "could not cast {} to {}",
+                $NAME,
+                type_name::<$ARRAY_TYPE>()
+            ))
+        })?
+    }};
+}
+
+macro_rules! union_arrays {
+    ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => {{
+        let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?;
+        let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?;
+        let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE);
+        let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE);
+
+        let chained = lhs_prim_array
+            .iter()
+            .chain(rhs_prim_array.iter())
+            .collect::<$ARR_DTYPE>();
+
+        Arc::new(chained)
+    }};
+}
+
 pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
-    Ok(match (lhs, rhs) {
-        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+    let result = match (lhs.get_datatype(), rhs.get_datatype()) {
+        (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => {
             let max_precision = p1.max(p2);
-            if s1.eq(s2) {
-                // s1 = s2
-                sum_decimal(v1, v2, max_precision, s1)
-            } else if s1.gt(s2) {
-                // s1 > s2
-                sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
-            } else {
-                // s1 < s2
-                sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+
+            match (lhs, rhs) {
+                (
+                    ScalarValue::Decimal128(v1, _, _),
+                    ScalarValue::Decimal128(v2, _, _),
+                ) => {
+                    Ok(if s1.eq(&s2) {
+                        // s1 = s2
+                        sum_decimal(v1, v2, &max_precision, &s1)
+                    } else if s1.gt(&s2) {
+                        // s1 > s2
+                        sum_decimal_with_diff_scale(v1, v2, &max_precision, &s1, &s2)
+                    } else {
+                        // s1 < s2
+                        sum_decimal_with_diff_scale(v2, v1, &max_precision, &s2, &s1)
+                    })
+                }
+                _ => Err(DataFusionError::Internal(
+                    "Internal state error on sum decimals ".to_string(),
+                )),
             }
         }
-        // float64 coerces everything to f64
-        (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
+        (DataType::Float64, _) | (_, DataType::Float64) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, "f64");
+            sum_batch(&data, &arrow::datatypes::DataType::Float64)
         }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        // float32 has no cast
-        (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float32, f32)
-        }
-        // u64 coerces u* to u64
-        (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, UInt64, u64)
+        (DataType::Float32, _) | (_, DataType::Float32) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, "f32");

Review Comment:
   Do you mean adding a sum function to the ScalarValue struct?
   
   Something like:
   ```
    sum(self, other: ScalarValue) -> ScalarValue
   ```





[GitHub] [arrow-datafusion] comphead commented on a diff in pull request #2516: Sum refactor draft

Posted by GitBox <gi...@apache.org>.
comphead commented on code in PR #2516:
URL: https://github.com/apache/arrow-datafusion/pull/2516#discussion_r876757215


##########
datafusion/physical-expr/src/aggregate/sum.rs:
##########
@@ -262,98 +249,83 @@ fn sum_decimal_with_diff_scale(
     }
 }
 
+macro_rules! downcast_arg {
+    ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{
+        $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| {
+            DataFusionError::Internal(format!(
+                "could not cast {} to {}",
+                $NAME,
+                type_name::<$ARRAY_TYPE>()
+            ))
+        })?
+    }};
+}
+
+macro_rules! union_arrays {
+    ($LHS: expr, $RHS: expr, $DTYPE: expr, $ARR_DTYPE: ident, $NAME: expr) => {{
+        let lhs_casted = &cast(&$LHS.to_array(), $DTYPE)?;
+        let rhs_casted = &cast(&$RHS.to_array(), $DTYPE)?;
+        let lhs_prim_array = downcast_arg!(lhs_casted, $NAME, $ARR_DTYPE);
+        let rhs_prim_array = downcast_arg!(rhs_casted, $NAME, $ARR_DTYPE);
+
+        let chained = lhs_prim_array
+            .iter()
+            .chain(rhs_prim_array.iter())
+            .collect::<$ARR_DTYPE>();
+
+        Arc::new(chained)
+    }};
+}
+
 pub(crate) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result<ScalarValue> {
-    Ok(match (lhs, rhs) {
-        (ScalarValue::Decimal128(v1, p1, s1), ScalarValue::Decimal128(v2, p2, s2)) => {
+    let result = match (lhs.get_datatype(), rhs.get_datatype()) {
+        (DataType::Decimal(p1, s1), DataType::Decimal(p2, s2)) => {
             let max_precision = p1.max(p2);
-            if s1.eq(s2) {
-                // s1 = s2
-                sum_decimal(v1, v2, max_precision, s1)
-            } else if s1.gt(s2) {
-                // s1 > s2
-                sum_decimal_with_diff_scale(v1, v2, max_precision, s1, s2)
-            } else {
-                // s1 < s2
-                sum_decimal_with_diff_scale(v2, v1, max_precision, s2, s1)
+
+            match (lhs, rhs) {
+                (
+                    ScalarValue::Decimal128(v1, _, _),
+                    ScalarValue::Decimal128(v2, _, _),
+                ) => {
+                    Ok(if s1.eq(&s2) {
+                        // s1 = s2
+                        sum_decimal(v1, v2, &max_precision, &s1)
+                    } else if s1.gt(&s2) {
+                        // s1 > s2
+                        sum_decimal_with_diff_scale(v1, v2, &max_precision, &s1, &s2)
+                    } else {
+                        // s1 < s2
+                        sum_decimal_with_diff_scale(v2, v1, &max_precision, &s2, &s1)
+                    })
+                }
+                _ => Err(DataFusionError::Internal(
+                    "Internal state error on sum decimals ".to_string(),
+                )),
             }
         }
-        // float64 coerces everything to f64
-        (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::Int8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
+        (DataType::Float64, _) | (_, DataType::Float64) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float64, Float64Array, "f64");
+            sum_batch(&data, &arrow::datatypes::DataType::Float64)
         }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt32(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt16(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        (ScalarValue::Float64(lhs), ScalarValue::UInt8(rhs)) => {
-            typed_sum!(lhs, rhs, Float64, f64)
-        }
-        // float32 has no cast
-        (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => {
-            typed_sum!(lhs, rhs, Float32, f32)
-        }
-        // u64 coerces u* to u64
-        (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => {
-            typed_sum!(lhs, rhs, UInt64, u64)
+        (DataType::Float32, _) | (_, DataType::Float32) => {
+            let data: ArrayRef =
+                union_arrays!(lhs, rhs, &DataType::Float32, Float32Array, "f32");

Review Comment:
   Right, I'm still thinking about how we can do that. I checked arrow-rs, but it doesn't have the same model of enum + `Option`.
   Even if we implement the add function once with boilerplate code, as soon as we need a minus function, or any other function that operates directly on the values underlying `ScalarValue`, we will have to add the boilerplate again. The reason is that we can't access the underlying value without a bunch of matches.
   
   I think that, to make things easier, we need a function that returns the inner value somehow, something like
   `scalarValue.unwrap()`
   
   @alamb, perhaps you have better thoughts?
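   
   To illustrate the boilerplate concern, one possible way to write the per-variant match only once is a macro parameterized by the operator, so that add and minus share it. This is a sketch on a simplified stand-in enum: `Scalar`, `binary_op!`, `add`, and `sub` are hypothetical names, not existing DataFusion items. It assumes NULL (`None`) propagates, as in plain SQL arithmetic; an aggregate sum would skip NULLs instead.
   
   ```rust
   /// Hypothetical, simplified stand-in for `ScalarValue` (two variants only).
   #[derive(Debug, Clone, PartialEq)]
   pub enum Scalar {
       Int64(Option<i64>),
       Float64(Option<f64>),
   }
   
   /// Expands to a single match that applies `$OP` to same-typed operands.
   /// If either operand is `None`, the result is `None` of the same variant.
   macro_rules! binary_op {
       ($LHS:expr, $RHS:expr, $OP:tt) => {
           match ($LHS, $RHS) {
               (Scalar::Int64(a), Scalar::Int64(b)) => Ok(Scalar::Int64(match (a, b) {
                   (Some(x), Some(y)) => Some(x $OP y),
                   _ => None,
               })),
               (Scalar::Float64(a), Scalar::Float64(b)) => Ok(Scalar::Float64(match (a, b) {
                   (Some(x), Some(y)) => Some(x $OP y),
                   _ => None,
               })),
               // Mixed variants are rejected; no coercion in this sketch.
               (a, b) => Err(format!("type mismatch: {:?} vs {:?}", a, b)),
           }
       };
   }
   
   /// Addition reuses the shared match.
   pub fn add(lhs: &Scalar, rhs: &Scalar) -> Result<Scalar, String> {
       binary_op!(lhs, rhs, +)
   }
   
   /// Subtraction reuses the same match with a different operator.
   pub fn sub(lhs: &Scalar, rhs: &Scalar) -> Result<Scalar, String> {
       binary_op!(lhs, rhs, -)
   }
   ```
   
   Each new variant still needs one arm in the macro, but each new operation becomes a one-line function.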
   


