You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by "comphead (via GitHub)" <gi...@apache.org> on 2023/03/13 21:30:36 UTC

[GitHub] [arrow-datafusion] comphead commented on a diff in pull request #5554: Improve performance of COUNT (distinct x) for dictionary columns #258

comphead commented on code in PR #5554:
URL: https://github.com/apache/arrow-datafusion/pull/5554#discussion_r1134637197


##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -158,38 +219,98 @@ impl Accumulator for DistinctCountAccumulator {
         })
     }
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-        if states.is_empty() {
-            return Ok(());
+        merge_values(&mut self.values, states)
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
+    }
+
+    fn size(&self) -> usize {
+        let values_size = match &self.state_data_type {
+            DataType::Boolean | DataType::Null => values_fixed_size(&self.values),
+            d if d.is_primitive() => values_fixed_size(&self.values),
+            _ => values_full_size(&self.values),
+        };
+        std::mem::size_of_val(self) + values_size + std::mem::size_of::<DataType>()
+    }
+}
+/// Special case accumulator for counting distinct values in a dict
+struct CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    /// `K` is required when casting to dict array
+    _dt: core::marker::PhantomData<K>,
+    values_datatype: DataType,
+    values: ValueSet,
+}
+
+impl<K> std::fmt::Debug for CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("CountDistinctDictAccumulator")
+            .field("values", &self.values)
+            .field("values_datatype", &self.values_datatype)
+            .finish()
+    }
+}
+impl<K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync>
+    CountDistinctDictAccumulator<K>
+{
+    fn new(values_datatype: DataType) -> Self {
+        Self {
+            _dt: core::marker::PhantomData,
+            values: Default::default(),
+            values_datatype,
         }
-        let arr = &states[0];
-        (0..arr.len()).try_for_each(|index| {
-            let scalar = ScalarValue::try_from_array(arr, index)?;
+    }
+}
+impl<K> Accumulator for CountDistinctDictAccumulator<K>
+where
+    K: ArrowDictionaryKeyType + std::marker::Send + std::marker::Sync,
+{
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        values_to_state(&self.values, &self.values_datatype)
+    }
 
-            if let ScalarValue::List(Some(scalar), _) = scalar {
-                scalar.iter().for_each(|scalar| {
-                    if !ScalarValue::is_null(scalar) {
-                        self.values.insert(scalar.clone());
-                    }
-                });
-            } else {
-                return Err(DataFusionError::Internal(
-                    "Unexpected accumulator state".into(),
-                ));
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+        let arr = as_dictionary_array::<K>(&values[0])?;
+        let nvalues = arr.values().len();
+        // map keys to whether their corresponding value has been seen or not
+        let mut seen_map = (0..nvalues).map(|_| false).collect::<Vec<_>>();
+        for idx in arr.keys_iter().flatten() {
+            seen_map[idx] = true;
+        }
+        for (idx, seen) in seen_map.into_iter().enumerate() {
+            if seen {

Review Comment:
   Do you think it's possible to merge these two loops into one?



##########
datafusion/physical-expr/src/aggregate/count_distinct.rs:
##########
@@ -85,64 +86,124 @@ impl AggregateExpr for DistinctCount {
     }
 
     fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(DistinctCountAccumulator {
-            values: HashSet::default(),
-            state_data_type: self.state_data_type.clone(),
-        }))
+        use arrow::datatypes;
+        use datatypes::DataType::*;
+
+        Ok(match &self.state_data_type {
+            Dictionary(key, val) if key.is_dictionary_key_type() => {
+                let val_type = *val.clone();
+                match **key {
+                    Int8 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int8Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int16 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int16Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int32 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int32Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    Int64 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::Int64Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    UInt8 => Box::new(
+                        CountDistinctDictAccumulator::<datatypes::UInt8Type>::new(
+                            val_type,
+                        ),
+                    ),
+                    UInt16 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt16Type,
+                    >::new(val_type)),
+                    UInt32 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt32Type,
+                    >::new(val_type)),
+                    UInt64 => Box::new(CountDistinctDictAccumulator::<
+                        datatypes::UInt64Type,
+                    >::new(val_type)),
+                    _ => {
+                        // just checked that datatype is a valid dict key type
+                        unreachable!()

Review Comment:
   As part of other tickets, we have been trying to get rid of `panic!`, `unreachable!`, and dangerous `unwrap` calls. Please return an `Err` here instead.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org