You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by "alamb (via GitHub)" <gi...@apache.org> on 2023/06/13 16:05:12 UTC

[GitHub] [arrow-rs] alamb commented on a diff in pull request #4405: Improve `take` kernel performance on primitive arrays, fix bad null index handling (#4404)

alamb commented on code in PR #4405:
URL: https://github.com/apache/arrow-rs/pull/4405#discussion_r1228144351


##########
arrow-select/src/take.rs:
##########
@@ -374,148 +243,91 @@ fn take_primitive<T, I>(
 where
     T: ArrowPrimitiveType,
     I: ArrowPrimitiveType,
-    I::Native: ToPrimitive,
 {
-    let indices_nulls = indices.nulls().filter(|x| x.null_count() > 0);
-    let values_nulls = values.nulls().filter(|x| x.null_count() > 0);
-
-    // note: this function should only panic when "an index is not null and out of bounds".
-    // if the index is null, its value is undefined and therefore we should not read from it.
-    let (buffer, nulls) = match (values_nulls, indices_nulls) {
-        (None, None) => {
-            // * no nulls
-            // * all `indices.values()` are valid
-            take_no_nulls(values.values(), indices.values())?
-        }
-        (Some(values_nulls), None) => {
-            // * nulls come from `values` alone
-            // * all `indices.values()` are valid
-            take_values_nulls(values.values(), values_nulls, indices.values())?
-        }
-        (None, Some(indices_nulls)) => {
-            // in this branch it is unsound to read and use `index.values()`,
-            // as doing so is UB when they come from a null slot.
-            take_indices_nulls(values.values(), indices.values(), indices_nulls)?
-        }
-        (Some(values_nulls), Some(indices_nulls)) => {
-            // in this branch it is unsound to read and use `index.values()`,
-            // as doing so is UB when they come from a null slot.
-            take_values_indices_nulls(
-                values.values(),
-                values_nulls,
-                indices.values(),
-                indices_nulls,
-            )?
-        }
-    };
-
-    let data = unsafe {
-        ArrayData::new_unchecked(
-            values.data_type().clone(),
-            indices.len(),
-            None,
-            nulls,
-            0,
-            vec![buffer],
-            vec![],
-        )
-    };
-    Ok(PrimitiveArray::<T>::from(data))
+    let values_buf = take_native(values.values(), indices);
+    let nulls = take_nulls(values.nulls(), indices);
+    Ok(PrimitiveArray::new(values_buf, nulls).with_data_type(values.data_type().clone()))
 }
 
-fn take_bits<IndexType>(
-    values: &Buffer,
-    values_offset: usize,
-    indices: &PrimitiveArray<IndexType>,
-) -> Result<Buffer, ArrowError>
-where
-    IndexType: ArrowPrimitiveType,
-    IndexType::Native: ToPrimitive,
-{
-    let len = indices.len();
-    let values_slice = values.as_slice();
-    let mut output_buffer = MutableBuffer::new_null(len);
-    let output_slice = output_buffer.as_slice_mut();
-
-    let indices_has_nulls = indices.null_count() > 0;
+#[inline(never)]
+fn take_nulls<I: ArrowPrimitiveType>(
+    values: Option<&NullBuffer>,
+    indices: &PrimitiveArray<I>,
+) -> Option<NullBuffer> {
+    match values.filter(|n| n.null_count() > 0) {

Review Comment:
   It is certainly neat to see nice Rust code like this and then know that `rustc` / `LLVM` did the right thing to make it fast.



##########
arrow-buffer/src/buffer/scalar.rs:
##########
@@ -140,6 +140,12 @@ impl<T: ArrowNativeType> From<Vec<T>> for ScalarBuffer<T> {
     }
 }
 
+impl<T: ArrowNativeType> FromIterator<T> for ScalarBuffer<T> {
+    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
+        iter.into_iter().collect::<Vec<_>>().into()

Review Comment:
   Can you please add this (very interesting) information as a comment inline?



##########
arrow-select/src/take.rs:
##########
@@ -376,134 +253,87 @@ where
     I: ArrowPrimitiveType,
     I::Native: ToPrimitive,
 {
-    let indices_nulls = indices.nulls().filter(|x| x.null_count() > 0);
-    let values_nulls = values.nulls().filter(|x| x.null_count() > 0);
-
-    // note: this function should only panic when "an index is not null and out of bounds".
-    // if the index is null, its value is undefined and therefore we should not read from it.
-    let (buffer, nulls) = match (values_nulls, indices_nulls) {
-        (None, None) => {
-            // * no nulls
-            // * all `indices.values()` are valid
-            take_no_nulls(values.values(), indices.values())?
-        }
-        (Some(values_nulls), None) => {
-            // * nulls come from `values` alone
-            // * all `indices.values()` are valid
-            take_values_nulls(values.values(), values_nulls, indices.values())?
-        }
-        (None, Some(indices_nulls)) => {
-            // in this branch it is unsound to read and use `index.values()`,
-            // as doing so is UB when they come from a null slot.
-            take_indices_nulls(values.values(), indices.values(), indices_nulls)?
-        }
-        (Some(values_nulls), Some(indices_nulls)) => {
-            // in this branch it is unsound to read and use `index.values()`,
-            // as doing so is UB when they come from a null slot.
-            take_values_indices_nulls(
-                values.values(),
-                values_nulls,
-                indices.values(),
-                indices_nulls,
-            )?
-        }
-    };
-
-    let data = unsafe {
-        ArrayData::new_unchecked(
-            values.data_type().clone(),
-            indices.len(),
-            None,
-            nulls,
-            0,
-            vec![buffer],
-            vec![],
-        )
-    };
-    Ok(PrimitiveArray::<T>::from(data))
+    let values_buf = take_native(values.values(), indices);
+    let nulls = take_nulls(values.nulls(), indices);
+    Ok(PrimitiveArray::new(values_buf, nulls).with_data_type(values.data_type().clone()))
 }
 
-fn take_bits<IndexType>(
-    values: &Buffer,
-    values_offset: usize,
-    indices: &PrimitiveArray<IndexType>,
-) -> Result<Buffer, ArrowError>
-where
-    IndexType: ArrowPrimitiveType,
-    IndexType::Native: ToPrimitive,
-{
-    let len = indices.len();
-    let values_slice = values.as_slice();
-    let mut output_buffer = MutableBuffer::new_null(len);
-    let output_slice = output_buffer.as_slice_mut();
-
-    let indices_has_nulls = indices.null_count() > 0;
+#[inline(never)]

Review Comment:
   Likewise, this may be worth a comment in the code as well.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org