You are viewing a plain-text version of this content. The link to its canonical version was not preserved in this rendering.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/09/11 06:34:26 UTC

[arrow-datafusion] branch main updated: Simplify ScalarValue::distance (#7517) (#7519)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 4abae3b4fa Simplify ScalarValue::distance (#7517) (#7519)
4abae3b4fa is described below

commit 4abae3b4fadeadc8a368155e14186016117529c8
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Mon Sep 11 07:34:20 2023 +0100

    Simplify ScalarValue::distance (#7517) (#7519)
---
 datafusion/common/src/scalar.rs | 44 ++++++++++++++---------------------------
 1 file changed, 15 insertions(+), 29 deletions(-)

diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs
index 6189a293b4..fa2175c223 100644
--- a/datafusion/common/src/scalar.rs
+++ b/datafusion/common/src/scalar.rs
@@ -1135,31 +1135,22 @@ impl ScalarValue {
     ///
     /// Note: the datatype itself must support subtraction.
     pub fn distance(&self, other: &ScalarValue) -> Option<usize> {
-        // Having an explicit null check here is important because the
-        // subtraction for scalar values will return a real value even
-        // if one side is null.
-        if self.is_null() || other.is_null() {
-            return None;
-        }
-
-        let distance = if self > other {
-            self.sub_checked(other).ok()?
-        } else {
-            other.sub_checked(self).ok()?
-        };
-
-        match distance {
-            ScalarValue::Int8(Some(v)) => usize::try_from(v).ok(),
-            ScalarValue::Int16(Some(v)) => usize::try_from(v).ok(),
-            ScalarValue::Int32(Some(v)) => usize::try_from(v).ok(),
-            ScalarValue::Int64(Some(v)) => usize::try_from(v).ok(),
-            ScalarValue::UInt8(Some(v)) => Some(v as usize),
-            ScalarValue::UInt16(Some(v)) => Some(v as usize),
-            ScalarValue::UInt32(Some(v)) => usize::try_from(v).ok(),
-            ScalarValue::UInt64(Some(v)) => usize::try_from(v).ok(),
+        match (self, other) {
+            (Self::Int8(Some(l)), Self::Int8(Some(r))) => Some(l.abs_diff(*r) as _),
+            (Self::Int16(Some(l)), Self::Int16(Some(r))) => Some(l.abs_diff(*r) as _),
+            (Self::Int32(Some(l)), Self::Int32(Some(r))) => Some(l.abs_diff(*r) as _),
+            (Self::Int64(Some(l)), Self::Int64(Some(r))) => Some(l.abs_diff(*r) as _),
+            (Self::UInt8(Some(l)), Self::UInt8(Some(r))) => Some(l.abs_diff(*r) as _),
+            (Self::UInt16(Some(l)), Self::UInt16(Some(r))) => Some(l.abs_diff(*r) as _),
+            (Self::UInt32(Some(l)), Self::UInt32(Some(r))) => Some(l.abs_diff(*r) as _),
+            (Self::UInt64(Some(l)), Self::UInt64(Some(r))) => Some(l.abs_diff(*r) as _),
             // TODO: we might want to look into supporting ceil/floor here for floats.
-            ScalarValue::Float32(Some(v)) => Some(v.round() as usize),
-            ScalarValue::Float64(Some(v)) => Some(v.round() as usize),
+            (Self::Float32(Some(l)), Self::Float32(Some(r))) => {
+                Some((l - r).abs().round() as _)
+            }
+            (Self::Float64(Some(l)), Self::Float64(Some(r))) => {
+                Some((l - r).abs().round() as _)
+            }
             _ => None,
         }
     }
@@ -4725,11 +4716,6 @@ mod tests {
                 ScalarValue::Decimal128(Some(123), 5, 5),
                 ScalarValue::Decimal128(Some(120), 5, 5),
             ),
-            // Overflows
-            (
-                ScalarValue::Int8(Some(i8::MAX)),
-                ScalarValue::Int8(Some(i8::MIN)),
-            ),
         ];
         for (lhs, rhs) in cases {
             let distance = lhs.distance(&rhs);