You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/08/02 11:26:06 UTC

[arrow-datafusion] branch master updated: Produce correct answers for Group BY NULL (Option 1) (#793)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 2bcf040  Produce correct answers for Group BY NULL (Option 1) (#793)
2bcf040 is described below

commit 2bcf04017d710a6b8684617e81a64c9db1184f5c
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Mon Aug 2 07:23:49 2021 -0400

    Produce correct answers for Group BY NULL (Option 1) (#793)
    
    * Add support for group by hash of a null column, tests for same
    
    * Update datafusion/src/physical_plan/hash_aggregate.rs
    
    Co-authored-by: Daniël Heres <da...@gmail.com>
    
    Co-authored-by: Daniël Heres <da...@gmail.com>
---
 datafusion/src/physical_plan/hash_aggregate.rs |  60 +++++++++++++-
 datafusion/src/scalar.rs                       |  37 ++++++++-
 datafusion/tests/sql.rs                        | 110 +++++++++++++++++++++++++
 3 files changed, 202 insertions(+), 5 deletions(-)

diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs
index eb4a356..5c3c576 100644
--- a/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/datafusion/src/physical_plan/hash_aggregate.rs
@@ -395,7 +395,10 @@ fn group_aggregate_batch(
                 // We can safely unwrap here as we checked we can create an accumulator before
                 let accumulator_set = create_accumulators(aggr_expr).unwrap();
                 batch_keys.push(key.clone());
-                let _ = create_group_by_values(&group_values, row, &mut group_by_values);
+                // Note it would be nice to make this a real error (rather than panic)
+                // but it is better than silently ignoring the issue and getting wrong results
+                create_group_by_values(&group_values, row, &mut group_by_values)
+                    .expect("can not create group by value");
                 (
                     key.clone(),
                     (group_by_values.clone(), accumulator_set, vec![row as u32]),
@@ -508,7 +511,9 @@ fn dictionary_create_key_for_col<K: ArrowDictionaryKeyType>(
 }
 
 /// Appends a sequence of [u8] bytes for the value in `col[row]` to
-/// `vec` to be used as a key into the hash map
+/// `vec` to be used as a key into the hash map.
+///
+/// NOTE: This function does not check col.is_valid(). The caller must do so.
 fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec<u8>) -> Result<()> {
     match col.data_type() {
         DataType::Boolean => {
@@ -640,6 +645,50 @@ fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec<u8>) -> Result<(
 }
 
 /// Create a key `Vec<u8>` that is used as key for the hashmap
+///
+/// This looks like
+/// [null_byte][col_value_bytes][null_byte][col_value_bytes]
+///
+/// Note that relatively uncommon patterns (e.g. not 0x00) are chosen
+/// for the null_byte to make debugging easier. The actual values are
+/// arbitrary.
+///
+/// For a NULL value in a column, the key looks like
+/// [0xFE]
+///
+/// For a Non-NULL value in a column, this looks like:
+/// [0xFF][byte representation of column value]
+///
+/// Example of a key with no NULL values:
+/// ```text
+///                        0xFF byte at the start of each column
+///                           signifies the value is non-null
+///                                          │
+///
+///                      ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ─ ─ ─ ┐
+///
+///                      │        string len                 │  0x1234
+/// {                    ▼       (as usize le)      "foo"    ▼(as u16 le)
+///   k1: "foo"        ╔ ═┌──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──╦ ═┌──┬──┐
+///   k2: 0x1234u16     FF║03│00│00│00│00│00│00│00│"f│"o│"o│FF║34│12│
+/// }                  ╚ ═└──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──╩ ═└──┴──┘
+///                     0  1  2  3  4  5  6  7  8  9  10 11 12 13 14
+/// ```
+///
+///  Example of a key with NULL values:
+///
+///```text
+///                         0xFE byte at the start of k1 column
+///                     ┌ ─     signifies the value is NULL
+///
+///                     └ ┐
+///                              0x1234
+/// {                     ▼    (as u16 le)
+///   k1: NULL          ╔ ═╔ ═┌──┬──┐
+///   k2: 0x1234u16      FE║FF║34│12│
+/// }                   ╚ ═╚ ═└──┴──┘
+///                       0  1  2  3
+///```
 pub(crate) fn create_key(
     group_by_keys: &[ArrayRef],
     row: usize,
@@ -647,7 +696,12 @@ pub(crate) fn create_key(
 ) -> Result<()> {
     vec.clear();
     for col in group_by_keys {
-        create_key_for_col(col, row, vec)?
+        if !col.is_valid(row) {
+            vec.push(0xFE);
+        } else {
+            vec.push(0xFF);
+            create_key_for_col(col, row, vec)?
+        }
     }
     Ok(())
 }
diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs
index 8efea63..90c9bf7 100644
--- a/datafusion/src/scalar.rs
+++ b/datafusion/src/scalar.rs
@@ -28,7 +28,7 @@ use arrow::{
     },
 };
 use ordered_float::OrderedFloat;
-use std::convert::Infallible;
+use std::convert::{Infallible, TryInto};
 use std::str::FromStr;
 use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
 
@@ -796,6 +796,11 @@ impl ScalarValue {
 
     /// Converts a value in `array` at `index` into a ScalarValue
     pub fn try_from_array(array: &ArrayRef, index: usize) -> Result<Self> {
+        // handle NULL value
+        if !array.is_valid(index) {
+            return array.data_type().try_into();
+        }
+
         Ok(match array.data_type() {
             DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean),
             DataType::Float64 => typed_cast!(array, index, Float64Array, Float64),
@@ -897,6 +902,7 @@ impl ScalarValue {
         let dict_array = array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
 
         // look up the index in the values dictionary
+        // (note validity was previously checked in `try_from_array`)
         let keys_col = dict_array.keys();
         let values_index = keys_col.value(index).to_usize().ok_or_else(|| {
             DataFusionError::Internal(format!(
@@ -1132,6 +1138,7 @@ impl_try_from!(Boolean, bool);
 impl TryFrom<&DataType> for ScalarValue {
     type Error = DataFusionError;
 
+    /// Create a Null instance of ScalarValue for this datatype
     fn try_from(datatype: &DataType) -> Result<Self> {
         Ok(match datatype {
             DataType::Boolean => ScalarValue::Boolean(None),
@@ -1161,12 +1168,15 @@ impl TryFrom<&DataType> for ScalarValue {
             DataType::Timestamp(TimeUnit::Nanosecond, _) => {
                 ScalarValue::TimestampNanosecond(None)
             }
+            DataType::Dictionary(_index_type, value_type) => {
+                value_type.as_ref().try_into()?
+            }
             DataType::List(ref nested_type) => {
                 ScalarValue::List(None, Box::new(nested_type.data_type().clone()))
             }
             _ => {
                 return Err(DataFusionError::NotImplemented(format!(
-                    "Can't create a scalar of type \"{:?}\"",
+                    "Can't create a scalar from data_type \"{:?}\"",
                     datatype
                 )))
             }
@@ -1536,6 +1546,29 @@ mod tests {
     }
 
     #[test]
+    fn scalar_try_from_array_null() {
+        let array = vec![Some(33), None].into_iter().collect::<Int64Array>();
+        let array: ArrayRef = Arc::new(array);
+
+        assert_eq!(
+            ScalarValue::Int64(Some(33)),
+            ScalarValue::try_from_array(&array, 0).unwrap()
+        );
+        assert_eq!(
+            ScalarValue::Int64(None),
+            ScalarValue::try_from_array(&array, 1).unwrap()
+        );
+    }
+
+    #[test]
+    fn scalar_try_from_dict_datatype() {
+        let data_type =
+            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));
+        let data_type = &data_type;
+        assert_eq!(ScalarValue::Utf8(None), data_type.try_into().unwrap())
+    }
+
+    #[test]
     fn size_of_scalar() {
         // Since ScalarValues are used in a non trivial number of places,
         // making it larger means significant more memory consumption
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 42a7d20..3a83f20 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -3057,6 +3057,109 @@ async fn query_count_distinct() -> Result<()> {
 }
 
 #[tokio::test]
+async fn query_group_on_null() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![Arc::new(Int32Array::from(vec![
+            Some(0),
+            Some(3),
+            None,
+            Some(1),
+            Some(3),
+        ]))],
+    )?;
+
+    let table = MemTable::try_new(schema, vec![vec![data]])?;
+
+    let mut ctx = ExecutionContext::new();
+    ctx.register_table("test", Arc::new(table))?;
+    let sql = "SELECT COUNT(*), c1 FROM test GROUP BY c1";
+
+    let actual = execute_to_batches(&mut ctx, sql).await;
+
+    // Note that the results also
+    // include a row for NULL (c1=NULL, count = 1)
+    let expected = vec![
+        "+-----------------+----+",
+        "| COUNT(UInt8(1)) | c1 |",
+        "+-----------------+----+",
+        "| 1               |    |",
+        "| 1               | 0  |",
+        "| 1               | 1  |",
+        "| 2               | 3  |",
+        "+-----------------+----+",
+    ];
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn query_group_on_null_multi_col() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Int32, true),
+        Field::new("c2", DataType::Utf8, true),
+    ]));
+
+    let data = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Int32Array::from(vec![
+                Some(0),
+                Some(0),
+                Some(3),
+                None,
+                None,
+                Some(3),
+                Some(0),
+                None,
+                Some(3),
+            ])),
+            Arc::new(StringArray::from(vec![
+                None,
+                None,
+                Some("foo"),
+                None,
+                Some("bar"),
+                Some("foo"),
+                None,
+                Some("bar"),
+                Some("foo"),
+            ])),
+        ],
+    )?;
+
+    let table = MemTable::try_new(schema, vec![vec![data]])?;
+
+    let mut ctx = ExecutionContext::new();
+    ctx.register_table("test", Arc::new(table))?;
+    let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c1, c2";
+
+    let actual = execute_to_batches(&mut ctx, sql).await;
+
+    // Note that the results also include groups where c1 and/or c2
+    // is NULL (e.g. c1=NULL, c2=NULL, count = 1)
+    let expected = vec![
+        "+-----------------+----+-----+",
+        "| COUNT(UInt8(1)) | c1 | c2  |",
+        "+-----------------+----+-----+",
+        "| 1               |    |     |",
+        "| 2               |    | bar |",
+        "| 3               | 0  |     |",
+        "| 3               | 3  | foo |",
+        "+-----------------+----+-----+",
+    ];
+    assert_batches_sorted_eq!(expected, &actual);
+
+    // Also run query with group columns reversed (results should be the same)
+    let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c2, c1";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    assert_batches_sorted_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
 async fn query_on_string_dictionary() -> Result<()> {
     // Test to ensure DataFusion can operate on dictionary types
     // Use StringDictionary (32 bit indexes = keys)
@@ -3109,6 +3212,13 @@ async fn query_on_string_dictionary() -> Result<()> {
     let expected = vec![vec!["2"]];
     assert_eq!(expected, actual);
 
+    // grouping
+    let sql = "SELECT d1, COUNT(*) FROM test group by d1";
+    let mut actual = execute(&mut ctx, sql).await;
+    actual.sort();
+    let expected = vec![vec!["NULL", "1"], vec!["one", "1"], vec!["three", "1"]];
+    assert_eq!(expected, actual);
+
     Ok(())
 }