You are viewing a plain-text version of this content. (The original page linked to the canonical version here; that hyperlink was not preserved in this text.)
Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/15 19:00:35 UTC

[arrow-datafusion] branch master updated: Support unsigned integers in `unwrap_cast_in_comparison` Optimizer rule (#4149)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 4653df465 Support unsigned integers in  `unwrap_cast_in_comparison` Optimizer rule (#4149)
4653df465 is described below

commit 4653df4652c8af5d4e8841c489c1ab2e85a54e69
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Tue Nov 15 14:00:29 2022 -0500

    Support unsigned integers in  `unwrap_cast_in_comparison` Optimizer rule (#4149)
    
    * Support unsigned integers in  `unwrap_cast_in_comparison` Optimizer rule
    
    * Update comment
---
 datafusion/core/tests/sql/joins.rs                 |  12 +-
 .../optimizer/src/unwrap_cast_in_comparison.rs     | 126 ++++++++++++++++++---
 datafusion/optimizer/tests/integration-test.rs     |   6 +-
 3 files changed, 120 insertions(+), 24 deletions(-)

diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index 10d024025..324ccb4c7 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -1428,9 +1428,9 @@ async fn reduce_left_join_1() -> Result<()> {
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
         "    Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
-        "      Filter: CAST(t1.t1_id AS Int64) < Int64(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "      Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
         "        TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
-        "      Filter: CAST(t2.t2_id AS Int64) < Int64(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "      Filter: t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
         "        TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
     ];
     let formatted = plan.display_indent_schema().to_string();
@@ -1476,10 +1476,10 @@ async fn reduce_left_join_2() -> Result<()> {
     let expected = vec![
         "Explain [plan_type:Utf8, plan:Utf8]",
         "  Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
-        "    Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "    Filter: t2.t2_int < UInt32(10) OR t1.t1_int > UInt32(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
         "      Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
         "        TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
-        "        Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "        Filter: t2.t2_int < UInt32(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
         "          TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
     ];
     let formatted = plan.display_indent_schema().to_string();
@@ -1524,9 +1524,9 @@ async fn reduce_left_join_3() -> Result<()> {
         "      Projection: t1.t1_id, t1.t1_name, t1.t1_int, alias=t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
         "        Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
         "          Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
-        "            Filter: CAST(t1.t1_id AS Int64) < Int64(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+        "            Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
         "              TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
-        "            Filter: CAST(t2.t2_int AS Int64) < Int64(3) AND CAST(t2.t2_id AS Int64) < Int64(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+        "            Filter: t2.t2_int < UInt32(3) AND t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
         "              TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
         "      TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
     ];
diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 28b085684..5f542d749 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -283,7 +283,11 @@ fn is_comparison_op(op: &Operator) -> bool {
 fn is_support_data_type(data_type: &DataType) -> bool {
     matches!(
         data_type,
-        DataType::Int8
+        DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Int8
             | DataType::Int16
             | DataType::Int32
             | DataType::Int64
@@ -292,6 +296,25 @@ fn is_support_data_type(data_type: &DataType) -> bool {
     )
 }
 
+fn is_decimal_type(dt: &DataType) -> bool {
+    matches!(dt, DataType::Decimal128(_, _))
+}
+
+fn is_unsigned_type(dt: &DataType) -> bool {
+    matches!(
+        dt,
+        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64
+    )
+}
+
+/// Returns true for unsigned integer <--> `Decimal128` casts (either
+/// direction), which this rule skips until arrow-rs supports them, see
+/// <https://github.com/apache/arrow-rs/issues/1043>
+fn is_unsupported_cast(dt1: &DataType, dt2: &DataType) -> bool {
+    (is_decimal_type(dt1) && is_unsigned_type(dt2))
+        || (is_decimal_type(dt2) && is_unsigned_type(dt1))
+}
+
 fn try_cast_literal_to_type(
     lit_value: &ScalarValue,
     target_type: &DataType,
@@ -301,12 +324,22 @@ fn try_cast_literal_to_type(
     if !is_support_data_type(&lit_data_type) || !is_support_data_type(target_type) {
         return Ok(None);
     }
+    if is_unsupported_cast(&lit_data_type, target_type) {
+        return Ok(None);
+    }
     if lit_value.is_null() {
         // null value can be cast to any type of null value
         return Ok(Some(ScalarValue::try_from(target_type)?));
     }
     let mul = match target_type {
-        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => 1_i128,
+        DataType::UInt8
+        | DataType::UInt16
+        | DataType::UInt32
+        | DataType::UInt64
+        | DataType::Int8
+        | DataType::Int16
+        | DataType::Int32
+        | DataType::Int64 => 1_i128,
         DataType::Timestamp(_, _) => 1_i128,
         DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
         other_type => {
@@ -317,6 +350,10 @@ fn try_cast_literal_to_type(
         }
     };
     let (target_min, target_max) = match target_type {
+        DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128),
+        DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128),
+        DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128),
+        DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128),
         DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
         DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
         DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
@@ -341,6 +378,10 @@ fn try_cast_literal_to_type(
         ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
         ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
         ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
+        ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul),
+        ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul),
+        ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul),
+        ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul),
         ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul),
         ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul),
         ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul),
@@ -383,6 +424,10 @@ fn try_cast_literal_to_type(
                     DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
                     DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
                     DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
+                    DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)),
+                    DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)),
+                    DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
+                    DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
                     DataType::Timestamp(TimeUnit::Second, tz) => {
                         ScalarValue::TimestampSecond(Some(value as i64), tz.clone())
                     }
