You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by yj...@apache.org on 2022/04/18 02:19:30 UTC
[arrow-datafusion] branch master updated: [Ballista] Enable ApproxPercentileWithWeight in Ballista and fill UT (#2192)
This is an automated email from the ASF dual-hosted git repository.
yjshen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new c91efc276 [Ballista] Enable ApproxPercentileWithWeight in Ballista and fill UT (#2192)
c91efc276 is described below
commit c91efc27658e58264c4f346a5cfdec8810179e90
Author: Yang Jiang <37...@users.noreply.github.com>
AuthorDate: Mon Apr 18 10:19:26 2022 +0800
[Ballista] Enable ApproxPercentileWithWeight in Ballista and fill UT (#2192)
* enable ApproxPercentileWithWeight in Ballista
* add ApproxPercentileWithWeight in Ballista proto
---
ballista/rust/client/src/context.rs | 223 ++++++++++++++++++++-
ballista/rust/core/proto/datafusion.proto | 1 +
.../rust/core/src/serde/physical_plan/to_proto.rs | 6 +
.../approx_percentile_cont_with_weight.rs | 13 +-
4 files changed, 226 insertions(+), 17 deletions(-)
diff --git a/ballista/rust/client/src/context.rs b/ballista/rust/client/src/context.rs
index 5899598ba..7dc7ec63b 100644
--- a/ballista/rust/client/src/context.rs
+++ b/ballista/rust/client/src/context.rs
@@ -642,7 +642,6 @@ mod tests {
BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA,
};
use datafusion::arrow::util::pretty::pretty_format_batches;
- use datafusion::assert_batches_eq;
let config = BallistaConfigBuilder::default()
.set(BALLISTA_WITH_INFORMATION_SCHEMA, "true")
.build()
@@ -696,13 +695,15 @@ mod tests {
#[tokio::test]
#[cfg(feature = "standalone")]
- async fn test_percentile_func() {
+ async fn test_aggregate_func() {
use crate::context::BallistaContext;
use ballista_core::config::{
BallistaConfigBuilder, BALLISTA_WITH_INFORMATION_SCHEMA,
};
+ use datafusion::arrow;
use datafusion::arrow::util::pretty::pretty_format_batches;
use datafusion::prelude::ParquetReadOptions;
+
let config = BallistaConfigBuilder::default()
.set(BALLISTA_WITH_INFORMATION_SCHEMA, "true")
.build()
@@ -718,6 +719,199 @@ mod tests {
)
.await
.unwrap();
+
+ let df = context.sql("select min(\"id\") from test").await.unwrap();
+ let res = df.collect().await.unwrap();
+ let expected = vec![
+ "+--------------+",
+ "| MIN(test.id) |",
+ "+--------------+",
+ "| 0 |",
+ "+--------------+",
+ ];
+ assert_result_eq(expected, &*res);
+
+ let df = context.sql("select max(\"id\") from test").await.unwrap();
+ let res = df.collect().await.unwrap();
+ let expected = vec![
+ "+--------------+",
+ "| MAX(test.id) |",
+ "+--------------+",
+ "| 7 |",
+ "+--------------+",
+ ];
+ assert_result_eq(expected, &*res);
+
+ let df = context.sql("select SUM(\"id\") from test").await.unwrap();
+ let res = df.collect().await.unwrap();
+ let expected = vec![
+ "+--------------+",
+ "| SUM(test.id) |",
+ "+--------------+",
+ "| 28 |",
+ "+--------------+",
+ ];
+ assert_result_eq(expected, &*res);
+
+ let df = context.sql("select AVG(\"id\") from test").await.unwrap();
+ let res = df.collect().await.unwrap();
+ let expected = vec![
+ "+--------------+",
+ "| AVG(test.id) |",
+ "+--------------+",
+ "| 3.5 |",
+ "+--------------+",
+ ];
+ assert_result_eq(expected, &*res);
+
+ let df = context.sql("select COUNT(\"id\") from test").await.unwrap();
+ let res = df.collect().await.unwrap();
+ let expected = vec![
+ "+----------------+",
+ "| COUNT(test.id) |",
+ "+----------------+",
+ "| 8 |",
+ "+----------------+",
+ ];
+ assert_result_eq(expected, &*res);
+
+ let df = context
+ .sql("select approx_distinct(\"id\") from test")
+ .await
+ .unwrap();
+ let res = df.collect().await.unwrap();
+ let expected = vec![
+ "+-------------------------+",
+ "| APPROXDISTINCT(test.id) |",
+ "+-------------------------+",
+ "| 8 |",
+ "+-------------------------+",
+ ];
+ assert_result_eq(expected, &*res);
+
+ let df = context
+ .sql("select ARRAY_AGG(\"id\") from test")
+ .await
+ .unwrap();
+ let res = df.collect().await.unwrap();
+ let expected = vec![
+ "+--------------------------+",
+ "| ARRAYAGG(test.id) |",
+ "+--------------------------+",
+ "| [4, 5, 6, 7, 2, 3, 0, 1] |",
+ "+--------------------------+",
+ ];
+ assert_result_eq(expected, &*res);
+
+ let df = context.sql("select VAR(\"id\") from test").await.unwrap();
+ let res = df.collect().await.unwrap();
+ let expected = vec![
+ "+-------------------+",
+ "| VARIANCE(test.id) |",
+ "+-------------------+",
+ "| 6.000000000000001 |",
+ "+-------------------+",
+ ];
+ assert_result_eq(expected, &*res);
+
+ let df = context
+ .sql("select VAR_POP(\"id\") from test")
+ .await
+ .unwrap();
+ let res = df.collect().await.unwrap();
+ let expected = vec![
+ "+----------------------+",
+ "| VARIANCEPOP(test.id) |",
+ "+----------------------+",
+ "| 5.250000000000001 |",
+ "+----------------------+",
+ ];
+ assert_result_eq(expected, &*res);
+
+ let df = context
+ .sql("select VAR_SAMP(\"id\") from test")
+ .await
+ .unwrap();
+ let res = df.collect().await.unwrap();
+ let expected = vec![
+ "+-------------------+",
+ "| VARIANCE(test.id) |",
+ "+-------------------+",
+ "| 6.000000000000001 |",
+ "+-------------------+",
+ ];
+ assert_result_eq(expected, &*res);
+
+ let df = context
+ .sql("select STDDEV(\"id\") from test")
+ .await
+ .unwrap();
+ let res = df.collect().await.unwrap();
+ let expected = vec![
+ "+--------------------+",
+ "| STDDEV(test.id) |",
+ "+--------------------+",
+ "| 2.4494897427831783 |",
+ "+--------------------+",
+ ];
+ assert_result_eq(expected, &*res);
+
+ let df = context
+ .sql("select STDDEV_SAMP(\"id\") from test")
+ .await
+ .unwrap();
+ let res = df.collect().await.unwrap();
+ let expected = vec![
+ "+--------------------+",
+ "| STDDEV(test.id) |",
+ "+--------------------+",
+ "| 2.4494897427831783 |",
+ "+--------------------+",
+ ];
+ assert_result_eq(expected, &*res);
+
+ let df = context
+ .sql("select COVAR(id, tinyint_col) from test")
+ .await
+ .unwrap();
+ let res = df.collect().await.unwrap();
+ let expected = vec![
+ "+--------------------------------------+",
+ "| COVARIANCE(test.id,test.tinyint_col) |",
+ "+--------------------------------------+",
+ "| 0.28571428571428586 |",
+ "+--------------------------------------+",
+ ];
+ assert_result_eq(expected, &*res);
+
+ let df = context
+ .sql("select CORR(id, tinyint_col) from test")
+ .await
+ .unwrap();
+ let res = df.collect().await.unwrap();
+ let expected = vec![
+ "+---------------------------------------+",
+ "| CORRELATION(test.id,test.tinyint_col) |",
+ "+---------------------------------------+",
+ "| 0.21821789023599245 |",
+ "+---------------------------------------+",
+ ];
+ assert_result_eq(expected, &*res);
+
+ let df = context
+ .sql("select approx_percentile_cont_with_weight(\"id\", 2, 0.5) from test")
+ .await
+ .unwrap();
+ let res = df.collect().await.unwrap();
+ let expected = vec![
+ "+---------------------------------------------------------------+",
+ "| APPROXPERCENTILECONTWITHWEIGHT(test.id,Int64(2),Float64(0.5)) |",
+ "+---------------------------------------------------------------+",
+ "| 1 |",
+ "+---------------------------------------------------------------+",
+ ];
+ assert_result_eq(expected, &*res);
+
let df = context
.sql("select approx_percentile_cont(\"double_col\", 0.5) from test")
.await
@@ -731,14 +925,21 @@ mod tests {
"+----------------------------------------------------+",
];
- assert_eq!(
- expected,
- pretty_format_batches(&*res)
- .unwrap()
- .to_string()
- .trim()
- .lines()
- .collect::<Vec<&str>>()
- );
+ assert_result_eq(expected, &*res);
+
+ fn assert_result_eq(
+ expected: Vec<&str>,
+ results: &[arrow::record_batch::RecordBatch],
+ ) {
+ assert_eq!(
+ expected,
+ pretty_format_batches(results)
+ .unwrap()
+ .to_string()
+ .trim()
+ .lines()
+ .collect::<Vec<&str>>()
+ );
+ }
}
}
diff --git a/ballista/rust/core/proto/datafusion.proto b/ballista/rust/core/proto/datafusion.proto
index 1dc9b34f7..9999abbf2 100644
--- a/ballista/rust/core/proto/datafusion.proto
+++ b/ballista/rust/core/proto/datafusion.proto
@@ -201,6 +201,7 @@ enum AggregateFunction {
CORRELATION=13;
APPROX_PERCENTILE_CONT = 14;
APPROX_MEDIAN=15;
+ APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16;
}
message AggregateExprNode {
diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index 3a1f24d0f..d022766d8 100644
--- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -123,6 +123,12 @@ impl TryInto<protobuf::PhysicalExprNode> for Arc<dyn AggregateExpr> {
.is_some()
{
Ok(AggregateFunction::ApproxPercentileCont.into())
+ } else if self
+ .as_any()
+ .downcast_ref::<expressions::ApproxPercentileContWithWeight>()
+ .is_some()
+ {
+ Ok(AggregateFunction::ApproxPercentileContWithWeight.into())
} else if self
.as_any()
.downcast_ref::<expressions::ApproxMedian>()
diff --git a/datafusion/physical-expr/src/expressions/approx_percentile_cont_with_weight.rs b/datafusion/physical-expr/src/expressions/approx_percentile_cont_with_weight.rs
index 33b2ee7a6..1beb7a86c 100644
--- a/datafusion/physical-expr/src/expressions/approx_percentile_cont_with_weight.rs
+++ b/datafusion/physical-expr/src/expressions/approx_percentile_cont_with_weight.rs
@@ -38,6 +38,7 @@ pub struct ApproxPercentileContWithWeight {
approx_percentile_cont: ApproxPercentileCont,
column_expr: Arc<dyn PhysicalExpr>,
weight_expr: Arc<dyn PhysicalExpr>,
+ percentile_expr: Arc<dyn PhysicalExpr>,
}
impl ApproxPercentileContWithWeight {
@@ -58,6 +59,7 @@ impl ApproxPercentileContWithWeight {
approx_percentile_cont,
column_expr: expr[0].clone(),
weight_expr: expr[1].clone(),
+ percentile_expr: expr[2].clone(),
})
}
}
@@ -79,7 +81,11 @@ impl AggregateExpr for ApproxPercentileContWithWeight {
}
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
- vec![self.column_expr.clone(), self.weight_expr.clone()]
+ vec![
+ self.column_expr.clone(),
+ self.weight_expr.clone(),
+ self.percentile_expr.clone(),
+ ]
}
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
@@ -115,11 +121,6 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator {
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- debug_assert_eq!(
- values.len(),
- 2,
- "invalid number of values in batch percentile update"
- );
let means = &values[0];
let weights = &values[1];
debug_assert_eq!(