You are viewing a plain-text version of this content. The canonical link to it was a hyperlink in the original message and was not preserved in this copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2022/04/28 17:39:44 UTC
[arrow-rs] branch master updated: Fix Null Mask Handling in ArrayData And UnionArray (#1589)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 37085d2f7 Fix Null Mask Handling in ArrayData And UnionArray (#1589)
37085d2f7 is described below
commit 37085d2f73661e9b7b1683c8569ec783b396a08e
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Thu Apr 28 18:39:37 2022 +0100
Fix Null Mask Handling in ArrayData And UnionArray (#1589)
* Fix ListArray and StructArray equality (#626)
* Simplify null masking in equality comparisons
Various UnionArray fixes (#1598) (#1596) (#1591) (#1590)
Fix handling of null masks in ArrayData equality (#1599)
* Miscellaneous fixes
* Fix structure null equality
* Review feedback
---
arrow/src/array/array_union.rs | 72 +++++-----
arrow/src/array/builder.rs | 137 +++++++-----------
arrow/src/array/data.rs | 24 +++-
arrow/src/array/equal/boolean.rs | 11 +-
arrow/src/array/equal/decimal.rs | 11 +-
arrow/src/array/equal/dictionary.rs | 15 +-
arrow/src/array/equal/fixed_binary.rs | 11 +-
arrow/src/array/equal/fixed_list.rs | 23 ++--
arrow/src/array/equal/list.rs | 205 ++++++++++-----------------
arrow/src/array/equal/mod.rs | 203 +++++++++++----------------
arrow/src/array/equal/primitive.rs | 11 +-
arrow/src/array/equal/structure.rs | 51 +++----
arrow/src/array/equal/union.rs | 33 +----
arrow/src/array/equal/utils.rs | 245 ++-------------------------------
arrow/src/array/equal/variable_size.rs | 14 +-
arrow/src/array/transform/union.rs | 156 +++++----------------
arrow/src/compute/kernels/filter.rs | 47 +++++--
arrow/src/ipc/reader.rs | 15 +-
arrow/src/ipc/writer.rs | 9 +-
19 files changed, 412 insertions(+), 881 deletions(-)
diff --git a/arrow/src/array/array_union.rs b/arrow/src/array/array_union.rs
index 2a4a42d95..63cf5c2a0 100644
--- a/arrow/src/array/array_union.rs
+++ b/arrow/src/array/array_union.rs
@@ -61,7 +61,6 @@ use std::any::Any;
/// type_id_buffer,
/// Some(value_offsets_buffer),
/// children,
-/// None,
/// ).unwrap();
///
/// let value = array.value(0).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
@@ -94,7 +93,6 @@ use std::any::Any;
/// type_id_buffer,
/// None,
/// children,
-/// None,
/// ).unwrap();
///
/// let value = array.value(0).as_any().downcast_ref::<Int32Array>().unwrap().value(0);
@@ -140,7 +138,6 @@ impl UnionArray {
type_ids: Buffer,
value_offsets: Option<Buffer>,
child_arrays: Vec<(Field, ArrayRef)>,
- bitmap_data: Option<Buffer>,
) -> Self {
let (field_types, field_values): (Vec<_>, Vec<_>) =
child_arrays.into_iter().unzip();
@@ -152,13 +149,11 @@ impl UnionArray {
UnionMode::Sparse
};
- let mut builder = ArrayData::builder(DataType::Union(field_types, mode))
+ let builder = ArrayData::builder(DataType::Union(field_types, mode))
.add_buffer(type_ids)
.child_data(field_values.into_iter().map(|a| a.data().clone()).collect())
.len(len);
- if let Some(bitmap) = bitmap_data {
- builder = builder.null_bit_buffer(bitmap)
- }
+
let data = match value_offsets {
Some(b) => builder.add_buffer(b).build_unchecked(),
None => builder.build_unchecked(),
@@ -171,7 +166,6 @@ impl UnionArray {
type_ids: Buffer,
value_offsets: Option<Buffer>,
child_arrays: Vec<(Field, ArrayRef)>,
- bitmap: Option<Buffer>,
) -> Result<Self> {
if let Some(b) = &value_offsets {
if ((type_ids.len()) * 4) != b.len() {
@@ -216,7 +210,7 @@ impl UnionArray {
// Unsafe Justification: arguments were validated above (and
// re-revalidated as part of data().validate() below)
let new_self =
- unsafe { Self::new_unchecked(type_ids, value_offsets, child_arrays, bitmap) };
+ unsafe { Self::new_unchecked(type_ids, value_offsets, child_arrays) };
new_self.data().validate()?;
Ok(new_self)
@@ -512,7 +506,7 @@ mod tests {
builder.append::<Int32Type>("a", 1).unwrap();
builder.append::<Int64Type>("c", 3).unwrap();
builder.append::<Int32Type>("a", 10).unwrap();
- builder.append_null().unwrap();
+ builder.append_null::<Int32Type>("a").unwrap();
builder.append::<Int32Type>("a", 6).unwrap();
let union = builder.build().unwrap();
@@ -522,29 +516,29 @@ mod tests {
match i {
0 => {
let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
- assert!(!union.is_null(i));
+ assert!(!slot.is_null(0));
assert_eq!(slot.len(), 1);
let value = slot.value(0);
assert_eq!(1_i32, value);
}
1 => {
let slot = slot.as_any().downcast_ref::<Int64Array>().unwrap();
- assert!(!union.is_null(i));
+ assert!(!slot.is_null(0));
assert_eq!(slot.len(), 1);
let value = slot.value(0);
assert_eq!(3_i64, value);
}
2 => {
let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
- assert!(!union.is_null(i));
+ assert!(!slot.is_null(0));
assert_eq!(slot.len(), 1);
let value = slot.value(0);
assert_eq!(10_i32, value);
}
- 3 => assert!(union.is_null(i)),
+ 3 => assert!(slot.is_null(0)),
4 => {
let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
- assert!(!union.is_null(i));
+ assert!(!slot.is_null(0));
assert_eq!(slot.len(), 1);
let value = slot.value(0);
assert_eq!(6_i32, value);
@@ -560,7 +554,7 @@ mod tests {
builder.append::<Int32Type>("a", 1).unwrap();
builder.append::<Int64Type>("c", 3).unwrap();
builder.append::<Int32Type>("a", 10).unwrap();
- builder.append_null().unwrap();
+ builder.append_null::<Int32Type>("a").unwrap();
builder.append::<Int32Type>("a", 6).unwrap();
let union = builder.build().unwrap();
@@ -573,15 +567,15 @@ mod tests {
match i {
0 => {
let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
- assert!(!union.is_null(i));
+ assert!(!slot.is_null(0));
assert_eq!(slot.len(), 1);
let value = slot.value(0);
assert_eq!(10_i32, value);
}
- 1 => assert!(new_union.is_null(i)),
+ 1 => assert!(slot.is_null(0)),
2 => {
let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
- assert!(!union.is_null(i));
+ assert!(!slot.is_null(0));
assert_eq!(slot.len(), 1);
let value = slot.value(0);
assert_eq!(6_i32, value);
@@ -614,13 +608,9 @@ mod tests {
Arc::new(float_array),
),
];
- let array = UnionArray::try_new(
- type_id_buffer,
- Some(value_offsets_buffer),
- children,
- None,
- )
- .unwrap();
+ let array =
+ UnionArray::try_new(type_id_buffer, Some(value_offsets_buffer), children)
+ .unwrap();
// Check type ids
assert_eq!(Buffer::from_slice_ref(&type_ids), array.data().buffers()[0]);
@@ -800,7 +790,7 @@ mod tests {
fn test_sparse_mixed_with_nulls() {
let mut builder = UnionBuilder::new_sparse(5);
builder.append::<Int32Type>("a", 1).unwrap();
- builder.append_null().unwrap();
+ builder.append_null::<Int32Type>("a").unwrap();
builder.append::<Float64Type>("c", 3.0).unwrap();
builder.append::<Int32Type>("a", 4).unwrap();
let union = builder.build().unwrap();
@@ -824,22 +814,22 @@ mod tests {
match i {
0 => {
let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
- assert!(!union.is_null(i));
+ assert!(!slot.is_null(0));
assert_eq!(slot.len(), 1);
let value = slot.value(0);
assert_eq!(1_i32, value);
}
- 1 => assert!(union.is_null(i)),
+ 1 => assert!(slot.is_null(0)),
2 => {
let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap();
- assert!(!union.is_null(i));
+ assert!(!slot.is_null(0));
assert_eq!(slot.len(), 1);
let value = slot.value(0);
assert_eq!(value, 3_f64);
}
3 => {
let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
- assert!(!union.is_null(i));
+ assert!(!slot.is_null(0));
assert_eq!(slot.len(), 1);
let value = slot.value(0);
assert_eq!(4_i32, value);
@@ -853,9 +843,9 @@ mod tests {
fn test_sparse_mixed_with_nulls_and_offset() {
let mut builder = UnionBuilder::new_sparse(5);
builder.append::<Int32Type>("a", 1).unwrap();
- builder.append_null().unwrap();
+ builder.append_null::<Int32Type>("a").unwrap();
builder.append::<Float64Type>("c", 3.0).unwrap();
- builder.append_null().unwrap();
+ builder.append_null::<Float64Type>("c").unwrap();
builder.append::<Int32Type>("a", 4).unwrap();
let union = builder.build().unwrap();
@@ -866,18 +856,18 @@ mod tests {
for i in 0..new_union.len() {
let slot = new_union.value(i);
match i {
- 0 => assert!(new_union.is_null(i)),
+ 0 => assert!(slot.is_null(0)),
1 => {
let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap();
- assert!(!new_union.is_null(i));
+ assert!(!slot.is_null(0));
assert_eq!(slot.len(), 1);
let value = slot.value(0);
assert_eq!(value, 3_f64);
}
- 2 => assert!(new_union.is_null(i)),
+ 2 => assert!(slot.is_null(0)),
3 => {
let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
- assert!(!new_union.is_null(i));
+ assert!(!slot.is_null(0));
assert_eq!(slot.len(), 1);
let value = slot.value(0);
assert_eq!(4_i32, value);
@@ -886,4 +876,12 @@ mod tests {
}
}
}
+
+ #[test]
+ fn test_type_check() {
+ let mut builder = UnionBuilder::new_sparse(2);
+ builder.append::<Float32Type>("a", 1.0).unwrap();
+ let err = builder.append::<Int32Type>("a", 1).unwrap_err().to_string();
+ assert!(err.contains("Attempt to write col \"a\" with type Int32 doesn't match existing type Float32"), "{}", err);
+ }
}
diff --git a/arrow/src/array/builder.rs b/arrow/src/array/builder.rs
index e98627bae..1c64b5062 100644
--- a/arrow/src/array/builder.rs
+++ b/arrow/src/array/builder.rs
@@ -1894,23 +1894,19 @@ struct FieldData {
values_buffer: Option<MutableBuffer>,
/// The number of array slots represented by the buffer
slots: usize,
- /// A builder for the bitmap if required (for Sparse Unions)
- bitmap_builder: Option<BooleanBufferBuilder>,
+ /// A builder for the null bitmap
+ bitmap_builder: BooleanBufferBuilder,
}
impl FieldData {
/// Creates a new `FieldData`.
- fn new(
- type_id: i8,
- data_type: DataType,
- bitmap_builder: Option<BooleanBufferBuilder>,
- ) -> Self {
+ fn new(type_id: i8, data_type: DataType) -> Self {
Self {
type_id,
data_type,
values_buffer: Some(MutableBuffer::new(1)),
slots: 0,
- bitmap_builder,
+ bitmap_builder: BooleanBufferBuilder::new(1),
}
}
@@ -1931,28 +1927,26 @@ impl FieldData {
self.values_buffer = Some(mutable_buffer);
self.slots += 1;
- if let Some(b) = &mut self.bitmap_builder {
- b.append(true)
- };
+ self.bitmap_builder.append(true);
Ok(())
}
/// Appends a null to this `FieldData`.
#[allow(clippy::unnecessary_wraps)]
fn append_null<T: ArrowPrimitiveType>(&mut self) -> Result<()> {
- if let Some(b) = &mut self.bitmap_builder {
- let values_buffer = self
- .values_buffer
- .take()
- .expect("Values buffer was never created");
- let mut builder: BufferBuilder<T::Native> =
- mutable_buffer_to_builder(values_buffer, self.slots);
- builder.advance(1);
- let mutable_buffer = builder_to_mutable_buffer(builder);
- self.values_buffer = Some(mutable_buffer);
- self.slots += 1;
- b.append(false);
- };
+ let values_buffer = self
+ .values_buffer
+ .take()
+ .expect("Values buffer was never created");
+
+ let mut builder: BufferBuilder<T::Native> =
+ mutable_buffer_to_builder(values_buffer, self.slots);
+
+ builder.advance(1);
+ let mutable_buffer = builder_to_mutable_buffer(builder);
+ self.values_buffer = Some(mutable_buffer);
+ self.slots += 1;
+ self.bitmap_builder.append(false);
Ok(())
}
@@ -2047,8 +2041,6 @@ pub struct UnionBuilder {
type_id_builder: Int8BufferBuilder,
/// Builder to keep track of offsets (`None` for sparse unions)
value_offset_builder: Option<Int32BufferBuilder>,
- /// Optional builder for null slots
- bitmap_builder: Option<BooleanBufferBuilder>,
}
impl UnionBuilder {
@@ -2059,7 +2051,6 @@ impl UnionBuilder {
fields: HashMap::default(),
type_id_builder: Int8BufferBuilder::new(capacity),
value_offset_builder: Some(Int32BufferBuilder::new(capacity)),
- bitmap_builder: None,
}
}
@@ -2070,39 +2061,13 @@ impl UnionBuilder {
fields: HashMap::default(),
type_id_builder: Int8BufferBuilder::new(capacity),
value_offset_builder: None,
- bitmap_builder: None,
}
}
/// Appends a null to this builder.
#[inline]
- pub fn append_null(&mut self) -> Result<()> {
- if self.bitmap_builder.is_none() {
- let mut builder = BooleanBufferBuilder::new(self.len + 1);
- for _ in 0..self.len {
- builder.append(true);
- }
- self.bitmap_builder = Some(builder)
- }
- self.bitmap_builder
- .as_mut()
- .expect("Cannot be None")
- .append(false);
-
- self.type_id_builder.append(i8::default());
-
- match &mut self.value_offset_builder {
- // Handle dense union
- Some(value_offset_builder) => value_offset_builder.append(i32::default()),
- // Handle sparse union
- None => {
- for (_, fd) in self.fields.iter_mut() {
- fd.append_null_dynamic()?;
- }
- }
- };
- self.len += 1;
- Ok(())
+ pub fn append_null<T: ArrowPrimitiveType>(&mut self, type_name: &str) -> Result<()> {
+ self.append_option::<T>(type_name, None)
}
/// Appends a value to this builder.
@@ -2111,22 +2076,28 @@ impl UnionBuilder {
&mut self,
type_name: &str,
v: T::Native,
+ ) -> Result<()> {
+ self.append_option::<T>(type_name, Some(v))
+ }
+
+ fn append_option<T: ArrowPrimitiveType>(
+ &mut self,
+ type_name: &str,
+ v: Option<T::Native>,
) -> Result<()> {
let type_name = type_name.to_string();
let mut field_data = match self.fields.remove(&type_name) {
- Some(data) => data,
- None => match self.value_offset_builder {
- Some(_) => {
- // For Dense Union, we don't build bitmap in individual field
- FieldData::new(self.fields.len() as i8, T::DATA_TYPE, None)
+ Some(data) => {
+ if data.data_type != T::DATA_TYPE {
+ return Err(ArrowError::InvalidArgumentError(format!("Attempt to write col \"{}\" with type {} doesn't match existing type {}", type_name, T::DATA_TYPE, data.data_type)));
}
+ data
+ }
+ None => match self.value_offset_builder {
+ Some(_) => FieldData::new(self.fields.len() as i8, T::DATA_TYPE),
None => {
- let mut fd = FieldData::new(
- self.fields.len() as i8,
- T::DATA_TYPE,
- Some(BooleanBufferBuilder::new(1)),
- );
+ let mut fd = FieldData::new(self.fields.len() as i8, T::DATA_TYPE);
for _ in 0..self.len {
fd.append_null::<T>()?;
}
@@ -2143,20 +2114,19 @@ impl UnionBuilder {
}
// Sparse Union
None => {
- for (name, fd) in self.fields.iter_mut() {
- if name != &type_name {
- fd.append_null_dynamic()?;
- }
+ for (_, fd) in self.fields.iter_mut() {
+ // Append to all bar the FieldData currently being appended to
+ fd.append_null_dynamic()?;
}
}
}
- field_data.append_to_values_buffer::<T>(v)?;
- self.fields.insert(type_name, field_data);
- // Update the bitmap builder if it exists
- if let Some(b) = &mut self.bitmap_builder {
- b.append(true);
+ match v {
+ Some(v) => field_data.append_to_values_buffer::<T>(v)?,
+ None => field_data.append_null::<T>()?,
}
+
+ self.fields.insert(type_name, field_data);
self.len += 1;
Ok(())
}
@@ -2173,7 +2143,7 @@ impl UnionBuilder {
data_type,
values_buffer,
slots,
- bitmap_builder,
+ mut bitmap_builder,
},
) in self.fields.into_iter()
{
@@ -2182,16 +2152,10 @@ impl UnionBuilder {
.into();
let arr_data_builder = ArrayDataBuilder::new(data_type.clone())
.add_buffer(buffer)
- .len(slots);
- // .build();
- let arr_data_ref = unsafe {
- match bitmap_builder {
- Some(mut bb) => arr_data_builder
- .null_bit_buffer(bb.finish())
- .build_unchecked(),
- None => arr_data_builder.build_unchecked(),
- }
- };
+ .len(slots)
+ .null_bit_buffer(bitmap_builder.finish());
+
+ let arr_data_ref = unsafe { arr_data_builder.build_unchecked() };
let array_ref = make_array(arr_data_ref);
children.push((type_id, (Field::new(&name, data_type, false), array_ref)))
}
@@ -2201,9 +2165,8 @@ impl UnionBuilder {
.expect("This will never be None as type ids are always i8 values.")
});
let children: Vec<_> = children.into_iter().map(|(_, b)| b).collect();
- let bitmap = self.bitmap_builder.map(|mut b| b.finish());
- UnionArray::try_new(type_id_buffer, value_offsets_buffer, children, bitmap)
+ UnionArray::try_new(type_id_buffer, value_offsets_buffer, children)
}
}
diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs
index 2afc00b58..c0ecef75d 100644
--- a/arrow/src/array/data.rs
+++ b/arrow/src/array/data.rs
@@ -621,6 +621,13 @@ impl ArrayData {
// Check that the data layout conforms to the spec
let layout = layout(&self.data_type);
+ if !layout.can_contain_null_mask && self.null_bitmap.is_some() {
+ return Err(ArrowError::InvalidArgumentError(format!(
+ "Arrays of type {:?} cannot contain a null bitmask",
+ self.data_type,
+ )));
+ }
+
if self.buffers.len() != layout.buffers.len() {
return Err(ArrowError::InvalidArgumentError(format!(
"Expected {} buffers in array of type {:?}, got {}",
@@ -1224,9 +1231,13 @@ fn layout(data_type: &DataType) -> DataTypeLayout {
// https://github.com/apache/arrow/blob/661c7d749150905a63dd3b52e0a04dac39030d95/cpp/src/arrow/type.h (and .cc)
use std::mem::size_of;
match data_type {
- DataType::Null => DataTypeLayout::new_empty(),
+ DataType::Null => DataTypeLayout {
+ buffers: vec![],
+ can_contain_null_mask: false,
+ },
DataType::Boolean => DataTypeLayout {
buffers: vec![BufferSpec::BitMap],
+ can_contain_null_mask: true,
},
DataType::Int8 => DataTypeLayout::new_fixed_width(size_of::<i8>()),
DataType::Int16 => DataTypeLayout::new_fixed_width(size_of::<i16>()),
@@ -1287,6 +1298,7 @@ fn layout(data_type: &DataType) -> DataTypeLayout {
]
}
},
+ can_contain_null_mask: false,
}
}
DataType::Dictionary(key_type, _value_type) => layout(key_type),
@@ -1308,6 +1320,9 @@ fn layout(data_type: &DataType) -> DataTypeLayout {
struct DataTypeLayout {
/// A vector of buffer layout specifications, one for each expected buffer
pub buffers: Vec<BufferSpec>,
+
+ /// Can contain a null bitmask
+ pub can_contain_null_mask: bool,
}
impl DataTypeLayout {
@@ -1315,6 +1330,7 @@ impl DataTypeLayout {
pub fn new_fixed_width(byte_width: usize) -> Self {
Self {
buffers: vec![BufferSpec::FixedWidth { byte_width }],
+ can_contain_null_mask: true,
}
}
@@ -1322,7 +1338,10 @@ impl DataTypeLayout {
/// (e.g. FixedSizeList). Note such arrays may still have a Null
/// Bitmap
pub fn new_empty() -> Self {
- Self { buffers: vec![] }
+ Self {
+ buffers: vec![],
+ can_contain_null_mask: true,
+ }
}
/// Describes a basic numeric array where each element has a fixed
@@ -1338,6 +1357,7 @@ impl DataTypeLayout {
// values
BufferSpec::VariableWidth,
],
+ can_contain_null_mask: true,
}
}
}
diff --git a/arrow/src/array/equal/boolean.rs b/arrow/src/array/equal/boolean.rs
index 35c9786e4..de34d7fab 100644
--- a/arrow/src/array/equal/boolean.rs
+++ b/arrow/src/array/equal/boolean.rs
@@ -16,7 +16,6 @@
// under the License.
use crate::array::{data::count_nulls, ArrayData};
-use crate::buffer::Buffer;
use crate::util::bit_util::get_bit;
use super::utils::{equal_bits, equal_len};
@@ -24,8 +23,6 @@ use super::utils::{equal_bits, equal_len};
pub(super) fn boolean_equal(
lhs: &ArrayData,
rhs: &ArrayData,
- lhs_nulls: Option<&Buffer>,
- rhs_nulls: Option<&Buffer>,
mut lhs_start: usize,
mut rhs_start: usize,
mut len: usize,
@@ -33,8 +30,8 @@ pub(super) fn boolean_equal(
let lhs_values = lhs.buffers()[0].as_slice();
let rhs_values = rhs.buffers()[0].as_slice();
- let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len);
- let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len);
+ let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len);
+ let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len);
if lhs_null_count == 0 && rhs_null_count == 0 {
// Optimize performance for starting offset at u8 boundary.
@@ -73,8 +70,8 @@ pub(super) fn boolean_equal(
)
} else {
// get a ref of the null buffer bytes, to use in testing for nullness
- let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice();
- let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice();
+ let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice();
+ let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice();
let lhs_start = lhs.offset() + lhs_start;
let rhs_start = rhs.offset() + rhs_start;
diff --git a/arrow/src/array/equal/decimal.rs b/arrow/src/array/equal/decimal.rs
index 1ee6ec9b5..e9879f3f2 100644
--- a/arrow/src/array/equal/decimal.rs
+++ b/arrow/src/array/equal/decimal.rs
@@ -16,7 +16,6 @@
// under the License.
use crate::array::{data::count_nulls, ArrayData};
-use crate::buffer::Buffer;
use crate::datatypes::DataType;
use crate::util::bit_util::get_bit;
@@ -25,8 +24,6 @@ use super::utils::equal_len;
pub(super) fn decimal_equal(
lhs: &ArrayData,
rhs: &ArrayData,
- lhs_nulls: Option<&Buffer>,
- rhs_nulls: Option<&Buffer>,
lhs_start: usize,
rhs_start: usize,
len: usize,
@@ -39,8 +36,8 @@ pub(super) fn decimal_equal(
let lhs_values = &lhs.buffers()[0].as_slice()[lhs.offset() * size..];
let rhs_values = &rhs.buffers()[0].as_slice()[rhs.offset() * size..];
- let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len);
- let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len);
+ let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len);
+ let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len);
if lhs_null_count == 0 && rhs_null_count == 0 {
equal_len(
@@ -52,8 +49,8 @@ pub(super) fn decimal_equal(
)
} else {
// get a ref of the null buffer bytes, to use in testing for nullness
- let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice();
- let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice();
+ let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice();
+ let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice();
// with nulls, we need to compare item by item whenever it is not null
(0..len).all(|i| {
let lhs_pos = lhs_start + i;
diff --git a/arrow/src/array/equal/dictionary.rs b/arrow/src/array/equal/dictionary.rs
index 22add2494..4c9bcf798 100644
--- a/arrow/src/array/equal/dictionary.rs
+++ b/arrow/src/array/equal/dictionary.rs
@@ -16,7 +16,6 @@
// under the License.
use crate::array::{data::count_nulls, ArrayData};
-use crate::buffer::Buffer;
use crate::datatypes::ArrowNativeType;
use crate::util::bit_util::get_bit;
@@ -25,8 +24,6 @@ use super::equal_range;
pub(super) fn dictionary_equal<T: ArrowNativeType>(
lhs: &ArrayData,
rhs: &ArrayData,
- lhs_nulls: Option<&Buffer>,
- rhs_nulls: Option<&Buffer>,
lhs_start: usize,
rhs_start: usize,
len: usize,
@@ -37,8 +34,8 @@ pub(super) fn dictionary_equal<T: ArrowNativeType>(
let lhs_values = &lhs.child_data()[0];
let rhs_values = &rhs.child_data()[0];
- let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len);
- let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len);
+ let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len);
+ let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len);
if lhs_null_count == 0 && rhs_null_count == 0 {
(0..len).all(|i| {
@@ -48,8 +45,6 @@ pub(super) fn dictionary_equal<T: ArrowNativeType>(
equal_range(
lhs_values,
rhs_values,
- lhs_values.null_buffer(),
- rhs_values.null_buffer(),
lhs_keys[lhs_pos].to_usize().unwrap(),
rhs_keys[rhs_pos].to_usize().unwrap(),
1,
@@ -57,8 +52,8 @@ pub(super) fn dictionary_equal<T: ArrowNativeType>(
})
} else {
// get a ref of the null buffer bytes, to use in testing for nullness
- let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice();
- let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice();
+ let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice();
+ let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice();
(0..len).all(|i| {
let lhs_pos = lhs_start + i;
let rhs_pos = rhs_start + i;
@@ -71,8 +66,6 @@ pub(super) fn dictionary_equal<T: ArrowNativeType>(
&& equal_range(
lhs_values,
rhs_values,
- lhs_values.null_buffer(),
- rhs_values.null_buffer(),
lhs_keys[lhs_pos].to_usize().unwrap(),
rhs_keys[rhs_pos].to_usize().unwrap(),
1,
diff --git a/arrow/src/array/equal/fixed_binary.rs b/arrow/src/array/equal/fixed_binary.rs
index 5f8f93232..aea0e08a9 100644
--- a/arrow/src/array/equal/fixed_binary.rs
+++ b/arrow/src/array/equal/fixed_binary.rs
@@ -16,7 +16,6 @@
// under the License.
use crate::array::{data::count_nulls, ArrayData};
-use crate::buffer::Buffer;
use crate::datatypes::DataType;
use crate::util::bit_util::get_bit;
@@ -25,8 +24,6 @@ use super::utils::equal_len;
pub(super) fn fixed_binary_equal(
lhs: &ArrayData,
rhs: &ArrayData,
- lhs_nulls: Option<&Buffer>,
- rhs_nulls: Option<&Buffer>,
lhs_start: usize,
rhs_start: usize,
len: usize,
@@ -39,8 +36,8 @@ pub(super) fn fixed_binary_equal(
let lhs_values = &lhs.buffers()[0].as_slice()[lhs.offset() * size..];
let rhs_values = &rhs.buffers()[0].as_slice()[rhs.offset() * size..];
- let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len);
- let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len);
+ let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len);
+ let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len);
if lhs_null_count == 0 && rhs_null_count == 0 {
equal_len(
@@ -52,8 +49,8 @@ pub(super) fn fixed_binary_equal(
)
} else {
// get a ref of the null buffer bytes, to use in testing for nullness
- let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice();
- let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice();
+ let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice();
+ let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice();
// with nulls, we need to compare item by item whenever it is not null
(0..len).all(|i| {
let lhs_pos = lhs_start + i;
diff --git a/arrow/src/array/equal/fixed_list.rs b/arrow/src/array/equal/fixed_list.rs
index e708a06ef..82a347c86 100644
--- a/arrow/src/array/equal/fixed_list.rs
+++ b/arrow/src/array/equal/fixed_list.rs
@@ -16,7 +16,6 @@
// under the License.
use crate::array::{data::count_nulls, ArrayData};
-use crate::buffer::Buffer;
use crate::datatypes::DataType;
use crate::util::bit_util::get_bit;
@@ -25,8 +24,6 @@ use super::equal_range;
pub(super) fn fixed_list_equal(
lhs: &ArrayData,
rhs: &ArrayData,
- lhs_nulls: Option<&Buffer>,
- rhs_nulls: Option<&Buffer>,
lhs_start: usize,
rhs_start: usize,
len: usize,
@@ -39,23 +36,21 @@ pub(super) fn fixed_list_equal(
let lhs_values = &lhs.child_data()[0];
let rhs_values = &rhs.child_data()[0];
- let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len);
- let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len);
+ let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len);
+ let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len);
if lhs_null_count == 0 && rhs_null_count == 0 {
equal_range(
lhs_values,
rhs_values,
- lhs_values.null_buffer(),
- rhs_values.null_buffer(),
- size * lhs_start,
- size * rhs_start,
+ (lhs_start + lhs.offset()) * size,
+ (rhs_start + rhs.offset()) * size,
size * len,
)
} else {
// get a ref of the null buffer bytes, to use in testing for nullness
- let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice();
- let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice();
+ let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice();
+ let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice();
// with nulls, we need to compare item by item whenever it is not null
(0..len).all(|i| {
let lhs_pos = lhs_start + i;
@@ -69,10 +64,8 @@ pub(super) fn fixed_list_equal(
&& equal_range(
lhs_values,
rhs_values,
- lhs_values.null_buffer(),
- rhs_values.null_buffer(),
- lhs_pos * size,
- rhs_pos * size,
+ (lhs_pos + lhs.offset()) * size,
+ (rhs_pos + rhs.offset()) * size,
size, // 1 * size since we are comparing a single entry
)
})
diff --git a/arrow/src/array/equal/list.rs b/arrow/src/array/equal/list.rs
index 000b31a1f..09ad896f4 100644
--- a/arrow/src/array/equal/list.rs
+++ b/arrow/src/array/equal/list.rs
@@ -15,17 +15,13 @@
// specific language governing permissions and limitations
// under the License.
-use crate::datatypes::DataType;
use crate::{
array::ArrayData,
array::{data::count_nulls, OffsetSizeTrait},
- buffer::Buffer,
util::bit_util::get_bit,
};
-use super::{
- equal_range, equal_values, utils::child_logical_null_buffer, utils::equal_nulls,
-};
+use super::equal_range;
fn lengths_equal<T: OffsetSizeTrait>(lhs: &[T], rhs: &[T]) -> bool {
// invariant from `base_equal`
@@ -49,66 +45,9 @@ fn lengths_equal<T: OffsetSizeTrait>(lhs: &[T], rhs: &[T]) -> bool {
})
}
-#[allow(clippy::too_many_arguments)]
-#[inline]
-fn offset_value_equal<T: OffsetSizeTrait>(
- lhs_values: &ArrayData,
- rhs_values: &ArrayData,
- lhs_nulls: Option<&Buffer>,
- rhs_nulls: Option<&Buffer>,
- lhs_offsets: &[T],
- rhs_offsets: &[T],
- lhs_pos: usize,
- rhs_pos: usize,
- len: usize,
- data_type: &DataType,
-) -> bool {
- let lhs_start = lhs_offsets[lhs_pos].to_usize().unwrap();
- let rhs_start = rhs_offsets[rhs_pos].to_usize().unwrap();
- let lhs_len = lhs_offsets[lhs_pos + len] - lhs_offsets[lhs_pos];
- let rhs_len = rhs_offsets[rhs_pos + len] - rhs_offsets[rhs_pos];
-
- lhs_len == rhs_len && {
- match data_type {
- DataType::Map(_, _) => {
- // Don't use `equal_range` which calls `utils::base_equal` that checks
- // struct fields, but we don't enforce struct field names.
- equal_nulls(
- lhs_values,
- rhs_values,
- lhs_nulls,
- rhs_nulls,
- lhs_start,
- rhs_start,
- lhs_len.to_usize().unwrap(),
- ) && equal_values(
- lhs_values,
- rhs_values,
- lhs_nulls,
- rhs_nulls,
- lhs_start,
- rhs_start,
- lhs_len.to_usize().unwrap(),
- )
- }
- _ => equal_range(
- lhs_values,
- rhs_values,
- lhs_nulls,
- rhs_nulls,
- lhs_start,
- rhs_start,
- lhs_len.to_usize().unwrap(),
- ),
- }
- }
-}
-
pub(super) fn list_equal<T: OffsetSizeTrait>(
lhs: &ArrayData,
rhs: &ArrayData,
- lhs_nulls: Option<&Buffer>,
- rhs_nulls: Option<&Buffer>,
lhs_start: usize,
rhs_start: usize,
len: usize,
@@ -123,7 +62,7 @@ pub(super) fn list_equal<T: OffsetSizeTrait>(
// no child values. This causes panics when trying to count set bits.
//
// We caught this by chance from an accidental test-case, but due to the nature of this
- // crash only occuring on list equality checks, we are adding a check here, instead of
+ // crash only occurring on list equality checks, we are adding a check here, instead of
// on the buffer/bitmap utilities, as a length check would incur a penalty for almost all
// other use-cases.
//
@@ -134,10 +73,11 @@ pub(super) fn list_equal<T: OffsetSizeTrait>(
// however, one is more likely to slice into a list array and get a region that has 0
// child values.
// The test that triggered this behaviour had [4, 4] as a slice of 1 value slot.
- let lhs_child_length = lhs_offsets.get(len).unwrap().to_usize().unwrap()
- - lhs_offsets.first().unwrap().to_usize().unwrap();
- let rhs_child_length = rhs_offsets.get(len).unwrap().to_usize().unwrap()
- - rhs_offsets.first().unwrap().to_usize().unwrap();
+ let lhs_child_length = lhs_offsets[lhs_start + len].to_usize().unwrap()
+ - lhs_offsets[lhs_start].to_usize().unwrap();
+
+ let rhs_child_length = rhs_offsets[rhs_start + len].to_usize().unwrap()
+ - rhs_offsets[rhs_start].to_usize().unwrap();
if lhs_child_length == 0 && lhs_child_length == rhs_child_length {
return true;
@@ -146,64 +86,33 @@ pub(super) fn list_equal<T: OffsetSizeTrait>(
let lhs_values = &lhs.child_data()[0];
let rhs_values = &rhs.child_data()[0];
- let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len);
- let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len);
+ let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len);
+ let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len);
- // compute the child logical bitmap
- let child_lhs_nulls =
- child_logical_null_buffer(lhs, lhs_nulls, lhs.child_data().get(0).unwrap());
- let child_rhs_nulls =
- child_logical_null_buffer(rhs, rhs_nulls, rhs.child_data().get(0).unwrap());
+ if lhs_null_count != rhs_null_count {
+ return false;
+ }
if lhs_null_count == 0 && rhs_null_count == 0 {
- lengths_equal(
- &lhs_offsets[lhs_start..lhs_start + len],
- &rhs_offsets[rhs_start..rhs_start + len],
- ) && {
- match lhs.data_type() {
- DataType::Map(_, _) => {
- // Don't use `equal_range` which calls `utils::base_equal` that checks
- // struct fields, but we don't enforce struct field names.
- equal_nulls(
- lhs_values,
- rhs_values,
- child_lhs_nulls.as_ref(),
- child_rhs_nulls.as_ref(),
- lhs_offsets[lhs_start].to_usize().unwrap(),
- rhs_offsets[rhs_start].to_usize().unwrap(),
- (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
- .to_usize()
- .unwrap(),
- ) && equal_values(
- lhs_values,
- rhs_values,
- child_lhs_nulls.as_ref(),
- child_rhs_nulls.as_ref(),
- lhs_offsets[lhs_start].to_usize().unwrap(),
- rhs_offsets[rhs_start].to_usize().unwrap(),
- (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
- .to_usize()
- .unwrap(),
- )
- }
- _ => equal_range(
- lhs_values,
- rhs_values,
- child_lhs_nulls.as_ref(),
- child_rhs_nulls.as_ref(),
- lhs_offsets[lhs_start].to_usize().unwrap(),
- rhs_offsets[rhs_start].to_usize().unwrap(),
- (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start])
- .to_usize()
- .unwrap(),
- ),
- }
- }
+ lhs_child_length == rhs_child_length
+ && lengths_equal(
+ &lhs_offsets[lhs_start..lhs_start + len],
+ &rhs_offsets[rhs_start..rhs_start + len],
+ )
+ && equal_range(
+ lhs_values,
+ rhs_values,
+ lhs_offsets[lhs_start].to_usize().unwrap(),
+ rhs_offsets[rhs_start].to_usize().unwrap(),
+ lhs_child_length,
+ )
} else {
// get a ref of the parent null buffer bytes, to use in testing for nullness
- let lhs_null_bytes = lhs_nulls.unwrap().as_slice();
- let rhs_null_bytes = rhs_nulls.unwrap().as_slice();
+ let lhs_null_bytes = lhs.null_buffer().unwrap().as_slice();
+ let rhs_null_bytes = rhs.null_buffer().unwrap().as_slice();
+
// with nulls, we need to compare item by item whenever it is not null
+ // TODO: Could potentially compare runs of not NULL values
(0..len).all(|i| {
let lhs_pos = lhs_start + i;
let rhs_pos = rhs_start + i;
@@ -211,20 +120,56 @@ pub(super) fn list_equal<T: OffsetSizeTrait>(
let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos + lhs.offset());
let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos + rhs.offset());
+ if lhs_is_null != rhs_is_null {
+ return false;
+ }
+
+ let lhs_offset_start = lhs_offsets[lhs_pos].to_usize().unwrap();
+ let lhs_offset_end = lhs_offsets[lhs_pos + 1].to_usize().unwrap();
+ let rhs_offset_start = rhs_offsets[rhs_pos].to_usize().unwrap();
+ let rhs_offset_end = rhs_offsets[rhs_pos + 1].to_usize().unwrap();
+
+ let lhs_len = lhs_offset_end - lhs_offset_start;
+ let rhs_len = rhs_offset_end - rhs_offset_start;
+
lhs_is_null
- || (lhs_is_null == rhs_is_null)
- && offset_value_equal::<T>(
+ || (lhs_len == rhs_len
+ && equal_range(
lhs_values,
rhs_values,
- child_lhs_nulls.as_ref(),
- child_rhs_nulls.as_ref(),
- lhs_offsets,
- rhs_offsets,
- lhs_pos,
- rhs_pos,
- 1,
- lhs.data_type(),
- )
+ lhs_offset_start,
+ rhs_offset_start,
+ lhs_len,
+ ))
})
}
}
+
+#[cfg(test)]
+mod tests {
+ use crate::array::{Int64Builder, ListBuilder};
+
+ #[test]
+ fn list_array_non_zero_nulls() {
+ // Tests handling of list arrays with non-empty null ranges
+ let mut builder = ListBuilder::new(Int64Builder::new(10));
+ builder.values().append_value(1).unwrap();
+ builder.values().append_value(2).unwrap();
+ builder.values().append_value(3).unwrap();
+ builder.append(true).unwrap();
+ builder.append(false).unwrap();
+ let array1 = builder.finish();
+
+ let mut builder = ListBuilder::new(Int64Builder::new(10));
+ builder.values().append_value(1).unwrap();
+ builder.values().append_value(2).unwrap();
+ builder.values().append_value(3).unwrap();
+ builder.append(true).unwrap();
+ builder.values().append_null().unwrap();
+ builder.values().append_null().unwrap();
+ builder.append(false).unwrap();
+ let array2 = builder.finish();
+
+ assert_eq!(array1, array2);
+ }
+}
diff --git a/arrow/src/array/equal/mod.rs b/arrow/src/array/equal/mod.rs
index 07c173b13..f5f0d60c7 100644
--- a/arrow/src/array/equal/mod.rs
+++ b/arrow/src/array/equal/mod.rs
@@ -25,10 +25,7 @@ use super::{
GenericStringArray, MapArray, NullArray, OffsetSizeTrait, PrimitiveArray,
StringOffsetSizeTrait, StructArray,
};
-use crate::{
- buffer::Buffer,
- datatypes::{ArrowPrimitiveType, DataType, IntervalUnit},
-};
+use crate::datatypes::{ArrowPrimitiveType, DataType, IntervalUnit};
use half::f16;
mod boolean;
@@ -144,147 +141,99 @@ impl PartialEq for StructArray {
}
/// Compares the values of two [ArrayData] starting at `lhs_start` and `rhs_start` respectively
-/// for `len` slots. The null buffers `lhs_nulls` and `rhs_nulls` inherit parent nullability.
-///
-/// If an array is a child of a struct or list, the array's nulls have to be merged with the parent.
-/// This then affects the null count of the array, thus the merged nulls are passed separately
-/// as `lhs_nulls` and `rhs_nulls` variables to functions.
-/// The nulls are merged with a bitwise AND, and null counts are recomputed where necessary.
+/// for `len` slots.
#[inline]
fn equal_values(
lhs: &ArrayData,
rhs: &ArrayData,
- lhs_nulls: Option<&Buffer>,
- rhs_nulls: Option<&Buffer>,
lhs_start: usize,
rhs_start: usize,
len: usize,
) -> bool {
match lhs.data_type() {
DataType::Null => null_equal(lhs, rhs, lhs_start, rhs_start, len),
- DataType::Boolean => {
- boolean_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
- }
- DataType::UInt8 => primitive_equal::<u8>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::UInt16 => primitive_equal::<u16>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::UInt32 => primitive_equal::<u32>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::UInt64 => primitive_equal::<u64>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::Int8 => primitive_equal::<i8>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::Int16 => primitive_equal::<i16>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::Int32 => primitive_equal::<i32>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::Int64 => primitive_equal::<i64>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::Float32 => primitive_equal::<f32>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::Float64 => primitive_equal::<f64>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
+ DataType::Boolean => boolean_equal(lhs, rhs, lhs_start, rhs_start, len),
+ DataType::UInt8 => primitive_equal::<u8>(lhs, rhs, lhs_start, rhs_start, len),
+ DataType::UInt16 => primitive_equal::<u16>(lhs, rhs, lhs_start, rhs_start, len),
+ DataType::UInt32 => primitive_equal::<u32>(lhs, rhs, lhs_start, rhs_start, len),
+ DataType::UInt64 => primitive_equal::<u64>(lhs, rhs, lhs_start, rhs_start, len),
+ DataType::Int8 => primitive_equal::<i8>(lhs, rhs, lhs_start, rhs_start, len),
+ DataType::Int16 => primitive_equal::<i16>(lhs, rhs, lhs_start, rhs_start, len),
+ DataType::Int32 => primitive_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len),
+ DataType::Int64 => primitive_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len),
+ DataType::Float32 => primitive_equal::<f32>(lhs, rhs, lhs_start, rhs_start, len),
+ DataType::Float64 => primitive_equal::<f64>(lhs, rhs, lhs_start, rhs_start, len),
DataType::Date32
| DataType::Time32(_)
- | DataType::Interval(IntervalUnit::YearMonth) => primitive_equal::<i32>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
+ | DataType::Interval(IntervalUnit::YearMonth) => {
+ primitive_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len)
+ }
DataType::Date64
| DataType::Interval(IntervalUnit::DayTime)
| DataType::Time64(_)
| DataType::Timestamp(_, _)
- | DataType::Duration(_) => primitive_equal::<i64>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::Interval(IntervalUnit::MonthDayNano) => primitive_equal::<i128>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::Utf8 | DataType::Binary => variable_sized_equal::<i32>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::LargeUtf8 | DataType::LargeBinary => variable_sized_equal::<i64>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::FixedSizeBinary(_) => {
- fixed_binary_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
+ | DataType::Duration(_) => {
+ primitive_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len)
}
- DataType::Decimal(_, _) => {
- decimal_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
+ DataType::Interval(IntervalUnit::MonthDayNano) => {
+ primitive_equal::<i128>(lhs, rhs, lhs_start, rhs_start, len)
}
- DataType::List(_) => {
- list_equal::<i32>(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
+ DataType::Utf8 | DataType::Binary => {
+ variable_sized_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len)
}
- DataType::LargeList(_) => {
- list_equal::<i64>(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
+ DataType::LargeUtf8 | DataType::LargeBinary => {
+ variable_sized_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len)
}
- DataType::FixedSizeList(_, _) => {
- fixed_list_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
- }
- DataType::Struct(_) => {
- struct_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
+ DataType::FixedSizeBinary(_) => {
+ fixed_binary_equal(lhs, rhs, lhs_start, rhs_start, len)
}
- DataType::Union(_, _) => {
- union_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
+ DataType::Decimal(_, _) => decimal_equal(lhs, rhs, lhs_start, rhs_start, len),
+ DataType::List(_) => list_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len),
+ DataType::LargeList(_) => list_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len),
+ DataType::FixedSizeList(_, _) => {
+ fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len)
}
+ DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len),
+ DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len),
DataType::Dictionary(data_type, _) => match data_type.as_ref() {
- DataType::Int8 => dictionary_equal::<i8>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::Int16 => dictionary_equal::<i16>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::Int32 => dictionary_equal::<i32>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::Int64 => dictionary_equal::<i64>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::UInt8 => dictionary_equal::<u8>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::UInt16 => dictionary_equal::<u16>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::UInt32 => dictionary_equal::<u32>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::UInt64 => dictionary_equal::<u64>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
+ DataType::Int8 => dictionary_equal::<i8>(lhs, rhs, lhs_start, rhs_start, len),
+ DataType::Int16 => {
+ dictionary_equal::<i16>(lhs, rhs, lhs_start, rhs_start, len)
+ }
+ DataType::Int32 => {
+ dictionary_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len)
+ }
+ DataType::Int64 => {
+ dictionary_equal::<i64>(lhs, rhs, lhs_start, rhs_start, len)
+ }
+ DataType::UInt8 => {
+ dictionary_equal::<u8>(lhs, rhs, lhs_start, rhs_start, len)
+ }
+ DataType::UInt16 => {
+ dictionary_equal::<u16>(lhs, rhs, lhs_start, rhs_start, len)
+ }
+ DataType::UInt32 => {
+ dictionary_equal::<u32>(lhs, rhs, lhs_start, rhs_start, len)
+ }
+ DataType::UInt64 => {
+ dictionary_equal::<u64>(lhs, rhs, lhs_start, rhs_start, len)
+ }
_ => unreachable!(),
},
- DataType::Float16 => primitive_equal::<f16>(
- lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len,
- ),
- DataType::Map(_, _) => {
- list_equal::<i32>(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
- }
+ DataType::Float16 => primitive_equal::<f16>(lhs, rhs, lhs_start, rhs_start, len),
+ DataType::Map(_, _) => list_equal::<i32>(lhs, rhs, lhs_start, rhs_start, len),
}
}
fn equal_range(
lhs: &ArrayData,
rhs: &ArrayData,
- lhs_nulls: Option<&Buffer>,
- rhs_nulls: Option<&Buffer>,
lhs_start: usize,
rhs_start: usize,
len: usize,
) -> bool {
- utils::base_equal(lhs, rhs)
- && utils::equal_nulls(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
- && equal_values(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
+ utils::equal_nulls(lhs, rhs, lhs_start, rhs_start, len)
+ && equal_values(lhs, rhs, lhs_start, rhs_start, len)
}
/// Logically compares two [ArrayData].
@@ -300,12 +249,10 @@ fn equal_range(
/// This function may panic whenever any of the [ArrayData] does not follow the Arrow specification.
/// (e.g. wrong number of buffers, buffer `len` does not correspond to the declared `len`)
pub fn equal(lhs: &ArrayData, rhs: &ArrayData) -> bool {
- let lhs_nulls = lhs.null_buffer();
- let rhs_nulls = rhs.null_buffer();
utils::base_equal(lhs, rhs)
&& lhs.null_count() == rhs.null_count()
- && utils::equal_nulls(lhs, rhs, lhs_nulls, rhs_nulls, 0, 0, lhs.len())
- && equal_values(lhs, rhs, lhs_nulls, rhs_nulls, 0, 0, lhs.len())
+ && utils::equal_nulls(lhs, rhs, 0, 0, lhs.len())
+ && equal_values(lhs, rhs, 0, 0, lhs.len())
}
#[cfg(test)]
@@ -494,6 +441,13 @@ mod tests {
(1, 2),
true,
),
+ (
+ vec![Some(1), Some(2), None, Some(0)],
+ (2, 2),
+ vec![Some(4), Some(5), Some(0), None],
+ (2, 2),
+ false,
+ ),
];
for (lhs, slice_lhs, rhs, slice_rhs, expected) in cases {
@@ -990,6 +944,11 @@ mod tests {
None,
]);
test_equal(&a, &b, false);
+
+ let b = create_fixed_size_list_array(&[None, Some(&[4, 5, 6]), None, None]);
+
+ test_equal(&a.slice(2, 4), &b, true);
+ test_equal(&a.slice(3, 3), &b.slice(1, 3), true);
}
#[test]
@@ -1359,7 +1318,7 @@ mod tests {
builder.append::<Int32Type>("b", 2).unwrap();
builder.append::<Int32Type>("c", 3).unwrap();
builder.append::<Int32Type>("a", 4).unwrap();
- builder.append_null().unwrap();
+ builder.append_null::<Int32Type>("a").unwrap();
builder.append::<Int32Type>("a", 6).unwrap();
builder.append::<Int32Type>("b", 7).unwrap();
let union1 = builder.build().unwrap();
@@ -1369,7 +1328,7 @@ mod tests {
builder.append::<Int32Type>("b", 2).unwrap();
builder.append::<Int32Type>("c", 3).unwrap();
builder.append::<Int32Type>("a", 4).unwrap();
- builder.append_null().unwrap();
+ builder.append_null::<Int32Type>("a").unwrap();
builder.append::<Int32Type>("a", 6).unwrap();
builder.append::<Int32Type>("b", 7).unwrap();
let union2 = builder.build().unwrap();
@@ -1389,8 +1348,8 @@ mod tests {
builder.append::<Int32Type>("b", 2).unwrap();
builder.append::<Int32Type>("c", 3).unwrap();
builder.append::<Int32Type>("a", 4).unwrap();
- builder.append_null().unwrap();
- builder.append_null().unwrap();
+ builder.append_null::<Int32Type>("c").unwrap();
+ builder.append_null::<Int32Type>("b").unwrap();
builder.append::<Int32Type>("b", 7).unwrap();
let union4 = builder.build().unwrap();
@@ -1406,7 +1365,7 @@ mod tests {
builder.append::<Int32Type>("b", 2).unwrap();
builder.append::<Int32Type>("c", 3).unwrap();
builder.append::<Int32Type>("a", 4).unwrap();
- builder.append_null().unwrap();
+ builder.append_null::<Int32Type>("a").unwrap();
builder.append::<Int32Type>("a", 6).unwrap();
builder.append::<Int32Type>("b", 7).unwrap();
let union1 = builder.build().unwrap();
@@ -1416,7 +1375,7 @@ mod tests {
builder.append::<Int32Type>("b", 2).unwrap();
builder.append::<Int32Type>("c", 3).unwrap();
builder.append::<Int32Type>("a", 4).unwrap();
- builder.append_null().unwrap();
+ builder.append_null::<Int32Type>("a").unwrap();
builder.append::<Int32Type>("a", 6).unwrap();
builder.append::<Int32Type>("b", 7).unwrap();
let union2 = builder.build().unwrap();
@@ -1436,8 +1395,8 @@ mod tests {
builder.append::<Int32Type>("b", 2).unwrap();
builder.append::<Int32Type>("c", 3).unwrap();
builder.append::<Int32Type>("a", 4).unwrap();
- builder.append_null().unwrap();
- builder.append_null().unwrap();
+ builder.append_null::<Int32Type>("a").unwrap();
+ builder.append_null::<Int32Type>("a").unwrap();
builder.append::<Int32Type>("b", 7).unwrap();
let union4 = builder.build().unwrap();
diff --git a/arrow/src/array/equal/primitive.rs b/arrow/src/array/equal/primitive.rs
index db7587915..09882cd78 100644
--- a/arrow/src/array/equal/primitive.rs
+++ b/arrow/src/array/equal/primitive.rs
@@ -18,7 +18,6 @@
use std::mem::size_of;
use crate::array::{data::count_nulls, ArrayData};
-use crate::buffer::Buffer;
use crate::util::bit_util::get_bit;
use super::utils::equal_len;
@@ -26,8 +25,6 @@ use super::utils::equal_len;
pub(super) fn primitive_equal<T>(
lhs: &ArrayData,
rhs: &ArrayData,
- lhs_nulls: Option<&Buffer>,
- rhs_nulls: Option<&Buffer>,
lhs_start: usize,
rhs_start: usize,
len: usize,
@@ -36,8 +33,8 @@ pub(super) fn primitive_equal<T>(
let lhs_values = &lhs.buffers()[0].as_slice()[lhs.offset() * byte_width..];
let rhs_values = &rhs.buffers()[0].as_slice()[rhs.offset() * byte_width..];
- let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len);
- let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len);
+ let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len);
+ let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len);
if lhs_null_count == 0 && rhs_null_count == 0 {
// without nulls, we just need to compare slices
@@ -50,8 +47,8 @@ pub(super) fn primitive_equal<T>(
)
} else {
// get a ref of the null buffer bytes, to use in testing for nullness
- let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice();
- let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice();
+ let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice();
+ let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice();
// with nulls, we need to compare item by item whenever it is not null
(0..len).all(|i| {
let lhs_pos = lhs_start + i;
diff --git a/arrow/src/array/equal/structure.rs b/arrow/src/array/equal/structure.rs
index b3cc4029e..0f943e40c 100644
--- a/arrow/src/array/equal/structure.rs
+++ b/arrow/src/array/equal/structure.rs
@@ -15,24 +15,15 @@
// specific language governing permissions and limitations
// under the License.
-use crate::{
- array::data::count_nulls, array::ArrayData, buffer::Buffer, util::bit_util::get_bit,
-};
+use crate::{array::data::count_nulls, array::ArrayData, util::bit_util::get_bit};
-use super::{equal_range, utils::child_logical_null_buffer};
+use super::equal_range;
/// Compares the values of two [ArrayData] starting at `lhs_start` and `rhs_start` respectively
-/// for `len` slots. The null buffers `lhs_nulls` and `rhs_nulls` inherit parent nullability.
-///
-/// If an array is a child of a struct or list, the array's nulls have to be merged with the parent.
-/// This then affects the null count of the array, thus the merged nulls are passed separately
-/// as `lhs_nulls` and `rhs_nulls` variables to functions.
-/// The nulls are merged with a bitwise AND, and null counts are recomputed where necessary.
-fn equal_values(
+/// for `len` slots.
+fn equal_child_values(
lhs: &ArrayData,
rhs: &ArrayData,
- lhs_nulls: Option<&Buffer>,
- rhs_nulls: Option<&Buffer>,
lhs_start: usize,
rhs_start: usize,
len: usize,
@@ -41,39 +32,27 @@ fn equal_values(
.iter()
.zip(rhs.child_data())
.all(|(lhs_values, rhs_values)| {
- // merge the null data
- let lhs_merged_nulls = child_logical_null_buffer(lhs, lhs_nulls, lhs_values);
- let rhs_merged_nulls = child_logical_null_buffer(rhs, rhs_nulls, rhs_values);
- equal_range(
- lhs_values,
- rhs_values,
- lhs_merged_nulls.as_ref(),
- rhs_merged_nulls.as_ref(),
- lhs_start,
- rhs_start,
- len,
- )
+ equal_range(lhs_values, rhs_values, lhs_start, rhs_start, len)
})
}
pub(super) fn struct_equal(
lhs: &ArrayData,
rhs: &ArrayData,
- lhs_nulls: Option<&Buffer>,
- rhs_nulls: Option<&Buffer>,
lhs_start: usize,
rhs_start: usize,
len: usize,
) -> bool {
// we have to recalculate null counts from the null buffers
- let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len);
- let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len);
+ let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len);
+ let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len);
+
if lhs_null_count == 0 && rhs_null_count == 0 {
- equal_values(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
+ equal_child_values(lhs, rhs, lhs_start, rhs_start, len)
} else {
// get a ref of the null buffer bytes, to use in testing for nullness
- let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice();
- let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice();
+ let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice();
+ let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice();
// with nulls, we need to compare item by item whenever it is not null
(0..len).all(|i| {
let lhs_pos = lhs_start + i;
@@ -82,9 +61,11 @@ pub(super) fn struct_equal(
let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos + lhs.offset());
let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos + rhs.offset());
- lhs_is_null
- || (lhs_is_null == rhs_is_null)
- && equal_values(lhs, rhs, lhs_nulls, rhs_nulls, lhs_pos, rhs_pos, 1)
+ if lhs_is_null != rhs_is_null {
+ return false;
+ }
+
+ lhs_is_null || equal_child_values(lhs, rhs, lhs_pos, rhs_pos, 1)
})
}
}
diff --git a/arrow/src/array/equal/union.rs b/arrow/src/array/equal/union.rs
index 36cd19725..021b0a3b7 100644
--- a/arrow/src/array/equal/union.rs
+++ b/arrow/src/array/equal/union.rs
@@ -15,13 +15,9 @@
// specific language governing permissions and limitations
// under the License.
-use crate::{
- array::ArrayData, buffer::Buffer, datatypes::DataType, datatypes::UnionMode,
-};
+use crate::{array::ArrayData, datatypes::DataType, datatypes::UnionMode};
-use super::{
- equal_range, equal_values, utils::child_logical_null_buffer, utils::equal_nulls,
-};
+use super::equal_range;
fn equal_dense(
lhs: &ArrayData,
@@ -41,11 +37,9 @@ fn equal_dense(
let lhs_values = &lhs.child_data()[*l_type_id as usize];
let rhs_values = &rhs.child_data()[*r_type_id as usize];
- equal_values(
+ equal_range(
lhs_values,
rhs_values,
- None,
- None,
*l_offset as usize,
*r_offset as usize,
1,
@@ -56,8 +50,6 @@ fn equal_dense(
fn equal_sparse(
lhs: &ArrayData,
rhs: &ArrayData,
- lhs_nulls: Option<&Buffer>,
- rhs_nulls: Option<&Buffer>,
lhs_start: usize,
rhs_start: usize,
len: usize,
@@ -66,26 +58,13 @@ fn equal_sparse(
.iter()
.zip(rhs.child_data())
.all(|(lhs_values, rhs_values)| {
- // merge the null data
- let lhs_merged_nulls = child_logical_null_buffer(lhs, lhs_nulls, lhs_values);
- let rhs_merged_nulls = child_logical_null_buffer(rhs, rhs_nulls, rhs_values);
- equal_range(
- lhs_values,
- rhs_values,
- lhs_merged_nulls.as_ref(),
- rhs_merged_nulls.as_ref(),
- lhs_start,
- rhs_start,
- len,
- )
+ equal_range(lhs_values, rhs_values, lhs_start, rhs_start, len)
})
}
pub(super) fn union_equal(
lhs: &ArrayData,
rhs: &ArrayData,
- lhs_nulls: Option<&Buffer>,
- rhs_nulls: Option<&Buffer>,
lhs_start: usize,
rhs_start: usize,
len: usize,
@@ -104,9 +83,7 @@ pub(super) fn union_equal(
let lhs_offsets_range = &lhs_offsets[lhs_start..lhs_start + len];
let rhs_offsets_range = &rhs_offsets[rhs_start..rhs_start + len];
- // nullness is kept in the parent UnionArray, so we compare its nulls here
lhs_type_id_range == rhs_type_id_range
- && equal_nulls(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
&& equal_dense(
lhs,
rhs,
@@ -121,7 +98,7 @@ pub(super) fn union_equal(
DataType::Union(_, UnionMode::Sparse),
) => {
lhs_type_id_range == rhs_type_id_range
- && equal_sparse(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len)
+ && equal_sparse(lhs, rhs, lhs_start, rhs_start, len)
}
_ => unimplemented!(
"Logical equality not yet implemented between dense and sparse union arrays"
diff --git a/arrow/src/array/equal/utils.rs b/arrow/src/array/equal/utils.rs
index b6690f936..8875239ca 100644
--- a/arrow/src/array/equal/utils.rs
+++ b/arrow/src/array/equal/utils.rs
@@ -15,10 +15,8 @@
// specific language governing permissions and limitations
// under the License.
-use crate::array::{data::count_nulls, ArrayData, OffsetSizeTrait};
-use crate::bitmap::Bitmap;
-use crate::buffer::{Buffer, MutableBuffer};
-use crate::datatypes::{DataType, UnionMode};
+use crate::array::{data::count_nulls, ArrayData};
+use crate::datatypes::DataType;
use crate::util::bit_util;
// whether bits along the positions are equal
@@ -41,17 +39,20 @@ pub(super) fn equal_bits(
pub(super) fn equal_nulls(
lhs: &ArrayData,
rhs: &ArrayData,
- lhs_nulls: Option<&Buffer>,
- rhs_nulls: Option<&Buffer>,
lhs_start: usize,
rhs_start: usize,
len: usize,
) -> bool {
- let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len);
- let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len);
+ let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len);
+ let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len);
+
+ if lhs_null_count != rhs_null_count {
+ return false;
+ }
+
if lhs_null_count > 0 || rhs_null_count > 0 {
- let lhs_values = lhs_nulls.unwrap().as_slice();
- let rhs_values = rhs_nulls.unwrap().as_slice();
+ let lhs_values = lhs.null_buffer().unwrap().as_slice();
+ let rhs_values = rhs.null_buffer().unwrap().as_slice();
equal_bits(
lhs_values,
rhs_values,
@@ -111,227 +112,3 @@ pub(super) fn equal_len(
) -> bool {
lhs_values[lhs_start..(lhs_start + len)] == rhs_values[rhs_start..(rhs_start + len)]
}
-
-/// Computes the logical validity bitmap of the array data using the
-/// parent's array data. The parent should be a list or struct, else
-/// the logical bitmap of the array is returned unaltered.
-///
-/// Parent data is passed along with the parent's logical bitmap, as
-/// nested arrays could have a logical bitmap different to the physical
-/// one on the `ArrayData`.
-pub(super) fn child_logical_null_buffer(
- parent_data: &ArrayData,
- logical_null_buffer: Option<&Buffer>,
- child_data: &ArrayData,
-) -> Option<Buffer> {
- let parent_len = parent_data.len();
- let parent_bitmap = logical_null_buffer
- .cloned()
- .map(Bitmap::from)
- .unwrap_or_else(|| {
- let ceil = bit_util::ceil(parent_len, 8);
- Bitmap::from(Buffer::from(vec![0b11111111; ceil]))
- });
- let self_null_bitmap = child_data.null_bitmap().cloned().unwrap_or_else(|| {
- let ceil = bit_util::ceil(child_data.len(), 8);
- Bitmap::from(Buffer::from(vec![0b11111111; ceil]))
- });
- match parent_data.data_type() {
- DataType::List(_) | DataType::Map(_, _) => Some(logical_list_bitmap::<i32>(
- parent_data,
- parent_bitmap,
- self_null_bitmap,
- )),
- DataType::LargeList(_) => Some(logical_list_bitmap::<i64>(
- parent_data,
- parent_bitmap,
- self_null_bitmap,
- )),
- DataType::FixedSizeList(_, len) => {
- let len = *len as usize;
- let array_offset = parent_data.offset();
- let bitmap_len = bit_util::ceil(parent_len * len, 8);
- let mut buffer = MutableBuffer::from_len_zeroed(bitmap_len);
- let null_slice = buffer.as_slice_mut();
- (array_offset..parent_len + array_offset).for_each(|index| {
- let start = index * len;
- let end = start + len;
- let mask = parent_bitmap.is_set(index);
- (start..end).for_each(|child_index| {
- if mask && self_null_bitmap.is_set(child_index) {
- bit_util::set_bit(null_slice, child_index);
- }
- });
- });
- Some(buffer.into())
- }
- DataType::Struct(_) => {
- // Arrow implementations are free to pad data, which can result in null buffers not
- // having the same length.
- // Rust bitwise comparisons will return an error if left AND right is performed on
- // buffers of different length.
- // This might be a valid case during integration testing, where we read Arrow arrays
- // from IPC data, which has padding.
- //
- // We first perform a bitwise comparison, and if there is an error, we revert to a
- // slower method that indexes into the buffers one-by-one.
- let result = &parent_bitmap & &self_null_bitmap;
- if let Ok(bitmap) = result {
- return Some(bitmap.bits);
- }
- // slow path
- let array_offset = parent_data.offset();
- let mut buffer = MutableBuffer::new_null(parent_len);
- let null_slice = buffer.as_slice_mut();
- (0..parent_len).for_each(|index| {
- if parent_bitmap.is_set(index + array_offset)
- && self_null_bitmap.is_set(index + array_offset)
- {
- bit_util::set_bit(null_slice, index);
- }
- });
- Some(buffer.into())
- }
- DataType::Union(_, mode) => union_child_logical_null_buffer(
- parent_data,
- parent_len,
- &parent_bitmap,
- &self_null_bitmap,
- mode,
- ),
- DataType::Dictionary(_, _) => {
- unimplemented!("Logical equality not yet implemented for nested dictionaries")
- }
- data_type => panic!("Data type {:?} is not a supported nested type", data_type),
- }
-}
-
-pub(super) fn union_child_logical_null_buffer(
- parent_data: &ArrayData,
- parent_len: usize,
- parent_bitmap: &Bitmap,
- self_null_bitmap: &Bitmap,
- mode: &UnionMode,
-) -> Option<Buffer> {
- match mode {
- UnionMode::Sparse => {
- // See the logic of `DataType::Struct` in `child_logical_null_buffer`.
- let result = parent_bitmap & self_null_bitmap;
- if let Ok(bitmap) = result {
- return Some(bitmap.bits);
- }
-
- // slow path
- let array_offset = parent_data.offset();
- let mut buffer = MutableBuffer::new_null(parent_len);
- let null_slice = buffer.as_slice_mut();
- (0..parent_len).for_each(|index| {
- if parent_bitmap.is_set(index + array_offset)
- && self_null_bitmap.is_set(index + array_offset)
- {
- bit_util::set_bit(null_slice, index);
- }
- });
- Some(buffer.into())
- }
- UnionMode::Dense => {
- // We don't keep bitmap in child data of Dense UnionArray
- unimplemented!("Logical equality not yet implemented for dense union arrays")
- }
- }
-}
-
-// Calculate a list child's logical bitmap/buffer
-#[inline]
-fn logical_list_bitmap<OffsetSize: OffsetSizeTrait>(
- parent_data: &ArrayData,
- parent_bitmap: Bitmap,
- child_bitmap: Bitmap,
-) -> Buffer {
- let offsets = parent_data.buffer::<OffsetSize>(0);
- let offset_start = offsets.first().unwrap().to_usize().unwrap();
- let offset_len = offsets.get(parent_data.len()).unwrap().to_usize().unwrap();
- let mut buffer = MutableBuffer::new_null(offset_len - offset_start);
- let null_slice = buffer.as_slice_mut();
-
- offsets
- .windows(2)
- .enumerate()
- .take(parent_data.len())
- .for_each(|(index, window)| {
- let start = window[0].to_usize().unwrap();
- let end = window[1].to_usize().unwrap();
- let mask = parent_bitmap.is_set(index);
- (start..end).for_each(|child_index| {
- if mask && child_bitmap.is_set(child_index) {
- bit_util::set_bit(null_slice, child_index - offset_start);
- }
- });
- });
- buffer.into()
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- use crate::datatypes::{Field, ToByteSlice};
-
- #[test]
- fn test_logical_null_buffer() {
- let child_data = ArrayData::builder(DataType::Int32)
- .len(11)
- .add_buffer(Buffer::from(
- vec![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11].to_byte_slice(),
- ))
- .build()
- .unwrap();
-
- let data = ArrayData::builder(DataType::List(Box::new(Field::new(
- "item",
- DataType::Int32,
- false,
- ))))
- .len(7)
- .add_buffer(Buffer::from(vec![0, 0, 3, 5, 6, 9, 10, 11].to_byte_slice()))
- .null_bit_buffer(Buffer::from(vec![0b01011010]))
- .add_child_data(child_data.clone())
- .build()
- .unwrap();
-
- // Get the child logical null buffer. The child is non-nullable, but because the list has nulls,
- // we expect the child to logically have some nulls, inherited from the parent:
- // [1, 2, 3, null, null, 6, 7, 8, 9, null, 11]
- let nulls = child_logical_null_buffer(
- &data,
- data.null_buffer(),
- data.child_data().get(0).unwrap(),
- );
- let expected = Some(Buffer::from(vec![0b11100111, 0b00000101]));
- assert_eq!(nulls, expected);
-
- // test with offset
- let data = ArrayData::builder(DataType::List(Box::new(Field::new(
- "item",
- DataType::Int32,
- false,
- ))))
- .len(4)
- .offset(3)
- .add_buffer(Buffer::from(vec![0, 0, 3, 5, 6, 9, 10, 11].to_byte_slice()))
- // the null_bit_buffer doesn't have an offset, i.e. cleared the 3 offset bits 0b[---]01011[010]
- .null_bit_buffer(Buffer::from(vec![0b00001011]))
- .add_child_data(child_data)
- .build()
- .unwrap();
-
- let nulls = child_logical_null_buffer(
- &data,
- data.null_buffer(),
- data.child_data().get(0).unwrap(),
- );
-
- let expected = Some(Buffer::from(vec![0b00101111]));
- assert_eq!(nulls, expected);
- }
-}
diff --git a/arrow/src/array/equal/variable_size.rs b/arrow/src/array/equal/variable_size.rs
index 946f107f3..f40f79e40 100644
--- a/arrow/src/array/equal/variable_size.rs
+++ b/arrow/src/array/equal/variable_size.rs
@@ -15,7 +15,6 @@
// specific language governing permissions and limitations
// under the License.
-use crate::buffer::Buffer;
use crate::util::bit_util::get_bit;
use crate::{
array::data::count_nulls,
@@ -51,8 +50,6 @@ fn offset_value_equal<T: OffsetSizeTrait>(
pub(super) fn variable_sized_equal<T: OffsetSizeTrait>(
lhs: &ArrayData,
rhs: &ArrayData,
- lhs_nulls: Option<&Buffer>,
- rhs_nulls: Option<&Buffer>,
lhs_start: usize,
rhs_start: usize,
len: usize,
@@ -64,8 +61,8 @@ pub(super) fn variable_sized_equal<T: OffsetSizeTrait>(
let lhs_values = lhs.buffers()[1].as_slice();
let rhs_values = rhs.buffers()[1].as_slice();
- let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len);
- let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len);
+ let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len);
+ let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len);
if lhs_null_count == 0
&& rhs_null_count == 0
@@ -87,10 +84,13 @@ pub(super) fn variable_sized_equal<T: OffsetSizeTrait>(
let rhs_pos = rhs_start + i;
// the null bits can still be `None`, indicating that the value is valid.
- let lhs_is_null = !lhs_nulls
+ let lhs_is_null = !lhs
+ .null_buffer()
.map(|v| get_bit(v.as_slice(), lhs.offset() + lhs_pos))
.unwrap_or(true);
- let rhs_is_null = !rhs_nulls
+
+ let rhs_is_null = !rhs
+ .null_buffer()
.map(|v| get_bit(v.as_slice(), rhs.offset() + rhs_pos))
.unwrap_or(true);
diff --git a/arrow/src/array/transform/union.rs b/arrow/src/array/transform/union.rs
index ec672daf4..bbea50821 100644
--- a/arrow/src/array/transform/union.rs
+++ b/arrow/src/array/transform/union.rs
@@ -22,140 +22,50 @@ use super::{Extend, _MutableArrayData};
pub(super) fn build_extend_sparse(array: &ArrayData) -> Extend {
let type_ids = array.buffer::<i8>(0);
- if array.null_count() == 0 {
- Box::new(
- move |mutable: &mut _MutableArrayData,
- index: usize,
- start: usize,
- len: usize| {
- // extends type_ids
- mutable
- .buffer1
- .extend_from_slice(&type_ids[start..start + len]);
-
- mutable
- .child_data
- .iter_mut()
- .for_each(|child| child.extend(index, start, start + len))
- },
- )
- } else {
- Box::new(
- move |mutable: &mut _MutableArrayData,
- index: usize,
- start: usize,
- len: usize| {
- // extends type_ids
- mutable
- .buffer1
- .extend_from_slice(&type_ids[start..start + len]);
+ Box::new(
+ move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| {
+ // extends type_ids
+ mutable
+ .buffer1
+ .extend_from_slice(&type_ids[start..start + len]);
- (start..start + len).for_each(|i| {
- if array.is_valid(i) {
- mutable
- .child_data
- .iter_mut()
- .for_each(|child| child.extend(index, i, i + 1))
- } else {
- mutable
- .child_data
- .iter_mut()
- .for_each(|child| child.extend_nulls(1))
- }
- })
- },
- )
- }
+ mutable
+ .child_data
+ .iter_mut()
+ .for_each(|child| child.extend(index, start, start + len))
+ },
+ )
}
pub(super) fn build_extend_dense(array: &ArrayData) -> Extend {
let type_ids = array.buffer::<i8>(0);
let offsets = array.buffer::<i32>(1);
- if array.null_count() == 0 {
- Box::new(
- move |mutable: &mut _MutableArrayData,
- index: usize,
- start: usize,
- len: usize| {
- // extends type_ids
- mutable
- .buffer1
- .extend_from_slice(&type_ids[start..start + len]);
- // extends offsets
- mutable
- .buffer2
- .extend_from_slice(&offsets[start..start + len]);
-
- (start..start + len).for_each(|i| {
- let type_id = type_ids[i] as usize;
- let offset_start = offsets[start] as usize;
-
- mutable.child_data[type_id].extend(
- index,
- offset_start,
- offset_start + 1,
- )
- })
- },
- )
- } else {
- Box::new(
- move |mutable: &mut _MutableArrayData,
- index: usize,
- start: usize,
- len: usize| {
- // extends type_ids
- mutable
- .buffer1
- .extend_from_slice(&type_ids[start..start + len]);
- // extends offsets
- mutable
- .buffer2
- .extend_from_slice(&offsets[start..start + len]);
+ Box::new(
+ move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| {
+ // extends type_ids
+ mutable
+ .buffer1
+ .extend_from_slice(&type_ids[start..start + len]);
- (start..start + len).for_each(|i| {
- let type_id = type_ids[i] as usize;
- let offset_start = offsets[start] as usize;
+ (start..start + len).for_each(|i| {
+ let type_id = type_ids[i] as usize;
+ let src_offset = offsets[i] as usize;
+ let child_data = &mut mutable.child_data[type_id];
+ let dst_offset = child_data.len();
- if array.is_valid(i) {
- mutable.child_data[type_id].extend(
- index,
- offset_start,
- offset_start + 1,
- )
- } else {
- mutable.child_data[type_id].extend_nulls(1)
- }
- })
- },
- )
- }
+ // Extend offsets
+ mutable.buffer2.push(dst_offset as i32);
+ mutable.child_data[type_id].extend(index, src_offset, src_offset + 1)
+ })
+ },
+ )
}
-pub(super) fn extend_nulls_dense(mutable: &mut _MutableArrayData, len: usize) {
- let mut count: usize = 0;
- let num = len / mutable.child_data.len();
- mutable
- .child_data
- .iter_mut()
- .enumerate()
- .for_each(|(idx, child)| {
- let n = if count + num > len { len - count } else { num };
- count += n;
- mutable
- .buffer1
- .extend_from_slice(vec![idx as i8; n].as_slice());
- mutable
- .buffer2
- .extend_from_slice(vec![child.len() as i32; n].as_slice());
- child.extend_nulls(n)
- })
+pub(super) fn extend_nulls_dense(_mutable: &mut _MutableArrayData, _len: usize) {
+ panic!("cannot call extend_nulls on UnionArray as cannot infer type");
}
-pub(super) fn extend_nulls_sparse(mutable: &mut _MutableArrayData, len: usize) {
- mutable
- .child_data
- .iter_mut()
- .for_each(|child| child.extend_nulls(len))
+pub(super) fn extend_nulls_sparse(_mutable: &mut _MutableArrayData, _len: usize) {
+ panic!("cannot call extend_nulls on UnionArray as cannot infer type");
}
diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs
index df59ba63c..b4abcd5a4 100644
--- a/arrow/src/compute/kernels/filter.rs
+++ b/arrow/src/compute/kernels/filter.rs
@@ -1670,22 +1670,53 @@ mod tests {
test_filter_union_array(array);
}
+ #[test]
+ fn test_filter_run_union_array_dense() {
+ let mut builder = UnionBuilder::new_dense(3);
+ builder.append::<Int32Type>("A", 1).unwrap();
+ builder.append::<Int32Type>("A", 3).unwrap();
+ builder.append::<Int32Type>("A", 34).unwrap();
+ let array = builder.build().unwrap();
+
+ let filter_array = BooleanArray::from(vec![true, true, false]);
+ let c = filter(&array, &filter_array).unwrap();
+ let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
+
+ let mut builder = UnionBuilder::new_dense(3);
+ builder.append::<Int32Type>("A", 1).unwrap();
+ builder.append::<Int32Type>("A", 3).unwrap();
+ let expected = builder.build().unwrap();
+
+ assert_eq!(filtered.data(), expected.data());
+ }
+
#[test]
fn test_filter_union_array_dense_with_nulls() {
let mut builder = UnionBuilder::new_dense(4);
builder.append::<Int32Type>("A", 1).unwrap();
builder.append::<Float64Type>("B", 3.2).unwrap();
- builder.append_null().unwrap();
+ builder.append_null::<Float64Type>("B").unwrap();
builder.append::<Int32Type>("A", 34).unwrap();
let array = builder.build().unwrap();
+ let filter_array = BooleanArray::from(vec![true, true, false, false]);
+ let c = filter(&array, &filter_array).unwrap();
+ let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
+
+ let mut builder = UnionBuilder::new_dense(2);
+ builder.append::<Int32Type>("A", 1).unwrap();
+ builder.append::<Float64Type>("B", 3.2).unwrap();
+ let expected_array = builder.build().unwrap();
+
+ compare_union_arrays(filtered, &expected_array);
+
let filter_array = BooleanArray::from(vec![true, false, true, false]);
let c = filter(&array, &filter_array).unwrap();
let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
- let mut builder = UnionBuilder::new_dense(1);
+ let mut builder = UnionBuilder::new_dense(2);
builder.append::<Int32Type>("A", 1).unwrap();
- builder.append_null().unwrap();
+ builder.append_null::<Float64Type>("B").unwrap();
let expected_array = builder.build().unwrap();
compare_union_arrays(filtered, &expected_array);
@@ -1707,7 +1738,7 @@ mod tests {
let mut builder = UnionBuilder::new_sparse(4);
builder.append::<Int32Type>("A", 1).unwrap();
builder.append::<Float64Type>("B", 3.2).unwrap();
- builder.append_null().unwrap();
+ builder.append_null::<Float64Type>("B").unwrap();
builder.append::<Int32Type>("A", 34).unwrap();
let array = builder.build().unwrap();
@@ -1715,9 +1746,9 @@ mod tests {
let c = filter(&array, &filter_array).unwrap();
let filtered = c.as_any().downcast_ref::<UnionArray>().unwrap();
- let mut builder = UnionBuilder::new_dense(1);
+ let mut builder = UnionBuilder::new_sparse(2);
builder.append::<Int32Type>("A", 1).unwrap();
- builder.append_null().unwrap();
+ builder.append_null::<Float64Type>("B").unwrap();
let expected_array = builder.build().unwrap();
compare_union_arrays(filtered, &expected_array);
@@ -1732,9 +1763,9 @@ mod tests {
let slot1 = union1.value(i);
let slot2 = union2.value(i);
- assert_eq!(union1.is_null(i), union2.is_null(i));
+ assert_eq!(slot1.is_null(0), slot2.is_null(0));
- if !union1.is_null(i) && !union2.is_null(i) {
+ if !slot1.is_null(0) && !slot2.is_null(0) {
match type_id {
0 => {
let slot1 = slot1.as_any().downcast_ref::<Int32Array>().unwrap();
diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs
index 8a26167db..f3e46e27f 100644
--- a/arrow/src/ipc/reader.rs
+++ b/arrow/src/ipc/reader.rs
@@ -190,11 +190,10 @@ fn create_array(
let len = union_node.length() as usize;
- let null_buffer: Buffer = read_buffer(&buffers[buffer_index], data);
let type_ids: Buffer =
- read_buffer(&buffers[buffer_index + 1], data)[..len].into();
+ read_buffer(&buffers[buffer_index], data)[..len].into();
- buffer_index += 2;
+ buffer_index += 1;
let value_offsets = match mode {
UnionMode::Dense => {
@@ -224,13 +223,7 @@ fn create_array(
children.push((field.clone(), triple.0));
}
- let array = UnionArray::try_new(
- type_ids,
- value_offsets,
- children,
- Some(null_buffer),
- )?;
-
+ let array = UnionArray::try_new(type_ids, value_offsets, children)?;
Arc::new(array)
}
Null => {
@@ -1359,7 +1352,7 @@ mod tests {
fn check_union_with_builder(mut builder: UnionBuilder) {
builder.append::<datatypes::Int32Type>("a", 1).unwrap();
- builder.append_null().unwrap();
+ builder.append_null::<datatypes::Int32Type>("a").unwrap();
builder.append::<datatypes::Float64Type>("c", 3.0).unwrap();
builder.append::<datatypes::Int32Type>("a", 4).unwrap();
builder.append::<datatypes::Int64Type>("d", 11).unwrap();
diff --git a/arrow/src/ipc/writer.rs b/arrow/src/ipc/writer.rs
index efc878a12..c03d5e449 100644
--- a/arrow/src/ipc/writer.rs
+++ b/arrow/src/ipc/writer.rs
@@ -862,7 +862,11 @@ fn write_array_data(
let mut offset = offset;
nodes.push(ipc::FieldNode::new(num_rows as i64, null_count as i64));
// NullArray does not have any buffers, thus the null buffer is not generated
- if array_data.data_type() != &DataType::Null {
+ // UnionArray does not have a validity buffer
+ if !matches!(
+ array_data.data_type(),
+ DataType::Null | DataType::Union(_, _)
+ ) {
// write null buffer if exists
let null_buffer = match array_data.null_buffer() {
None => {
@@ -1324,8 +1328,7 @@ mod tests {
let offsets = Buffer::from_slice_ref(&[0_i32, 1, 2]);
let union =
- UnionArray::try_new(types, Some(offsets), vec![(dctfield, array)], None)
- .unwrap();
+ UnionArray::try_new(types, Some(offsets), vec![(dctfield, array)]).unwrap();
let schema = Arc::new(Schema::new(vec![Field::new(
"union",