Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2020/10/26 11:53:49 UTC

[GitHub] [arrow] alamb commented on a change in pull request #8517: ARROW-10381: [Rust] Generalized Ordering for inter-array comparisons

alamb commented on a change in pull request #8517:
URL: https://github.com/apache/arrow/pull/8517#discussion_r511896397



##########
File path: rust/arrow/src/compute/kernels/sort.rs
##########
@@ -41,110 +41,123 @@ pub fn sort(values: &ArrayRef, options: Option<SortOptions>) -> Result<ArrayRef>
     take(values, &indices, None)
 }
 
+fn partition_nan<T: ArrowPrimitiveType>(
+    array: &ArrayRef,
+    v: Vec<u32>,
+) -> (Vec<u32>, Vec<u32>) {
+    // partition by nan for float types
+    if T::DATA_TYPE == DataType::Float32 {
+        // T::Native has no `is_nan` and thus we need to downcast
+        let array = array
+            .as_any()
+            .downcast_ref::<Float32Array>()
+            .expect("Unable to downcast array");
+        let has_nan = v.iter().any(|index| array.value(*index as usize).is_nan());

Review comment:
       A minor comment: you might be able to make the code smaller (and avoid scanning the array twice) if we don't bother to check for `has_nan` and just always do `v.into_iter().partition(...)`. The rationale is that if `has_nan` is false, you had to scan the entire array anyway.
   
   I have not tested this, and it assumes the compiler is clever about optimizing `into_iter()`.
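
A minimal sketch of the single-pass variant suggested above, assuming a plain `&[f32]` slice stands in for the downcast `Float32Array` (the slice-based signature is an illustration, not the PR's code). It drops the `has_nan` pre-scan and always partitions:

```rust
/// Splits `indices` into (non-NaN, NaN) with one pass over the values.
/// `values` is a hypothetical stand-in for the Arrow float array.
fn partition_nan(values: &[f32], indices: Vec<u32>) -> (Vec<u32>, Vec<u32>) {
    indices
        .into_iter()
        // Indices whose value is not NaN go to the first vector,
        // the rest to the second; relative order is preserved.
        .partition(|&index| !values[index as usize].is_nan())
}
```

When no NaN is present the second vector is simply empty, so the caller sees the same result as the `(v, vec![])` branch. Whether this is actually faster is untested here, as the comment itself notes.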
   

##########
File path: rust/arrow/src/compute/kernels/sort.rs
##########
@@ -41,110 +41,123 @@ pub fn sort(values: &ArrayRef, options: Option<SortOptions>) -> Result<ArrayRef>
     take(values, &indices, None)
 }
 
+fn partition_nan<T: ArrowPrimitiveType>(

Review comment:
       Suggest adding a comment:
   
   ```
       // partition indices into non-NaN and NaN
   ```

##########
File path: rust/arrow/src/compute/kernels/sort.rs
##########
@@ -41,110 +41,123 @@ pub fn sort(values: &ArrayRef, options: Option<SortOptions>) -> Result<ArrayRef>
     take(values, &indices, None)
 }
 
+fn partition_nan<T: ArrowPrimitiveType>(
+    array: &ArrayRef,
+    v: Vec<u32>,
+) -> (Vec<u32>, Vec<u32>) {
+    // partition by nan for float types
+    if T::DATA_TYPE == DataType::Float32 {
+        // T::Native has no `is_nan` and thus we need to downcast
+        let array = array
+            .as_any()
+            .downcast_ref::<Float32Array>()
+            .expect("Unable to downcast array");
+        let has_nan = v.iter().any(|index| array.value(*index as usize).is_nan());
+        if has_nan {
+            v.into_iter()
+                .partition(|index| !array.value(*index as usize).is_nan())
+        } else {
+            (v, vec![])
+        }
+    } else if T::DATA_TYPE == DataType::Float64 {
+        let array = array
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .expect("Unable to downcast array");
+        let has_nan = v.iter().any(|index| array.value(*index as usize).is_nan());
+        if has_nan {
+            v.into_iter()
+                .partition(|index| !array.value(*index as usize).is_nan())
+        } else {
+            (v, vec![])
+        }
+    } else {
+        unreachable!("Partition by nan is only applicable to float types")
+    }
+}
+
+fn partition_validity(array: &ArrayRef) -> (Vec<u32>, Vec<u32>) {

Review comment:
       ```
       // partition indices into valid and null indices
   ```
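
To illustrate what "valid and null indices" means in the suggested comment, here is a small sketch that models a nullable column as a slice of `Option` values. This modeling is an assumption made for the example; the PR's `partition_validity` works on an `ArrayRef` and its validity information instead.

```rust
/// Returns (valid indices, null indices) for a column modeled as `Option`s.
/// `Some(_)` plays the role of a valid slot, `None` the role of a null slot.
fn partition_validity<T>(values: &[Option<T>]) -> (Vec<u32>, Vec<u32>) {
    (0..values.len() as u32).partition(|&index| values[index as usize].is_some())
}
```

For the input `[Some(1), None, Some(3)]` this yields `[0, 2]` as the valid indices and `[1]` as the null indices.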

##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,259 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;

Review comment:
       ```
   /// Compare the values at two arbitrary indices in two arrays.
   ```
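
A rough sketch of how such a comparator is built and called, using plain slices rather than Arrow arrays. The `DynComparator` alias matches the one in the diff; `build_slice_compare` is a hypothetical helper invented for this illustration and is not part of the PR.

```rust
use std::cmp::Ordering;

/// Same shape as the alias in the diff: a boxed closure over two indices.
type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;

/// Hypothetical stand-in for a comparator factory, using slices instead of arrays.
/// The closure compares `left[i]` against `right[j]`.
fn build_slice_compare<'a, T: Ord>(left: &'a [T], right: &'a [T]) -> DynComparator<'a> {
    Box::new(move |i, j| left[i].cmp(&right[j]))
}
```

With `left = [1, 2]` and `right = [3, 4]`, calling the returned closure as `cmp(0, 1)` compares 1 with 4 and gives `Ordering::Less`. The lifetime `'a` keeps the closure from outliving either slice, mirroring the note on `build_compare` in the diff.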

##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,280 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// compares two floats, placing NaNs at last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Ord,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Float,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: ArrowDictionaryKeyType,
+{
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(left.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
+    })
 }
 
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
+/// returns a comparison function that compares two values at two different positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> Result<DynComparator<'a>> {
+    use DataType::*;
+    use IntervalUnit::*;
+    use TimeUnit::*;
+    Ok(match (left.data_type(), right.data_type()) {
+        (a, b) if a != b => {
+            return Err(ArrowError::InvalidArgumentError(
+                "Can't compare arrays of different types".to_string(),
+            ));
+        }
+        (Boolean, Boolean) => compare_primitives::<BooleanType>(left, right),
+        (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right),
+        (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right),
+        (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right),
+        (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right),
+        (Int8, Int8) => compare_primitives::<Int8Type>(left, right),
+        (Int16, Int16) => compare_primitives::<Int16Type>(left, right),
+        (Int32, Int32) => compare_primitives::<Int32Type>(left, right),
+        (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
+        (Float32, Float32) => compare_float::<Float32Type>(left, right),
+        (Float64, Float64) => compare_float::<Float64Type>(left, right),
+        (Date32(_), Date32(_)) => compare_primitives::<Date32Type>(left, right),
+        (Date64(_), Date64(_)) => compare_primitives::<Date64Type>(left, right),
+        (Time32(Second), Time32(Second)) => {
+            compare_primitives::<Time32SecondType>(left, right)
+        }
+        (Time32(Millisecond), Time32(Millisecond)) => {
+            compare_primitives::<Time32MillisecondType>(left, right)
+        }
+        (Time64(Microsecond), Time64(Microsecond)) => {
+            compare_primitives::<Time64MicrosecondType>(left, right)
+        }
+        (Time64(Nanosecond), Time64(Nanosecond)) => {
+            compare_primitives::<Time64NanosecondType>(left, right)
+        }
+        (Timestamp(Second, _), Timestamp(Second, _)) => {
+            compare_primitives::<TimestampSecondType>(left, right)
+        }
+        (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => {
+            compare_primitives::<TimestampMillisecondType>(left, right)
+        }
+        (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => {
+            compare_primitives::<TimestampMicrosecondType>(left, right)
+        }
+        (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
+            compare_primitives::<TimestampNanosecondType>(left, right)
+        }
+        (Interval(YearMonth), Interval(YearMonth)) => {
+            compare_primitives::<IntervalYearMonthType>(left, right)
+        }
+        (Interval(DayTime), Interval(DayTime)) => {
+            compare_primitives::<IntervalDayTimeType>(left, right)
+        }
+        (Duration(Second), Duration(Second)) => {
+            compare_primitives::<DurationSecondType>(left, right)
+        }
+        (Duration(Millisecond), Duration(Millisecond)) => {
+            compare_primitives::<DurationMillisecondType>(left, right)
+        }
+        (Duration(Microsecond), Duration(Microsecond)) => {
+            compare_primitives::<DurationMicrosecondType>(left, right)
+        }
+        (Duration(Nanosecond), Duration(Nanosecond)) => {
+            compare_primitives::<DurationNanosecondType>(left, right)
+        }
+        (Utf8, Utf8) => compare_string::<i32>(left, right),
+        (LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right),
+        (
+            Dictionary(key_type_lhs, value_type_lhs),
+            Dictionary(key_type_rhs, value_type_rhs),
+        ) => {
+            if value_type_lhs.as_ref() != &DataType::Utf8
+                || value_type_rhs.as_ref() != &DataType::Utf8
+            {
+                return Err(ArrowError::InvalidArgumentError(
+                    "Arrow still does not support comparisons of non-string dictionary arrays"

Review comment:
       Or maybe we can leverage the comparison kernel somehow

##########
File path: rust/arrow/src/compute/kernels/sort.rs
##########
@@ -41,110 +41,123 @@ pub fn sort(values: &ArrayRef, options: Option<SortOptions>) -> Result<ArrayRef>
     take(values, &indices, None)
 }
 
+fn partition_nan<T: ArrowPrimitiveType>(
+    array: &ArrayRef,
+    v: Vec<u32>,
+) -> (Vec<u32>, Vec<u32>) {
+    // partition by nan for float types
+    if T::DATA_TYPE == DataType::Float32 {
+        // T::Native has no `is_nan` and thus we need to downcast
+        let array = array
+            .as_any()
+            .downcast_ref::<Float32Array>()
+            .expect("Unable to downcast array");
+        let has_nan = v.iter().any(|index| array.value(*index as usize).is_nan());

Review comment:
       But I see this is just refactored from the code below, so there is no need to change it.

##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,280 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// compares two floats, placing NaNs at last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Ord,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Float,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: ArrowDictionaryKeyType,
+{
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(left.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
+    })
 }
 
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
+/// returns a comparison function that compares two values at two different positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> Result<DynComparator<'a>> {
+    use DataType::*;
+    use IntervalUnit::*;
+    use TimeUnit::*;
+    Ok(match (left.data_type(), right.data_type()) {
+        (a, b) if a != b => {
+            return Err(ArrowError::InvalidArgumentError(
+                "Can't compare arrays of different types".to_string(),
+            ));
+        }
+        (Boolean, Boolean) => compare_primitives::<BooleanType>(left, right),
+        (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right),
+        (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right),
+        (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right),
+        (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right),
+        (Int8, Int8) => compare_primitives::<Int8Type>(left, right),
+        (Int16, Int16) => compare_primitives::<Int16Type>(left, right),
+        (Int32, Int32) => compare_primitives::<Int32Type>(left, right),
+        (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
+        (Float32, Float32) => compare_float::<Float32Type>(left, right),
+        (Float64, Float64) => compare_float::<Float64Type>(left, right),
+        (Date32(_), Date32(_)) => compare_primitives::<Date32Type>(left, right),
+        (Date64(_), Date64(_)) => compare_primitives::<Date64Type>(left, right),
+        (Time32(Second), Time32(Second)) => {
+            compare_primitives::<Time32SecondType>(left, right)
+        }
+        (Time32(Millisecond), Time32(Millisecond)) => {
+            compare_primitives::<Time32MillisecondType>(left, right)
+        }
+        (Time64(Microsecond), Time64(Microsecond)) => {
+            compare_primitives::<Time64MicrosecondType>(left, right)
+        }
+        (Time64(Nanosecond), Time64(Nanosecond)) => {
+            compare_primitives::<Time64NanosecondType>(left, right)
+        }
+        (Timestamp(Second, _), Timestamp(Second, _)) => {
+            compare_primitives::<TimestampSecondType>(left, right)
+        }
+        (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => {
+            compare_primitives::<TimestampMillisecondType>(left, right)
+        }
+        (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => {
+            compare_primitives::<TimestampMicrosecondType>(left, right)
+        }
+        (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
+            compare_primitives::<TimestampNanosecondType>(left, right)
+        }
+        (Interval(YearMonth), Interval(YearMonth)) => {
+            compare_primitives::<IntervalYearMonthType>(left, right)
+        }
+        (Interval(DayTime), Interval(DayTime)) => {
+            compare_primitives::<IntervalDayTimeType>(left, right)
+        }
+        (Duration(Second), Duration(Second)) => {
+            compare_primitives::<DurationSecondType>(left, right)
+        }
+        (Duration(Millisecond), Duration(Millisecond)) => {
+            compare_primitives::<DurationMillisecondType>(left, right)
+        }
+        (Duration(Microsecond), Duration(Microsecond)) => {
+            compare_primitives::<DurationMicrosecondType>(left, right)
+        }
+        (Duration(Nanosecond), Duration(Nanosecond)) => {
+            compare_primitives::<DurationNanosecondType>(left, right)
+        }
+        (Utf8, Utf8) => compare_string::<i32>(left, right),
+        (LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right),
+        (
+            Dictionary(key_type_lhs, value_type_lhs),
+            Dictionary(key_type_rhs, value_type_rhs),
+        ) => {
+            if value_type_lhs.as_ref() != &DataType::Utf8
+                || value_type_rhs.as_ref() != &DataType::Utf8
+            {
+                return Err(ArrowError::InvalidArgumentError(
+                    "Arrow still does not support comparisons of non-string dictionary arrays"
+                        .to_string(),
+                ));
             }
-
-            // convert to bits with canonical pattern for NaN
-            let a = if a.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                a.to_bits()
-            };
-            let b = if b.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                b.to_bits()
-            };
-
-            if a == b {
-                // Equal or both NaN
-                Ordering::Equal
-            } else if a < b {
-                // (-0.0, 0.0) or (!NaN, NaN)
-                Ordering::Less
-            } else {
-                // (0.0, -0.0) or (NaN, !NaN)
-                Ordering::Greater
+            match (key_type_lhs.as_ref(), key_type_rhs.as_ref()) {
+                (a, b) if a != b => {
+                    return Err(ArrowError::InvalidArgumentError(
+                        "Can't compare arrays of different types".to_string(),
+                    ));
+                }
+                (UInt8, UInt8) => compare_dict_string::<UInt8Type>(left, right),
+                (UInt16, UInt16) => compare_dict_string::<UInt16Type>(left, right),
+                (UInt32, UInt32) => compare_dict_string::<UInt32Type>(left, right),
+                (UInt64, UInt64) => compare_dict_string::<UInt64Type>(left, right),
+                (Int8, Int8) => compare_dict_string::<Int8Type>(left, right),
+                (Int16, Int16) => compare_dict_string::<Int16Type>(left, right),
+                (Int32, Int32) => compare_dict_string::<Int32Type>(left, right),
+                (Int64, Int64) => compare_dict_string::<Int64Type>(left, right),
+                _ => todo!(),
             }
         }
-    };
+        _ => todo!(),

Review comment:
       ➕ 

##########
File path: rust/arrow/benches/sort_kernel.rs
##########
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[macro_use]
+extern crate criterion;
+use criterion::Criterion;
+
+use rand::Rng;
+use std::sync::Arc;
+
+extern crate arrow;
+
+use arrow::array::*;
+use arrow::compute::kernels::sort::{lexsort, SortColumn};
+
+fn create_array(size: usize, with_nulls: bool) -> ArrayRef {
+    // use random numbers to avoid spurious compiler optimizations wrt to branching
+    let mut rng = rand::thread_rng();
+    let mut builder = Float32Builder::new(size);
+
+    for _ in 0..size {
+        if with_nulls && rng.gen::<f32>() > 0.5 {
+            builder.append_null().unwrap();
+        } else {
+            builder.append_value(rng.gen()).unwrap();
+        }
+    }
+    Arc::new(builder.finish())
+}
+
+fn bench_sort(arr_a: &ArrayRef, array_b: &ArrayRef) {
+    let columns = vec![
+        SortColumn {
+            values: arr_a.clone(),
+            options: None,
+        },
+        SortColumn {
+            values: array_b.clone(),
+            options: None,
+        },
+    ];
+
+    criterion::black_box(lexsort(&columns).unwrap());
+}
+
+fn add_benchmark(c: &mut Criterion) {

Review comment:
       What would you think about adding a sort kernel for `u64` data as well?

##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,259 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// compares two floats, placing NaNs at last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
-}
-
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Ord,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
-}
-
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
-}
-
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
-            }
-
-            // convert to bits with canonical pattern for NaN
-            let a = if a.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                a.to_bits()
-            };
-            let b = if b.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                b.to_bits()
-            };
-
-            if a == b {
-                // Equal or both NaN
-                Ordering::Equal
-            } else if a < b {
-                // (-0.0, 0.0) or (!NaN, NaN)
-                Ordering::Less
-            } else {
-                // (0.0, -0.0) or (NaN, !NaN)
-                Ordering::Greater
-            }
-        }
-    };
-}
-
-float_ord_cmp!(cmp_f64, f64);
-float_ord_cmp!(cmp_f32, f32);
-
-#[repr(transparent)]
-struct Float64ArrayAsOrdArray<'a>(&'a Float64Array);
-#[repr(transparent)]
-struct Float32ArrayAsOrdArray<'a>(&'a Float32Array);
-
-impl OrdArray for Float64ArrayAsOrdArray<'_> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        let a: f64 = self.0.value(i);
-        let b: f64 = self.0.value(j);
-
-        cmp_f64(a, b)
-    }
-}
-
-impl OrdArray for Float32ArrayAsOrdArray<'_> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        let a: f32 = self.0.value(i);
-        let b: f32 = self.0.value(j);
-
-        cmp_f32(a, b)
-    }
-}
-
-fn float32_as_ord_array<'a>(array: &'a ArrayRef) -> Box<dyn OrdArray + 'a> {
-    let float_array: &Float32Array = as_primitive_array::<Float32Type>(array);
-    Box::new(Float32ArrayAsOrdArray(float_array))
-}
-
-fn float64_as_ord_array<'a>(array: &'a ArrayRef) -> Box<dyn OrdArray + 'a> {
-    let float_array: &Float64Array = as_primitive_array::<Float64Type>(array);
-    Box::new(Float64ArrayAsOrdArray(float_array))
-}
-
-struct StringDictionaryArrayAsOrdArray<'a, T: ArrowDictionaryKeyType> {
-    dict_array: &'a DictionaryArray<T>,
-    values: StringArray,
-    keys: PrimitiveArray<T>,
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Float,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl<T: ArrowDictionaryKeyType> OrdArray for StringDictionaryArrayAsOrdArray<'_, T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        let keys = &self.keys;
-        let dict = &self.values;
-
-        let key_a: T::Native = keys.value(i);
-        let key_b: T::Native = keys.value(j);
-
-        let str_a = dict.value(key_a.to_usize().unwrap());
-        let str_b = dict.value(key_b.to_usize().unwrap());
-
-        str_a.cmp(str_b)
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-fn string_dict_as_ord_array<'a, T: ArrowDictionaryKeyType>(
-    array: &'a ArrayRef,
-) -> Box<dyn OrdArray + 'a>
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T: ArrowDictionaryKeyType,
 {
-    let dict_array = as_dictionary_array::<T>(array);
-    let keys = dict_array.keys_array();
-
-    let values = &dict_array.values();
-    let values = StringArray::from(values.data());
-
-    Box::new(StringDictionaryArrayAsOrdArray {
-        dict_array,
-        values,
-        keys,
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(right.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
     })
 }
 
-/// Convert ArrayRef to OrdArray trait object
-pub fn as_ordarray<'a>(values: &'a ArrayRef) -> Result<Box<OrdArray + 'a>> {
-    match values.data_type() {
-        DataType::Boolean => Ok(Box::new(as_boolean_array(&values))),
-        DataType::Utf8 => Ok(Box::new(as_string_array(&values))),
-        DataType::Null => Ok(Box::new(as_null_array(&values))),
-        DataType::Int8 => Ok(Box::new(as_primitive_array::<Int8Type>(&values))),
-        DataType::Int16 => Ok(Box::new(as_primitive_array::<Int16Type>(&values))),
-        DataType::Int32 => Ok(Box::new(as_primitive_array::<Int32Type>(&values))),
-        DataType::Int64 => Ok(Box::new(as_primitive_array::<Int64Type>(&values))),
-        DataType::UInt8 => Ok(Box::new(as_primitive_array::<UInt8Type>(&values))),
-        DataType::UInt16 => Ok(Box::new(as_primitive_array::<UInt16Type>(&values))),
-        DataType::UInt32 => Ok(Box::new(as_primitive_array::<UInt32Type>(&values))),
-        DataType::UInt64 => Ok(Box::new(as_primitive_array::<UInt64Type>(&values))),
-        DataType::Date32(_) => Ok(Box::new(as_primitive_array::<Date32Type>(&values))),
-        DataType::Date64(_) => Ok(Box::new(as_primitive_array::<Date64Type>(&values))),
-        DataType::Time32(Second) => {
-            Ok(Box::new(as_primitive_array::<Time32SecondType>(&values)))
+/// returns a comparison function that compares two values at two different positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> Result<DynComparator<'a>> {

Review comment:
       This all looks cool, though I do wonder about the runtime overhead of doing a dynamic dispatch for each comparison. However, if you have measured no performance regression, this code seems like an improvement over what is on master.
   
   I also wonder if there is some way to reuse the work in the comparison kernel
   (https://github.com/apache/arrow/blob/master/rust/arrow/src/compute/kernels/comparison.rs) to improve performance over a row-by-row comparison plus a dynamic dispatch.
   
   Perhaps that is a good optimization for some future PR.
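
To make the trade-off concrete, here is a minimal sketch using only the standard library. `Column`, `lt_batch`, and the bodies of `build_compare` and `cmp_nans_last` are hypothetical stand-ins written for illustration; they are not the Arrow types or the PR's actual implementation. The row-by-row path pays one indirect call through the boxed closure per comparison, while the batch path resolves the element type once and then runs a tight loop, which is roughly the shape of a comparison kernel.

```rust
use std::cmp::Ordering;

/// Boxed comparator over two index positions (assumed shape of `DynComparator`).
type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;

/// Hypothetical stand-in for a type-erased array.
enum Column {
    Int(Vec<i32>),
    Float(Vec<f64>),
}

/// Total order on floats that places NaN after every other value.
fn cmp_nans_last(a: &f64, b: &f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither side is NaN here, so `partial_cmp` cannot return `None`.
        (false, false) => a.partial_cmp(b).unwrap(),
    }
}

/// Resolve the element type once and return a closure; every later
/// comparison goes through one dynamic call.
fn build_compare<'a>(left: &'a Column, right: &'a Column) -> Option<DynComparator<'a>> {
    match (left, right) {
        (Column::Int(l), Column::Int(r)) => {
            Some(Box::new(move |i: usize, j: usize| l[i].cmp(&r[j])))
        }
        (Column::Float(l), Column::Float(r)) => {
            Some(Box::new(move |i: usize, j: usize| cmp_nans_last(&l[i], &r[j])))
        }
        // Mismatched element types cannot be compared.
        _ => None,
    }
}

/// Element-wise `left[i] < right[i]`: one type match, then a monomorphic loop.
fn lt_batch(left: &Column, right: &Column) -> Option<Vec<bool>> {
    match (left, right) {
        (Column::Int(l), Column::Int(r)) => {
            Some(l.iter().zip(r).map(|(a, b)| a < b).collect())
        }
        (Column::Float(l), Column::Float(r)) => Some(
            l.iter()
                .zip(r)
                .map(|(a, b)| cmp_nans_last(a, b) == Ordering::Less)
                .collect(),
        ),
        _ => None,
    }
}

fn main() {
    let left = Column::Float(vec![1.0, f64::NAN, 3.0]);
    let right = Column::Float(vec![2.0, 0.5, f64::NAN]);

    // Row-by-row path: one boxed-closure call per comparison.
    let cmp = build_compare(&left, &right).expect("matching column types");
    let row_by_row: Vec<bool> = (0..3).map(|i| cmp(i, i) == Ordering::Less).collect();

    // Batch path: a single type match, then a tight loop.
    let batch = lt_batch(&left, &right).expect("matching column types");
    assert_eq!(row_by_row, batch);

    // Columns of different element types yield no comparator.
    assert!(build_compare(&left, &Column::Int(vec![1])).is_none());
    println!("{:?}", batch);
}
```

Both paths agree on the result; whether the batch form is actually faster for sorting would need a benchmark, since a sort compares arbitrary index pairs rather than aligned rows.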



