You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2020/10/25 20:22:00 UTC

[GitHub] [arrow] nevi-me commented on a change in pull request #8517: ARROW-10381: [Rust] Generalized Ordering for inter-array comparisons

nevi-me commented on a change in pull request #8517:
URL: https://github.com/apache/arrow/pull/8517#discussion_r511642958



##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,280 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// compares two floats, placing NaNs at last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Ord,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Float,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: ArrowDictionaryKeyType,
+{
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(left.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
+    })
 }
 
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
+/// returns a comparison function that compares two values at two different positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> Result<DynComparator<'a>> {
+    use DataType::*;
+    use IntervalUnit::*;
+    use TimeUnit::*;
+    Ok(match (left.data_type(), right.data_type()) {
+        (a, b) if a != b => {
+            return Err(ArrowError::InvalidArgumentError(
+                "Can't compare arrays of different types".to_string(),
+            ));
+        }
+        (Boolean, Boolean) => compare_primitives::<BooleanType>(left, right),
+        (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right),
+        (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right),
+        (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right),
+        (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right),
+        (Int8, Int8) => compare_primitives::<Int8Type>(left, right),
+        (Int16, Int16) => compare_primitives::<Int16Type>(left, right),
+        (Int32, Int32) => compare_primitives::<Int32Type>(left, right),
+        (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
+        (Float32, Float32) => compare_float::<Float32Type>(left, right),
+        (Float64, Float64) => compare_float::<Float64Type>(left, right),
+        (Date32(_), Date32(_)) => compare_primitives::<Date32Type>(left, right),
+        (Date64(_), Date64(_)) => compare_primitives::<Date64Type>(left, right),
+        (Time32(Second), Time32(Second)) => {
+            compare_primitives::<Time32SecondType>(left, right)
+        }
+        (Time32(Millisecond), Time32(Millisecond)) => {
+            compare_primitives::<Time32MillisecondType>(left, right)
+        }
+        (Time64(Microsecond), Time64(Microsecond)) => {
+            compare_primitives::<Time64MicrosecondType>(left, right)
+        }
+        (Time64(Nanosecond), Time64(Nanosecond)) => {
+            compare_primitives::<Time64NanosecondType>(left, right)
+        }
+        (Timestamp(Second, _), Timestamp(Second, _)) => {
+            compare_primitives::<TimestampSecondType>(left, right)
+        }
+        (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => {
+            compare_primitives::<TimestampMillisecondType>(left, right)
+        }
+        (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => {
+            compare_primitives::<TimestampMicrosecondType>(left, right)
+        }
+        (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
+            compare_primitives::<TimestampNanosecondType>(left, right)
+        }
+        (Interval(YearMonth), Interval(YearMonth)) => {
+            compare_primitives::<IntervalYearMonthType>(left, right)
+        }
+        (Interval(DayTime), Interval(DayTime)) => {
+            compare_primitives::<IntervalDayTimeType>(left, right)
+        }
+        (Duration(Second), Duration(Second)) => {
+            compare_primitives::<DurationSecondType>(left, right)
+        }
+        (Duration(Millisecond), Duration(Millisecond)) => {
+            compare_primitives::<DurationMillisecondType>(left, right)
+        }
+        (Duration(Microsecond), Duration(Microsecond)) => {
+            compare_primitives::<DurationMicrosecondType>(left, right)
+        }
+        (Duration(Nanosecond), Duration(Nanosecond)) => {
+            compare_primitives::<DurationNanosecondType>(left, right)
+        }
+        (Utf8, Utf8) => compare_string::<i32>(left, right),
+        (LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right),
+        (
+            Dictionary(key_type_lhs, value_type_lhs),
+            Dictionary(key_type_rhs, value_type_rhs),
+        ) => {
+            if value_type_lhs.as_ref() != &DataType::Utf8
+                || value_type_rhs.as_ref() != &DataType::Utf8
+            {
+                return Err(ArrowError::InvalidArgumentError(
+                    "Arrow still does not support comparisons of non-string dictionary arrays"

Review comment:
       Would we incur a high cost if we cast dictionaries to primitives, then compared the primitives?

##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,280 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// compares two floats, placing NaNs at last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Ord,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Float,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: ArrowDictionaryKeyType,
+{
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(left.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
+    })
 }
 
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
+/// returns a comparison function that compares two values at two different positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> Result<DynComparator<'a>> {
+    use DataType::*;
+    use IntervalUnit::*;
+    use TimeUnit::*;
+    Ok(match (left.data_type(), right.data_type()) {
+        (a, b) if a != b => {
+            return Err(ArrowError::InvalidArgumentError(
+                "Can't compare arrays of different types".to_string(),
+            ));
+        }
+        (Boolean, Boolean) => compare_primitives::<BooleanType>(left, right),
+        (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right),
+        (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right),
+        (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right),
+        (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right),
+        (Int8, Int8) => compare_primitives::<Int8Type>(left, right),
+        (Int16, Int16) => compare_primitives::<Int16Type>(left, right),
+        (Int32, Int32) => compare_primitives::<Int32Type>(left, right),
+        (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
+        (Float32, Float32) => compare_float::<Float32Type>(left, right),
+        (Float64, Float64) => compare_float::<Float64Type>(left, right),
+        (Date32(_), Date32(_)) => compare_primitives::<Date32Type>(left, right),
+        (Date64(_), Date64(_)) => compare_primitives::<Date64Type>(left, right),
+        (Time32(Second), Time32(Second)) => {
+            compare_primitives::<Time32SecondType>(left, right)
+        }
+        (Time32(Millisecond), Time32(Millisecond)) => {
+            compare_primitives::<Time32MillisecondType>(left, right)
+        }
+        (Time64(Microsecond), Time64(Microsecond)) => {
+            compare_primitives::<Time64MicrosecondType>(left, right)
+        }
+        (Time64(Nanosecond), Time64(Nanosecond)) => {
+            compare_primitives::<Time64NanosecondType>(left, right)
+        }
+        (Timestamp(Second, _), Timestamp(Second, _)) => {
+            compare_primitives::<TimestampSecondType>(left, right)
+        }
+        (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => {
+            compare_primitives::<TimestampMillisecondType>(left, right)
+        }
+        (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => {
+            compare_primitives::<TimestampMicrosecondType>(left, right)
+        }
+        (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
+            compare_primitives::<TimestampNanosecondType>(left, right)
+        }
+        (Interval(YearMonth), Interval(YearMonth)) => {
+            compare_primitives::<IntervalYearMonthType>(left, right)
+        }
+        (Interval(DayTime), Interval(DayTime)) => {
+            compare_primitives::<IntervalDayTimeType>(left, right)
+        }
+        (Duration(Second), Duration(Second)) => {
+            compare_primitives::<DurationSecondType>(left, right)
+        }
+        (Duration(Millisecond), Duration(Millisecond)) => {
+            compare_primitives::<DurationMillisecondType>(left, right)
+        }
+        (Duration(Microsecond), Duration(Microsecond)) => {
+            compare_primitives::<DurationMicrosecondType>(left, right)
+        }
+        (Duration(Nanosecond), Duration(Nanosecond)) => {
+            compare_primitives::<DurationNanosecondType>(left, right)
+        }
+        (Utf8, Utf8) => compare_string::<i32>(left, right),
+        (LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right),
+        (
+            Dictionary(key_type_lhs, value_type_lhs),
+            Dictionary(key_type_rhs, value_type_rhs),
+        ) => {
+            if value_type_lhs.as_ref() != &DataType::Utf8
+                || value_type_rhs.as_ref() != &DataType::Utf8
+            {
+                return Err(ArrowError::InvalidArgumentError(
+                    "Arrow still does not support comparisons of non-string dictionary arrays"
+                        .to_string(),
+                ));
             }
-
-            // convert to bits with canonical pattern for NaN
-            let a = if a.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                a.to_bits()
-            };
-            let b = if b.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                b.to_bits()
-            };
-
-            if a == b {
-                // Equal or both NaN
-                Ordering::Equal
-            } else if a < b {
-                // (-0.0, 0.0) or (!NaN, NaN)
-                Ordering::Less
-            } else {
-                // (0.0, -0.0) or (NaN, !NaN)
-                Ordering::Greater
+            match (key_type_lhs.as_ref(), key_type_rhs.as_ref()) {
+                (a, b) if a != b => {
+                    return Err(ArrowError::InvalidArgumentError(
+                        "Can't compare arrays of different types".to_string(),
+                    ));
+                }
+                (UInt8, UInt8) => compare_dict_string::<UInt8Type>(left, right),
+                (UInt16, UInt16) => compare_dict_string::<UInt16Type>(left, right),
+                (UInt32, UInt32) => compare_dict_string::<UInt32Type>(left, right),
+                (UInt64, UInt64) => compare_dict_string::<UInt64Type>(left, right),
+                (Int8, Int8) => compare_dict_string::<Int8Type>(left, right),
+                (Int16, Int16) => compare_dict_string::<Int16Type>(left, right),
+                (Int32, Int32) => compare_dict_string::<Int32Type>(left, right),
+                (Int64, Int64) => compare_dict_string::<Int64Type>(left, right),
+                _ => todo!(),
             }
         }
-    };
+        _ => todo!(),

Review comment:
       We can return a helpful error here instead of panicking via `todo!()`

##########
File path: rust/arrow/src/compute/kernels/sort.rs
##########
@@ -453,49 +466,46 @@ pub fn lexsort(columns: &[SortColumn]) -> Result<Vec<ArrayRef>> {
 /// Sort elements lexicographically from a list of `ArrayRef` into an unsigned integer
 /// (`UInt32Array`) of indices.
 pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result<UInt32Array> {
+    if columns.len() == 0 {
+        return Err(ArrowError::InvalidArgumentError(
+            "Sort requires at least one column".to_string(),
+        ));
+    }
     if columns.len() == 1 {
         // fallback to non-lexical sort
         let column = &columns[0];
         return sort_to_indices(&column.values, column.options);
     }
 
-    let mut row_count = None;
+    let row_count = columns[0].values.len();
+    if columns.iter().any(|item| item.values.len() != row_count) {
+        return Err(ArrowError::ComputeError(
+            "lexical sort columns have different row counts".to_string(),
+        ));
+    };
+
     // convert ArrayRefs to OrdArray trait objects and perform row count check
     let flat_columns = columns
         .iter()
-        .map(|column| -> Result<(&Array, Box<OrdArray>, SortOptions)> {
-            // row count check
-            let curr_row_count = column.values.len() - column.values.offset();
-            match row_count {
-                None => {
-                    row_count = Some(curr_row_count);
-                }
-                Some(cnt) => {
-                    if curr_row_count != cnt {
-                        return Err(ArrowError::ComputeError(
-                            "lexical sort columns have different row counts".to_string(),
-                        ));
-                    }
-                }
-            }
-            // flatten and convert to OrdArray
+        .map(|column| -> Result<(&Array, DynComparator, SortOptions)> {
+            // flatten and convert build comparators
             Ok((
                 column.values.as_ref(),
-                as_ordarray(&column.values)?,
+                build_compare(column.values.as_ref(), column.values.as_ref())?,

Review comment:
       I'm happy with this approach, very creative




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org