You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this copy.
Posted to commits@arrow.apache.org by vi...@apache.org on 2023/01/22 01:03:03 UTC

[arrow-rs] branch master updated: Remove unwrap on datetime cast for CSV writer (#3570)

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 24e5daef3 Remove unwrap on datetime cast for CSV writer (#3570)
24e5daef3 is described below

commit 24e5daef3248c38a6fb354c8427c9ba653e2b3e9
Author: comphead <co...@users.noreply.github.com>
AuthorDate: Sat Jan 21 17:02:56 2023 -0800

    Remove unwrap on datetime cast for CSV writer (#3570)
    
    * avoid unwrap on casting
    
    * avoid unwrap on cast
    
    * fmt
    
    * fixes
---
 arrow-csv/src/writer.rs | 143 +++++++++++++++++++++++++++++++++++-------------
 1 file changed, 104 insertions(+), 39 deletions(-)

diff --git a/arrow-csv/src/writer.rs b/arrow-csv/src/writer.rs
index c5eed7f1e..3ab28c2df 100644
--- a/arrow-csv/src/writer.rs
+++ b/arrow-csv/src/writer.rs
@@ -88,6 +88,26 @@ where
     lexical_to_string(c.value(i))
 }
 
+/// Builds the `CastError` for a value at (`col_index`, `row_index`) not convertible to `dt`.
+fn invalid_cast_error(dt: &str, col_index: usize, row_index: usize) -> ArrowError {
+    let msg =
+        format!("Cannot cast to {dt} at col index: {col_index} row index: {row_index}");
+    ArrowError::CastError(msg)
+}
+
+/// Downcasts `$array` to `$tpe`, converts row `$row_index` with `$cast_func` and
+/// formats the result with `$format`. Index arguments are evaluated exactly once.
+/// On failure, `?` returns `invalid_cast_error($tpe_name, $col_index, $row_index)`.
+macro_rules! write_temporal_value {
+    ($array:expr, $tpe: ident, $format: expr, $col_index: expr, $row_index: expr, $cast_func: ident, $tpe_name: expr) => {{
+        let (col_index, row_index): (usize, usize) = ($col_index, $row_index);
+        let err = || invalid_cast_error($tpe_name, col_index, row_index);
+        $array.as_any().downcast_ref::<$tpe>().ok_or_else(err)?
+            .$cast_func(row_index).ok_or_else(err)?
+            .format($format).to_string()
+    }};
+}
+
 /// A CSV writer
 #[derive(Debug)]
 pub struct Writer<W: Write> {
@@ -171,55 +191,70 @@ impl<W: Write> Writer<W> {
                     c.value(row_index).to_owned()
                 }
                 DataType::Date32 => {
-                    let c = col.as_any().downcast_ref::<Date32Array>().unwrap();
-                    c.value_as_date(row_index)
-                        .unwrap()
-                        .format(&self.date_format)
-                        .to_string()
+                    write_temporal_value!(
+                        col,
+                        Date32Array,
+                        &self.date_format,
+                        col_index,
+                        row_index,
+                        value_as_date,
+                        "Date32"
+                    )
                 }
                 DataType::Date64 => {
-                    let c = col.as_any().downcast_ref::<Date64Array>().unwrap();
-                    c.value_as_datetime(row_index)
-                        .unwrap()
-                        .format(&self.datetime_format)
-                        .to_string()
+                    write_temporal_value!(
+                        col,
+                        Date64Array,
+                        &self.datetime_format,
+                        col_index,
+                        row_index,
+                        value_as_datetime,
+                        "Date64"
+                    )
                 }
                 DataType::Time32(TimeUnit::Second) => {
-                    let c = col.as_any().downcast_ref::<Time32SecondArray>().unwrap();
-                    c.value_as_time(row_index)
-                        .unwrap()
-                        .format(&self.time_format)
-                        .to_string()
+                    write_temporal_value!(
+                        col,
+                        Time32SecondArray,
+                        &self.time_format,
+                        col_index,
+                        row_index,
+                        value_as_time,
+                        "Time32"
+                    )
                 }
                 DataType::Time32(TimeUnit::Millisecond) => {
-                    let c = col
-                        .as_any()
-                        .downcast_ref::<Time32MillisecondArray>()
-                        .unwrap();
-                    c.value_as_time(row_index)
-                        .unwrap()
-                        .format(&self.time_format)
-                        .to_string()
+                    write_temporal_value!(
+                        col,
+                        Time32MillisecondArray,
+                        &self.time_format,
+                        col_index,
+                        row_index,
+                        value_as_time,
+                        "Time32"
+                    )
                 }
                 DataType::Time64(TimeUnit::Microsecond) => {
-                    let c = col
-                        .as_any()
-                        .downcast_ref::<Time64MicrosecondArray>()
-                        .unwrap();
-                    c.value_as_time(row_index)
-                        .unwrap()
-                        .format(&self.time_format)
-                        .to_string()
+                    write_temporal_value!(
+                        col,
+                        Time64MicrosecondArray,
+                        &self.time_format,
+                        col_index,
+                        row_index,
+                        value_as_time,
+                        "Time64"
+                    )
                 }
                 DataType::Time64(TimeUnit::Nanosecond) => {
-                    let c = col
-                        .as_any()
-                        .downcast_ref::<Time64NanosecondArray>()
-                        .unwrap();
-                    c.value_as_time(row_index)
-                        .unwrap()
-                        .format(&self.time_format)
-                        .to_string()
+                    write_temporal_value!(
+                        col,
+                        Time64NanosecondArray,
+                        &self.time_format,
+                        col_index,
+                        row_index,
+                        value_as_time,
+                        "Time64"
+                    )
                 }
                 DataType::Timestamp(time_unit, time_zone) => {
                     self.handle_timestamp(time_unit, time_zone.as_ref(), row_index, col)?
@@ -672,4 +707,34 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo
         let expected = nanoseconds.into_iter().map(Some).collect::<Vec<_>>();
         assert_eq!(actual, expected);
     }
+
+    /// Writing a Date64 value that cannot be converted to a datetime must return a
+    /// `CastError` naming the offending column and row instead of panicking.
+    #[test]
+    fn test_write_csv_invalid_cast() {
+        let schema = Schema::new(vec![
+            Field::new("c0", DataType::UInt32, false),
+            Field::new("c1", DataType::Date64, false),
+        ]);
+
+        let c0 = UInt32Array::from(vec![Some(123), Some(234)]);
+        // Row 1 is too large to convert, so `value_as_datetime` yields `None`.
+        let c1 = Date64Array::from(vec![Some(1926632005177), Some(1926632005177685347)]);
+        let batch =
+            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c0), Arc::new(c1)])
+                .unwrap();
+
+        let mut file = tempfile::tempfile().unwrap();
+        let mut writer = Writer::new(&mut file);
+        let expected = invalid_cast_error("Date64", 1, 1).to_string();
+        // Both writes must fail on the same cell (col 1, row 1).
+        for batch in [&batch, &batch] {
+            let err = writer.write(batch).unwrap_err();
+            assert!(
+                err.to_string().ends_with(&expected),
+                "unexpected error: {}",
+                err
+            );
+        }
+    }
 }