You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by dh...@apache.org on 2022/10/15 19:52:56 UTC

[arrow-datafusion] branch master updated: Infer the count of maximum distinct values from min/max (#3837)

This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new fe0000e6d Infer the count of maximum distinct values from min/max (#3837)
fe0000e6d is described below

commit fe0000e6de4f1a2e7e7fd1b0d75ef32f4b61f194
Author: Batuhan Taskaya <is...@gmail.com>
AuthorDate: Sat Oct 15 22:52:50 2022 +0300

    Infer the count of maximum distinct values from min/max (#3837)
    
    * Infer the count of maximum distinct values from min/max
    
    * Even if the delta is 0, ensure that the distinct count is 1 (when min=max)
---
 datafusion/core/src/physical_plan/join_utils.rs | 227 ++++++++++++++++++++----
 1 file changed, 196 insertions(+), 31 deletions(-)

diff --git a/datafusion/core/src/physical_plan/join_utils.rs b/datafusion/core/src/physical_plan/join_utils.rs
index 780a5e96f..d010f4219 100644
--- a/datafusion/core/src/physical_plan/join_utils.rs
+++ b/datafusion/core/src/physical_plan/join_utils.rs
@@ -22,6 +22,7 @@ use crate::logical_expr::JoinType;
 use crate::physical_plan::expressions::Column;
 use arrow::datatypes::{Field, Schema};
 use arrow::error::ArrowError;
+use datafusion_common::ScalarValue;
 use datafusion_physical_expr::PhysicalExpr;
 use futures::future::{BoxFuture, Shared};
 use futures::{ready, FutureExt};
@@ -423,7 +424,9 @@ fn estimate_inner_join_cardinality(
             return None;
         }
 
-        let max_distinct = max(left_stat.distinct_count, right_stat.distinct_count);
+        let left_max_distinct = max_distinct_count(left_num_rows, left_stat.clone());
+        let right_max_distinct = max_distinct_count(right_num_rows, right_stat.clone());
+        let max_distinct = max(left_max_distinct, right_max_distinct);
         if max_distinct > join_selectivity {
             // Seems like there are a few implementations of this algorithm that implement
             // exponential decay for the selectivity (like Hive's Optiq Optimizer). Needs
@@ -447,6 +450,50 @@ fn estimate_inner_join_cardinality(
     }
 }
 
+/// Estimate the number of maximum distinct values that can be present in the
+/// given column from its statistics.
+///
+/// If distinct_count is available, uses it directly. If the column is an integer
+/// column with min/max values, the size of that range (capped at the number of
+/// non-null rows) is used as a fallback. Otherwise, returns None.
+fn max_distinct_count(num_rows: usize, stats: ColumnStatistics) -> Option<usize> {
+    match (stats.distinct_count, stats.max_value, stats.min_value) {
+        (Some(distinct_count), _, _) => Some(distinct_count),
+        (_, Some(max), Some(min)) => {
+            // Note that float support is intentionally omitted here, since the computation
+            // of a range between two float values is not trivial and the result would be
+            // highly inaccurate.
+            let numeric_range = get_int_range(min, max)?;
+
+            // The number can never be greater than the number of non-null rows, since
+            // nulls are not distinct values. Saturate in case null_count > num_rows.
+            let ceiling = num_rows.saturating_sub(stats.null_count.unwrap_or(0));
+            Some(numeric_range.min(ceiling))
+        }
+        _ => None,
+    }
+}
+
+/// Return the number of integer values in the closed range `[min, max]`.
+fn get_int_range(min: ScalarValue, max: ScalarValue) -> Option<usize> {
+    let delta = &max.sub(&min).ok()?;
+    match delta {
+        ScalarValue::Int8(Some(delta)) => usize::try_from(*delta).ok(),
+        ScalarValue::Int16(Some(delta)) => usize::try_from(*delta).ok(),
+        ScalarValue::Int32(Some(delta)) => usize::try_from(*delta).ok(),
+        ScalarValue::Int64(Some(delta)) => usize::try_from(*delta).ok(),
+        ScalarValue::UInt8(Some(delta)) => Some(usize::from(*delta)),
+        ScalarValue::UInt16(Some(delta)) => Some(usize::from(*delta)),
+        ScalarValue::UInt32(Some(delta)) => usize::try_from(*delta).ok(),
+        ScalarValue::UInt64(Some(delta)) => usize::try_from(*delta).ok(),
+        _ => None,
+    }
+    // The delta alone excludes one endpoint of the closed range: for (min=2, max=4)
+    // the delta is 4 - 2 = 2, but there are 3 values (2, 3, 4). Negative deltas,
+    // non-integer types and counts that do not fit in usize all yield None.
+    .and_then(|open_ended_range| open_ended_range.checked_add(1))
+}
+
 enum OnceFutState<T> {
     Pending(OnceFutPending<T>),
     Ready(Arc<Result<T>>),
@@ -626,19 +673,19 @@ mod tests {
     }
 
     fn create_column_stats(
-        min: Option<u64>,
-        max: Option<u64>,
+        min: Option<i64>,
+        max: Option<i64>,
         distinct_count: Option<usize>,
     ) -> ColumnStatistics {
         ColumnStatistics {
             distinct_count,
-            min_value: min.map(|size| ScalarValue::UInt64(Some(size))),
-            max_value: max.map(|size| ScalarValue::UInt64(Some(size))),
+            min_value: min.map(Some).map(ScalarValue::Int64),
+            max_value: max.map(Some).map(ScalarValue::Int64),
             ..Default::default()
         }
     }

-    type PartialStats = (usize, u64, u64, Option<usize>);
+    type PartialStats = (usize, Option<i64>, Option<i64>, Option<usize>);
 
     // This is mainly for validating the all edge cases of the estimation, but
     // more advanced (and real world test cases) are below where we need some control
@@ -650,40 +697,135 @@ mod tests {
             // | left(rows, min, max, distinct), right(rows, min, max, distinct), expected |
             // -----------------------------------------------------------------------------
 
-            // distinct(left) is None OR distinct(right) is None
+            // Cardinality computation
+            // =======================
+            //
+            // distinct(left) == None, distinct(right) == None
+            (
+                (10, Some(1), Some(10), None),
+                (10, Some(1), Some(10), None),
+                Some(10),
+            ),
+            // range(left) > range(right)
+            (
+                (10, Some(6), Some(10), None),
+                (10, Some(8), Some(10), None),
+                Some(20),
+            ),
+            // range(right) > range(left)
+            (
+                (10, Some(8), Some(10), None),
+                (10, Some(6), Some(10), None),
+                Some(20),
+            ),
+            // range(left) > len(left), range(right) > len(right)
+            (
+                (10, Some(1), Some(15), None),
+                (20, Some(1), Some(40), None),
+                Some(10),
+            ),
+            // When we have distinct count.
+            (
+                (10, Some(1), Some(10), Some(10)),
+                (10, Some(1), Some(10), Some(10)),
+                Some(10),
+            ),
+            // distinct(left) > distinct(right)
+            (
+                (10, Some(1), Some(10), Some(5)),
+                (10, Some(1), Some(10), Some(2)),
+                Some(20),
+            ),
+            // distinct(right) > distinct(left)
+            (
+                (10, Some(1), Some(10), Some(2)),
+                (10, Some(1), Some(10), Some(5)),
+                Some(20),
+            ),
+            // min(left) < 0 (range(left) > range(right))
+            (
+                (10, Some(-5), Some(5), None),
+                (10, Some(1), Some(5), None),
+                Some(10),
+            ),
+            // min(right) < 0, max(right) < 0 (range(right) > range(left))
+            (
+                (10, Some(-25), Some(-20), None),
+                (10, Some(-25), Some(-15), None),
+                Some(10),
+            ),
+            // range(left) < 0, range(right) >= 0
+            // (there isn't a case where both left and right ranges are negative,
+            //  so one of them is always going to work; this just proves that negative
+            //  ranges with bigger absolute values are not accidentally used).
+            (
+                (10, Some(10), Some(0), None),
+                (10, Some(0), Some(10), Some(5)),
+                Some(20), // It would have been ten if we have used abs(range(left))
+            ),
+            // range(left) = 1, range(right) = 1
+            (
+                (10, Some(1), Some(1), None),
+                (10, Some(1), Some(1), None),
+                Some(100),
+            ),
             //
-            // len(left) = len(right), len(left) * len(right)
-            ((10, 0, 10, None), (10, 0, 10, None), None),
-            // len(left) > len(right) OR len(left) < len(right), len(left) * len(right)
-            ((10, 0, 10, None), (5, 0, 10, None), None),
-            ((5, 0, 10, None), (10, 0, 10, None), None),
-            ((10, 0, 10, None), (5, 0, 10, None), None),
-            ((5, 0, 10, None), (10, 0, 10, None), None),
-            // min(left) > max(right) OR min(right) > max(left), None
-            ((10, 0, 10, None), (10, 11, 20, None), None),
-            ((10, 11, 20, None), (10, 0, 10, None), None),
-            ((10, 5, 10, None), (10, 11, 3, None), None),
-            ((10, 10, 5, None), (10, 3, 7, None), None),
-            // distinct(left) is not None AND distinct(right) is not None
+            // Edge cases
+            // ==========
             //
-            // len(left) = len(right), len(left) * len(right) / max(distinct(left), distinct(right))
-            ((10, 0, 10, Some(5)), (10, 0, 10, Some(5)), Some(20)),
-            ((10, 0, 10, Some(10)), (10, 0, 10, Some(5)), Some(10)),
-            ((10, 0, 10, Some(5)), (10, 0, 10, Some(10)), Some(10)),
+            // No column level stats.
+            ((10, None, None, None), (10, None, None, None), None),
+            // No min or max (or both).
+            ((10, None, None, Some(3)), (10, None, None, Some(3)), None),
+            (
+                (10, Some(2), None, Some(3)),
+                (10, None, Some(5), Some(3)),
+                None,
+            ),
+            (
+                (10, None, Some(3), Some(3)),
+                (10, Some(1), None, Some(3)),
+                None,
+            ),
+            ((10, None, Some(3), None), (10, Some(1), None, None), None),
+            // Non overlapping min/max.
+            (
+                (10, Some(0), Some(10), None),
+                (10, Some(11), Some(20), None),
+                None,
+            ),
+            (
+                (10, Some(11), Some(20), None),
+                (10, Some(0), Some(10), None),
+                None,
+            ),
+            (
+                (10, Some(5), Some(10), Some(10)),
+                (10, Some(11), Some(3), Some(10)),
+                None,
+            ),
+            (
+                (10, Some(10), Some(5), Some(10)),
+                (10, Some(3), Some(7), Some(10)),
+                None,
+            ),
+            // distinct(left) = 0, distinct(right) = 0
+            (
+                (10, Some(1), Some(10), Some(0)),
+                (10, Some(1), Some(10), Some(0)),
+                None,
+            ),
         ];
 
         for (left_info, right_info, expected_cardinality) in cases {
             let left_num_rows = left_info.0;
-            let left_col_stats = vec![create_column_stats(
-                Some(left_info.1),
-                Some(left_info.2),
-                left_info.3,
-            )];
+            let left_col_stats =
+                vec![create_column_stats(left_info.1, left_info.2, left_info.3)];
 
             let right_num_rows = right_info.0;
             let right_col_stats = vec![create_column_stats(
-                Some(right_info.1),
-                Some(right_info.2),
+                right_info.1,
+                right_info.2,
                 right_info.3,
             )];
 
@@ -740,6 +882,29 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn test_inner_join_cardinality_decimal_range() -> Result<()> {
+        // Only integer min/max ranges are used to infer the number of distinct
+        // values, so a decimal range must not produce a cardinality estimate.
+        fn decimal_stats(min: i128, max: i128) -> ColumnStatistics {
+            ColumnStatistics {
+                distinct_count: None,
+                min_value: Some(ScalarValue::Decimal128(Some(min), 14, 4)),
+                max_value: Some(ScalarValue::Decimal128(Some(max), 14, 4)),
+                ..Default::default()
+            }
+        }
+
+        // 3.2500..3.5000 on the left, 3.3500..3.4000 on the right.
+        let left_col_stats = vec![decimal_stats(32500, 35000)];
+        let right_col_stats = vec![decimal_stats(33500, 34000)];
+
+        let cardinality =
+            estimate_inner_join_cardinality(100, 100, left_col_stats, right_col_stats);
+        assert_eq!(cardinality, None);
+        Ok(())
+    }
+
     #[test]
     fn test_join_cardinality() -> Result<()> {
         // Left table (rows=1000)