You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2021/12/22 16:21:36 UTC

[GitHub] [arrow-rs] shepmaster commented on a change in pull request #1074: Define eq_dyn_scalar API

shepmaster commented on a change in pull request #1074:
URL: https://github.com/apache/arrow-rs/pull/1074#discussion_r774003760



##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +898,126 @@ pub fn gt_eq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_cmp_scalar {
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{

Review comment:
       Should `$T` and `$TT` be [`ty`](https://doc.rust-lang.org/stable/reference/macros-by-example.html#metavariables)?

##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +898,126 @@ pub fn gt_eq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_cmp_scalar {
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
+        let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Left array cannot be cast to {}",
+                type_name::<$T>()
+            ))
+        })?;
+        let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Right array cannot be cast to {}",
+                type_name::<$T>(),
+            ))
+        })?;
+        $OP::<$TT>(left, right)

Review comment:
       `$OP::<$TT>` could probably be fused as one `expr` macro argument

##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +898,126 @@ pub fn gt_eq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_cmp_scalar {
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
+        let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Left array cannot be cast to {}",
+                type_name::<$T>()
+            ))
+        })?;
+        let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Right array cannot be cast to {}",
+                type_name::<$T>(),
+            ))
+        })?;
+        $OP::<$TT>(left, right)
+    }};
+}
+
+macro_rules! dyn_compare_scalar {
+    ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+        let right = $RIGHT.try_into().map_err(|_| {
+            ArrowError::ComputeError(format!(
+                "Can not convert scalar {:?} to i128",
+                $RIGHT
+            ))
+        });
+        match ($LEFT.data_type(), $RIGHT::Arrow) {
+            // (DataType::Boolean, DataType::Boolean) => {
+            //     typed_cmp_scalar!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+            // }
+            (DataType::Int8, DataType::Int8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int8Array, $OP, Int8Type)
+            }
+            (DataType::Int16, DataType::Int16) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int16Array, $OP, Int16Type)
+            }
+            (DataType::Int32, DataType::Int32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int32Array, $OP, Int32Type)
+            }
+            (DataType::Int64, DataType::Int64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int64Array, $OP, Int64Type)
+            }
+            (DataType::UInt8, DataType::UInt8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt8Array, $OP, UInt8Type)
+            }
+            (DataType::UInt16, DataType::UInt16) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt16Array, $OP, UInt16Type)
+            }
+            (DataType::UInt32, DataType::UInt32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt32Array, $OP, UInt32Type)
+            }
+            (DataType::UInt64, DataType::UInt64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt64Array, $OP, UInt64Type)
+            }
+            (DataType::Float32, DataType::Float32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Float32Array, $OP, Float32Type)
+            }
+            (DataType::Float64, DataType::Float64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Float64Array, $OP, Float64Type)
+            }
+            (DataType::Utf8, DataType::Utf8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, StringArray, $OP, i32)
+            }
+            (DataType::LargeUtf8, DataType::LargeUtf8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, LargeStringArray, $OP, i64)
+            }
+            (DataType::Dictionary(DataType::UInt8, DataType::UInt8), DataType::UInt8) => {
+                let values_comp =
+                    typed_compare_scalar!($LEFT.values(), $RIGHT, eq_scalar);
+                unpack_dict_comparison($LEFT, values_comp)
+            }
+            (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
+                "Comparing arrays of type {} is not yet implemented",
+                t1
+            ))),
+            (t1, t2) => Err(ArrowError::CastError(format!(
+                "Cannot compare an array with a scalar of different type ({} and {})",
+                t1, t2
+            ))),
+        }
+    }};
+}
+
+/// Perform `left == right` operation on an array and a numeric scalar
+/// value. Supports PrimtiveArrays, and DictionaryArrays that have primitive values

Review comment:
       ```suggestion
       /// value. Supports PrimitiveArrays, and DictionaryArrays that have primitive values
       ```

