You are viewing a plain-text version of this content. The canonical version is available at the original mailing-list archive link, which is not preserved in this plain-text rendering.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/11/15 19:00:35 UTC
[arrow-datafusion] branch master updated: Support unsigned integers in `unwrap_cast_in_comparison` Optimizer rule (#4149)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 4653df465 Support unsigned integers in `unwrap_cast_in_comparison` Optimizer rule (#4149)
4653df465 is described below
commit 4653df4652c8af5d4e8841c489c1ab2e85a54e69
Author: Andrew Lamb <an...@nerdnetworks.org>
AuthorDate: Tue Nov 15 14:00:29 2022 -0500
Support unsigned integers in `unwrap_cast_in_comparison` Optimizer rule (#4149)
* Support unsigned integers in `unwrap_cast_in_comparison` Optimizer rule
* Update comment
---
datafusion/core/tests/sql/joins.rs | 12 +-
.../optimizer/src/unwrap_cast_in_comparison.rs | 126 ++++++++++++++++++---
datafusion/optimizer/tests/integration-test.rs | 6 +-
3 files changed, 120 insertions(+), 24 deletions(-)
diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs
index 10d024025..324ccb4c7 100644
--- a/datafusion/core/tests/sql/joins.rs
+++ b/datafusion/core/tests/sql/joins.rs
@@ -1428,9 +1428,9 @@ async fn reduce_left_join_1() -> Result<()> {
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " Filter: CAST(t1.t1_id AS Int64) < Int64(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Filter: CAST(t2.t2_id AS Int64) < Int64(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Filter: t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
@@ -1476,10 +1476,10 @@ async fn reduce_left_join_2() -> Result<()> {
let expected = vec![
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR CAST(t1.t1_int AS Int64) > Int64(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Filter: t2.t2_int < UInt32(10) OR t1.t1_int > UInt32(2) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Filter: CAST(t2.t2_int AS Int64) < Int64(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Filter: t2.t2_int < UInt32(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
];
let formatted = plan.display_indent_schema().to_string();
@@ -1524,9 +1524,9 @@ async fn reduce_left_join_3() -> Result<()> {
" Projection: t1.t1_id, t1.t1_name, t1.t1_int, alias=t3 [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
- " Filter: CAST(t1.t1_id AS Int64) < Int64(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
+ " Filter: t1.t1_id < UInt32(100) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
" TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]",
- " Filter: CAST(t2.t2_int AS Int64) < Int64(3) AND CAST(t2.t2_id AS Int64) < Int64(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
+ " Filter: t2.t2_int < UInt32(3) AND t2.t2_id < UInt32(100) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
" TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]",
];
diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
index 28b085684..5f542d749 100644
--- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
+++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs
@@ -283,7 +283,11 @@ fn is_comparison_op(op: &Operator) -> bool {
fn is_support_data_type(data_type: &DataType) -> bool {
matches!(
data_type,
- DataType::Int8
+ DataType::UInt8
+ | DataType::UInt16
+ | DataType::UInt32
+ | DataType::UInt64
+ | DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
@@ -292,6 +296,25 @@ fn is_support_data_type(data_type: &DataType) -> bool {
)
}
+fn is_decimal_type(dt: &DataType) -> bool {
+ matches!(dt, DataType::Decimal128(_, _))
+}
+
+fn is_unsigned_type(dt: &DataType) -> bool {
+ matches!(
+ dt,
+ DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64
+ )
+}
+
+/// Until https://github.com/apache/arrow-rs/issues/1043 is done
+/// (support for unsigned <--> decimal casts) we also don't do that
+/// kind of cast in this optimizer
+fn is_unsupported_cast(dt1: &DataType, dt2: &DataType) -> bool {
+ (is_decimal_type(dt1) && is_unsigned_type(dt2))
+ || (is_decimal_type(dt2) && is_unsigned_type(dt1))
+}
+
fn try_cast_literal_to_type(
lit_value: &ScalarValue,
target_type: &DataType,
@@ -301,12 +324,22 @@ fn try_cast_literal_to_type(
if !is_support_data_type(&lit_data_type) || !is_support_data_type(target_type) {
return Ok(None);
}
+ if is_unsupported_cast(&lit_data_type, target_type) {
+ return Ok(None);
+ }
if lit_value.is_null() {
// null value can be cast to any type of null value
return Ok(Some(ScalarValue::try_from(target_type)?));
}
let mul = match target_type {
- DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => 1_i128,
+ DataType::UInt8
+ | DataType::UInt16
+ | DataType::UInt32
+ | DataType::UInt64
+ | DataType::Int8
+ | DataType::Int16
+ | DataType::Int32
+ | DataType::Int64 => 1_i128,
DataType::Timestamp(_, _) => 1_i128,
DataType::Decimal128(_, scale) => 10_i128.pow(*scale as u32),
other_type => {
@@ -317,6 +350,10 @@ fn try_cast_literal_to_type(
}
};
let (target_min, target_max) = match target_type {
+ DataType::UInt8 => (u8::MIN as i128, u8::MAX as i128),
+ DataType::UInt16 => (u16::MIN as i128, u16::MAX as i128),
+ DataType::UInt32 => (u32::MIN as i128, u32::MAX as i128),
+ DataType::UInt64 => (u64::MIN as i128, u64::MAX as i128),
DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
@@ -341,6 +378,10 @@ fn try_cast_literal_to_type(
ScalarValue::Int16(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::Int32(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::Int64(Some(v)) => (*v as i128).checked_mul(mul),
+ ScalarValue::UInt8(Some(v)) => (*v as i128).checked_mul(mul),
+ ScalarValue::UInt16(Some(v)) => (*v as i128).checked_mul(mul),
+ ScalarValue::UInt32(Some(v)) => (*v as i128).checked_mul(mul),
+ ScalarValue::UInt64(Some(v)) => (*v as i128).checked_mul(mul),
ScalarValue::TimestampSecond(Some(v), _) => (*v as i128).checked_mul(mul),
ScalarValue::TimestampMillisecond(Some(v), _) => (*v as i128).checked_mul(mul),
ScalarValue::TimestampMicrosecond(Some(v), _) => (*v as i128).checked_mul(mul),
@@ -383,6 +424,10 @@ fn try_cast_literal_to_type(
DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
DataType::Int64 => ScalarValue::Int64(Some(value as i64)),
+ DataType::UInt8 => ScalarValue::UInt8(Some(value as u8)),
+ DataType::UInt16 => ScalarValue::UInt16(Some(value as u16)),
+ DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
+ DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
DataType::Timestamp(TimeUnit::Second, tz) => {
ScalarValue::TimestampSecond(Some(value as i64), tz.clone())
}
@@ -469,6 +514,15 @@ mod tests {
assert_eq!(optimize_test(lit_lt_lit, &schema), expected);
}
+ #[test]
+ fn test_unwrap_cast_comparison_unsigned() {
+ // "cast(c6, UINT64) = 0u64 => c6 = 0u32
+ let schema = expr_test_schema();
+ let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64));
+ let expected = col("c6").eq(lit(0u32));
+ assert_eq!(optimize_test(expr_input, &schema), expected);
+ }
+
#[test]
fn test_not_unwrap_cast_with_decimal_comparison() {
let schema = expr_test_schema();
@@ -635,16 +689,16 @@ mod tests {
#[test]
fn test_not_support_data_type() {
- // "c6 > 0" will be cast to `cast(c6 as int64) > 0
+ // "c6 > 0" will be cast to `cast(c6 as float) > 0
// but the type of c6 is uint32
// the rewriter will not throw error and just return the original expr
let schema = expr_test_schema();
- let expr_input = cast(col("c6"), DataType::Int64).eq(lit(0i64));
+ let expr_input = cast(col("c6"), DataType::Float64).eq(lit(0f64));
assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
// inlist for unsupported data type
let expr_input =
- in_list(cast(col("c6"), DataType::Int64), vec![lit(0i64)], false);
+ in_list(cast(col("c6"), DataType::Float64), vec![lit(0f64)], false);
assert_eq!(optimize_test(expr_input.clone(), &schema), expr_input);
}
@@ -733,17 +787,24 @@ mod tests {
ScalarValue::Int16(None),
ScalarValue::Int32(None),
ScalarValue::Int64(None),
+ ScalarValue::UInt8(None),
+ ScalarValue::UInt16(None),
+ ScalarValue::UInt32(None),
+ ScalarValue::UInt64(None),
ScalarValue::Decimal128(None, 3, 0),
ScalarValue::Decimal128(None, 8, 2),
];
for s1 in &scalars {
for s2 in &scalars {
- expect_cast(
- s1.clone(),
- s2.get_datatype(),
- ExpectedCast::Value(s2.clone()),
- );
+ let expected_value =
+ if is_unsupported_cast(&s1.get_datatype(), &s2.get_datatype()) {
+ ExpectedCast::NoValue
+ } else {
+ ExpectedCast::Value(s2.clone())
+ };
+
+ expect_cast(s1.clone(), s2.get_datatype(), expected_value);
}
}
}
@@ -756,25 +817,56 @@ mod tests {
ScalarValue::Int16(Some(123)),
ScalarValue::Int32(Some(123)),
ScalarValue::Int64(Some(123)),
+ ScalarValue::UInt8(Some(123)),
+ ScalarValue::UInt16(Some(123)),
+ ScalarValue::UInt32(Some(123)),
+ ScalarValue::UInt64(Some(123)),
ScalarValue::Decimal128(Some(123), 3, 0),
ScalarValue::Decimal128(Some(12300), 8, 2),
];
for s1 in &scalars {
for s2 in &scalars {
- expect_cast(
- s1.clone(),
- s2.get_datatype(),
- ExpectedCast::Value(s2.clone()),
- );
+ let expected_value =
+ if is_unsupported_cast(&s1.get_datatype(), &s2.get_datatype()) {
+ ExpectedCast::NoValue
+ } else {
+ ExpectedCast::Value(s2.clone())
+ };
+
+ expect_cast(s1.clone(), s2.get_datatype(), expected_value);
}
}
+
+ let max_i32 = ScalarValue::Int32(Some(i32::MAX));
+ expect_cast(
+ max_i32,
+ DataType::UInt64,
+ ExpectedCast::Value(ScalarValue::UInt64(Some(i32::MAX as u64))),
+ );
+
+ let min_i32 = ScalarValue::Int32(Some(i32::MIN));
+ expect_cast(
+ min_i32,
+ DataType::Int64,
+ ExpectedCast::Value(ScalarValue::Int64(Some(i32::MIN as i64))),
+ );
+
+ let max_i64 = ScalarValue::Int64(Some(i64::MAX));
+ expect_cast(
+ max_i64,
+ DataType::UInt64,
+ ExpectedCast::Value(ScalarValue::UInt64(Some(i64::MAX as u64))),
+ );
}
#[test]
fn test_try_cast_to_type_int_out_of_range() {
+ let min_i32 = ScalarValue::Int32(Some(i32::MIN));
+ let min_i64 = ScalarValue::Int64(Some(i64::MIN));
let max_i64 = ScalarValue::Int64(Some(i64::MAX));
let max_u64 = ScalarValue::UInt64(Some(u64::MAX));
+
expect_cast(max_i64.clone(), DataType::Int8, ExpectedCast::NoValue);
expect_cast(max_i64.clone(), DataType::Int16, ExpectedCast::NoValue);
@@ -783,6 +875,10 @@ mod tests {
expect_cast(max_u64, DataType::Int64, ExpectedCast::NoValue);
+ expect_cast(min_i64, DataType::UInt64, ExpectedCast::NoValue);
+
+ expect_cast(min_i32, DataType::UInt64, ExpectedCast::NoValue);
+
// decimal out of range
expect_cast(
ScalarValue::Decimal128(Some(99999999999999999999999999999999999900), 38, 0),
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index 779f156c0..2fdec1f2a 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -47,8 +47,8 @@ fn case_when() -> Result<()> {
let sql = "SELECT CASE WHEN col_uint32 > 0 THEN 1 ELSE 0 END FROM test";
let plan = test_sql(sql)?;
- let expected = "Projection: CASE WHEN CAST(test.col_uint32 AS Int64) > Int64(0) THEN Int64(1) ELSE Int64(0) END\
- \n TableScan: test projection=[col_uint32]";
+ let expected = "Projection: CASE WHEN test.col_uint32 > UInt32(0) THEN Int64(1) ELSE Int64(0) END AS CASE WHEN test.col_uint32 > Int64(0) THEN Int64(1) ELSE Int64(0) END\
+ \n TableScan: test projection=[col_uint32]";
assert_eq!(expected, format!("{:?}", plan));
Ok(())
}
@@ -91,7 +91,7 @@ fn unsigned_target_type() -> Result<()> {
let sql = "SELECT col_utf8 FROM test WHERE col_uint32 > 0";
let plan = test_sql(sql)?;
let expected = "Projection: test.col_utf8\
- \n Filter: CAST(test.col_uint32 AS Int64) > Int64(0)\
+ \n Filter: test.col_uint32 > UInt32(0)\
\n TableScan: test projection=[col_uint32, col_utf8]";
assert_eq!(expected, format!("{:?}", plan));
Ok(())