You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by yj...@apache.org on 2022/04/18 02:19:30 UTC

[arrow-datafusion] branch master updated: [Ballista] Enable ApproxPercentileWithWeight in Ballista and fill UT (#2192)

This is an automated email from the ASF dual-hosted git repository.

yjshen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new c91efc276 [Ballista] Enable ApproxPercentileWithWeight in Ballista and fill UT  (#2192)
c91efc276 is described below

commit c91efc27658e58264c4f346a5cfdec8810179e90
Author: Yang Jiang <37...@users.noreply.github.com>
AuthorDate: Mon Apr 18 10:19:26 2022 +0800

    [Ballista] Enable ApproxPercentileWithWeight in Ballista and fill UT  (#2192)
    
    * enable ApproxPercentileWithWeight in Ballista
    
    * add ApproxPercentileWithWeight in Ballista proto
---
 ballista/rust/client/src/context.rs                | 223 ++++++++++++++++++++-
 ballista/rust/core/proto/datafusion.proto          |   1 +
 .../rust/core/src/serde/physical_plan/to_proto.rs  |   6 +
 .../approx_percentile_cont_with_weight.rs          |  13 +-
 4 files changed, 226 insertions(+), 17 deletions(-)

diff --git a/ballista/rust/client/src/context.rs b/ballista/rust/client/src/context.rs
index 5899598ba..7dc7ec63b 100644
--- a/ballista/rust/client/src/context.rs
+++ b/ballista/rust/client/src/context.rs
@@ -642,7 +642,6 @@ mod tests {
             BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA,
         };
         use datafusion::arrow::util::pretty::pretty_format_batches;
-        use datafusion::assert_batches_eq;
         let config = BallistaConfigBuilder::default()
             .set(BALLISTA_WITH_INFORMATION_SCHEMA, "true")
             .build()
@@ -696,13 +695,15 @@ mod tests {
 
     #[tokio::test]
     #[cfg(feature = "standalone")]
-    async fn test_percentile_func() {
+    async fn test_aggregate_func() {
         use crate::context::BallistaContext;
         use ballista_core::config::{
             BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA,
         };
+        use datafusion::arrow;
         use datafusion::arrow::util::pretty::pretty_format_batches;
         use datafusion::prelude::ParquetReadOptions;
+
         let config = BallistaConfigBuilder::default()
             .set(BALLISTA_WITH_INFORMATION_SCHEMA, "true")
             .build()
@@ -718,6 +719,199 @@ mod tests {
             )
             .await
             .unwrap();
+
+        let df = context.sql("select min(\"id\") from test").await.unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+--------------+",
+            "| MIN(test.id) |",
+            "+--------------+",
+            "| 0            |",
+            "+--------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context.sql("select max(\"id\") from test").await.unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+--------------+",
+            "| MAX(test.id) |",
+            "+--------------+",
+            "| 7            |",
+            "+--------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context.sql("select SUM(\"id\") from test").await.unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+--------------+",
+            "| SUM(test.id) |",
+            "+--------------+",
+            "| 28           |",
+            "+--------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context.sql("select AVG(\"id\") from test").await.unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+--------------+",
+            "| AVG(test.id) |",
+            "+--------------+",
+            "| 3.5          |",
+            "+--------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context.sql("select COUNT(\"id\") from test").await.unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+----------------+",
+            "| COUNT(test.id) |",
+            "+----------------+",
+            "| 8              |",
+            "+----------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select approx_distinct(\"id\") from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+-------------------------+",
+            "| APPROXDISTINCT(test.id) |",
+            "+-------------------------+",
+            "| 8                       |",
+            "+-------------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select ARRAY_AGG(\"id\") from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+--------------------------+",
+            "| ARRAYAGG(test.id)        |",
+            "+--------------------------+",
+            "| [4, 5, 6, 7, 2, 3, 0, 1] |",
+            "+--------------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context.sql("select VAR(\"id\") from test").await.unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+-------------------+",
+            "| VARIANCE(test.id) |",
+            "+-------------------+",
+            "| 6.000000000000001 |",
+            "+-------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select VAR_POP(\"id\") from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+----------------------+",
+            "| VARIANCEPOP(test.id) |",
+            "+----------------------+",
+            "| 5.250000000000001    |",
+            "+----------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select VAR_SAMP(\"id\") from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+-------------------+",
+            "| VARIANCE(test.id) |",
+            "+-------------------+",
+            "| 6.000000000000001 |",
+            "+-------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select STDDEV(\"id\") from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+--------------------+",
+            "| STDDEV(test.id)    |",
+            "+--------------------+",
+            "| 2.4494897427831783 |",
+            "+--------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select STDDEV_SAMP(\"id\") from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+--------------------+",
+            "| STDDEV(test.id)    |",
+            "+--------------------+",
+            "| 2.4494897427831783 |",
+            "+--------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select COVAR(id, tinyint_col) from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+--------------------------------------+",
+            "| COVARIANCE(test.id,test.tinyint_col) |",
+            "+--------------------------------------+",
+            "| 0.28571428571428586                  |",
+            "+--------------------------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select CORR(id, tinyint_col) from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+---------------------------------------+",
+            "| CORRELATION(test.id,test.tinyint_col) |",
+            "+---------------------------------------+",
+            "| 0.21821789023599245                   |",
+            "+---------------------------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
+        let df = context
+            .sql("select approx_percentile_cont_with_weight(\"id\", 2, 0.5) from test")
+            .await
+            .unwrap();
+        let res = df.collect().await.unwrap();
+        let expected = vec![
+            "+---------------------------------------------------------------+",
+            "| APPROXPERCENTILECONTWITHWEIGHT(test.id,Int64(2),Float64(0.5)) |",
+            "+---------------------------------------------------------------+",
+            "| 1                                                             |",
+            "+---------------------------------------------------------------+",
+        ];
+        assert_result_eq(expected, &*res);
+
         let df = context
             .sql("select approx_percentile_cont(\"double_col\", 0.5) from test")
             .await
@@ -731,14 +925,21 @@ mod tests {
             "+----------------------------------------------------+",
         ];
 
-        assert_eq!(
-            expected,
-            pretty_format_batches(&*res)
-                .unwrap()
-                .to_string()
-                .trim()
-                .lines()
-                .collect::<Vec<&str>>()
-        );
+        assert_result_eq(expected, &*res);
+
+        fn assert_result_eq(
+            expected: Vec<&str>,
+            results: &[arrow::record_batch::RecordBatch],
+        ) {
+            assert_eq!(
+                expected,
+                pretty_format_batches(results)
+                    .unwrap()
+                    .to_string()
+                    .trim()
+                    .lines()
+                    .collect::<Vec<&str>>()
+            );
+        }
     }
 }
diff --git a/ballista/rust/core/proto/datafusion.proto b/ballista/rust/core/proto/datafusion.proto
index 1dc9b34f7..9999abbf2 100644
--- a/ballista/rust/core/proto/datafusion.proto
+++ b/ballista/rust/core/proto/datafusion.proto
@@ -201,6 +201,7 @@ enum AggregateFunction {
   CORRELATION=13;
   APPROX_PERCENTILE_CONT = 14;
   APPROX_MEDIAN=15;
+  APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16;
 }
 
 message AggregateExprNode {
diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index 3a1f24d0f..d022766d8 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -123,6 +123,12 @@ impl TryInto<protobuf::PhysicalExprNode> for Arc<dyn AggregateExpr> {
             .is_some()
         {
             Ok(AggregateFunction::ApproxPercentileCont.into())
+        } else if self
+            .as_any()
+            .downcast_ref::<expressions::ApproxPercentileContWithWeight>()
+            .is_some()
+        {
+            Ok(AggregateFunction::ApproxPercentileContWithWeight.into())
         } else if self
             .as_any()
             .downcast_ref::<expressions::ApproxMedian>()
diff --git a/datafusion/physical-expr/src/expressions/approx_percentile_cont_with_weight.rs b/datafusion/physical-expr/src/expressions/approx_percentile_cont_with_weight.rs
index 33b2ee7a6..1beb7a86c 100644
--- a/datafusion/physical-expr/src/expressions/approx_percentile_cont_with_weight.rs
+++ b/datafusion/physical-expr/src/expressions/approx_percentile_cont_with_weight.rs
@@ -38,6 +38,7 @@ pub struct ApproxPercentileContWithWeight {
     approx_percentile_cont: ApproxPercentileCont,
     column_expr: Arc<dyn PhysicalExpr>,
     weight_expr: Arc<dyn PhysicalExpr>,
+    percentile_expr: Arc<dyn PhysicalExpr>,
 }
 
 impl ApproxPercentileContWithWeight {
@@ -58,6 +59,7 @@ impl ApproxPercentileContWithWeight {
             approx_percentile_cont,
             column_expr: expr[0].clone(),
             weight_expr: expr[1].clone(),
+            percentile_expr: expr[2].clone(),
         })
     }
 }
@@ -79,7 +81,11 @@ impl AggregateExpr for ApproxPercentileContWithWeight {
     }
 
     fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
-        vec![self.column_expr.clone(), self.weight_expr.clone()]
+        vec![
+            self.column_expr.clone(),
+            self.weight_expr.clone(),
+            self.percentile_expr.clone(),
+        ]
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
@@ -115,11 +121,6 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator {
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        debug_assert_eq!(
-            values.len(),
-            2,
-            "invalid number of values in batch percentile update"
-        );
         let means = &values[0];
         let weights = &values[1];
         debug_assert_eq!(