##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +898,126 @@ pub fn gt_eq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_cmp_scalar {
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
+        let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Left array cannot be cast to {}",
+                type_name::<$T>()
+            ))
+        })?;
+        let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Right array cannot be cast to {}",
+                type_name::<$T>(),
+            ))
+        })?;
+        $OP::<$TT>(left, right)
+    }};
+}
+
+macro_rules! dyn_compare_scalar {
+    ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+        let right = $RIGHT.try_into().map_err(|_| {
+            ArrowError::ComputeError(format!(
+                "Can not convert scalar {:?} to i128",
+                $RIGHT
+            ))
+        });
+        match ($LEFT.data_type(), $RIGHT::Arrow) {
+            // (DataType::Boolean, DataType::Boolean) => {
+            //     typed_cmp_scalar!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+            // }
+            (DataType::Int8, DataType::Int8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int8Array, $OP, Int8Type)
+            }
+            (DataType::Int16, DataType::Int16) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int16Array, $OP, Int16Type)
+            }
+            (DataType::Int32, DataType::Int32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int32Array, $OP, Int32Type)
+            }
+            (DataType::Int64, DataType::Int64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int64Array, $OP, Int64Type)
+            }
+            (DataType::UInt8, DataType::UInt8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt8Array, $OP, UInt8Type)
+            }
+            (DataType::UInt16, DataType::UInt16) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt16Array, $OP, UInt16Type)
+            }
+            (DataType::UInt32, DataType::UInt32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt32Array, $OP, UInt32Type)
+            }
+            (DataType::UInt64, DataType::UInt64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt64Array, $OP, UInt64Type)
+            }
+            (DataType::Float32, DataType::Float32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Float32Array, $OP, Float32Type)
+            }
+            (DataType::Float64, DataType::Float64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Float64Array, $OP, Float64Type)
+            }
+            (DataType::Utf8, DataType::Utf8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, StringArray, $OP, i32)
+            }
+            (DataType::LargeUtf8, DataType::LargeUtf8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, LargeStringArray, $OP, i64)
+            }
+            (DataType::Dictionary(DataType::UInt8, DataType::UInt8), DataType::UInt8) => {
+                let values_comp =
+                    typed_compare_scalar!($LEFT.values(), $RIGHT, eq_scalar);
+                unpack_dict_comparison($LEFT, values_comp)
+            }
+            (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
+                "Comparing arrays of type {} is not yet implemented",
+                t1
+            ))),
+            (t1, t2) => Err(ArrowError::CastError(format!(
+                "Cannot compare an array with a scalar of different type ({} and {})",
+                t1, t2
+            ))),
+        }
+    }};
+}
+
+/// Perform `left == right` operation on an array and a numeric scalar
+/// value. Supports PrimtiveArrays, and DictionaryArrays that have primitive values
+pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
+where
+    T: IntoArrowNumericType + TryInto<i128> + Copy + std::fmt::Debug,
+{
+    dyn_compare_scalar!(left, right, eq_scalar)

Review comment:
       I admit this is a drive-by review, but I'm not seeing the benefit of the macros here yet. They don't do any repetition reduction. It looks like `dyn_compare_scalar` could be inlined and `dyn_cmp_scalar` could be a regular function.

##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +898,126 @@ pub fn gt_eq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_cmp_scalar {
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
+        let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Left array cannot be cast to {}",
+                type_name::<$T>()
+            ))
+        })?;
+        let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Right array cannot be cast to {}",
+                type_name::<$T>(),
+            ))
+        })?;
+        $OP::<$TT>(left, right)
+    }};
+}
+
+macro_rules! dyn_compare_scalar {
+    ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+        let right = $RIGHT.try_into().map_err(|_| {

Review comment:
       I'm missing where this `right` value is used...

##########
File path: arrow/src/compute/kernels/comparison.rs
##########
@@ -898,6 +898,126 @@ pub fn gt_eq_utf8_scalar<OffsetSize: StringOffsetSizeTrait>(
     compare_op_scalar!(left, right, |a, b| a >= b)
 }
 
+macro_rules! dyn_cmp_scalar {
+    ($LEFT: expr, $RIGHT: expr, $T: ident, $OP: ident, $TT: tt) => {{
+        let left = $LEFT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Left array cannot be cast to {}",
+                type_name::<$T>()
+            ))
+        })?;
+        let right = $RIGHT.as_any().downcast_ref::<$T>().ok_or_else(|| {
+            ArrowError::CastError(format!(
+                "Right array cannot be cast to {}",
+                type_name::<$T>(),
+            ))
+        })?;
+        $OP::<$TT>(left, right)
+    }};
+}
+
+macro_rules! dyn_compare_scalar {
+    ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{
+        let right = $RIGHT.try_into().map_err(|_| {
+            ArrowError::ComputeError(format!(
+                "Can not convert scalar {:?} to i128",
+                $RIGHT
+            ))
+        });
+        match ($LEFT.data_type(), $RIGHT::Arrow) {
+            // (DataType::Boolean, DataType::Boolean) => {
+            //     typed_cmp_scalar!($LEFT, $RIGHT, BooleanArray, $OP_BOOL)
+            // }
+            (DataType::Int8, DataType::Int8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int8Array, $OP, Int8Type)
+            }
+            (DataType::Int16, DataType::Int16) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int16Array, $OP, Int16Type)
+            }
+            (DataType::Int32, DataType::Int32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int32Array, $OP, Int32Type)
+            }
+            (DataType::Int64, DataType::Int64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Int64Array, $OP, Int64Type)
+            }
+            (DataType::UInt8, DataType::UInt8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt8Array, $OP, UInt8Type)
+            }
+            (DataType::UInt16, DataType::UInt16) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt16Array, $OP, UInt16Type)
+            }
+            (DataType::UInt32, DataType::UInt32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt32Array, $OP, UInt32Type)
+            }
+            (DataType::UInt64, DataType::UInt64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, UInt64Array, $OP, UInt64Type)
+            }
+            (DataType::Float32, DataType::Float32) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Float32Array, $OP, Float32Type)
+            }
+            (DataType::Float64, DataType::Float64) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, Float64Array, $OP, Float64Type)
+            }
+            (DataType::Utf8, DataType::Utf8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, StringArray, $OP, i32)
+            }
+            (DataType::LargeUtf8, DataType::LargeUtf8) => {
+                dyn_cmp_scalar!($LEFT, $RIGHT, LargeStringArray, $OP, i64)
+            }
+            (DataType::Dictionary(DataType::UInt8, DataType::UInt8), DataType::UInt8) => {
+                let values_comp =
+                    typed_compare_scalar!($LEFT.values(), $RIGHT, eq_scalar);
+                unpack_dict_comparison($LEFT, values_comp)
+            }
+            (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!(
+                "Comparing arrays of type {} is not yet implemented",
+                t1
+            ))),
+            (t1, t2) => Err(ArrowError::CastError(format!(
+                "Cannot compare an array with a scalar of different type ({} and {})",
+                t1, t2
+            ))),
+        }
+    }};
+}
+
+/// Perform `left == right` operation on an array and a numeric scalar
+/// value. Supports PrimtiveArrays, and DictionaryArrays that have primitive values
+pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
+where
+    T: IntoArrowNumericType + TryInto<i128> + Copy + std::fmt::Debug,
+{
+    dyn_compare_scalar!(left, right, eq_scalar)
+}
+
+/// unpacks the results of comparing left.values (as a boolean)
+///
+/// TODO add example
+///
+fn unpack_dict_comparison<K>(
+    left: &DictionaryArray<K>,
+    dict_comparison: BooleanArray,
+) -> Result<BooleanArray>
+where
+    K: ArrowNumericType,
+{
+    assert_eq!(dict_comparison.len(), left.values().len());

Review comment:
       Why use an assertion here, as opposed to returning an error or a "no, this does not match" result?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org