You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/01/21 12:11:23 UTC

[arrow] branch master updated: ARROW-11220: [Rust] Implement GROUP BY support for Boolean

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 8b56f85  ARROW-11220: [Rust] Implement GROUP BY support for Boolean
8b56f85 is described below

commit 8b56f85084d7d544002d99368fe2e65162743d0f
Author: Dmitry Patsura <za...@gmail.com>
AuthorDate: Thu Jan 21 07:09:49 2021 -0500

    ARROW-11220: [Rust] Implement GROUP BY support for Boolean
    
    Introduce support in DataFusion for GROUP BY on boolean values. The Boolean type in Rust implements the Eq and Hash traits, which allows us to use GroupByScalar.
    
    Closes #9174 from ovr/issue-11220
    
    Authored-by: Dmitry Patsura <za...@gmail.com>
    Signed-off-by: Andrew Lamb <an...@nerdnetworks.org>
---
 rust/datafusion/src/physical_plan/group_scalar.rs  |  4 +++
 .../datafusion/src/physical_plan/hash_aggregate.rs | 14 +++++++++-
 rust/datafusion/src/physical_plan/hash_join.rs     | 24 +++++++++++++++++
 rust/datafusion/tests/aggregate_floats.csv         | 16 ------------
 rust/datafusion/tests/aggregate_simple.csv         | 16 ++++++++++++
 rust/datafusion/tests/sql.rs                       | 30 +++++++++++++++++-----
 6 files changed, 80 insertions(+), 24 deletions(-)

diff --git a/rust/datafusion/src/physical_plan/group_scalar.rs b/rust/datafusion/src/physical_plan/group_scalar.rs
index 3d02a8d..6aa699b 100644
--- a/rust/datafusion/src/physical_plan/group_scalar.rs
+++ b/rust/datafusion/src/physical_plan/group_scalar.rs
@@ -37,6 +37,7 @@ pub(crate) enum GroupByScalar {
     Int32(i32),
     Int64(i64),
     Utf8(Box<String>),
+    Boolean(bool),
     TimeMicrosecond(i64),
     TimeNanosecond(i64),
 }
@@ -52,6 +53,7 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
             ScalarValue::Float64(Some(v)) => {
                 GroupByScalar::Float64(OrderedFloat::from(*v))
             }
+            ScalarValue::Boolean(Some(v)) => GroupByScalar::Boolean(*v),
             ScalarValue::Int8(Some(v)) => GroupByScalar::Int8(*v),
             ScalarValue::Int16(Some(v)) => GroupByScalar::Int16(*v),
             ScalarValue::Int32(Some(v)) => GroupByScalar::Int32(*v),
