Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2020/10/24 16:41:48 UTC

[GitHub] [arrow] jorgecarleitao opened a new pull request #8517: ARROW-10381: [Rust] Generalized Ordering for inter-array comparisons

jorgecarleitao opened a new pull request #8517:
URL: https://github.com/apache/arrow/pull/8517


   Currently, the code in `array/ord.rs` is centered around intra-array comparison. This makes it impossible to compare values from two different arrays, which is required by e.g. merge-sort operations.
   
   This PR:
   * simplifies the code around `sort` by splitting it into smaller functions for ease of understanding (first 2 commits)
   * adds a benchmark for the `sort` kernel, which I used to verify that there was no performance regression (3rd commit)
   * simplifies and generalizes the code around `ord` to support comparisons between two arrays, which may be the same array (4th commit)
   * adds more tests for edge cases around float comparison (NaN and zeros)
   
   There was no performance change on my computer.
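
The float edge cases from the last bullet (NaN and signed zeros) can be sketched without Arrow. This is only a minimal illustration, assuming plain `f64` values instead of Arrow arrays; `sort_nans_last` is a hypothetical helper and not part of the PR.

```rust
use std::cmp::Ordering;

/// Total order for `f64` that places every NaN after every non-NaN value.
/// Two NaNs compare as equal.
fn cmp_nans_last(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither side is NaN, so `partial_cmp` cannot return `None`.
        (false, false) => a.partial_cmp(&b).unwrap(),
    }
}

/// Hypothetical helper: sorts a slice of floats in place, NaNs last.
fn sort_nans_last(values: &mut [f64]) {
    values.sort_by(|a, b| cmp_nans_last(*a, *b));
}
```

Under this order `-0.0` and `0.0` compare as equal, because `partial_cmp` treats them as the same value.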


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [arrow] jorgecarleitao commented on a change in pull request #8517: ARROW-10381: [Rust] Generalized Ordering for inter-array comparisons

Posted by GitBox <gi...@apache.org>.
jorgecarleitao commented on a change in pull request #8517:
URL: https://github.com/apache/arrow/pull/8517#discussion_r511491453



##########
File path: rust/arrow/src/compute/kernels/sort.rs
##########
@@ -453,49 +466,46 @@ pub fn lexsort(columns: &[SortColumn]) -> Result<Vec<ArrayRef>> {
 /// Sort elements lexicographically from a list of `ArrayRef` into an unsigned integer
 /// (`UInt32Array`) of indices.
 pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result<UInt32Array> {
+    if columns.len() == 0 {
+        return Err(ArrowError::InvalidArgumentError(
+            "Sort requires at least one column".to_string(),
+        ));
+    }
     if columns.len() == 1 {
         // fallback to non-lexical sort
         let column = &columns[0];
         return sort_to_indices(&column.values, column.options);
     }
 
-    let mut row_count = None;
+    let row_count = columns[0].values.len();
+    if columns.iter().any(|item| item.values.len() != row_count) {
+        return Err(ArrowError::ComputeError(
+            "lexical sort columns have different row counts".to_string(),
+        ));
+    };
+
     // convert ArrayRefs to OrdArray trait objects and perform row count check
     let flat_columns = columns
         .iter()
-        .map(|column| -> Result<(&Array, Box<OrdArray>, SortOptions)> {
-            // row count check
-            let curr_row_count = column.values.len() - column.values.offset();
-            match row_count {
-                None => {
-                    row_count = Some(curr_row_count);
-                }
-                Some(cnt) => {
-                    if curr_row_count != cnt {
-                        return Err(ArrowError::ComputeError(
-                            "lexical sort columns have different row counts".to_string(),
-                        ));
-                    }
-                }
-            }
-            // flatten and convert to OrdArray
+        .map(|column| -> Result<(&Array, DynComparator, SortOptions)> {
+            // flatten and build comparators
             Ok((
                 column.values.as_ref(),
-                as_ordarray(&column.values)?,
+                build_compare(column.values.as_ref(), column.values.as_ref())?,

Review comment:
       This is the main change: we no longer create an `OrdArray`. Instead, `build_compare` returns a heap-allocated function (`DynComparator`) that can compare values.
   
   In this case we pass the same array on the left and right side, but it works more generally (with different arrays), which is the interface we need to merge-sort arrays.
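
A minimal sketch of why one comparator factory covers both uses, assuming plain slices instead of Arrow arrays. `SliceComparator`, `build_slice_compare`, `merge_indices`, and this `sort_to_indices` are hypothetical stand-ins for illustration, not the functions in the PR.

```rust
use std::cmp::Ordering;

/// Stand-in for `DynComparator`: a boxed closure that compares the element
/// at index `i` of the left input with the element at index `j` of the right.
type SliceComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;

/// Hypothetical factory over plain slices.
fn build_slice_compare<'a, T: Ord>(left: &'a [T], right: &'a [T]) -> SliceComparator<'a> {
    Box::new(move |i, j| left[i].cmp(&right[j]))
}

/// Merges two sorted slices. Each output entry is `(from_right, index)`.
fn merge_indices<T: Ord>(left: &[T], right: &[T]) -> Vec<(bool, usize)> {
    let cmp = build_slice_compare(left, right);
    let (mut i, mut j) = (0, 0);
    let mut out = Vec::with_capacity(left.len() + right.len());
    while i < left.len() && j < right.len() {
        // Take from the left on ties so the merge is stable.
        if cmp(i, j) != Ordering::Greater {
            out.push((false, i));
            i += 1;
        } else {
            out.push((true, j));
            j += 1;
        }
    }
    // One of the inputs is exhausted; append the rest of the other.
    out.extend((i..left.len()).map(|k| (false, k)));
    out.extend((j..right.len()).map(|k| (true, k)));
    out
}

/// Sorts a single slice to indices by passing it as both sides.
fn sort_to_indices<T: Ord>(values: &[T]) -> Vec<usize> {
    let cmp = build_slice_compare(values, values);
    let mut indices: Vec<usize> = (0..values.len()).collect();
    indices.sort_by(|&i, &j| cmp(i, j));
    indices
}
```

The sort path and the merge path differ only in which inputs are handed to the factory; the closure signature stays `(usize, usize) -> Ordering` in both.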







[GitHub] [arrow] nevi-me commented on a change in pull request #8517: ARROW-10381: [Rust] Generalized Ordering for inter-array comparisons

Posted by GitBox <gi...@apache.org>.
nevi-me commented on a change in pull request #8517:
URL: https://github.com/apache/arrow/pull/8517#discussion_r511642958



##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,280 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// Compares two floats, placing NaNs last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Ord,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Float,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: ArrowDictionaryKeyType,
+{
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(right.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
+    })
 }
 
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
+/// returns a comparison function that compares two values at two different positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> Result<DynComparator<'a>> {
+    use DataType::*;
+    use IntervalUnit::*;
+    use TimeUnit::*;
+    Ok(match (left.data_type(), right.data_type()) {
+        (a, b) if a != b => {
+            return Err(ArrowError::InvalidArgumentError(
+                "Can't compare arrays of different types".to_string(),
+            ));
+        }
+        (Boolean, Boolean) => compare_primitives::<BooleanType>(left, right),
+        (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right),
+        (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right),
+        (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right),
+        (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right),
+        (Int8, Int8) => compare_primitives::<Int8Type>(left, right),
+        (Int16, Int16) => compare_primitives::<Int16Type>(left, right),
+        (Int32, Int32) => compare_primitives::<Int32Type>(left, right),
+        (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
+        (Float32, Float32) => compare_float::<Float32Type>(left, right),
+        (Float64, Float64) => compare_float::<Float64Type>(left, right),
+        (Date32(_), Date32(_)) => compare_primitives::<Date32Type>(left, right),
+        (Date64(_), Date64(_)) => compare_primitives::<Date64Type>(left, right),
+        (Time32(Second), Time32(Second)) => {
+            compare_primitives::<Time32SecondType>(left, right)
+        }
+        (Time32(Millisecond), Time32(Millisecond)) => {
+            compare_primitives::<Time32MillisecondType>(left, right)
+        }
+        (Time64(Microsecond), Time64(Microsecond)) => {
+            compare_primitives::<Time64MicrosecondType>(left, right)
+        }
+        (Time64(Nanosecond), Time64(Nanosecond)) => {
+            compare_primitives::<Time64NanosecondType>(left, right)
+        }
+        (Timestamp(Second, _), Timestamp(Second, _)) => {
+            compare_primitives::<TimestampSecondType>(left, right)
+        }
+        (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => {
+            compare_primitives::<TimestampMillisecondType>(left, right)
+        }
+        (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => {
+            compare_primitives::<TimestampMicrosecondType>(left, right)
+        }
+        (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
+            compare_primitives::<TimestampNanosecondType>(left, right)
+        }
+        (Interval(YearMonth), Interval(YearMonth)) => {
+            compare_primitives::<IntervalYearMonthType>(left, right)
+        }
+        (Interval(DayTime), Interval(DayTime)) => {
+            compare_primitives::<IntervalDayTimeType>(left, right)
+        }
+        (Duration(Second), Duration(Second)) => {
+            compare_primitives::<DurationSecondType>(left, right)
+        }
+        (Duration(Millisecond), Duration(Millisecond)) => {
+            compare_primitives::<DurationMillisecondType>(left, right)
+        }
+        (Duration(Microsecond), Duration(Microsecond)) => {
+            compare_primitives::<DurationMicrosecondType>(left, right)
+        }
+        (Duration(Nanosecond), Duration(Nanosecond)) => {
+            compare_primitives::<DurationNanosecondType>(left, right)
+        }
+        (Utf8, Utf8) => compare_string::<i32>(left, right),
+        (LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right),
+        (
+            Dictionary(key_type_lhs, value_type_lhs),
+            Dictionary(key_type_rhs, value_type_rhs),
+        ) => {
+            if value_type_lhs.as_ref() != &DataType::Utf8
+                || value_type_rhs.as_ref() != &DataType::Utf8
+            {
+                return Err(ArrowError::InvalidArgumentError(
+                    "Arrow still does not support comparisons of non-string dictionary arrays"

Review comment:
       Would we incur a high cost if we cast dictionaries to primitives, then compared the primitives?
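
To make the two options concrete, here is a rough sketch, assuming a toy `DictStrings` type in place of `DictionaryArray` (all names below are hypothetical). It only shows where the extra work would happen; it says nothing about the measured cost.

```rust
use std::cmp::Ordering;

/// Toy stand-in for a dictionary-encoded string array:
/// `keys[i]` is an index into `values`.
struct DictStrings {
    keys: Vec<u32>,
    values: Vec<String>,
}

impl DictStrings {
    /// Looks up the string for row `i` without copying it.
    fn get(&self, i: usize) -> &str {
        &self.values[self.keys[i] as usize]
    }

    /// "Casts" the dictionary to a plain string array: one allocation per row.
    fn materialize(&self) -> Vec<String> {
        (0..self.keys.len()).map(|i| self.get(i).to_string()).collect()
    }
}

/// Option 1: compare through the keys, resolving each side in its own dictionary.
fn compare_via_keys<'a>(
    left: &'a DictStrings,
    right: &'a DictStrings,
) -> Box<dyn Fn(usize, usize) -> Ordering + 'a> {
    Box::new(move |i, j| left.get(i).cmp(right.get(j)))
}

/// Option 2: cast both sides up front, then compare the materialized values.
fn compare_via_cast(
    left: &DictStrings,
    right: &DictStrings,
) -> Box<dyn Fn(usize, usize) -> Ordering> {
    let (l, r) = (left.materialize(), right.materialize());
    Box::new(move |i, j| l[i].cmp(&r[j]))
}
```

Both closures return the same ordering. The cast pays for a full copy of each array before the first comparison, while the key-based version pays for one extra lookup per comparison. Note that the same key can map to different strings in the two dictionaries, so each side has to be resolved against its own values.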

##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,280 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// Compares two floats, placing NaNs last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Ord,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Float,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: ArrowDictionaryKeyType,
+{
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(right.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
+    })
 }
 
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
+/// returns a comparison function that compares two values at two different positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> Result<DynComparator<'a>> {
+    use DataType::*;
+    use IntervalUnit::*;
+    use TimeUnit::*;
+    Ok(match (left.data_type(), right.data_type()) {
+        (a, b) if a != b => {
+            return Err(ArrowError::InvalidArgumentError(
+                "Can't compare arrays of different types".to_string(),
+            ));
+        }
+        (Boolean, Boolean) => compare_primitives::<BooleanType>(left, right),
+        (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right),
+        (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right),
+        (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right),
+        (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right),
+        (Int8, Int8) => compare_primitives::<Int8Type>(left, right),
+        (Int16, Int16) => compare_primitives::<Int16Type>(left, right),
+        (Int32, Int32) => compare_primitives::<Int32Type>(left, right),
+        (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
+        (Float32, Float32) => compare_float::<Float32Type>(left, right),
+        (Float64, Float64) => compare_float::<Float64Type>(left, right),
+        (Date32(_), Date32(_)) => compare_primitives::<Date32Type>(left, right),
+        (Date64(_), Date64(_)) => compare_primitives::<Date64Type>(left, right),
+        (Time32(Second), Time32(Second)) => {
+            compare_primitives::<Time32SecondType>(left, right)
+        }
+        (Time32(Millisecond), Time32(Millisecond)) => {
+            compare_primitives::<Time32MillisecondType>(left, right)
+        }
+        (Time64(Microsecond), Time64(Microsecond)) => {
+            compare_primitives::<Time64MicrosecondType>(left, right)
+        }
+        (Time64(Nanosecond), Time64(Nanosecond)) => {
+            compare_primitives::<Time64NanosecondType>(left, right)
+        }
+        (Timestamp(Second, _), Timestamp(Second, _)) => {
+            compare_primitives::<TimestampSecondType>(left, right)
+        }
+        (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => {
+            compare_primitives::<TimestampMillisecondType>(left, right)
+        }
+        (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => {
+            compare_primitives::<TimestampMicrosecondType>(left, right)
+        }
+        (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
+            compare_primitives::<TimestampNanosecondType>(left, right)
+        }
+        (Interval(YearMonth), Interval(YearMonth)) => {
+            compare_primitives::<IntervalYearMonthType>(left, right)
+        }
+        (Interval(DayTime), Interval(DayTime)) => {
+            compare_primitives::<IntervalDayTimeType>(left, right)
+        }
+        (Duration(Second), Duration(Second)) => {
+            compare_primitives::<DurationSecondType>(left, right)
+        }
+        (Duration(Millisecond), Duration(Millisecond)) => {
+            compare_primitives::<DurationMillisecondType>(left, right)
+        }
+        (Duration(Microsecond), Duration(Microsecond)) => {
+            compare_primitives::<DurationMicrosecondType>(left, right)
+        }
+        (Duration(Nanosecond), Duration(Nanosecond)) => {
+            compare_primitives::<DurationNanosecondType>(left, right)
+        }
+        (Utf8, Utf8) => compare_string::<i32>(left, right),
+        (LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right),
+        (
+            Dictionary(key_type_lhs, value_type_lhs),
+            Dictionary(key_type_rhs, value_type_rhs),
+        ) => {
+            if value_type_lhs.as_ref() != &DataType::Utf8
+                || value_type_rhs.as_ref() != &DataType::Utf8
+            {
+                return Err(ArrowError::InvalidArgumentError(
+                    "Arrow still does not support comparisons of non-string dictionary arrays"
+                        .to_string(),
+                ));
             }
-
-            // convert to bits with canonical pattern for NaN
-            let a = if a.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                a.to_bits()
-            };
-            let b = if b.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                b.to_bits()
-            };
-
-            if a == b {
-                // Equal or both NaN
-                Ordering::Equal
-            } else if a < b {
-                // (-0.0, 0.0) or (!NaN, NaN)
-                Ordering::Less
-            } else {
-                // (0.0, -0.0) or (NaN, !NaN)
-                Ordering::Greater
+            match (key_type_lhs.as_ref(), key_type_rhs.as_ref()) {
+                (a, b) if a != b => {
+                    return Err(ArrowError::InvalidArgumentError(
+                        "Can't compare arrays of different types".to_string(),
+                    ));
+                }
+                (UInt8, UInt8) => compare_dict_string::<UInt8Type>(left, right),
+                (UInt16, UInt16) => compare_dict_string::<UInt16Type>(left, right),
+                (UInt32, UInt32) => compare_dict_string::<UInt32Type>(left, right),
+                (UInt64, UInt64) => compare_dict_string::<UInt64Type>(left, right),
+                (Int8, Int8) => compare_dict_string::<Int8Type>(left, right),
+                (Int16, Int16) => compare_dict_string::<Int16Type>(left, right),
+                (Int32, Int32) => compare_dict_string::<Int32Type>(left, right),
+                (Int64, Int64) => compare_dict_string::<Int64Type>(left, right),
+                _ => todo!(),
             }
         }
-    };
+        _ => todo!(),

Review comment:
       We can return a helpful error here instead of panicking with `todo!()`.
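
Something along these lines, as a sketch only: `SketchType` and `SketchError` are hypothetical, heavily reduced stand-ins for `DataType` and `ArrowError`, and the error text is just a suggestion.

```rust
/// Hypothetical, reduced stand-in for Arrow's `DataType`.
#[derive(Debug, PartialEq)]
enum SketchType {
    Int32,
    Utf8,
    Binary,
}

/// Hypothetical, reduced stand-in for `ArrowError`.
#[derive(Debug, PartialEq)]
enum SketchError {
    InvalidArgument(String),
}

/// Checks whether a comparator could be built for the two types.
fn check_comparable(left: &SketchType, right: &SketchType) -> Result<(), SketchError> {
    match (left, right) {
        (a, b) if a != b => Err(SketchError::InvalidArgument(
            "Can't compare arrays of different types".to_string(),
        )),
        (SketchType::Int32, SketchType::Int32) | (SketchType::Utf8, SketchType::Utf8) => Ok(()),
        // Fallback arm: name the offending type instead of calling `todo!()`.
        (other, _) => Err(SketchError::InvalidArgument(format!(
            "Comparison is not yet supported for type {:?}",
            other
        ))),
    }
}
```

The caller then gets an `Err` it can propagate with `?` rather than a panic at runtime.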

##########
File path: rust/arrow/src/compute/kernels/sort.rs
##########
@@ -453,49 +466,46 @@ pub fn lexsort(columns: &[SortColumn]) -> Result<Vec<ArrayRef>> {
 /// Sort elements lexicographically from a list of `ArrayRef` into an unsigned integer
 /// (`UInt32Array`) of indices.
 pub fn lexsort_to_indices(columns: &[SortColumn]) -> Result<UInt32Array> {
+    if columns.len() == 0 {
+        return Err(ArrowError::InvalidArgumentError(
+            "Sort requires at least one column".to_string(),
+        ));
+    }
     if columns.len() == 1 {
         // fallback to non-lexical sort
         let column = &columns[0];
         return sort_to_indices(&column.values, column.options);
     }
 
-    let mut row_count = None;
+    let row_count = columns[0].values.len();
+    if columns.iter().any(|item| item.values.len() != row_count) {
+        return Err(ArrowError::ComputeError(
+            "lexical sort columns have different row counts".to_string(),
+        ));
+    };
+
     // convert ArrayRefs to OrdArray trait objects and perform row count check
     let flat_columns = columns
         .iter()
-        .map(|column| -> Result<(&Array, Box<OrdArray>, SortOptions)> {
-            // row count check
-            let curr_row_count = column.values.len() - column.values.offset();
-            match row_count {
-                None => {
-                    row_count = Some(curr_row_count);
-                }
-                Some(cnt) => {
-                    if curr_row_count != cnt {
-                        return Err(ArrowError::ComputeError(
-                            "lexical sort columns have different row counts".to_string(),
-                        ));
-                    }
-                }
-            }
-            // flatten and convert to OrdArray
+        .map(|column| -> Result<(&Array, DynComparator, SortOptions)> {
+            // flatten and build comparators
             Ok((
                 column.values.as_ref(),
-                as_ordarray(&column.values)?,
+                build_compare(column.values.as_ref(), column.values.as_ref())?,

Review comment:
       I'm happy with this approach; it's very creative.







[GitHub] [arrow] alamb commented on a change in pull request #8517: ARROW-10381: [Rust] Generalized Ordering for inter-array comparisons

Posted by GitBox <gi...@apache.org>.
alamb commented on a change in pull request #8517:
URL: https://github.com/apache/arrow/pull/8517#discussion_r511896397



##########
File path: rust/arrow/src/compute/kernels/sort.rs
##########
@@ -41,110 +41,123 @@ pub fn sort(values: &ArrayRef, options: Option<SortOptions>) -> Result<ArrayRef>
     take(values, &indices, None)
 }
 
+fn partition_nan<T: ArrowPrimitiveType>(
+    array: &ArrayRef,
+    v: Vec<u32>,
+) -> (Vec<u32>, Vec<u32>) {
+    // partition by nan for float types
+    if T::DATA_TYPE == DataType::Float32 {
+        // T::Native has no `is_nan` and thus we need to downcast
+        let array = array
+            .as_any()
+            .downcast_ref::<Float32Array>()
+            .expect("Unable to downcast array");
+        let has_nan = v.iter().any(|index| array.value(*index as usize).is_nan());

Review comment:
       A minor comment: you might be able to make the code smaller (and avoid scanning the array twice) if we don't bother to check for `has_nan` and just always do `v.into_iter().partition(....)`. The rationale is that even if `has_nan` is false, you had to scan the entire array anyway.
   
   I have not tested this, and it relies on the compiler being clever about optimizing `into_iter()`.
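
A minimal sketch of the single-pass version, assuming a plain `&[f64]` instead of a `Float64Array` (the signature below is hypothetical and simplified for illustration):

```rust
/// Splits `indices` into (non-NaN, NaN) with one pass over the data.
fn partition_nan(values: &[f64], indices: Vec<u32>) -> (Vec<u32>, Vec<u32>) {
    indices
        .into_iter()
        // `partition` sends indices for which the predicate holds to the
        // first vector and the remaining ones to the second.
        .partition(|&index| !values[index as usize].is_nan())
}
```

One trade-off to check with the benchmark: this always builds two new vectors, whereas the `has_nan` branch returns `v` unchanged when there are no NaNs.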
   

##########
File path: rust/arrow/src/compute/kernels/sort.rs
##########
@@ -41,110 +41,123 @@ pub fn sort(values: &ArrayRef, options: Option<SortOptions>) -> Result<ArrayRef>
     take(values, &indices, None)
 }
 
+fn partition_nan<T: ArrowPrimitiveType>(

Review comment:
       Suggest adding a comment:
   
   ```
       // partition indices into non-NaN and NaN
   ```

##########
File path: rust/arrow/src/compute/kernels/sort.rs
##########
@@ -41,110 +41,123 @@ pub fn sort(values: &ArrayRef, options: Option<SortOptions>) -> Result<ArrayRef>
     take(values, &indices, None)
 }
 
+fn partition_nan<T: ArrowPrimitiveType>(
+    array: &ArrayRef,
+    v: Vec<u32>,
+) -> (Vec<u32>, Vec<u32>) {
+    // partition by nan for float types
+    if T::DATA_TYPE == DataType::Float32 {
+        // T::Native has no `is_nan` and thus we need to downcast
+        let array = array
+            .as_any()
+            .downcast_ref::<Float32Array>()
+            .expect("Unable to downcast array");
+        let has_nan = v.iter().any(|index| array.value(*index as usize).is_nan());
+        if has_nan {
+            v.into_iter()
+                .partition(|index| !array.value(*index as usize).is_nan())
+        } else {
+            (v, vec![])
+        }
+    } else if T::DATA_TYPE == DataType::Float64 {
+        let array = array
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .expect("Unable to downcast array");
+        let has_nan = v.iter().any(|index| array.value(*index as usize).is_nan());
+        if has_nan {
+            v.into_iter()
+                .partition(|index| !array.value(*index as usize).is_nan())
+        } else {
+            (v, vec![])
+        }
+    } else {
+        unreachable!("Partition by nan is only applicable to float types")
+    }
+}
+
+fn partition_validity(array: &ArrayRef) -> (Vec<u32>, Vec<u32>) {

Review comment:
       ```
       // partition indices into valid and null indices
   ```

##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,259 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;

Review comment:
       ```
   /// Compare the values at two arbitrary indices in two arrays.
   ```
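   
   For readers of the docs, a small standalone sketch of what such a comparator does may help. It mirrors the shape of `DynComparator` over plain slices so that it runs without the arrow crate; `SliceComparator` and `build_slice_compare` are hypothetical names used only for illustration.
   
   ```rust
   use std::cmp::Ordering;
   
   // Same shape as `DynComparator`: a boxed closure that takes an index into
   // the left input and an index into the right input.
   type SliceComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
   
   // Hypothetical factory: the lifetime 'a ties the closure to both inputs,
   // so it cannot outlive either of them.
   fn build_slice_compare<'a>(left: &'a [i32], right: &'a [i32]) -> SliceComparator<'a> {
       Box::new(move |i: usize, j: usize| left[i].cmp(&right[j]))
   }
   
   fn main() {
       let a = [1, 5];
       let b = [3, 4];
   
       // Inter-array comparison: index i refers to `a`, index j refers to `b`.
       let cmp = build_slice_compare(&a, &b);
       assert_eq!(cmp(0, 1), Ordering::Less); // 1 < 4
       assert_eq!(cmp(1, 0), Ordering::Greater); // 5 > 3
   
       // Intra-array comparison is the special case where both sides are the same.
       let same = build_slice_compare(&a, &a);
       assert_eq!(same(0, 0), Ordering::Equal);
   }
   ```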

##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,280 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// compares two floats, placing NaNs at last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Ord,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Float,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: ArrowDictionaryKeyType,
+{
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(right.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
+    })
 }
 
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
+/// returns a comparison function that compares two values at two different positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> Result<DynComparator<'a>> {
+    use DataType::*;
+    use IntervalUnit::*;
+    use TimeUnit::*;
+    Ok(match (left.data_type(), right.data_type()) {
+        (a, b) if a != b => {
+            return Err(ArrowError::InvalidArgumentError(
+                "Can't compare arrays of different types".to_string(),
+            ));
+        }
+        (Boolean, Boolean) => compare_primitives::<BooleanType>(left, right),
+        (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right),
+        (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right),
+        (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right),
+        (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right),
+        (Int8, Int8) => compare_primitives::<Int8Type>(left, right),
+        (Int16, Int16) => compare_primitives::<Int16Type>(left, right),
+        (Int32, Int32) => compare_primitives::<Int32Type>(left, right),
+        (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
+        (Float32, Float32) => compare_float::<Float32Type>(left, right),
+        (Float64, Float64) => compare_float::<Float64Type>(left, right),
+        (Date32(_), Date32(_)) => compare_primitives::<Date32Type>(left, right),
+        (Date64(_), Date64(_)) => compare_primitives::<Date64Type>(left, right),
+        (Time32(Second), Time32(Second)) => {
+            compare_primitives::<Time32SecondType>(left, right)
+        }
+        (Time32(Millisecond), Time32(Millisecond)) => {
+            compare_primitives::<Time32MillisecondType>(left, right)
+        }
+        (Time64(Microsecond), Time64(Microsecond)) => {
+            compare_primitives::<Time64MicrosecondType>(left, right)
+        }
+        (Time64(Nanosecond), Time64(Nanosecond)) => {
+            compare_primitives::<Time64NanosecondType>(left, right)
+        }
+        (Timestamp(Second, _), Timestamp(Second, _)) => {
+            compare_primitives::<TimestampSecondType>(left, right)
+        }
+        (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => {
+            compare_primitives::<TimestampMillisecondType>(left, right)
+        }
+        (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => {
+            compare_primitives::<TimestampMicrosecondType>(left, right)
+        }
+        (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
+            compare_primitives::<TimestampNanosecondType>(left, right)
+        }
+        (Interval(YearMonth), Interval(YearMonth)) => {
+            compare_primitives::<IntervalYearMonthType>(left, right)
+        }
+        (Interval(DayTime), Interval(DayTime)) => {
+            compare_primitives::<IntervalDayTimeType>(left, right)
+        }
+        (Duration(Second), Duration(Second)) => {
+            compare_primitives::<DurationSecondType>(left, right)
+        }
+        (Duration(Millisecond), Duration(Millisecond)) => {
+            compare_primitives::<DurationMillisecondType>(left, right)
+        }
+        (Duration(Microsecond), Duration(Microsecond)) => {
+            compare_primitives::<DurationMicrosecondType>(left, right)
+        }
+        (Duration(Nanosecond), Duration(Nanosecond)) => {
+            compare_primitives::<DurationNanosecondType>(left, right)
+        }
+        (Utf8, Utf8) => compare_string::<i32>(left, right),
+        (LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right),
+        (
+            Dictionary(key_type_lhs, value_type_lhs),
+            Dictionary(key_type_rhs, value_type_rhs),
+        ) => {
+            if value_type_lhs.as_ref() != &DataType::Utf8
+                || value_type_rhs.as_ref() != &DataType::Utf8
+            {
+                return Err(ArrowError::InvalidArgumentError(
+                    "Arrow still does not support comparisons of non-string dictionary arrays"

Review comment:
       Or maybe we can leverage the comparison kernel somehow

##########
File path: rust/arrow/src/compute/kernels/sort.rs
##########
@@ -41,110 +41,123 @@ pub fn sort(values: &ArrayRef, options: Option<SortOptions>) -> Result<ArrayRef>
     take(values, &indices, None)
 }
 
+fn partition_nan<T: ArrowPrimitiveType>(
+    array: &ArrayRef,
+    v: Vec<u32>,
+) -> (Vec<u32>, Vec<u32>) {
+    // partition by nan for float types
+    if T::DATA_TYPE == DataType::Float32 {
+        // T::Native has no `is_nan` and thus we need to downcast
+        let array = array
+            .as_any()
+            .downcast_ref::<Float32Array>()
+            .expect("Unable to downcast array");
+        let has_nan = v.iter().any(|index| array.value(*index as usize).is_nan());

Review comment:
       But I see this was just refactored from the code below, so there is no need to change it.

##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,280 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// compares two floats, placing NaNs at last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Ord,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Float,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: ArrowDictionaryKeyType,
+{
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(right.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
+    })
 }
 
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
+/// returns a comparison function that compares two values at two different positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> Result<DynComparator<'a>> {
+    use DataType::*;
+    use IntervalUnit::*;
+    use TimeUnit::*;
+    Ok(match (left.data_type(), right.data_type()) {
+        (a, b) if a != b => {
+            return Err(ArrowError::InvalidArgumentError(
+                "Can't compare arrays of different types".to_string(),
+            ));
+        }
+        (Boolean, Boolean) => compare_primitives::<BooleanType>(left, right),
+        (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right),
+        (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right),
+        (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right),
+        (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right),
+        (Int8, Int8) => compare_primitives::<Int8Type>(left, right),
+        (Int16, Int16) => compare_primitives::<Int16Type>(left, right),
+        (Int32, Int32) => compare_primitives::<Int32Type>(left, right),
+        (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
+        (Float32, Float32) => compare_float::<Float32Type>(left, right),
+        (Float64, Float64) => compare_float::<Float64Type>(left, right),
+        (Date32(_), Date32(_)) => compare_primitives::<Date32Type>(left, right),
+        (Date64(_), Date64(_)) => compare_primitives::<Date64Type>(left, right),
+        (Time32(Second), Time32(Second)) => {
+            compare_primitives::<Time32SecondType>(left, right)
+        }
+        (Time32(Millisecond), Time32(Millisecond)) => {
+            compare_primitives::<Time32MillisecondType>(left, right)
+        }
+        (Time64(Microsecond), Time64(Microsecond)) => {
+            compare_primitives::<Time64MicrosecondType>(left, right)
+        }
+        (Time64(Nanosecond), Time64(Nanosecond)) => {
+            compare_primitives::<Time64NanosecondType>(left, right)
+        }
+        (Timestamp(Second, _), Timestamp(Second, _)) => {
+            compare_primitives::<TimestampSecondType>(left, right)
+        }
+        (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => {
+            compare_primitives::<TimestampMillisecondType>(left, right)
+        }
+        (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => {
+            compare_primitives::<TimestampMicrosecondType>(left, right)
+        }
+        (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
+            compare_primitives::<TimestampNanosecondType>(left, right)
+        }
+        (Interval(YearMonth), Interval(YearMonth)) => {
+            compare_primitives::<IntervalYearMonthType>(left, right)
+        }
+        (Interval(DayTime), Interval(DayTime)) => {
+            compare_primitives::<IntervalDayTimeType>(left, right)
+        }
+        (Duration(Second), Duration(Second)) => {
+            compare_primitives::<DurationSecondType>(left, right)
+        }
+        (Duration(Millisecond), Duration(Millisecond)) => {
+            compare_primitives::<DurationMillisecondType>(left, right)
+        }
+        (Duration(Microsecond), Duration(Microsecond)) => {
+            compare_primitives::<DurationMicrosecondType>(left, right)
+        }
+        (Duration(Nanosecond), Duration(Nanosecond)) => {
+            compare_primitives::<DurationNanosecondType>(left, right)
+        }
+        (Utf8, Utf8) => compare_string::<i32>(left, right),
+        (LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right),
+        (
+            Dictionary(key_type_lhs, value_type_lhs),
+            Dictionary(key_type_rhs, value_type_rhs),
+        ) => {
+            if value_type_lhs.as_ref() != &DataType::Utf8
+                || value_type_rhs.as_ref() != &DataType::Utf8
+            {
+                return Err(ArrowError::InvalidArgumentError(
+                    "Arrow still does not support comparisons of non-string dictionary arrays"
+                        .to_string(),
+                ));
             }
-
-            // convert to bits with canonical pattern for NaN
-            let a = if a.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                a.to_bits()
-            };
-            let b = if b.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                b.to_bits()
-            };
-
-            if a == b {
-                // Equal or both NaN
-                Ordering::Equal
-            } else if a < b {
-                // (-0.0, 0.0) or (!NaN, NaN)
-                Ordering::Less
-            } else {
-                // (0.0, -0.0) or (NaN, !NaN)
-                Ordering::Greater
+            match (key_type_lhs.as_ref(), key_type_rhs.as_ref()) {
+                (a, b) if a != b => {
+                    return Err(ArrowError::InvalidArgumentError(
+                        "Can't compare arrays of different types".to_string(),
+                    ));
+                }
+                (UInt8, UInt8) => compare_dict_string::<UInt8Type>(left, right),
+                (UInt16, UInt16) => compare_dict_string::<UInt16Type>(left, right),
+                (UInt32, UInt32) => compare_dict_string::<UInt32Type>(left, right),
+                (UInt64, UInt64) => compare_dict_string::<UInt64Type>(left, right),
+                (Int8, Int8) => compare_dict_string::<Int8Type>(left, right),
+                (Int16, Int16) => compare_dict_string::<Int16Type>(left, right),
+                (Int32, Int32) => compare_dict_string::<Int32Type>(left, right),
+                (Int64, Int64) => compare_dict_string::<Int64Type>(left, right),
+                _ => todo!(),
             }
         }
-    };
+        _ => todo!(),

Review comment:
       ➕ 

##########
File path: rust/arrow/benches/sort_kernel.rs
##########
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[macro_use]
+extern crate criterion;
+use criterion::Criterion;
+
+use rand::Rng;
+use std::sync::Arc;
+
+extern crate arrow;
+
+use arrow::array::*;
+use arrow::compute::kernels::sort::{lexsort, SortColumn};
+
+fn create_array(size: usize, with_nulls: bool) -> ArrayRef {
+    // use random numbers to avoid spurious compiler optimizations wrt to branching
+    let mut rng = rand::thread_rng();
+    let mut builder = Float32Builder::new(size);
+
+    for _ in 0..size {
+        if with_nulls && rng.gen::<f32>() > 0.5 {
+            builder.append_null().unwrap();
+        } else {
+            builder.append_value(rng.gen()).unwrap();
+        }
+    }
+    Arc::new(builder.finish())
+}
+
+fn bench_sort(arr_a: &ArrayRef, array_b: &ArrayRef) {
+    let columns = vec![
+        SortColumn {
+            values: arr_a.clone(),
+            options: None,
+        },
+        SortColumn {
+            values: array_b.clone(),
+            options: None,
+        },
+    ];
+
+    criterion::black_box(lexsort(&columns).unwrap());
+}
+
+fn add_benchmark(c: &mut Criterion) {

Review comment:
       What would you think about adding a sort kernel benchmark for `u64` data as well?

##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,259 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// compares two floats, placing NaNs at last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
-}
-
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Ord,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
-}
-
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
-}
-
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
-            }
-
-            // convert to bits with canonical pattern for NaN
-            let a = if a.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                a.to_bits()
-            };
-            let b = if b.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                b.to_bits()
-            };
-
-            if a == b {
-                // Equal or both NaN
-                Ordering::Equal
-            } else if a < b {
-                // (-0.0, 0.0) or (!NaN, NaN)
-                Ordering::Less
-            } else {
-                // (0.0, -0.0) or (NaN, !NaN)
-                Ordering::Greater
-            }
-        }
-    };
-}
-
-float_ord_cmp!(cmp_f64, f64);
-float_ord_cmp!(cmp_f32, f32);
-
-#[repr(transparent)]
-struct Float64ArrayAsOrdArray<'a>(&'a Float64Array);
-#[repr(transparent)]
-struct Float32ArrayAsOrdArray<'a>(&'a Float32Array);
-
-impl OrdArray for Float64ArrayAsOrdArray<'_> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        let a: f64 = self.0.value(i);
-        let b: f64 = self.0.value(j);
-
-        cmp_f64(a, b)
-    }
-}
-
-impl OrdArray for Float32ArrayAsOrdArray<'_> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        let a: f32 = self.0.value(i);
-        let b: f32 = self.0.value(j);
-
-        cmp_f32(a, b)
-    }
-}
-
-fn float32_as_ord_array<'a>(array: &'a ArrayRef) -> Box<dyn OrdArray + 'a> {
-    let float_array: &Float32Array = as_primitive_array::<Float32Type>(array);
-    Box::new(Float32ArrayAsOrdArray(float_array))
-}
-
-fn float64_as_ord_array<'a>(array: &'a ArrayRef) -> Box<dyn OrdArray + 'a> {
-    let float_array: &Float64Array = as_primitive_array::<Float64Type>(array);
-    Box::new(Float64ArrayAsOrdArray(float_array))
-}
-
-struct StringDictionaryArrayAsOrdArray<'a, T: ArrowDictionaryKeyType> {
-    dict_array: &'a DictionaryArray<T>,
-    values: StringArray,
-    keys: PrimitiveArray<T>,
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Float,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl<T: ArrowDictionaryKeyType> OrdArray for StringDictionaryArrayAsOrdArray<'_, T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        let keys = &self.keys;
-        let dict = &self.values;
-
-        let key_a: T::Native = keys.value(i);
-        let key_b: T::Native = keys.value(j);
-
-        let str_a = dict.value(key_a.to_usize().unwrap());
-        let str_b = dict.value(key_b.to_usize().unwrap());
-
-        str_a.cmp(str_b)
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-fn string_dict_as_ord_array<'a, T: ArrowDictionaryKeyType>(
-    array: &'a ArrayRef,
-) -> Box<dyn OrdArray + 'a>
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T: ArrowDictionaryKeyType,
 {
-    let dict_array = as_dictionary_array::<T>(array);
-    let keys = dict_array.keys_array();
-
-    let values = &dict_array.values();
-    let values = StringArray::from(values.data());
-
-    Box::new(StringDictionaryArrayAsOrdArray {
-        dict_array,
-        values,
-        keys,
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(right.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
     })
 }
 
-/// Convert ArrayRef to OrdArray trait object
-pub fn as_ordarray<'a>(values: &'a ArrayRef) -> Result<Box<OrdArray + 'a>> {
-    match values.data_type() {
-        DataType::Boolean => Ok(Box::new(as_boolean_array(&values))),
-        DataType::Utf8 => Ok(Box::new(as_string_array(&values))),
-        DataType::Null => Ok(Box::new(as_null_array(&values))),
-        DataType::Int8 => Ok(Box::new(as_primitive_array::<Int8Type>(&values))),
-        DataType::Int16 => Ok(Box::new(as_primitive_array::<Int16Type>(&values))),
-        DataType::Int32 => Ok(Box::new(as_primitive_array::<Int32Type>(&values))),
-        DataType::Int64 => Ok(Box::new(as_primitive_array::<Int64Type>(&values))),
-        DataType::UInt8 => Ok(Box::new(as_primitive_array::<UInt8Type>(&values))),
-        DataType::UInt16 => Ok(Box::new(as_primitive_array::<UInt16Type>(&values))),
-        DataType::UInt32 => Ok(Box::new(as_primitive_array::<UInt32Type>(&values))),
-        DataType::UInt64 => Ok(Box::new(as_primitive_array::<UInt64Type>(&values))),
-        DataType::Date32(_) => Ok(Box::new(as_primitive_array::<Date32Type>(&values))),
-        DataType::Date64(_) => Ok(Box::new(as_primitive_array::<Date64Type>(&values))),
-        DataType::Time32(Second) => {
-            Ok(Box::new(as_primitive_array::<Time32SecondType>(&values)))
+/// returns a comparison function that compares two values at two different positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> Result<DynComparator<'a>> {

Review comment:
       This all looks cool, though I do wonder about the runtime overhead of doing a dynamic dispatch for each comparison. However, if you have measured no performance regression, this code seems like an improvement over what is on master.
   
   I do wonder if there is some way to reuse the work in the comparison kernel:
   https://github.com/apache/arrow/blob/master/rust/arrow/src/compute/kernels/comparison.rs to improve performance over a row-by-row comparison plus a dynamic dispatch.
   
   Perhaps that is a good optimization for some future PR.
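
   To make the cost being discussed concrete, here is a minimal sketch of a merge step driven by a boxed comparator over two inputs. It uses plain `i32` slices rather than Arrow arrays, and `build_compare_i32` and `merge_sorted` are hypothetical names for illustration, not functions from this PR. Each call to `cmp` goes through the `dyn Fn` vtable, which is the per-comparison dynamic dispatch mentioned above.

```rust
use std::cmp::Ordering;

// Same shape as the `DynComparator` alias in this diff: a boxed closure over two indices.
type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;

// Hypothetical stand-in for `build_compare`, restricted to `i32` slices.
fn build_compare_i32<'a>(left: &'a [i32], right: &'a [i32]) -> DynComparator<'a> {
    Box::new(move |i: usize, j: usize| left[i].cmp(&right[j]))
}

// Merge two sorted slices. The comparator is called once per loop iteration
// while both inputs still have elements, i.e. one dynamic dispatch per step.
fn merge_sorted(left: &[i32], right: &[i32]) -> Vec<i32> {
    let cmp = build_compare_i32(left, right);
    let (mut i, mut j) = (0, 0);
    let mut out = Vec::with_capacity(left.len() + right.len());
    while i < left.len() && j < right.len() {
        if cmp(i, j) != Ordering::Greater {
            // Ties take the left element first, which keeps the merge stable.
            out.push(left[i]);
            i += 1;
        } else {
            out.push(right[j]);
            j += 1;
        }
    }
    // At most one of these tails is non-empty.
    out.extend_from_slice(&left[i..]);
    out.extend_from_slice(&right[j..]);
    out
}
```

   Whether this overhead matters in practice would have to be measured; the sketch only shows where the indirect call sits.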







[GitHub] [arrow] jorgecarleitao closed pull request #8517: ARROW-10381: [Rust] Generalized Ordering for inter-array comparisons

Posted by GitBox <gi...@apache.org>.
jorgecarleitao closed pull request #8517:
URL: https://github.com/apache/arrow/pull/8517


   





[GitHub] [arrow] jorgecarleitao commented on a change in pull request #8517: ARROW-10381: [Rust] Generalized Ordering for inter-array comparisons

Posted by GitBox <gi...@apache.org>.
jorgecarleitao commented on a change in pull request #8517:
URL: https://github.com/apache/arrow/pull/8517#discussion_r513161172



##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,259 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// compares two floats, placing NaNs at last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
-}
-
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Ord,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
-}
-
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
-}
-
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
-            }
-
-            // convert to bits with canonical pattern for NaN
-            let a = if a.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                a.to_bits()
-            };
-            let b = if b.is_nan() {
-                <$T>::NAN.to_bits()
-            } else {
-                b.to_bits()
-            };
-
-            if a == b {
-                // Equal or both NaN
-                Ordering::Equal
-            } else if a < b {
-                // (-0.0, 0.0) or (!NaN, NaN)
-                Ordering::Less
-            } else {
-                // (0.0, -0.0) or (NaN, !NaN)
-                Ordering::Greater
-            }
-        }
-    };
-}
-
-float_ord_cmp!(cmp_f64, f64);
-float_ord_cmp!(cmp_f32, f32);
-
-#[repr(transparent)]
-struct Float64ArrayAsOrdArray<'a>(&'a Float64Array);
-#[repr(transparent)]
-struct Float32ArrayAsOrdArray<'a>(&'a Float32Array);
-
-impl OrdArray for Float64ArrayAsOrdArray<'_> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        let a: f64 = self.0.value(i);
-        let b: f64 = self.0.value(j);
-
-        cmp_f64(a, b)
-    }
-}
-
-impl OrdArray for Float32ArrayAsOrdArray<'_> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        let a: f32 = self.0.value(i);
-        let b: f32 = self.0.value(j);
-
-        cmp_f32(a, b)
-    }
-}
-
-fn float32_as_ord_array<'a>(array: &'a ArrayRef) -> Box<dyn OrdArray + 'a> {
-    let float_array: &Float32Array = as_primitive_array::<Float32Type>(array);
-    Box::new(Float32ArrayAsOrdArray(float_array))
-}
-
-fn float64_as_ord_array<'a>(array: &'a ArrayRef) -> Box<dyn OrdArray + 'a> {
-    let float_array: &Float64Array = as_primitive_array::<Float64Type>(array);
-    Box::new(Float64ArrayAsOrdArray(float_array))
-}
-
-struct StringDictionaryArrayAsOrdArray<'a, T: ArrowDictionaryKeyType> {
-    dict_array: &'a DictionaryArray<T>,
-    values: StringArray,
-    keys: PrimitiveArray<T>,
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Float,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl<T: ArrowDictionaryKeyType> OrdArray for StringDictionaryArrayAsOrdArray<'_, T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        let keys = &self.keys;
-        let dict = &self.values;
-
-        let key_a: T::Native = keys.value(i);
-        let key_b: T::Native = keys.value(j);
-
-        let str_a = dict.value(key_a.to_usize().unwrap());
-        let str_b = dict.value(key_b.to_usize().unwrap());
-
-        str_a.cmp(str_b)
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-fn string_dict_as_ord_array<'a, T: ArrowDictionaryKeyType>(
-    array: &'a ArrayRef,
-) -> Box<dyn OrdArray + 'a>
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T: ArrowDictionaryKeyType,
 {
-    let dict_array = as_dictionary_array::<T>(array);
-    let keys = dict_array.keys_array();
-
-    let values = &dict_array.values();
-    let values = StringArray::from(values.data());
-
-    Box::new(StringDictionaryArrayAsOrdArray {
-        dict_array,
-        values,
-        keys,
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(right.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
     })
 }
 
-/// Convert ArrayRef to OrdArray trait object
-pub fn as_ordarray<'a>(values: &'a ArrayRef) -> Result<Box<OrdArray + 'a>> {
-    match values.data_type() {
-        DataType::Boolean => Ok(Box::new(as_boolean_array(&values))),
-        DataType::Utf8 => Ok(Box::new(as_string_array(&values))),
-        DataType::Null => Ok(Box::new(as_null_array(&values))),
-        DataType::Int8 => Ok(Box::new(as_primitive_array::<Int8Type>(&values))),
-        DataType::Int16 => Ok(Box::new(as_primitive_array::<Int16Type>(&values))),
-        DataType::Int32 => Ok(Box::new(as_primitive_array::<Int32Type>(&values))),
-        DataType::Int64 => Ok(Box::new(as_primitive_array::<Int64Type>(&values))),
-        DataType::UInt8 => Ok(Box::new(as_primitive_array::<UInt8Type>(&values))),
-        DataType::UInt16 => Ok(Box::new(as_primitive_array::<UInt16Type>(&values))),
-        DataType::UInt32 => Ok(Box::new(as_primitive_array::<UInt32Type>(&values))),
-        DataType::UInt64 => Ok(Box::new(as_primitive_array::<UInt64Type>(&values))),
-        DataType::Date32(_) => Ok(Box::new(as_primitive_array::<Date32Type>(&values))),
-        DataType::Date64(_) => Ok(Box::new(as_primitive_array::<Date64Type>(&values))),
-        DataType::Time32(Second) => {
-            Ok(Box::new(as_primitive_array::<Time32SecondType>(&values)))
+/// returns a comparison function that compares two values at two different positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> Result<DynComparator<'a>> {

Review comment:
       Note that we were already using dynamic dispatch with the `OrdArray`: in the lexical sort, we built a vector of arrays of unknown types and then called their `cmp_value`. Because the vector contains heterogeneous array types, the calls are dynamically dispatched.
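
   As a rough illustration of that point, the sketch below sorts row indices by two columns of different element types. It uses plain slices, and `lexsort_indices` is a hypothetical helper, not the `lexsort` kernel itself. Because the columns have different types, the per-column comparators have to be type-erased before they can share one vector, so every call is dynamically dispatched.

```rust
use std::cmp::Ordering;

// Same shape as the `DynComparator` alias in this diff.
type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;

// Hypothetical helper: return row indices sorted by `ints`, then by `strs`.
// Both columns are assumed to have the same length.
fn lexsort_indices(ints: &[i32], strs: &[&str]) -> Vec<usize> {
    // One type-erased comparator per column; the closures have different
    // concrete types, so they can only share a Vec behind `dyn Fn`.
    let comparators: Vec<DynComparator<'_>> = vec![
        Box::new(move |i: usize, j: usize| ints[i].cmp(&ints[j])),
        Box::new(move |i: usize, j: usize| strs[i].cmp(strs[j])),
    ];

    let mut indices: Vec<usize> = (0..ints.len()).collect();
    indices.sort_by(|&i, &j| {
        // Walk the columns in order and stop at the first one that differs.
        for cmp in &comparators {
            match cmp(i, j) {
                Ordering::Equal => continue,
                other => return other,
            }
        }
        Ordering::Equal
    });
    indices
}
```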







[GitHub] [arrow] jorgecarleitao commented on a change in pull request #8517: ARROW-10381: [Rust] Generalized Ordering for inter-array comparisons

Posted by GitBox <gi...@apache.org>.
jorgecarleitao commented on a change in pull request #8517:
URL: https://github.com/apache/arrow/pull/8517#discussion_r513164636



##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,280 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// compares two floats, placing NaNs at last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Ord,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Float,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: ArrowDictionaryKeyType,
+{
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(right.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
+    })
 }
 
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
+/// returns a comparison function that compares two values at two different positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> Result<DynComparator<'a>> {
+    use DataType::*;
+    use IntervalUnit::*;
+    use TimeUnit::*;
+    Ok(match (left.data_type(), right.data_type()) {
+        (a, b) if a != b => {
+            return Err(ArrowError::InvalidArgumentError(
+                "Can't compare arrays of different types".to_string(),
+            ));
+        }
+        (Boolean, Boolean) => compare_primitives::<BooleanType>(left, right),
+        (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right),
+        (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right),
+        (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right),
+        (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right),
+        (Int8, Int8) => compare_primitives::<Int8Type>(left, right),
+        (Int16, Int16) => compare_primitives::<Int16Type>(left, right),
+        (Int32, Int32) => compare_primitives::<Int32Type>(left, right),
+        (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
+        (Float32, Float32) => compare_float::<Float32Type>(left, right),
+        (Float64, Float64) => compare_float::<Float64Type>(left, right),
+        (Date32(_), Date32(_)) => compare_primitives::<Date32Type>(left, right),
+        (Date64(_), Date64(_)) => compare_primitives::<Date64Type>(left, right),
+        (Time32(Second), Time32(Second)) => {
+            compare_primitives::<Time32SecondType>(left, right)
+        }
+        (Time32(Millisecond), Time32(Millisecond)) => {
+            compare_primitives::<Time32MillisecondType>(left, right)
+        }
+        (Time64(Microsecond), Time64(Microsecond)) => {
+            compare_primitives::<Time64MicrosecondType>(left, right)
+        }
+        (Time64(Nanosecond), Time64(Nanosecond)) => {
+            compare_primitives::<Time64NanosecondType>(left, right)
+        }
+        (Timestamp(Second, _), Timestamp(Second, _)) => {
+            compare_primitives::<TimestampSecondType>(left, right)
+        }
+        (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => {
+            compare_primitives::<TimestampMillisecondType>(left, right)
+        }
+        (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => {
+            compare_primitives::<TimestampMicrosecondType>(left, right)
+        }
+        (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
+            compare_primitives::<TimestampNanosecondType>(left, right)
+        }
+        (Interval(YearMonth), Interval(YearMonth)) => {
+            compare_primitives::<IntervalYearMonthType>(left, right)
+        }
+        (Interval(DayTime), Interval(DayTime)) => {
+            compare_primitives::<IntervalDayTimeType>(left, right)
+        }
+        (Duration(Second), Duration(Second)) => {
+            compare_primitives::<DurationSecondType>(left, right)
+        }
+        (Duration(Millisecond), Duration(Millisecond)) => {
+            compare_primitives::<DurationMillisecondType>(left, right)
+        }
+        (Duration(Microsecond), Duration(Microsecond)) => {
+            compare_primitives::<DurationMicrosecondType>(left, right)
+        }
+        (Duration(Nanosecond), Duration(Nanosecond)) => {
+            compare_primitives::<DurationNanosecondType>(left, right)
+        }
+        (Utf8, Utf8) => compare_string::<i32>(left, right),
+        (LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right),
+        (
+            Dictionary(key_type_lhs, value_type_lhs),
+            Dictionary(key_type_rhs, value_type_rhs),
+        ) => {
+            if value_type_lhs.as_ref() != &DataType::Utf8
+                || value_type_rhs.as_ref() != &DataType::Utf8
+            {
+                return Err(ArrowError::InvalidArgumentError(
+                    "Arrow still does not support comparisons of non-string dictionary arrays"

Review comment:
       This code is the same as it was before.
   
   FYI, I did try to make this work for arbitrary types, and I was very (very) close to having it done, but it requires some unsafe code, so I left it for a separate PR. The relevant thread is here: https://users.rust-lang.org/t/how-to-move-values-to-closure-that-indirectly-depends-on-them/50586
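
   For context, the string case in this diff avoids the problem because the closure takes ownership of the keys and values it needs. A minimal sketch of that pattern follows, using plain vectors; `Dict` and `compare_dicts` are hypothetical stand-ins, not the Arrow API, and cloning is just the simplest way to get owned data in the sketch.

```rust
use std::cmp::Ordering;

// Hypothetical stand-in for a string dictionary array: `keys[i]` indexes into `values`.
struct Dict {
    keys: Vec<usize>,
    values: Vec<String>,
}

// The returned closure owns copies of both sides' keys and values, so it
// borrows nothing from its arguments and needs no unsafe code.
fn compare_dicts(left: &Dict, right: &Dict) -> Box<dyn Fn(usize, usize) -> Ordering> {
    let (left_keys, left_values) = (left.keys.clone(), left.values.clone());
    let (right_keys, right_values) = (right.keys.clone(), right.values.clone());
    Box::new(move |i: usize, j: usize| {
        // Each side resolves its key against its own dictionary values.
        left_values[left_keys[i]].cmp(&right_values[right_keys[j]])
    })
}
```

   The two dictionaries need not share values, which is why each key is looked up in the values of its own side.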

##########
File path: rust/arrow/src/array/ord.rs
##########
@@ -15,297 +15,280 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines trait for array element comparison
+//! Contains functions and function factories to compare arrays.
 
 use std::cmp::Ordering;
 
 use crate::array::*;
+use crate::datatypes::TimeUnit;
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
 
-use TimeUnit::*;
+use num::Float;
 
-/// Trait for Arrays that can be sorted
-///
-/// Example:
-/// ```
-/// use std::cmp::Ordering;
-/// use arrow::array::*;
-/// use arrow::datatypes::*;
-///
-/// let arr: Box<dyn OrdArray> = Box::new(PrimitiveArray::<Int64Type>::from(vec![
-///     Some(-2),
-///     Some(89),
-///     Some(-64),
-///     Some(101),
-/// ]));
-///
-/// assert_eq!(arr.cmp_value(1, 2), Ordering::Greater);
-/// ```
-pub trait OrdArray {
-    /// Return ordering between array element at index i and j
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering;
-}
+/// The public interface to compare values from arrays in a dynamically-typed fashion.
+pub type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;
 
-impl<T: OrdArray> OrdArray for Box<T> {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
+/// compares two floats, placing NaNs at last
+fn cmp_nans_last<T: Float>(a: &T, b: &T) -> Ordering {
+    match (a, b) {
+        (x, y) if x.is_nan() && y.is_nan() => Ordering::Equal,
+        (x, _) if x.is_nan() => Ordering::Greater,
+        (_, y) if y.is_nan() => Ordering::Less,
+        (_, _) => a.partial_cmp(b).unwrap(),
     }
 }
 
-impl<T: OrdArray> OrdArray for &T {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        T::cmp_value(self, i, j)
-    }
+fn compare_primitives<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
+where
+    T::Native: Ord,
+{
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl<T: ArrowPrimitiveType> OrdArray for PrimitiveArray<T>
+fn compare_float<'a, T: ArrowPrimitiveType>(
+    left: &'a Array,
+    right: &'a Array,
+) -> DynComparator<'a>
 where
-    T::Native: std::cmp::Ord,
+    T::Native: Float,
 {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(&self.value(j))
-    }
+    let left = left.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
+    Box::new(move |i, j| cmp_nans_last(&left.value(i), &right.value(j)))
 }
 
-impl OrdArray for StringArray {
-    fn cmp_value(&self, i: usize, j: usize) -> Ordering {
-        self.value(i).cmp(self.value(j))
-    }
+fn compare_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: StringOffsetSizeTrait,
+{
+    let left = left
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    let right = right
+        .as_any()
+        .downcast_ref::<GenericStringArray<T>>()
+        .unwrap();
+    Box::new(move |i, j| left.value(i).cmp(&right.value(j)))
 }
 
-impl OrdArray for NullArray {
-    fn cmp_value(&self, _i: usize, _j: usize) -> Ordering {
-        Ordering::Equal
-    }
+fn compare_dict_string<'a, T>(left: &'a Array, right: &'a Array) -> DynComparator<'a>
+where
+    T: ArrowDictionaryKeyType,
+{
+    let left = left.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let right = right.as_any().downcast_ref::<DictionaryArray<T>>().unwrap();
+    let left_keys = left.keys_array();
+    let right_keys = right.keys_array();
+
+    let left_values = StringArray::from(left.values().data());
+    let right_values = StringArray::from(right.values().data());
+
+    Box::new(move |i: usize, j: usize| {
+        let key_left = left_keys.value(i).to_usize().unwrap();
+        let key_right = right_keys.value(j).to_usize().unwrap();
+        let left = left_values.value(key_left);
+        let right = right_values.value(key_right);
+        left.cmp(&right)
+    })
 }
 
-macro_rules! float_ord_cmp {
-    ($NAME: ident, $T: ty) => {
-        #[inline]
-        fn $NAME(a: $T, b: $T) -> Ordering {
-            if a < b {
-                return Ordering::Less;
-            }
-            if a > b {
-                return Ordering::Greater;
+/// returns a comparison function that compares two values at two different positions
+/// between the two arrays.
+/// The arrays' types must be equal.
+/// # Example
+/// ```
+/// use arrow::array::{build_compare, Int32Array};
+///
+/// # fn main() -> arrow::error::Result<()> {
+/// let array1 = Int32Array::from(vec![1, 2]);
+/// let array2 = Int32Array::from(vec![3, 4]);
+///
+/// let cmp = build_compare(&array1, &array2)?;
+///
+/// // 1 (index 0 of array1) is smaller than 4 (index 1 of array2)
+/// assert_eq!(std::cmp::Ordering::Less, (cmp)(0, 1));
+/// # Ok(())
+/// # }
+/// ```
+// This is a factory of comparisons.
+// The lifetime 'a enforces that we cannot use the closure beyond any of the array's lifetime.
+pub fn build_compare<'a>(left: &'a Array, right: &'a Array) -> Result<DynComparator<'a>> {
+    use DataType::*;
+    use IntervalUnit::*;
+    use TimeUnit::*;
+    Ok(match (left.data_type(), right.data_type()) {
+        (a, b) if a != b => {
+            return Err(ArrowError::InvalidArgumentError(
+                "Can't compare arrays of different types".to_string(),
+            ));
+        }
+        (Boolean, Boolean) => compare_primitives::<BooleanType>(left, right),
+        (UInt8, UInt8) => compare_primitives::<UInt8Type>(left, right),
+        (UInt16, UInt16) => compare_primitives::<UInt16Type>(left, right),
+        (UInt32, UInt32) => compare_primitives::<UInt32Type>(left, right),
+        (UInt64, UInt64) => compare_primitives::<UInt64Type>(left, right),
+        (Int8, Int8) => compare_primitives::<Int8Type>(left, right),
+        (Int16, Int16) => compare_primitives::<Int16Type>(left, right),
+        (Int32, Int32) => compare_primitives::<Int32Type>(left, right),
+        (Int64, Int64) => compare_primitives::<Int64Type>(left, right),
+        (Float32, Float32) => compare_float::<Float32Type>(left, right),
+        (Float64, Float64) => compare_float::<Float64Type>(left, right),
+        (Date32(_), Date32(_)) => compare_primitives::<Date32Type>(left, right),
+        (Date64(_), Date64(_)) => compare_primitives::<Date64Type>(left, right),
+        (Time32(Second), Time32(Second)) => {
+            compare_primitives::<Time32SecondType>(left, right)
+        }
+        (Time32(Millisecond), Time32(Millisecond)) => {
+            compare_primitives::<Time32MillisecondType>(left, right)
+        }
+        (Time64(Microsecond), Time64(Microsecond)) => {
+            compare_primitives::<Time64MicrosecondType>(left, right)
+        }
+        (Time64(Nanosecond), Time64(Nanosecond)) => {
+            compare_primitives::<Time64NanosecondType>(left, right)
+        }
+        (Timestamp(Second, _), Timestamp(Second, _)) => {
+            compare_primitives::<TimestampSecondType>(left, right)
+        }
+        (Timestamp(Millisecond, _), Timestamp(Millisecond, _)) => {
+            compare_primitives::<TimestampMillisecondType>(left, right)
+        }
+        (Timestamp(Microsecond, _), Timestamp(Microsecond, _)) => {
+            compare_primitives::<TimestampMicrosecondType>(left, right)
+        }
+        (Timestamp(Nanosecond, _), Timestamp(Nanosecond, _)) => {
+            compare_primitives::<TimestampNanosecondType>(left, right)
+        }
+        (Interval(YearMonth), Interval(YearMonth)) => {
+            compare_primitives::<IntervalYearMonthType>(left, right)
+        }
+        (Interval(DayTime), Interval(DayTime)) => {
+            compare_primitives::<IntervalDayTimeType>(left, right)
+        }
+        (Duration(Second), Duration(Second)) => {
+            compare_primitives::<DurationSecondType>(left, right)
+        }
+        (Duration(Millisecond), Duration(Millisecond)) => {
+            compare_primitives::<DurationMillisecondType>(left, right)
+        }
+        (Duration(Microsecond), Duration(Microsecond)) => {
+            compare_primitives::<DurationMicrosecondType>(left, right)
+        }
+        (Duration(Nanosecond), Duration(Nanosecond)) => {
+            compare_primitives::<DurationNanosecondType>(left, right)
+        }
+        (Utf8, Utf8) => compare_string::<i32>(left, right),
+        (LargeUtf8, LargeUtf8) => compare_string::<i64>(left, right),
+        (
+            Dictionary(key_type_lhs, value_type_lhs),
+            Dictionary(key_type_rhs, value_type_rhs),
+        ) => {
+            if value_type_lhs.as_ref() != &DataType::Utf8
+                || value_type_rhs.as_ref() != &DataType::Utf8
+            {
+                return Err(ArrowError::InvalidArgumentError(
+                    "Arrow still does not support comparisons of non-string dictionary arrays"

Review comment:
       This code is unchanged from what was here before.
   
   FYI, I did try to make this work for arbitrary types, and I was very (very) close to getting it done, but it requires some unsafe usage, so I left it for a separate PR. The relevant thread is here: https://users.rust-lang.org/t/how-to-move-values-to-closure-that-indirectly-depends-on-them/50586
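   
   For context, a dictionary comparator resolves each key to its value before comparing. Below is a minimal std-only sketch of that idea for string values. `DictColumn` and `compare_dict` are hypothetical stand-ins, not Arrow types, and the `DynComparator` alias is an assumed shape, since its real definition is not shown in this hunk.

```rust
use std::cmp::Ordering;

// Assumed shape of the comparator type; the PR's actual definition is not shown here.
type DynComparator<'a> = Box<dyn Fn(usize, usize) -> Ordering + 'a>;

/// Hypothetical stand-in for a dictionary-encoded string array:
/// `keys[i]` is an index into `values`.
struct DictColumn {
    keys: Vec<usize>,
    values: Vec<String>,
}

/// Builds a closure comparing slot `i` of `left` with slot `j` of `right`.
/// The lifetime `'a` ties the closure to both borrowed columns.
fn compare_dict<'a>(left: &'a DictColumn, right: &'a DictColumn) -> DynComparator<'a> {
    Box::new(move |i: usize, j: usize| {
        // Resolve each key to its string value, then compare the strings.
        let l = &left.values[left.keys[i]];
        let r = &right.values[right.keys[j]];
        l.cmp(r)
    })
}

fn main() {
    let left = DictColumn {
        keys: vec![1, 0, 1],
        values: vec!["a".to_string(), "b".to_string()],
    };
    let right = DictColumn {
        keys: vec![0, 0],
        values: vec!["b".to_string(), "c".to_string()],
    };
    let cmp = compare_dict(&left, &right);
    // Slot 1 of `left` is "a"; slot 0 of `right` is "b".
    assert_eq!(cmp(1, 0), Ordering::Less);
    // Slot 0 of `left` is "b"; slot 1 of `right` is "b".
    assert_eq!(cmp(0, 1), Ordering::Equal);
}
```

   The sketch borrows both columns, so it sidesteps the ownership problem discussed in the linked thread; it does not attempt the generic version.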







[GitHub] [arrow] jorgecarleitao commented on pull request #8517: ARROW-10381: [Rust] Generalized Ordering for inter-array comparisons

Posted by GitBox <gi...@apache.org>.
jorgecarleitao commented on pull request #8517:
URL: https://github.com/apache/arrow/pull/8517#issuecomment-715986049


   FYI @andygrove @alamb: I am planning to approach the MergeSort in DataFusion. My current idea is to merge-sort record batches within a partition and then merge-sort them again across partitions. This is preparatory work for that, as we can't merge-sort two arrays together before this PR.
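   
   To illustrate why a cross-array comparator is the missing piece, here is a minimal std-only sketch of the merge step. The names `Source` and `merge_indices` are hypothetical, and the closure stands in for whatever `build_compare` returns; this is not the DataFusion implementation.

```rust
use std::cmp::Ordering;

/// Which input a merged element comes from, and its index there.
#[derive(Debug, PartialEq, Clone, Copy)]
enum Source {
    Left(usize),
    Right(usize),
}

/// Merges two already-sorted inputs of the given lengths.
/// `cmp(i, j)` compares element `i` of the left input with element `j`
/// of the right input, which is what an inter-array comparator provides.
fn merge_indices<F>(left_len: usize, right_len: usize, cmp: F) -> Vec<Source>
where
    F: Fn(usize, usize) -> Ordering,
{
    let mut out = Vec::with_capacity(left_len + right_len);
    let (mut i, mut j) = (0, 0);
    while i < left_len && j < right_len {
        // Taking the left element on ties keeps the merge stable.
        if cmp(i, j) != Ordering::Greater {
            out.push(Source::Left(i));
            i += 1;
        } else {
            out.push(Source::Right(j));
            j += 1;
        }
    }
    // At most one of these ranges is non-empty.
    out.extend((i..left_len).map(Source::Left));
    out.extend((j..right_len).map(Source::Right));
    out
}

fn main() {
    let left: Vec<i32> = vec![1, 4, 9];
    let right: Vec<i32> = vec![2, 4, 10, 11];
    let merged = merge_indices(left.len(), right.len(), |i, j| left[i].cmp(&right[j]));
    // Materialize the merged values from the index list.
    let values: Vec<i32> = merged
        .iter()
        .map(|s| match s {
            Source::Left(i) => left[*i],
            Source::Right(j) => right[*j],
        })
        .collect();
    assert_eq!(values, vec![1, 2, 4, 4, 9, 10, 11]);
}
```

   Returning indices rather than values mirrors how the sort kernels produce index arrays, so the same index list could drive a `take` over every column of a record batch.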
   





[GitHub] [arrow] github-actions[bot] commented on pull request #8517: ARROW-10381: [Rust] Generalized Ordering for inter-array comparisons

Posted by GitBox <gi...@apache.org>.
github-actions[bot] commented on pull request #8517:
URL: https://github.com/apache/arrow/pull/8517#issuecomment-715988969


   https://issues.apache.org/jira/browse/ARROW-10381





[GitHub] [arrow] andygrove commented on pull request #8517: ARROW-10381: [Rust] Generalized Ordering for inter-array comparisons

Posted by GitBox <gi...@apache.org>.
andygrove commented on pull request #8517:
URL: https://github.com/apache/arrow/pull/8517#issuecomment-716031171


   @jorgecarleitao this sounds good. I am going to make time tomorrow to catch up on the current PRs.

