You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/08/02 11:26:06 UTC
[arrow-datafusion] branch master updated: Produce correct answers
for Group BY NULL (Option 1) (#793)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 2bcf040 Produce correct answers for Group BY NULL (Option 1) (#793)
2bcf040 is described below
commit 2bcf04017d710a6b8684617e81a64c9db1184f5c
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Mon Aug 2 07:23:49 2021 -0400
Produce correct answers for Group BY NULL (Option 1) (#793)
* Add support for group by hash of a null column, tests for same
* Update datafusion/src/physical_plan/hash_aggregate.rs
Co-authored-by: Daniël Heres <da...@gmail.com>
Co-authored-by: Daniël Heres <da...@gmail.com>
---
datafusion/src/physical_plan/hash_aggregate.rs | 60 +++++++++++++-
datafusion/src/scalar.rs | 37 ++++++++-
datafusion/tests/sql.rs | 110 +++++++++++++++++++++++++
3 files changed, 202 insertions(+), 5 deletions(-)
diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs
index eb4a356..5c3c576 100644
--- a/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/datafusion/src/physical_plan/hash_aggregate.rs
@@ -395,7 +395,10 @@ fn group_aggregate_batch(
// We can safely unwrap here as we checked we can create an accumulator before
let accumulator_set = create_accumulators(aggr_expr).unwrap();
batch_keys.push(key.clone());
- let _ = create_group_by_values(&group_values, row, &mut group_by_values);
+ // Note it would be nice to make this a real error (rather than panic)
+ // but it is better than silently ignoring the issue and getting wrong results
+ create_group_by_values(&group_values, row, &mut group_by_values)
+ .expect("can not create group by value");
(
key.clone(),
(group_by_values.clone(), accumulator_set, vec![row as u32]),
@@ -508,7 +511,9 @@ fn dictionary_create_key_for_col<K: ArrowDictionaryKeyType>(
}
/// Appends a sequence of [u8] bytes for the value in `col[row]` to
-/// `vec` to be used as a key into the hash map
+/// `vec` to be used as a key into the hash map.
+///
+/// NOTE: This function does not check `col.is_valid(row)`; the caller must do so.
fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec<u8>) -> Result<()> {
match col.data_type() {
DataType::Boolean => {
@@ -640,6 +645,50 @@ fn create_key_for_col(col: &ArrayRef, row: usize, vec: &mut Vec<u8>) -> Result<(
}
/// Create a key `Vec<u8>` that is used as key for the hashmap
+///
+/// This looks like
+/// [null_byte][col_value_bytes][null_byte][col_value_bytes]
+///
+/// Note that relatively uncommon patterns (e.g. not 0x00) are chosen
+/// for the null_byte to make debugging easier. The actual values are
+/// arbitrary.
+///
+/// For a NULL value in a column, the key looks like
+/// [0xFE]
+///
+/// For a Non-NULL value in a column, this looks like:
+/// [0xFF][byte representation of column value]
+///
+/// Example of a key with no NULL values:
+/// ```text
+/// 0xFF byte at the start of each column
+/// signifies the value is non-null
+/// │
+///
+/// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ─ ─ ─ ┐
+///
+/// │ string len │ 0x1234
+/// { ▼ (as usize le) "foo" ▼(as u16 le)
+/// k1: "foo" ╔ ═┌──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──╦ ═┌──┬──┐
+/// k2: 0x1234u16 FF║03│00│00│00│00│00│00│00│"f│"o│"o│FF║34│12│
+/// } ╚ ═└──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──╩ ═└──┴──┘
+/// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14
+/// ```
+///
+/// Example of a key with NULL values:
+///
+///```text
+/// 0xFE byte at the start of k1 column
+/// ┌ ─ signifies the value is NULL
+///
+/// └ ┐
+/// 0x1234
+/// { ▼ (as u16 le)
+/// k1: NULL ╔ ═╔ ═┌──┬──┐
+/// k2: 0x1234u16 FE║FF║34│12│
+/// } ╚ ═╚ ═└──┴──┘
+/// 0 1 2 3
+///```
pub(crate) fn create_key(
group_by_keys: &[ArrayRef],
row: usize,
@@ -647,7 +696,12 @@ pub(crate) fn create_key(
) -> Result<()> {
vec.clear();
for col in group_by_keys {
- create_key_for_col(col, row, vec)?
+ if !col.is_valid(row) {
+ vec.push(0xFE);
+ } else {
+ vec.push(0xFF);
+ create_key_for_col(col, row, vec)?
+ }
}
Ok(())
}
diff --git a/datafusion/src/scalar.rs b/datafusion/src/scalar.rs
index 8efea63..90c9bf7 100644
--- a/datafusion/src/scalar.rs
+++ b/datafusion/src/scalar.rs
@@ -28,7 +28,7 @@ use arrow::{
},
};
use ordered_float::OrderedFloat;
-use std::convert::Infallible;
+use std::convert::{Infallible, TryInto};
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
@@ -796,6 +796,11 @@ impl ScalarValue {
/// Converts a value in `array` at `index` into a ScalarValue
pub fn try_from_array(array: &ArrayRef, index: usize) -> Result<Self> {
+ // handle NULL value
+ if !array.is_valid(index) {
+ return array.data_type().try_into();
+ }
+
Ok(match array.data_type() {
DataType::Boolean => typed_cast!(array, index, BooleanArray, Boolean),
DataType::Float64 => typed_cast!(array, index, Float64Array, Float64),
@@ -897,6 +902,7 @@ impl ScalarValue {
let dict_array = array.as_any().downcast_ref::<DictionaryArray<K>>().unwrap();
// look up the index in the values dictionary
+ // (note validity was previously checked in `try_from_array`)
let keys_col = dict_array.keys();
let values_index = keys_col.value(index).to_usize().ok_or_else(|| {
DataFusionError::Internal(format!(
@@ -1132,6 +1138,7 @@ impl_try_from!(Boolean, bool);
impl TryFrom<&DataType> for ScalarValue {
type Error = DataFusionError;
+ /// Create a Null instance of ScalarValue for this datatype
fn try_from(datatype: &DataType) -> Result<Self> {
Ok(match datatype {
DataType::Boolean => ScalarValue::Boolean(None),
@@ -1161,12 +1168,15 @@ impl TryFrom<&DataType> for ScalarValue {
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
ScalarValue::TimestampNanosecond(None)
}
+ DataType::Dictionary(_index_type, value_type) => {
+ value_type.as_ref().try_into()?
+ }
DataType::List(ref nested_type) => {
ScalarValue::List(None, Box::new(nested_type.data_type().clone()))
}
_ => {
return Err(DataFusionError::NotImplemented(format!(
- "Can't create a scalar of type \"{:?}\"",
+ "Can't create a scalar from data_type \"{:?}\"",
datatype
)))
}
@@ -1536,6 +1546,29 @@ mod tests {
}
#[test]
+ fn scalar_try_from_array_null() {
+ let array = vec![Some(33), None].into_iter().collect::<Int64Array>();
+ let array: ArrayRef = Arc::new(array);
+
+ assert_eq!(
+ ScalarValue::Int64(Some(33)),
+ ScalarValue::try_from_array(&array, 0).unwrap()
+ );
+ assert_eq!(
+ ScalarValue::Int64(None),
+ ScalarValue::try_from_array(&array, 1).unwrap()
+ );
+ }
+
+ #[test]
+ fn scalar_try_from_dict_datatype() {
+ let data_type =
+ DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));
+ let data_type = &data_type;
+ assert_eq!(ScalarValue::Utf8(None), data_type.try_into().unwrap())
+ }
+
+ #[test]
fn size_of_scalar() {
// Since ScalarValues are used in a non trivial number of places,
// making it larger means significant more memory consumption
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 42a7d20..3a83f20 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -3057,6 +3057,109 @@ async fn query_count_distinct() -> Result<()> {
}
#[tokio::test]
+async fn query_group_on_null() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, true)]));
+
+ let data = RecordBatch::try_new(
+ schema.clone(),
+ vec![Arc::new(Int32Array::from(vec![
+ Some(0),
+ Some(3),
+ None,
+ Some(1),
+ Some(3),
+ ]))],
+ )?;
+
+ let table = MemTable::try_new(schema, vec![vec![data]])?;
+
+ let mut ctx = ExecutionContext::new();
+ ctx.register_table("test", Arc::new(table))?;
+ let sql = "SELECT COUNT(*), c1 FROM test GROUP BY c1";
+
+ let actual = execute_to_batches(&mut ctx, sql).await;
+
+ // Note that the results also
+ // include a row for NULL (c1=NULL, count = 1)
+ let expected = vec![
+ "+-----------------+----+",
+ "| COUNT(UInt8(1)) | c1 |",
+ "+-----------------+----+",
+ "| 1 | |",
+ "| 1 | 0 |",
+ "| 1 | 1 |",
+ "| 2 | 3 |",
+ "+-----------------+----+",
+ ];
+ assert_batches_sorted_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
+async fn query_group_on_null_multi_col() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("c1", DataType::Int32, true),
+ Field::new("c2", DataType::Utf8, true),
+ ]));
+
+ let data = RecordBatch::try_new(
+ schema.clone(),
+ vec![
+ Arc::new(Int32Array::from(vec![
+ Some(0),
+ Some(0),
+ Some(3),
+ None,
+ None,
+ Some(3),
+ Some(0),
+ None,
+ Some(3),
+ ])),
+ Arc::new(StringArray::from(vec![
+ None,
+ None,
+ Some("foo"),
+ None,
+ Some("bar"),
+ Some("foo"),
+ None,
+ Some("bar"),
+ Some("foo"),
+ ])),
+ ],
+ )?;
+
+ let table = MemTable::try_new(schema, vec![vec![data]])?;
+
+ let mut ctx = ExecutionContext::new();
+ ctx.register_table("test", Arc::new(table))?;
+ let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c1, c2";
+
+ let actual = execute_to_batches(&mut ctx, sql).await;
+
+ // Note that the results also include rows for NULL group keys
+ // (e.g. c1=NULL, c2=NULL, count = 1 and c1=NULL, c2='bar', count = 2)
+ let expected = vec![
+ "+-----------------+----+-----+",
+ "| COUNT(UInt8(1)) | c1 | c2 |",
+ "+-----------------+----+-----+",
+ "| 1 | | |",
+ "| 2 | | bar |",
+ "| 3 | 0 | |",
+ "| 3 | 3 | foo |",
+ "+-----------------+----+-----+",
+ ];
+ assert_batches_sorted_eq!(expected, &actual);
+
+ // Also run query with group columns reversed (results should be the same)
+ let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c2, c1";
+ let actual = execute_to_batches(&mut ctx, sql).await;
+ assert_batches_sorted_eq!(expected, &actual);
+ Ok(())
+}
+
+#[tokio::test]
async fn query_on_string_dictionary() -> Result<()> {
// Test to ensure DataFusion can operate on dictionary types
// Use StringDictionary (32 bit indexes = keys)
@@ -3109,6 +3212,13 @@ async fn query_on_string_dictionary() -> Result<()> {
let expected = vec![vec!["2"]];
assert_eq!(expected, actual);
+ // grouping
+ let sql = "SELECT d1, COUNT(*) FROM test group by d1";
+ let mut actual = execute(&mut ctx, sql).await;
+ actual.sort();
+ let expected = vec![vec!["NULL", "1"], vec!["one", "1"], vec!["three", "1"]];
+ assert_eq!(expected, actual);
+
Ok(())
}