@@ -469,6 +514,15 @@ mod tests {
         assert_eq!(optimize_test(lit_lt_lit, &schema), expected);
     }
 
+    #[test]
+    fn test_unwrap_cast_comparison_unsigned() {
+        // "cast(c6 AS UInt64) = 0u64" is unwrapped to "c6 = 0u32" (c6 is UInt32)
+        let schema = expr_test_schema();
+        let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64));
+        let expected = col("c6").eq(lit(0u32));
+        assert_eq!(optimize_test(expr_input, &schema), expected);
+    }
+
     #[test]
     fn test_not_unwrap_cast_with_decimal_comparison() {
         let schema = expr_test_schema();
@@ -635,16 +689,16 @@ mod tests {
 
     #[test]
     fn test_not_support_data_type() {
-        // "c6 > 0" will be cast to `cast(c6 as int64) > 0
+        // "c6 = 0.0" will be cast to `cast(c6 as Float64) = 0.0` (Float64 is unsupported)
         // but the type of c6 is uint32
         // the rewriter will not throw error and just return the original expr
         let schema = expr_test_schema();
-        let expr_input = cast(col("c6"), DataType::Int64).eq(lit(0i64));
+        let expr_input = cast(col("c6"), DataType::Float64).eq(lit(0f64));
         assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
 
         // inlist for unsupported data type
         let expr_input =
-            in_list(cast(col("c6"), DataType::Int64), vec![lit(0i64)], false);
+            in_list(cast(col("c6"), DataType::Float64), vec![lit(0f64)], false);
         assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
     }
 
@@ -733,17 +787,24 @@ mod tests {
             ScalarValue::Int16(None),
             ScalarValue::Int32(None),
             ScalarValue::Int64(None),
+            ScalarValue::UInt8(None),
+            ScalarValue::UInt16(None),
+            ScalarValue::UInt32(None),
+            ScalarValue::UInt64(None),
             ScalarValue::Decimal128(None, 3, 0),
             ScalarValue::Decimal128(None, 8, 2),
         ];
 
         for s1 in &scalars {
             for s2 in &scalars {
-                expect_cast(
-                    s1.clone(),
-                    s2.get_datatype(),
-                    ExpectedCast::Value(s2.clone()),
-                );
+                let expected_value =
+                    if is_unsupported_cast(&s1.get_datatype(), &s2.get_datatype()) {
+                        ExpectedCast::NoValue
+                    } else {
+                        ExpectedCast::Value(s2.clone())
+                    };
+
+                expect_cast(s1.clone(), s2.get_datatype(), expected_value);
             }
         }
     }
@@ -756,25 +817,56 @@ mod tests {
             ScalarValue::Int16(Some(123)),
             ScalarValue::Int32(Some(123)),
             ScalarValue::Int64(Some(123)),
+            ScalarValue::UInt8(Some(123)),
+            ScalarValue::UInt16(Some(123)),
+            ScalarValue::UInt32(Some(123)),
+            ScalarValue::UInt64(Some(123)),
             ScalarValue::Decimal128(Some(123), 3, 0),
             ScalarValue::Decimal128(Some(12300), 8, 2),
         ];
 
         for s1 in &scalars {
             for s2 in &scalars {
-                expect_cast(
-                    s1.clone(),
-                    s2.get_datatype(),
-                    ExpectedCast::Value(s2.clone()),
-                );
+                let expected_value =
+                    if is_unsupported_cast(&s1.get_datatype(), &s2.get_datatype()) {
+                        ExpectedCast::NoValue
+                    } else {
+                        ExpectedCast::Value(s2.clone())
+                    };
+
+                expect_cast(s1.clone(), s2.get_datatype(), expected_value);
             }
         }
+
+        let max_i32 = ScalarValue::Int32(Some(i32::MAX));
+        expect_cast(
+            max_i32,
+            DataType::UInt64,
+            ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))),
+        );
+
+        let min_i32 = ScalarValue::Int32(Some(i32::MIN));
+        expect_cast(
+            min_i32,
+            DataType::Int64,
+            ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))),
+        );
+
+        let max_i64 = ScalarValue::Int64(Some(i64::MAX));
+        expect_cast(
+            max_i64,
+            DataType::UInt64,
+            ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))),
+        );
     }
 
     #[test]
     fn test_try_cast_to_type_int_out_of_range() {
+        let min_i32 = ScalarValue::Int32(Some(i32::MIN));
+        let min_i64 = ScalarValue::Int64(Some(i64::MIN));
         let max_i64 = ScalarValue::Int64(Some(i64::MAX));
         let max_u64 = ScalarValue::UInt64(Some(u64::MAX));
+
         expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue);
 
         expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue);
