You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/03/23 20:37:05 UTC

[GitHub] [arrow-datafusion] alamb commented on a change in pull request #2031: #2004 approx percentile with weight

alamb commented on a change in pull request #2031:
URL: https://github.com/apache/arrow-datafusion/pull/2031#discussion_r833685712



##########
File path: datafusion-physical-expr/src/coercion_rule/aggregate_rule.rs
##########
@@ -152,6 +152,27 @@ pub fn coerce_types(
             }
             Ok(input_types.to_vec())
         }
+        AggregateFunction::ApproxPercentileContWithWeight => {
+            if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not support inputs of type {:?}.",
+                    agg_fun, input_types[0]
+                )));
+            }
+            if !is_approx_percentile_cont_supported_arg_type(&input_types[1]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The weight argument for {:?} does not support inputs of type {:?}.",
+                    agg_fun, input_types[0]

Review comment:
       ```suggestion
                       agg_fun, input_types[1]
   ```

##########
File path: datafusion/tests/sql/aggregates.rs
##########
@@ -476,6 +476,59 @@ async fn csv_query_approx_percentile_cont() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn csv_query_approx_percentile_cont_with_weight() -> Result<()> {
+    let mut ctx = SessionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+
+    // compare approx_percentile_cont and approx_percentile_cont_with_weight
+    let sql = "SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+----+--------+",
+        "| c1 | c3_p95 |",
+        "+----+--------+",
+        "| a  | 73     |",
+        "| b  | 68     |",
+        "| c  | 122    |",
+        "| d  | 124    |",
+        "| e  | 115    |",
+        "+----+--------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT c1, approx_percentile_cont_with_weight(c3, 1, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+----+--------+",
+        "| c1 | c3_p95 |",
+        "+----+--------+",
+        "| a  | 73     |",
+        "| b  | 68     |",
+        "| c  | 122    |",
+        "| d  | 124    |",
+        "| e  | 115    |",
+        "+----+--------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT c1, approx_percentile_cont_with_weight(c3, c2, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+----+--------+",

Review comment:
       These values are different from the unweighted results above -- honestly, I don't know whether they are correct or not 🤷 

##########
File path: datafusion-physical-expr/src/expressions/approx_percentile_cont.rs
##########
@@ -194,75 +204,125 @@ pub struct ApproxPercentileAccumulator {
 impl ApproxPercentileAccumulator {
     pub fn new(percentile: f64, return_type: DataType) -> Self {
         Self {
-            digest: TDigest::new(100),
+            digest: TDigest::new(DEFAULT_MAX_SIZE),
             percentile,
             return_type,
         }
     }
-}
 
-impl Accumulator for ApproxPercentileAccumulator {
-    fn state(&self) -> Result<Vec<ScalarValue>> {
-        Ok(self.digest.to_scalar_state())
+    pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) {
+        self.digest = TDigest::merge_digests(digests);
     }
 
-    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        debug_assert_eq!(
-            values.len(),
-            1,
-            "invalid number of values in batch percentile update"
-        );
-        let values = &values[0];
-
-        self.digest = match values.data_type() {
+    pub(crate) fn convert_to_ordered_float(
+        values: &ArrayRef,
+    ) -> Result<Vec<OrderedFloat<f64>>> {
+        match values.data_type() {
             DataType::Float64 => {
                 let array = values.as_any().downcast_ref::<Float64Array>().unwrap();
-                self.digest.merge_unsorted(array.values().iter().cloned())?
+                Ok(array
+                    .values()
+                    .iter()
+                    .filter_map(|v| v.try_as_f64().transpose())
+                    .collect::<Result<Vec<_>>>()?)

Review comment:
       Yeah, it is tough because the type of the various branches are different.
   
   You could save the copy by letting the caller provide a function that is invoked for each element (untested):
   
   ```rust
       pub(crate) fn convert_to_ordered_float(
           values: &ArrayRef,
       f: impl FnMut(Option<OrderedFloat<f64>>),
       ) ->  Result<()> {
   ```
   
   And then call `f()` on each element:
   
   ```rust
   ...
               DataType::Float32 => {
                   let array = values.as_any().downcast_ref::<Float32Array>().unwrap();
                   array
                       .values()
                       .iter()
                       .try_for_each(|v| {
                         f(v.try_as_f64()?)
                        })
               }
   ...
   ```

##########
File path: datafusion/tests/sql/aggregates.rs
##########
@@ -476,6 +476,59 @@ async fn csv_query_approx_percentile_cont() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn csv_query_approx_percentile_cont_with_weight() -> Result<()> {
+    let mut ctx = SessionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+
+    // compare approx_percentile_cont and approx_percentile_cont_with_weight
+    let sql = "SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+----+--------+",
+        "| c1 | c3_p95 |",
+        "+----+--------+",
+        "| a  | 73     |",
+        "| b  | 68     |",
+        "| c  | 122    |",
+        "| d  | 124    |",
+        "| e  | 115    |",
+        "+----+--------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT c1, approx_percentile_cont_with_weight(c3, 1, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+----+--------+",
+        "| c1 | c3_p95 |",
+        "+----+--------+",
+        "| a  | 73     |",
+        "| b  | 68     |",
+        "| c  | 122    |",
+        "| d  | 124    |",
+        "| e  | 115    |",
+        "+----+--------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT c1, approx_percentile_cont_with_weight(c3, c2, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+----+--------+",
+        "| c1 | c3_p95 |",
+        "+----+--------+",
+        "| a  | 74     |",
+        "| b  | 68     |",
+        "| c  | 123    |",
+        "| d  | 124    |",
+        "| e  | 115    |",
+        "+----+--------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+

Review comment:
       I wonder if it is worth adding one or two error cases here (e.g. try to invoke this function on a `StringArray`)

##########
File path: datafusion-physical-expr/src/expressions/approx_percentile_cont.rs
##########
@@ -194,75 +204,125 @@ pub struct ApproxPercentileAccumulator {
 impl ApproxPercentileAccumulator {
     pub fn new(percentile: f64, return_type: DataType) -> Self {
         Self {
-            digest: TDigest::new(100),
+            digest: TDigest::new(DEFAULT_MAX_SIZE),

Review comment:
       Changing to a symbolic constant `DEFAULT_MAX_SIZE` is a nice improvement 👍 

##########
File path: datafusion/tests/sql/aggregates.rs
##########
@@ -476,6 +476,59 @@ async fn csv_query_approx_percentile_cont() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn csv_query_approx_percentile_cont_with_weight() -> Result<()> {
+    let mut ctx = SessionContext::new();
+    register_aggregate_csv(&mut ctx).await?;
+
+    // compare approx_percentile_cont and approx_percentile_cont_with_weight
+    let sql = "SELECT c1, approx_percentile_cont(c3, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+----+--------+",
+        "| c1 | c3_p95 |",
+        "+----+--------+",
+        "| a  | 73     |",
+        "| b  | 68     |",
+        "| c  | 122    |",
+        "| d  | 124    |",
+        "| e  | 115    |",
+        "+----+--------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "SELECT c1, approx_percentile_cont_with_weight(c3, 1, 0.95) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1";
+    let actual = execute_to_batches(&ctx, sql).await;
+    let expected = vec![
+        "+----+--------+",
+        "| c1 | c3_p95 |",
+        "+----+--------+",
+        "| a  | 73     |",
+        "| b  | 68     |",
+        "| c  | 122    |",
+        "| d  | 124    |",
+        "| e  | 115    |",
+        "+----+--------+",
+    ];

Review comment:
       These values seem to be the same as the ones from `approx_percentile_cont(c3, 0.95)`, which I think is the point of the test. Perhaps we could encode that equivalence directly in the test:
   
   ```suggestion
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org