You are viewing a plain-text version of this content. The canonical link for it was given as a hyperlink ("here") that is not preserved in this plain-text copy.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/09/11 06:34:26 UTC
[arrow-datafusion] branch main updated: Simplify ScalarValue::distance (#7517) (#7519)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 4abae3b4fa Simplify ScalarValue::distance (#7517) (#7519)
4abae3b4fa is described below
commit 4abae3b4fadeadc8a368155e14186016117529c8
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Mon Sep 11 07:34:20 2023 +0100
Simplify ScalarValue::distance (#7517) (#7519)
---
datafusion/common/src/scalar.rs | 44 ++++++++++++++---------------------------
1 file changed, 15 insertions(+), 29 deletions(-)
diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 6189a293b4..fa2175c223 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -1135,31 +1135,22 @@ impl ScalarValue {
///
/// Note: the datatype itself must support subtraction.
pub fn distance(&self, other: &ScalarValue) -> Option<usize> {
- // Having an explicit null check here is important because the
- // subtraction for scalar values will return a real value even
- // if one side is null.
- if self.is_null() || other.is_null() {
- return None;
- }
-
- let distance = if self > other {
- self.sub_checked(other).ok()?
- } else {
- other.sub_checked(self).ok()?
- };
-
- match distance {
- ScalarValue::Int8(Some(v)) => usize::try_from(v).ok(),
- ScalarValue::Int16(Some(v)) => usize::try_from(v).ok(),
- ScalarValue::Int32(Some(v)) => usize::try_from(v).ok(),
- ScalarValue::Int64(Some(v)) => usize::try_from(v).ok(),
- ScalarValue::UInt8(Some(v)) => Some(v as usize),
- ScalarValue::UInt16(Some(v)) => Some(v as usize),
- ScalarValue::UInt32(Some(v)) => usize::try_from(v).ok(),
- ScalarValue::UInt64(Some(v)) => usize::try_from(v).ok(),
+ match (self, other) {
+ (Self::Int8(Some(l)), Self::Int8(Some(r))) => Some(l.abs_diff(*r) as _),
+ (Self::Int16(Some(l)), Self::Int16(Some(r))) => Some(l.abs_diff(*r) as _),
+ (Self::Int32(Some(l)), Self::Int32(Some(r))) => Some(l.abs_diff(*r) as _),
+ (Self::Int64(Some(l)), Self::Int64(Some(r))) => Some(l.abs_diff(*r) as _),
+ (Self::UInt8(Some(l)), Self::UInt8(Some(r))) => Some(l.abs_diff(*r) as _),
+ (Self::UInt16(Some(l)), Self::UInt16(Some(r))) => Some(l.abs_diff(*r) as _),
+ (Self::UInt32(Some(l)), Self::UInt32(Some(r))) => Some(l.abs_diff(*r) as _),
+ (Self::UInt64(Some(l)), Self::UInt64(Some(r))) => Some(l.abs_diff(*r) as _),
// TODO: we might want to look into supporting ceil/floor here for floats.
- ScalarValue::Float32(Some(v)) => Some(v.round() as usize),
- ScalarValue::Float64(Some(v)) => Some(v.round() as usize),
+ (Self::Float32(Some(l)), Self::Float32(Some(r))) => {
+ Some((l - r).abs().round() as _)
+ }
+ (Self::Float64(Some(l)), Self::Float64(Some(r))) => {
+ Some((l - r).abs().round() as _)
+ }
_ => None,
}
}
@@ -4725,11 +4716,6 @@ mod tests {
ScalarValue::Decimal128(Some(123), 5, 5),
ScalarValue::Decimal128(Some(120), 5, 5),
),
- // Overflows
- (
- ScalarValue::Int8(Some(i8::MAX)),
- ScalarValue::Int8(Some(i8::MIN)),
- ),
];
for (lhs, rhs) in cases {
let distance = lhs.distance(&rhs);