@@ -783,6 +875,10 @@ mod tests {
 
         expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue);
 
+        expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue);
+
+        expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue);
+
         // decimal out of range
         expect_cast(
             ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0),
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index 779f156c0..2fdec1f2a 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -47,8 +47,8 @@ fn case_when() -> Result<()> {
 
     let sql = "SELECT CASE WHEN col_uint32 > 0 THEN 1 ELSE 0 END FROM test";
     let plan = test_sql(sql)?;
-    let expected = "Projection: CASE WHEN CAST(test.col_uint32 AS Int64) > Int64(0) THEN Int64(1) ELSE Int64(0) END\
-    \n  TableScan: test projection=[col_uint32]";
+    let expected = "Projection: CASE WHEN test.col_uint32 > UInt32(0) THEN Int64(1) ELSE Int64(0) END AS CASE WHEN test.col_uint32 > Int64(0) THEN Int64(1) ELSE Int64(0) END\
+                    \n  TableScan: test projection=[col_uint32]";
     assert_eq!(expected, format!("{:?}", plan));
     Ok(())
 }
@@ -91,7 +91,7 @@ fn unsigned_target_type() -> Result<()> {
     let sql = "SELECT col_utf8 FROM test WHERE col_uint32 > 0";
     let plan = test_sql(sql)?;
     let expected = "Projection: test.col_utf8\
-                    \n  Filter: CAST(test.col_uint32 AS Int64) > Int64(0)\
+                    \n  Filter: test.col_uint32 > UInt32(0)\
                     \n    TableScan: test projection=[col_uint32, col_utf8]";
     assert_eq!(expected, format!("{:?}", plan));
     Ok(())