Posted to commits@arrow.apache.org by tu...@apache.org on 2023/05/12 17:25:10 UTC

[arrow-rs] branch master updated: Add RecordBatchWriter trait and implement it for CSV, JSON, IPC and Parquet (#4206)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 019040814 Add RecordBatchWriter trait and implement it for CSV, JSON, IPC and Parquet (#4206)
019040814 is described below

commit 0190408147a34c6c08fcc9ba57443c629c678ca6
Author: Alexandre Crayssac <al...@gmail.com>
AuthorDate: Fri May 12 19:25:04 2023 +0200

    Add RecordBatchWriter trait and implement it for CSV, JSON, IPC and Parquet (#4206)
    
    Co-authored-by: alexandreyc <al...@crayssac.net>
---
 arrow-array/src/lib.rs                |  1 +
 arrow-array/src/record_batch.rs       |  6 ++++
 arrow-csv/src/writer.rs               |  6 ++++
 arrow-ipc/src/writer.rs               | 12 ++++++++
 arrow-json/src/writer.rs              | 56 +++++++++++++++++++++--------------
 arrow/benches/json_reader.rs          |  2 +-
 parquet/src/arrow/arrow_writer/mod.rs | 10 +++++--
 7 files changed, 67 insertions(+), 26 deletions(-)
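
The new trait makes it possible to write batches through a single generic
bound instead of coding against each format's writer type. The sketch below
is illustrative only (the `write_all` helper is not part of this commit) and
relies solely on the trait as added in this diff; since the trait only
defines `write`, finishing or closing the output stays format-specific.

    use arrow_array::{RecordBatch, RecordBatchWriter};
    use arrow_schema::ArrowError;

    /// Write every batch through any writer implementing `RecordBatchWriter`.
    fn write_all<W: RecordBatchWriter>(
        writer: &mut W,
        batches: &[RecordBatch],
    ) -> Result<(), ArrowError> {
        for batch in batches {
            // Dispatches to the format-specific `write` added in this commit.
            writer.write(batch)?;
        }
        Ok(())
    }

With the implementations below, `arrow_csv::Writer`, `arrow_ipc::writer::FileWriter`
and `StreamWriter`, `arrow_json::Writer`, and parquet's `ArrowWriter` can all be
passed to such a helper unchanged.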

diff --git a/arrow-array/src/lib.rs b/arrow-array/src/lib.rs
index 6ee9f7f1d..46de381c3 100644
--- a/arrow-array/src/lib.rs
+++ b/arrow-array/src/lib.rs
@@ -183,6 +183,7 @@ pub use array::*;
 mod record_batch;
 pub use record_batch::{
     RecordBatch, RecordBatchIterator, RecordBatchOptions, RecordBatchReader,
+    RecordBatchWriter,
 };
 
 mod arithmetic;
diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs
index bd1cc65c7..aea49c047 100644
--- a/arrow-array/src/record_batch.rs
+++ b/arrow-array/src/record_batch.rs
@@ -43,6 +43,12 @@ pub trait RecordBatchReader: Iterator<Item = Result<RecordBatch, ArrowError>> {
     }
 }
 
+/// Trait for types that can write `RecordBatch`'s.
+pub trait RecordBatchWriter {
+    /// Write a single batch to the writer.
+    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>;
+}
+
 /// A two-dimensional batch of column-oriented data with a defined
 /// [schema](arrow_schema::Schema).
 ///
diff --git a/arrow-csv/src/writer.rs b/arrow-csv/src/writer.rs
index 5f542be30..ba2123a09 100644
--- a/arrow-csv/src/writer.rs
+++ b/arrow-csv/src/writer.rs
@@ -193,6 +193,12 @@ impl<W: Write> Writer<W> {
     }
 }
 
+impl<W: Write> RecordBatchWriter for Writer<W> {
+    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
+        self.write(batch)
+    }
+}
+
 /// A CSV writer builder
 #[derive(Clone, Debug)]
 pub struct WriterBuilder {
diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs
index b2fcec08d..fcfd4d97a 100644
--- a/arrow-ipc/src/writer.rs
+++ b/arrow-ipc/src/writer.rs
@@ -857,6 +857,12 @@ impl<W: Write> FileWriter<W> {
     }
 }
 
+impl<W: Write> RecordBatchWriter for FileWriter<W> {
+    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
+        self.write(batch)
+    }
+}
+
 pub struct StreamWriter<W: Write> {
     /// The object to write to
     writer: BufWriter<W>,
@@ -991,6 +997,12 @@ impl<W: Write> StreamWriter<W> {
     }
 }
 
+impl<W: Write> RecordBatchWriter for StreamWriter<W> {
+    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
+        self.write(batch)
+    }
+}
+
 /// Stores the encoded data, which is an crate::Message, and optional Arrow data
 pub struct EncodedData {
     /// An encoded crate::Message
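
Both IPC writers gain the trait, so either can be driven by generic code. A
minimal sketch, assuming the existing arrow-ipc `StreamWriter` API (`try_new`
taking the schema, `finish` writing the end-of-stream marker); the helper
function itself is hypothetical and not part of this commit:

    use arrow_array::{RecordBatch, RecordBatchWriter};
    use arrow_schema::ArrowError;

    /// Encode a single batch as an Arrow IPC stream in memory.
    fn ipc_stream_bytes(batch: &RecordBatch) -> Result<Vec<u8>, ArrowError> {
        let mut buf = Vec::new();
        {
            let mut writer =
                arrow_ipc::writer::StreamWriter::try_new(&mut buf, &batch.schema())?;
            // Calls the trait impl added above, which forwards to the
            // inherent `StreamWriter::write`.
            RecordBatchWriter::write(&mut writer, batch)?;
            writer.finish()?;
        }
        Ok(buf)
    }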
diff --git a/arrow-json/src/writer.rs b/arrow-json/src/writer.rs
index d610dd9a3..6f241be40 100644
--- a/arrow-json/src/writer.rs
+++ b/arrow-json/src/writer.rs
@@ -35,7 +35,7 @@
 //! let a = Int32Array::from(vec![1, 2, 3]);
 //! let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap();
 //!
-//! let json_rows = arrow_json::writer::record_batches_to_json_rows(&[batch]).unwrap();
+//! let json_rows = arrow_json::writer::record_batches_to_json_rows(&[&batch]).unwrap();
 //! assert_eq!(
 //!     serde_json::Value::Object(json_rows[1].clone()),
 //!     serde_json::json!({"a": 2}),
@@ -59,7 +59,7 @@
 //! // Write the record batch out as JSON
 //! let buf = Vec::new();
 //! let mut writer = arrow_json::LineDelimitedWriter::new(buf);
-//! writer.write_batches(&vec![batch]).unwrap();
+//! writer.write_batches(&vec![&batch]).unwrap();
 //! writer.finish().unwrap();
 //!
 //! // Get the underlying buffer back,
@@ -85,7 +85,7 @@
 //! // Write the record batch out as a JSON array
 //! let buf = Vec::new();
 //! let mut writer = arrow_json::ArrayWriter::new(buf);
-//! writer.write_batches(&vec![batch]).unwrap();
+//! writer.write_batches(&vec![&batch]).unwrap();
 //! writer.finish().unwrap();
 //!
 //! // Get the underlying buffer back,
@@ -390,7 +390,7 @@ fn set_column_for_json_rows(
 /// Converts an arrow [`RecordBatch`] into a `Vec` of Serde JSON
 /// [`JsonMap`]s (objects)
 pub fn record_batches_to_json_rows(
-    batches: &[RecordBatch],
+    batches: &[&RecordBatch],
 ) -> Result<Vec<JsonMap<String, Value>>, ArrowError> {
     let mut rows: Vec<JsonMap<String, Value>> = iter::repeat(JsonMap::new())
         .take(batches.iter().map(|b| b.num_rows()).sum())
@@ -554,7 +554,7 @@ where
     }
 
     /// Convert the `RecordBatch` into JSON rows, and write them to the output
-    pub fn write(&mut self, batch: RecordBatch) -> Result<(), ArrowError> {
+    pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
         for row in record_batches_to_json_rows(&[batch])? {
             self.write_row(&Value::Object(row))?;
         }
@@ -562,7 +562,7 @@ where
     }
 
     /// Convert the [`RecordBatch`] into JSON rows, and write them to the output
-    pub fn write_batches(&mut self, batches: &[RecordBatch]) -> Result<(), ArrowError> {
+    pub fn write_batches(&mut self, batches: &[&RecordBatch]) -> Result<(), ArrowError> {
         for row in record_batches_to_json_rows(batches)? {
             self.write_row(&Value::Object(row))?;
         }
@@ -586,6 +586,16 @@ where
     }
 }
 
+impl<W, F> RecordBatchWriter for Writer<W, F>
+where
+    W: Write,
+    F: JsonFormat,
+{
+    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
+        self.write(batch)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::fs::{read_to_string, File};
@@ -631,7 +641,7 @@ mod tests {
         let mut buf = Vec::new();
         {
             let mut writer = LineDelimitedWriter::new(&mut buf);
-            writer.write_batches(&[batch]).unwrap();
+            writer.write_batches(&[&batch]).unwrap();
         }
 
         assert_json_eq(
@@ -662,7 +672,7 @@ mod tests {
         let mut buf = Vec::new();
         {
             let mut writer = LineDelimitedWriter::new(&mut buf);
-            writer.write_batches(&[batch]).unwrap();
+            writer.write_batches(&[&batch]).unwrap();
         }
 
         assert_json_eq(
@@ -704,7 +714,7 @@ mod tests {
         let mut buf = Vec::new();
         {
             let mut writer = LineDelimitedWriter::new(&mut buf);
-            writer.write_batches(&[batch]).unwrap();
+            writer.write_batches(&[&batch]).unwrap();
         }
 
         assert_json_eq(
@@ -759,7 +769,7 @@ mod tests {
         let mut buf = Vec::new();
         {
             let mut writer = LineDelimitedWriter::new(&mut buf);
-            writer.write_batches(&[batch]).unwrap();
+            writer.write_batches(&[&batch]).unwrap();
         }
 
         assert_json_eq(
@@ -818,7 +828,7 @@ mod tests {
         let mut buf = Vec::new();
         {
             let mut writer = LineDelimitedWriter::new(&mut buf);
-            writer.write_batches(&[batch]).unwrap();
+            writer.write_batches(&[&batch]).unwrap();
         }
 
         assert_json_eq(
@@ -864,7 +874,7 @@ mod tests {
         let mut buf = Vec::new();
         {
             let mut writer = LineDelimitedWriter::new(&mut buf);
-            writer.write_batches(&[batch]).unwrap();
+            writer.write_batches(&[&batch]).unwrap();
         }
 
         assert_json_eq(
@@ -907,7 +917,7 @@ mod tests {
         let mut buf = Vec::new();
         {
             let mut writer = LineDelimitedWriter::new(&mut buf);
-            writer.write_batches(&[batch]).unwrap();
+            writer.write_batches(&[&batch]).unwrap();
         }
 
         assert_json_eq(
@@ -950,7 +960,7 @@ mod tests {
         let mut buf = Vec::new();
         {
             let mut writer = LineDelimitedWriter::new(&mut buf);
-            writer.write_batches(&[batch]).unwrap();
+            writer.write_batches(&[&batch]).unwrap();
         }
 
         assert_json_eq(
@@ -1010,7 +1020,7 @@ mod tests {
         let mut buf = Vec::new();
         {
             let mut writer = LineDelimitedWriter::new(&mut buf);
-            writer.write_batches(&[batch]).unwrap();
+            writer.write_batches(&[&batch]).unwrap();
         }
 
         assert_json_eq(
@@ -1053,7 +1063,7 @@ mod tests {
         let mut buf = Vec::new();
         {
             let mut writer = LineDelimitedWriter::new(&mut buf);
-            writer.write_batches(&[batch]).unwrap();
+            writer.write_batches(&[&batch]).unwrap();
         }
 
         assert_json_eq(
@@ -1113,7 +1123,7 @@ mod tests {
         let mut buf = Vec::new();
         {
             let mut writer = LineDelimitedWriter::new(&mut buf);
-            writer.write_batches(&[batch]).unwrap();
+            writer.write_batches(&[&batch]).unwrap();
         }
 
         assert_json_eq(
@@ -1192,7 +1202,7 @@ mod tests {
         let mut buf = Vec::new();
         {
             let mut writer = LineDelimitedWriter::new(&mut buf);
-            writer.write_batches(&[batch]).unwrap();
+            writer.write_batches(&[&batch]).unwrap();
         }
 
         assert_json_eq(
@@ -1217,7 +1227,7 @@ mod tests {
         let mut buf = Vec::new();
         {
             let mut writer = LineDelimitedWriter::new(&mut buf);
-            writer.write_batches(&[batch]).unwrap();
+            writer.write_batches(&[&batch]).unwrap();
         }
 
         let result = String::from_utf8(buf).unwrap();
@@ -1315,7 +1325,7 @@ mod tests {
         let mut buf = Vec::new();
         {
             let mut writer = LineDelimitedWriter::new(&mut buf);
-            writer.write_batches(&[batch]).unwrap();
+            writer.write_batches(&[&batch]).unwrap();
         }
 
         // NOTE: The last value should technically be {"list": [null]} but it appears
@@ -1378,7 +1388,7 @@ mod tests {
         let mut buf = Vec::new();
         {
             let mut writer = LineDelimitedWriter::new(&mut buf);
-            writer.write_batches(&[batch]).unwrap();
+            writer.write_batches(&[&batch]).unwrap();
         }
 
         assert_json_eq(
@@ -1408,7 +1418,7 @@ mod tests {
         let mut buf = Vec::new();
         {
             let mut writer = LineDelimitedWriter::new(&mut buf);
-            writer.write(batch).unwrap();
+            writer.write(&batch).unwrap();
         }
 
         let result = String::from_utf8(buf).unwrap();
@@ -1445,7 +1455,7 @@ mod tests {
         let batch = reader.next().unwrap().unwrap();
 
         // test batches = an empty batch + 2 same batches, finally result should be eq to 2 same batches
-        let batches = [RecordBatch::new_empty(schema), batch.clone(), batch];
+        let batches = [&RecordBatch::new_empty(schema), &batch, &batch];
 
         let mut buf = Vec::new();
         {
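
Besides implementing the new trait, the arrow-json changes above also switch
the existing API to take batches by reference: `write` now accepts
`&RecordBatch`, and `write_batches` / `record_batches_to_json_rows` accept
`&[&RecordBatch]`. A minimal call-site sketch, adapted from the doc examples
updated in this diff:

    use std::sync::Arc;
    use arrow_array::{Int32Array, RecordBatch};
    use arrow_schema::{ArrowError, DataType, Field, Schema};

    fn main() -> Result<(), ArrowError> {
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let a = Int32Array::from(vec![1, 2, 3]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;

        let mut writer = arrow_json::LineDelimitedWriter::new(Vec::new());
        writer.write(&batch)?;            // previously took `RecordBatch` by value
        writer.write_batches(&[&batch])?; // previously `&[RecordBatch]`
        writer.finish()?;
        Ok(())
    }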
diff --git a/arrow/benches/json_reader.rs b/arrow/benches/json_reader.rs
index 8cebc42e4..8f3898c51 100644
--- a/arrow/benches/json_reader.rs
+++ b/arrow/benches/json_reader.rs
@@ -92,7 +92,7 @@ fn large_bench_primitive(c: &mut Criterion) {
     .unwrap();
 
     let mut out = Vec::with_capacity(1024);
-    LineDelimitedWriter::new(&mut out).write(batch).unwrap();
+    LineDelimitedWriter::new(&mut out).write(&batch).unwrap();
 
     let json = std::str::from_utf8(&out).unwrap();
     do_bench(c, "large_bench_primitive", json, schema)
diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs
index 14eb30f0b..075ecc034 100644
--- a/parquet/src/arrow/arrow_writer/mod.rs
+++ b/parquet/src/arrow/arrow_writer/mod.rs
@@ -23,8 +23,8 @@ use std::sync::Arc;
 
 use arrow_array::cast::AsArray;
 use arrow_array::types::{Decimal128Type, Int32Type, Int64Type, UInt32Type, UInt64Type};
-use arrow_array::{types, Array, ArrayRef, RecordBatch};
-use arrow_schema::{DataType as ArrowDataType, IntervalUnit, SchemaRef};
+use arrow_array::{types, Array, ArrayRef, RecordBatch, RecordBatchWriter};
+use arrow_schema::{ArrowError, DataType as ArrowDataType, IntervalUnit, SchemaRef};
 
 use super::schema::{
     add_encoded_arrow_schema_to_metadata, arrow_to_parquet_schema,
@@ -246,6 +246,12 @@ impl<W: Write> ArrowWriter<W> {
     }
 }
 
+impl<W: Write> RecordBatchWriter for ArrowWriter<W> {
+    fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> {
+        self.write(batch).map_err(|e| e.into())
+    }
+}
+
 fn write_leaves<W: Write>(
     row_group_writer: &mut SerializedRowGroupWriter<'_, W>,
     arrays: &[ArrayRef],
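
For parquet, the forwarding impl above maps the writer's native `ParquetError`
into `ArrowError` via `.map_err(|e| e.into())` so that it matches the trait
signature. A hedged sketch of driving `ArrowWriter` through the trait; the
function and buffer names are illustrative, not part of this commit:

    use arrow_array::{RecordBatch, RecordBatchWriter};

    /// Write a single batch to an in-memory Parquet file.
    fn parquet_bytes(batch: &RecordBatch) -> Result<Vec<u8>, Box<dyn std::error::Error>> {
        let mut buf = Vec::new();
        let mut writer =
            parquet::arrow::ArrowWriter::try_new(&mut buf, batch.schema(), None)?;
        // Goes through the trait impl, which converts ParquetError into ArrowError.
        RecordBatchWriter::write(&mut writer, batch)?;
        // Closing (writing the footer) is not part of the trait and remains
        // specific to ArrowWriter.
        writer.close()?;
        Ok(buf)
    }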