@@ -63,6 +65,7 @@ impl TryFrom<&ScalarValue> for GroupByScalar {
             ScalarValue::Utf8(Some(v)) => GroupByScalar::Utf8(Box::new(v.clone())),
             ScalarValue::Float32(None)
             | ScalarValue::Float64(None)
+            | ScalarValue::Boolean(None)
             | ScalarValue::Int8(None)
             | ScalarValue::Int16(None)
             | ScalarValue::Int32(None)
@@ -92,6 +95,7 @@ impl From<&GroupByScalar> for ScalarValue {
         match group_by_scalar {
             GroupByScalar::Float32(v) => ScalarValue::Float32(Some((*v).into())),
             GroupByScalar::Float64(v) => ScalarValue::Float64(Some((*v).into())),
+            GroupByScalar::Boolean(v) => ScalarValue::Boolean(Some(*v)),
             GroupByScalar::Int8(v) => ScalarValue::Int8(Some(*v)),
             GroupByScalar::Int16(v) => ScalarValue::Int16(Some(*v)),
             GroupByScalar::Int32(v) => ScalarValue::Int32(Some(*v)),
diff --git a/rust/datafusion/src/physical_plan/hash_aggregate.rs b/rust/datafusion/src/physical_plan/hash_aggregate.rs
index de5c425..880f87b 100644
--- a/rust/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/rust/datafusion/src/physical_plan/hash_aggregate.rs
@@ -30,10 +30,13 @@ use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{Accumulator, AggregateExpr};
 use crate::physical_plan::{Distribution, ExecutionPlan, Partitioning, PhysicalExpr};
 
-use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
 use arrow::error::{ArrowError, Result as ArrowResult};
 use arrow::record_batch::RecordBatch;
 use arrow::{
+    array::BooleanArray,
+    datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit},
+};
+use arrow::{
     array::{
         ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
         Int8Array, StringArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
@@ -373,6 +376,10 @@ pub(crate) fn create_key(
     vec.clear();
     for col in group_by_keys {
         match col.data_type() {
+            DataType::Boolean => {
+                let array = col.as_any().downcast_ref::<BooleanArray>().unwrap();
+                vec.extend_from_slice(&[array.value(row) as u8]);
+            }
             DataType::Float32 => {
                 let array = col.as_any().downcast_ref::<Float32Array>().unwrap();
                 vec.extend_from_slice(&array.value(row).to_le_bytes());
@@ -799,6 +806,7 @@ fn create_batch_from_map(
                     GroupByScalar::Utf8(str) => {
                         Arc::new(StringArray::from(vec![&***str]))
                     }
+                    GroupByScalar::Boolean(b) => Arc::new(BooleanArray::from(vec![*b])),
                     GroupByScalar::TimeMicrosecond(n) => {
                         Arc::new(TimestampMicrosecondArray::from(vec![*n]))
                     }
@@ -921,6 +929,10 @@ pub(crate) fn create_group_by_values(
                 let array = col.as_any().downcast_ref::<StringArray>().unwrap();
                 vec[i] = GroupByScalar::Utf8(Box::new(array.value(row).into()))
             }
+            DataType::Boolean => {
+                let array = col.as_any().downcast_ref::<BooleanArray>().unwrap();
+                vec[i] = GroupByScalar::Boolean(array.value(row))
+            }
             DataType::Timestamp(TimeUnit::Microsecond, None) => {
                 let array = col
                     .as_any()
diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs
index a096171..874b9b2 100644
--- a/rust/datafusion/src/physical_plan/hash_join.rs
+++ b/rust/datafusion/src/physical_plan/hash_join.rs
@@ -696,6 +696,27 @@ macro_rules! hash_array {
     };
 }
 
+macro_rules! hash_array_cast {
+    ($array_type:ident, $column: ident, $f: ident, $hashes: ident, $random_state: ident, $as_type:tt) => {
+        let array = $column.as_any().downcast_ref::<$array_type>().unwrap();
+        if array.null_count() == 0 {
+            for (i, hash) in $hashes.iter_mut().enumerate() {
+                let mut hasher = $random_state.build_hasher();
+                hasher.$f(array.value(i) as $as_type);
+                *hash = combine_hashes(hasher.finish(), *hash);
+            }
+        } else {
+            for (i, hash) in $hashes.iter_mut().enumerate() {
+                let mut hasher = $random_state.build_hasher();
+                if !array.is_null(i) {
+                    hasher.$f(array.value(i) as $as_type);
+                    *hash = combine_hashes(hasher.finish(), *hash);
+                }
+            }
+        }
+    };
+}
+
 /// Creates hash values for every element in the row based on the values in the columns
 fn create_hashes(arrays: &[ArrayRef], random_state: &RandomState) -> Result<Vec<u64>> {
     let rows = arrays[0].len();
@@ -745,6 +766,9 @@ fn create_hashes(arrays: &[ArrayRef], random_state: &RandomState) -> Result<Vec<
                     random_state
                 );
             }
+            DataType::Boolean => {
+                hash_array_cast!(BooleanArray, col, write_u8, hashes, random_state, u8);
+            }
             DataType::Utf8 => {
                 let array = col.as_any().downcast_ref::<StringArray>().unwrap();
                 for (i, hash) in hashes.iter_mut().enumerate() {
diff --git a/rust/datafusion/tests/aggregate_floats.csv b/rust/datafusion/tests/aggregate_floats.csv
deleted file mode 100644
index 86f5750..0000000
--- a/rust/datafusion/tests/aggregate_floats.csv
+++ /dev/null
@@ -1,16 +0,0 @@
-c1,c2
-0.00001,0.000000000001
-0.00002,0.000000000002
-0.00002,0.000000000002
-0.00003,0.000000000003
-0.00003,0.000000000003
-0.00003,0.000000000003
-0.00004,0.000000000004
-0.00004,0.000000000004
-0.00004,0.000000000004
-0.00004,0.000000000004
-0.00005,0.000000000005
-0.00005,0.000000000005
-0.00005,0.000000000005
-0.00005,0.000000000005
-0.00005,0.000000000005
\ No newline at end of file
diff --git a/rust/datafusion/tests/aggregate_simple.csv b/rust/datafusion/tests/aggregate_simple.csv
new file mode 100644
index 0000000..7a0256c
--- /dev/null
+++ b/rust/datafusion/tests/aggregate_simple.csv
@@ -0,0 +1,16 @@
+c1,c2,c3
+0.00001,0.000000000001,true
+0.00002,0.000000000002,false
+0.00002,0.000000000002,false
+0.00003,0.000000000003,true
+0.00003,0.000000000003,true
+0.00003,0.000000000003,true
+0.00004,0.000000000004,false
+0.00004,0.000000000004,false
+0.00004,0.000000000004,false
+0.00004,0.000000000004,false
+0.00005,0.000000000005,true
+0.00005,0.000000000005,true
+0.00005,0.000000000005,true
+0.00005,0.000000000005,true
+0.00005,0.000000000005,true
\ No newline at end of file
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index 80ab70f..14b76a5 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -348,10 +348,10 @@ async fn csv_query_group_by_int_min_max() -> Result<()> {
 #[tokio::test]
 async fn csv_query_group_by_float32() -> Result<()> {
     let mut ctx = ExecutionContext::new();
-    register_aggregate_floats_csv(&mut ctx)?;
+    register_aggregate_simple_csv(&mut ctx)?;
 
     let sql =
-        "SELECT COUNT(*) as cnt, c1 FROM aggregate_floats GROUP BY c1 ORDER BY cnt DESC";
+        "SELECT COUNT(*) as cnt, c1 FROM aggregate_simple GROUP BY c1 ORDER BY cnt DESC";
     let actual = execute(&mut ctx, sql).await;
 
     let expected = vec![
@@ -369,10 +369,10 @@ async fn csv_query_group_by_float32() -> Result<()> {
 #[tokio::test]
 async fn csv_query_group_by_float64() -> Result<()> {
     let mut ctx = ExecutionContext::new();
-    register_aggregate_floats_csv(&mut ctx)?;
+    register_aggregate_simple_csv(&mut ctx)?;
 
     let sql =
-        "SELECT COUNT(*) as cnt, c2 FROM aggregate_floats GROUP BY c2 ORDER BY cnt DESC";
+        "SELECT COUNT(*) as cnt, c2 FROM aggregate_simple GROUP BY c2 ORDER BY cnt DESC";
     let actual = execute(&mut ctx, sql).await;
 
     let expected = vec![
@@ -388,6 +388,21 @@ async fn csv_query_group_by_float64() -> Result<()> {
 }
 
 #[tokio::test]
+async fn csv_query_group_by_boolean() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_aggregate_simple_csv(&mut ctx)?;
+
+    let sql =
+        "SELECT COUNT(*) as cnt, c3 FROM aggregate_simple GROUP BY c3 ORDER BY cnt DESC";
+    let actual = execute(&mut ctx, sql).await;
+
+    let expected = vec![vec!["9", "true"], vec!["6", "false"]];
+    assert_eq!(expected, actual);
+
+    Ok(())
+}
+
+#[tokio::test]
 async fn csv_query_group_by_two_columns() -> Result<()> {
     let mut ctx = ExecutionContext::new();
     register_aggregate_csv(&mut ctx)?;
@@ -1367,16 +1382,17 @@ fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> {
     Ok(())
 }
 
-fn register_aggregate_floats_csv(ctx: &mut ExecutionContext) -> Result<()> {
+fn register_aggregate_simple_csv(ctx: &mut ExecutionContext) -> Result<()> {
     // It's not possible to use aggregate_test_100, not enought similar values to test grouping on floats
     let schema = Arc::new(Schema::new(vec![
         Field::new("c1", DataType::Float32, false),
         Field::new("c2", DataType::Float64, false),
+        Field::new("c3", DataType::Boolean, false),
     ]));
 
     ctx.register_csv(
-        "aggregate_floats",
-        "tests/aggregate_floats.csv",
+        "aggregate_simple",
+        "tests/aggregate_simple.csv",
         CsvReadOptions::new().schema(&schema),
     )?;
     Ok(())