You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by vi...@apache.org on 2023/01/22 01:03:03 UTC
[arrow-rs] branch master updated: Remove unwrap on datetime cast for CSV writer (#3570)
This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 24e5daef3 Remove unwrap on datetime cast for CSV writer (#3570)
24e5daef3 is described below
commit 24e5daef3248c38a6fb354c8427c9ba653e2b3e9
Author: comphead <co...@users.noreply.github.com>
AuthorDate: Sat Jan 21 17:02:56 2023 -0800
Remove unwrap on datetime cast for CSV writer (#3570)
* avoid unwrap on casting
* avoid unwrap on cast
* fmt
* fixes
---
arrow-csv/src/writer.rs | 143 +++++++++++++++++++++++++++++++++++-------------
1 file changed, 104 insertions(+), 39 deletions(-)
diff --git a/arrow-csv/src/writer.rs b/arrow-csv/src/writer.rs
index c5eed7f1e..3ab28c2df 100644
--- a/arrow-csv/src/writer.rs
+++ b/arrow-csv/src/writer.rs
@@ -88,6 +88,26 @@ where
lexical_to_string(c.value(i))
}
+/// Builds the [`ArrowError::CastError`] reported when the cell at
+/// (`col_index`, `row_index`) cannot be converted to the type named `dt`.
+fn invalid_cast_error(dt: &str, col_index: usize, row_index: usize) -> ArrowError {
+    let message = format!(
+        "Cannot cast to {} at col index: {} row index: {}",
+        dt, col_index, row_index
+    );
+    ArrowError::CastError(message)
+}
+
+/// Downcasts `$array` to the concrete array type `$tpe`, converts the value at
+/// `$row_index` with the accessor `$cast_func`, and renders it with `$format`,
+/// producing a `String`.
+///
+/// Both a failed downcast and a value the accessor cannot convert (it returns
+/// `None`) are reported as `invalid_cast_error($tpe_name, $col_index, $row_index)`
+/// and propagated with `?`, so the macro must be expanded inside a function
+/// whose error type accepts an `ArrowError` via `?`.
+macro_rules! write_temporal_value {
+    ($array:expr, $tpe: ident, $format: expr, $col_index: expr, $row_index: expr, $cast_func: ident, $tpe_name: expr) => {{
+        // One lazily-invoked error builder shared by both failure paths.
+        let cell_error = || invalid_cast_error($tpe_name, $col_index, $row_index);
+        let typed_array = $array
+            .as_any()
+            .downcast_ref::<$tpe>()
+            .ok_or_else(&cell_error)?;
+        let temporal = typed_array.$cast_func($row_index).ok_or_else(&cell_error)?;
+        temporal.format($format).to_string()
+    }};
+}
+
/// A CSV writer
#[derive(Debug)]
pub struct Writer<W: Write> {
@@ -171,55 +191,70 @@ impl<W: Write> Writer<W> {
c.value(row_index).to_owned()
}
DataType::Date32 => {
- let c = col.as_any().downcast_ref::<Date32Array>().unwrap();
- c.value_as_date(row_index)
- .unwrap()
- .format(&self.date_format)
- .to_string()
+ write_temporal_value!(
+ col,
+ Date32Array,
+ &self.date_format,
+ col_index,
+ row_index,
+ value_as_date,
+ "Date32"
+ )
}
DataType::Date64 => {
- let c = col.as_any().downcast_ref::<Date64Array>().unwrap();
- c.value_as_datetime(row_index)
- .unwrap()
- .format(&self.datetime_format)
- .to_string()
+ write_temporal_value!(
+ col,
+ Date64Array,
+ &self.datetime_format,
+ col_index,
+ row_index,
+ value_as_datetime,
+ "Date64"
+ )
}
DataType::Time32(TimeUnit::Second) => {
- let c = col.as_any().downcast_ref::<Time32SecondArray>().unwrap();
- c.value_as_time(row_index)
- .unwrap()
- .format(&self.time_format)
- .to_string()
+ write_temporal_value!(
+ col,
+ Time32SecondArray,
+ &self.time_format,
+ col_index,
+ row_index,
+ value_as_time,
+ "Time32"
+ )
}
DataType::Time32(TimeUnit::Millisecond) => {
- let c = col
- .as_any()
- .downcast_ref::<Time32MillisecondArray>()
- .unwrap();
- c.value_as_time(row_index)
- .unwrap()
- .format(&self.time_format)
- .to_string()
+ write_temporal_value!(
+ col,
+ Time32MillisecondArray,
+ &self.time_format,
+ col_index,
+ row_index,
+ value_as_time,
+ "Time32"
+ )
}
DataType::Time64(TimeUnit::Microsecond) => {
- let c = col
- .as_any()
- .downcast_ref::<Time64MicrosecondArray>()
- .unwrap();
- c.value_as_time(row_index)
- .unwrap()
- .format(&self.time_format)
- .to_string()
+ write_temporal_value!(
+ col,
+ Time64MicrosecondArray,
+ &self.time_format,
+ col_index,
+ row_index,
+ value_as_time,
+ "Time64"
+ )
}
DataType::Time64(TimeUnit::Nanosecond) => {
- let c = col
- .as_any()
- .downcast_ref::<Time64NanosecondArray>()
- .unwrap();
- c.value_as_time(row_index)
- .unwrap()
- .format(&self.time_format)
- .to_string()
+ write_temporal_value!(
+ col,
+ Time64NanosecondArray,
+ &self.time_format,
+ col_index,
+ row_index,
+ value_as_time,
+ "Time64"
+ )
}
DataType::Timestamp(time_unit, time_zone) => {
self.handle_timestamp(time_unit, time_zone.as_ref(), row_index, col)?
@@ -672,4 +707,34 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo
let expected = nanoseconds.into_iter().map(Some).collect::<Vec<_>>();
assert_eq!(actual, expected);
}
+
+    /// Writing a batch whose Date64 column holds an unconvertible value must
+    /// return a cast error naming the offending cell instead of panicking.
+    #[test]
+    fn test_write_csv_invalid_cast() {
+        let schema = Schema::new(vec![
+            Field::new("c0", DataType::UInt32, false),
+            Field::new("c1", DataType::Date64, false),
+        ]);
+
+        let c0 = UInt32Array::from(vec![Some(123), Some(234)]);
+        // Row 1 is expected to be out of range for `value_as_datetime`, so the
+        // writer should fail at column 1, row 1.
+        let c1 = Date64Array::from(vec![Some(1926632005177), Some(1926632005177685347)]);
+        let batch =
+            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c0), Arc::new(c1)])
+                .unwrap();
+
+        let expected = invalid_cast_error("Date64", 1, 1).to_string();
+        let mut file = tempfile::tempfile().unwrap();
+        let mut writer = Writer::new(&mut file);
+        // Every attempt to write the batch must surface the same error.
+        for batch in [&batch, &batch] {
+            let err = writer
+                .write(batch)
+                .expect_err("writing an out-of-range Date64 value must fail");
+            assert!(
+                err.to_string().ends_with(&expected),
+                "unexpected error: {}",
+                err
+            );
+        }
+    }
}