You are viewing a plain-text version of this content. The link to its canonical version is not preserved in this plain-text rendering.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/10/13 20:08:28 UTC

[arrow-rs] branch master updated: Validate ArrayData type when converting to Array (#2834) (#2835)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 8adebca35 Validate ArrayData type when converting to Array (#2834) (#2835)
8adebca35 is described below

commit 8adebca35253943fffb0653e7521eaf7a25b0153
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Thu Oct 13 21:08:22 2022 +0100

    Validate ArrayData type when converting to Array (#2834) (#2835)
    
    * Validate ArrayData type when converting to Array (#2834)
    
    * Fix cast kernel and take kernel tests
    
    * Clippy
    
    * Fix parquet
    
    * Clippy
---
 arrow-array/src/array/binary_array.rs             |  9 ++++++++
 arrow-array/src/array/boolean_array.rs            | 17 ++++++++++++++
 arrow-array/src/array/decimal_array.rs            | 25 +++++++++++++++++----
 arrow-array/src/array/dictionary_array.rs         | 22 +++++++++++++++---
 arrow-array/src/array/list_array.rs               | 15 +++++++++++++
 arrow-array/src/array/map_array.rs                | 23 +++++++++++++++++++
 arrow-array/src/array/primitive_array.rs          | 19 ++++++++++++++++
 arrow/src/compute/kernels/cast.rs                 | 11 ++++-----
 arrow/src/compute/kernels/take.rs                 |  4 ++--
 parquet/src/arrow/array_reader/primitive_array.rs | 27 ++++++++++++++---------
 10 files changed, 148 insertions(+), 24 deletions(-)

diff --git a/arrow-array/src/array/binary_array.rs b/arrow-array/src/array/binary_array.rs
index 851fb60c0..c8407b252 100644
--- a/arrow-array/src/array/binary_array.rs
+++ b/arrow-array/src/array/binary_array.rs
@@ -297,6 +297,8 @@ impl<OffsetSize: OffsetSizeTrait> From<ArrayData> for GenericBinaryArray<OffsetS
         let values = data.buffers()[1].as_ptr();
         Self {
             data,
+            // SAFETY:
+            // ArrayData must be valid, and validated data type above
             value_offsets: unsafe { RawPtrBox::new(offsets) },
             value_data: unsafe { RawPtrBox::new(values) },
         }
@@ -833,6 +835,13 @@ mod tests {
         binary_array.value(4);
     }
 
+    #[test]
+    #[should_panic(expected = "[Large]BinaryArray expects Datatype::[Large]Binary")]
+    fn test_binary_array_validation() {
+        let array = BinaryArray::from_iter_values(&[&[1, 2]]);
+        let _ = LargeBinaryArray::from(array.into_data());
+    }
+
     #[test]
     fn test_binary_array_all_null() {
         let data = vec![None];
diff --git a/arrow-array/src/array/boolean_array.rs b/arrow-array/src/array/boolean_array.rs
index 24be122c9..c7a44c7d5 100644
--- a/arrow-array/src/array/boolean_array.rs
+++ b/arrow-array/src/array/boolean_array.rs
@@ -201,6 +201,13 @@ impl From<Vec<Option<bool>>> for BooleanArray {
 
 impl From<ArrayData> for BooleanArray {
     fn from(data: ArrayData) -> Self {
+        assert_eq!(
+            data.data_type(),
+            &DataType::Boolean,
+            "BooleanArray expected ArrayData with type {} got {}",
+            DataType::Boolean,
+            data.data_type()
+        );
         assert_eq!(
             data.buffers().len(),
             1,
@@ -209,6 +216,8 @@ impl From<ArrayData> for BooleanArray {
         let ptr = data.buffers()[0].as_ptr();
         Self {
             data,
+            // SAFETY:
+            // ArrayData must be valid, and validated data type above
             raw_values: unsafe { RawPtrBox::new(ptr) },
         }
     }
@@ -414,4 +423,12 @@ mod tests {
         };
         drop(BooleanArray::from(data));
     }
+
+    #[test]
+    #[should_panic(
+        expected = "BooleanArray expected ArrayData with type Boolean got Int32"
+    )]
+    fn test_from_array_data_validation() {
+        let _ = BooleanArray::from(ArrayData::new_empty(&DataType::Int32));
+    }
 }
diff --git a/arrow-array/src/array/decimal_array.rs b/arrow-array/src/array/decimal_array.rs
index 34b424092..5ca9b0715 100644
--- a/arrow-array/src/array/decimal_array.rs
+++ b/arrow-array/src/array/decimal_array.rs
@@ -407,13 +407,21 @@ impl<T: DecimalType> From<ArrayData> for DecimalArray<T> {
             "DecimalArray data should contain 1 buffer only (values)"
         );
         let values = data.buffers()[0].as_ptr();
-        let (precision, scale) = match (data.data_type(), Self::VALUE_LENGTH) {
-            (DataType::Decimal128(precision, scale), 16)
-            | (DataType::Decimal256(precision, scale), 32) => (*precision, *scale),
-            _ => panic!("Expected data type to be Decimal"),
+        let (precision, scale) = match (data.data_type(), Self::DEFAULT_TYPE) {
+            (DataType::Decimal128(precision, scale), DataType::Decimal128(_, _))
+            | (DataType::Decimal256(precision, scale), DataType::Decimal256(_, _)) => {
+                (*precision, *scale)
+            }
+            _ => panic!(
+                "Expected data type to match {} got {}",
+                Self::DEFAULT_TYPE,
+                data.data_type()
+            ),
         };
         Self {
             data,
+            // SAFETY:
+            // ArrayData must be valid, and verified data type above
             value_data: unsafe { RawPtrBox::new(values) },
             precision,
             scale,
@@ -977,4 +985,13 @@ mod tests {
 
         array.value(4);
     }
+
+    #[test]
+    #[should_panic(
+        expected = "Expected data type to match Decimal256(76, 10) got Decimal128(38, 10)"
+    )]
+    fn test_from_array_data_validation() {
+        let array = Decimal128Array::from_iter_values(vec![-100, 0, 101].into_iter());
+        let _ = Decimal256Array::from(array.into_data());
+    }
 }
diff --git a/arrow-array/src/array/dictionary_array.rs b/arrow-array/src/array/dictionary_array.rs
index 96e91f729..002ee6f47 100644
--- a/arrow-array/src/array/dictionary_array.rs
+++ b/arrow-array/src/array/dictionary_array.rs
@@ -408,10 +408,17 @@ impl<T: ArrowPrimitiveType> From<ArrayData> for DictionaryArray<T> {
         );
 
         if let DataType::Dictionary(key_data_type, _) = data.data_type() {
-            if key_data_type.as_ref() != &T::DATA_TYPE {
-                panic!("DictionaryArray's data type must match.")
-            };
+            assert_eq!(
+                &T::DATA_TYPE,
+                key_data_type.as_ref(),
+                "DictionaryArray's data type must match, expected {} got {}",
+                T::DATA_TYPE,
+                key_data_type
+            );
+
             // create a zero-copy of the keys' data
+            // SAFETY:
+            // ArrayData is valid and verified type above
             let keys = PrimitiveArray::<T>::from(unsafe {
                 ArrayData::new_unchecked(
                     T::DATA_TYPE,
@@ -925,4 +932,13 @@ mod tests {
         let keys: Float32Array = [Some(0_f32), None, Some(3_f32)].into_iter().collect();
         DictionaryArray::<Float32Type>::try_new(&keys, &values).unwrap();
     }
+
+    #[test]
+    #[should_panic(
+        expected = "DictionaryArray's data type must match, expected Int64 got Int32"
+    )]
+    fn test_from_array_data_validation() {
+        let a = DictionaryArray::<Int32Type>::from_iter(["32"]);
+        let _ = DictionaryArray::<Int64Type>::from(a.into_data());
+    }
 }
diff --git a/arrow-array/src/array/list_array.rs b/arrow-array/src/array/list_array.rs
index 3022db023..cdc7531d9 100644
--- a/arrow-array/src/array/list_array.rs
+++ b/arrow-array/src/array/list_array.rs
@@ -257,6 +257,8 @@ impl<OffsetSize: OffsetSizeTrait> GenericListArray<OffsetSize> {
             false => data.buffers()[0].as_ptr(),
         };
 
+        // SAFETY:
+        // Verified list type in call to `Self::get_type`
         let value_offsets = unsafe { RawPtrBox::new(offsets) };
         Ok(Self {
             data,
@@ -362,6 +364,7 @@ pub type LargeListArray = GenericListArray<i64>;
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::builder::{Int32Builder, ListBuilder};
     use crate::types::Int32Type;
     use crate::Int32Array;
     use arrow_buffer::{bit_util, Buffer, ToByteSlice};
@@ -820,6 +823,18 @@ mod tests {
         drop(ListArray::from(list_data));
     }
 
+    #[test]
+    #[should_panic(
+        expected = "[Large]ListArray's datatype must be [Large]ListArray(). It is List"
+    )]
+    fn test_from_array_data_validation() {
+        let mut builder = ListBuilder::new(Int32Builder::new());
+        builder.values().append_value(1);
+        builder.append(true);
+        let array = builder.finish();
+        let _ = LargeListArray::from(array.into_data());
+    }
+
     #[test]
     fn test_list_array_offsets_need_not_start_at_zero() {
         let value_data = ArrayData::builder(DataType::Int32)
diff --git a/arrow-array/src/array/map_array.rs b/arrow-array/src/array/map_array.rs
index bfe8d4072..0f3ae2e68 100644
--- a/arrow-array/src/array/map_array.rs
+++ b/arrow-array/src/array/map_array.rs
@@ -109,6 +109,12 @@ impl From<MapArray> for ArrayData {
 
 impl MapArray {
     fn try_new_from_array_data(data: ArrayData) -> Result<Self, ArrowError> {
+        assert!(
+            matches!(data.data_type(), DataType::Map(_, _)),
+            "MapArray expected ArrayData with DataType::Map got {}",
+            data.data_type()
+        );
+
         if data.buffers().len() != 1 {
             return Err(ArrowError::InvalidArgumentError(
                 format!("MapArray data should contain a single buffer only (value offsets), had {}",
@@ -141,6 +147,8 @@ impl MapArray {
         let values = make_array(entries);
         let value_offsets = data.buffers()[0].as_ptr();
 
+        // SAFETY:
+        // ArrayData is valid, and verified type above
         let value_offsets = unsafe { RawPtrBox::<i32>::new(value_offsets) };
         unsafe {
             if (*value_offsets.as_ptr().offset(0)) != 0 {
@@ -467,6 +475,21 @@ mod tests {
         map_array.value(map_array.len());
     }
 
+    #[test]
+    #[should_panic(
+        expected = "MapArray expected ArrayData with DataType::Map got Dictionary"
+    )]
+    fn test_from_array_data_validation() {
+        // A DictionaryArray has similar buffer layout to a MapArray
+        // but the meaning of the values differs
+        let struct_t = DataType::Struct(vec![
+            Field::new("keys", DataType::Int32, true),
+            Field::new("values", DataType::UInt32, true),
+        ]);
+        let dict_t = DataType::Dictionary(Box::new(DataType::Int32), Box::new(struct_t));
+        let _ = MapArray::from(ArrayData::new_empty(&dict_t));
+    }
+
     #[test]
     fn test_new_from_strings() {
         let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"];
diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs
index 928135463..895c80b07 100644
--- a/arrow-array/src/array/primitive_array.rs
+++ b/arrow-array/src/array/primitive_array.rs
@@ -818,6 +818,14 @@ impl<T: ArrowTimestampType> PrimitiveArray<T> {
 /// Constructs a `PrimitiveArray` from an array data reference.
 impl<T: ArrowPrimitiveType> From<ArrayData> for PrimitiveArray<T> {
     fn from(data: ArrayData) -> Self {
+        // Use discriminant to allow for decimals
+        assert_eq!(
+            std::mem::discriminant(&T::DATA_TYPE),
+            std::mem::discriminant(data.data_type()),
+            "PrimitiveArray expected ArrayData with type {} got {}",
+            T::DATA_TYPE,
+            data.data_type()
+        );
         assert_eq!(
             data.buffers().len(),
             1,
@@ -827,6 +835,8 @@ impl<T: ArrowPrimitiveType> From<ArrayData> for PrimitiveArray<T> {
         let ptr = data.buffers()[0].as_ptr();
         Self {
             data,
+            // SAFETY:
+            // ArrayData must be valid, and validated data type above
             raw_values: unsafe { RawPtrBox::new(ptr) },
         }
     }
@@ -1352,6 +1362,15 @@ mod tests {
         array.value(4);
     }
 
+    #[test]
+    #[should_panic(
+        expected = "PrimitiveArray expected ArrayData with type Int64 got Int32"
+    )]
+    fn test_from_array_data_validation() {
+        let foo = PrimitiveArray::<Int32Type>::from_iter([1, 2, 3]);
+        let _ = PrimitiveArray::<Int64Type>::from(foo.into_data());
+    }
+
     #[test]
     fn test_decimal128() {
         let values: Vec<_> = vec![0, 1, -1, i128::MIN, i128::MAX];
diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs
index b573c65d0..49a9b18d8 100644
--- a/arrow/src/compute/kernels/cast.rs
+++ b/arrow/src/compute/kernels/cast.rs
@@ -1312,15 +1312,16 @@ pub fn cast_with_options(
         )),
 
         (Timestamp(from_unit, _), Timestamp(to_unit, to_tz)) => {
-            let time_array = Int64Array::from(array.data().clone());
+            let array = cast_with_options(array, &Int64, cast_options)?;
+            let time_array = as_primitive_array::<Int64Type>(array.as_ref());
             let from_size = time_unit_multiple(from_unit);
             let to_size = time_unit_multiple(to_unit);
             // we either divide or multiply, depending on size of each unit
             // units are never the same when the types are the same
             let converted = if from_size >= to_size {
-                divide_scalar(&time_array, from_size / to_size)?
+                divide_scalar(time_array, from_size / to_size)?
             } else {
-                multiply_scalar(&time_array, to_size / from_size)?
+                multiply_scalar(time_array, to_size / from_size)?
             };
             Ok(make_timestamp_array(
                 &converted,
@@ -1329,10 +1330,10 @@ pub fn cast_with_options(
             ))
         }
         (Timestamp(from_unit, _), Date32) => {
-            let time_array = Int64Array::from(array.data().clone());
+            let array = cast_with_options(array, &Int64, cast_options)?;
+            let time_array = as_primitive_array::<Int64Type>(array.as_ref());
             let from_size = time_unit_multiple(from_unit) * SECONDS_IN_DAY;
 
-            // Int32Array::from_iter(tim.iter)
             let mut b = Date32Builder::with_capacity(array.len());
 
             for i in 0..array.len() {
diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs
index 1aa4473c0..b9cfae516 100644
--- a/arrow/src/compute/kernels/take.rs
+++ b/arrow/src/compute/kernels/take.rs
@@ -1398,7 +1398,7 @@ mod tests {
     fn test_take_bool_nullable_index() {
         // indices where the masked invalid elements would be out of bounds
         let index_data = ArrayData::try_new(
-            DataType::Int32,
+            DataType::UInt32,
             6,
             Some(Buffer::from_iter(vec![
                 false, true, false, true, false, true,
@@ -1421,7 +1421,7 @@ mod tests {
     fn test_take_bool_nullable_index_nonnull_values() {
         // indices where the masked invalid elements would be out of bounds
         let index_data = ArrayData::try_new(
-            DataType::Int32,
+            DataType::UInt32,
             6,
             Some(Buffer::from_iter(vec![
                 false, true, false, true, false, true,
diff --git a/parquet/src/arrow/array_reader/primitive_array.rs b/parquet/src/arrow/array_reader/primitive_array.rs
index d4f96e6a8..5fc5e639d 100644
--- a/parquet/src/arrow/array_reader/primitive_array.rs
+++ b/parquet/src/arrow/array_reader/primitive_array.rs
@@ -26,7 +26,8 @@ use crate::errors::{ParquetError, Result};
 use crate::schema::types::ColumnDescPtr;
 use arrow::array::{
     ArrayDataBuilder, ArrayRef, BooleanArray, BooleanBufferBuilder, Decimal128Array,
-    Float32Array, Float64Array, Int32Array, Int64Array,TimestampNanosecondArray, TimestampNanosecondBufferBuilder,
+    Float32Array, Float64Array, Int32Array, Int64Array, TimestampNanosecondArray,
+    TimestampNanosecondBufferBuilder, UInt32Array, UInt64Array,
 };
 use arrow::buffer::Buffer;
 use arrow::datatypes::{DataType as ArrowType, TimeUnit};
@@ -169,15 +170,21 @@ where
             .null_bit_buffer(self.record_reader.consume_bitmap_buffer());
 
         let array_data = unsafe { array_data.build_unchecked() };
-        let array = match T::get_physical_type() {
-            PhysicalType::BOOLEAN => Arc::new(BooleanArray::from(array_data)) as ArrayRef,
-            PhysicalType::INT32 => Arc::new(Int32Array::from(array_data)) as ArrayRef,
-            PhysicalType::INT64 => Arc::new(Int64Array::from(array_data)) as ArrayRef,
-            PhysicalType::FLOAT => Arc::new(Float32Array::from(array_data)) as ArrayRef,
-            PhysicalType::DOUBLE => Arc::new(Float64Array::from(array_data)) as ArrayRef,
-            PhysicalType::INT96 => {
-                Arc::new(TimestampNanosecondArray::from(array_data)) as ArrayRef
-            }
+        let array: ArrayRef = match T::get_physical_type() {
+            PhysicalType::BOOLEAN => Arc::new(BooleanArray::from(array_data)),
+            PhysicalType::INT32 => match array_data.data_type() {
+                ArrowType::UInt32 => Arc::new(UInt32Array::from(array_data)),
+                ArrowType::Int32 => Arc::new(Int32Array::from(array_data)),
+                _ => unreachable!(),
+            },
+            PhysicalType::INT64 => match array_data.data_type() {
+                ArrowType::UInt64 => Arc::new(UInt64Array::from(array_data)),
+                ArrowType::Int64 => Arc::new(Int64Array::from(array_data)),
+                _ => unreachable!(),
+            },
+            PhysicalType::FLOAT => Arc::new(Float32Array::from(array_data)),
+            PhysicalType::DOUBLE => Arc::new(Float64Array::from(array_data)),
+            PhysicalType::INT96 => Arc::new(TimestampNanosecondArray::from(array_data)),
             PhysicalType::BYTE_ARRAY | PhysicalType::FIXED_LEN_BYTE_ARRAY => {
                 unreachable!(
                     "PrimitiveArrayReaders don't support complex physical types"