You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/01/27 14:37:14 UTC
[arrow-rs] branch master updated: Add Push-Based CSV Decoder (#3604)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new d9c2681a8 Add Push-Based CSV Decoder (#3604)
d9c2681a8 is described below
commit d9c2681a8a477aa19feb492e536f9e1a034d2c8d
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Fri Jan 27 14:37:07 2023 +0000
Add Push-Based CSV Decoder (#3604)
* Add Push-Based CSV Decoder
* Clippy
* More tests
* Clippy
---
arrow-csv/src/reader/mod.rs | 370 +++++++++++++++++++++++++++-------------
arrow-csv/src/reader/records.rs | 303 ++++++++++++++++++--------------
2 files changed, 427 insertions(+), 246 deletions(-)
diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs
index 82b033f80..cff1337dd 100644
--- a/arrow-csv/src/reader/mod.rs
+++ b/arrow-csv/src/reader/mod.rs
@@ -57,7 +57,7 @@ use arrow_cast::parse::Parser;
use arrow_schema::*;
use crate::map_csv_error;
-use crate::reader::records::{RecordReader, StringRecords};
+use crate::reader::records::{RecordDecoder, StringRecords};
use arrow_data::decimal::validate_decimal_precision;
use csv::StringRecord;
use std::ops::Neg;
@@ -330,24 +330,11 @@ pub type Reader<R> = BufReader<StdBufReader<R>>;
/// CSV file reader
pub struct BufReader<R> {
- /// Explicit schema for the CSV file
- schema: SchemaRef,
- /// Optional projection for which columns to load (zero-based column indices)
- projection: Option<Vec<usize>>,
/// File reader
- reader: RecordReader<R>,
- /// Rows to skip
- to_skip: usize,
- /// Current line number
- line_number: usize,
- /// End line number
- end: usize,
- /// Number of records per batch
- batch_size: usize,
- /// datetime format used to parse datetime values, (format understood by chrono)
- ///
- /// For format refer to [chrono docs](https://docs.rs/chrono/0.4.19/chrono/format/strftime/index.html)
- datetime_format: Option<String>,
+ reader: R,
+
+ /// The decoder
+ decoder: Decoder,
}
impl<R> fmt::Debug for BufReader<R>
@@ -356,10 +343,7 @@ where
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Reader")
- .field("schema", &self.schema)
- .field("projection", &self.projection)
- .field("line_number", &self.line_number)
- .field("datetime_format", &self.datetime_format)
+ .field("decoder", &self.decoder)
.finish()
}
}
@@ -383,7 +367,8 @@ impl<R: Read> Reader<R> {
) -> Self {
let mut builder = ReaderBuilder::new()
.has_header(has_header)
- .with_batch_size(batch_size);
+ .with_batch_size(batch_size)
+ .with_schema(schema);
if let Some(delimiter) = delimiter {
builder = builder.with_delimiter(delimiter);
@@ -397,21 +382,25 @@ impl<R: Read> Reader<R> {
if let Some(format) = datetime_format {
builder = builder.with_datetime_format(format)
}
- builder.build_with_schema(StdBufReader::new(reader), schema)
+
+ Self {
+ decoder: builder.build_decoder(),
+ reader: StdBufReader::new(reader),
+ }
}
/// Returns the schema of the reader, useful for getting the schema without reading
/// record batches
pub fn schema(&self) -> SchemaRef {
- match &self.projection {
+ match &self.decoder.projection {
Some(projection) => {
- let fields = self.schema.fields();
+ let fields = self.decoder.schema.fields();
let projected_fields: Vec<Field> =
projection.iter().map(|i| fields[*i].clone()).collect();
Arc::new(Schema::new(projected_fields))
}
- None => self.schema.clone(),
+ None => self.decoder.schema.clone(),
}
}
@@ -444,38 +433,146 @@ impl<R: Read> Reader<R> {
}
}
+impl<R: BufRead> BufReader<R> {
+ fn read(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
+ loop {
+ let buf = self.reader.fill_buf()?;
+ let decoded = self.decoder.decode(buf)?;
+ if decoded == 0 {
+ break;
+ }
+ self.reader.consume(decoded);
+ }
+
+ self.decoder.flush()
+ }
+}
+
impl<R: BufRead> Iterator for BufReader<R> {
type Item = Result<RecordBatch, ArrowError>;
fn next(&mut self) -> Option<Self::Item> {
+ self.read().transpose()
+ }
+}
+
+/// A push-based interface for decoding CSV data from an arbitrary byte stream
+///
+/// See [`Reader`] for a higher-level interface for interfacing with [`Read`]
+///
+/// The push-based interface facilitates integration with sources that yield arbitrarily
+/// delimited byte ranges, such as [`BufRead`], or a chunked byte stream received from
+/// object storage
+///
+/// ```
+/// # use std::io::BufRead;
+/// # use arrow_array::RecordBatch;
+/// # use arrow_csv::ReaderBuilder;
+/// # use arrow_schema::{ArrowError, SchemaRef};
+/// #
+/// fn read_from_csv<R: BufRead>(
+/// mut reader: R,
+/// schema: SchemaRef,
+/// batch_size: usize,
+/// ) -> Result<impl Iterator<Item = Result<RecordBatch, ArrowError>>, ArrowError> {
+/// let mut decoder = ReaderBuilder::new()
+/// .with_schema(schema)
+/// .with_batch_size(batch_size)
+/// .build_decoder();
+///
+/// let mut next = move || {
+/// loop {
+/// let buf = reader.fill_buf()?;
+/// let decoded = decoder.decode(buf)?;
+/// if decoded == 0 {
+/// break;
+/// }
+///
+/// // Consume the number of bytes read
+/// reader.consume(decoded);
+/// }
+/// decoder.flush()
+/// };
+/// Ok(std::iter::from_fn(move || next().transpose()))
+/// }
+/// ```
+#[derive(Debug)]
+pub struct Decoder {
+ /// Explicit schema for the CSV file
+ schema: SchemaRef,
+
+ /// Optional projection for which columns to load (zero-based column indices)
+ projection: Option<Vec<usize>>,
+
+ /// Number of records per batch
+ batch_size: usize,
+
+ /// Rows to skip
+ to_skip: usize,
+
+ /// Current line number
+ line_number: usize,
+
+ /// End line number
+ end: usize,
+
+ /// A decoder for [`StringRecords`]
+ record_decoder: RecordDecoder,
+
+ /// datetime format used to parse datetime values, (format understood by chrono)
+ ///
+ /// For format refer to [chrono docs](https://docs.rs/chrono/0.4.19/chrono/format/strftime/index.html)
+ datetime_format: Option<String>,
+}
+
+impl Decoder {
+ /// Decode records from `buf` returning the number of bytes read
+ ///
+ /// This method returns once `batch_size` objects have been parsed since the
+ /// last call to [`Self::flush`], or `buf` is exhausted. Any remaining bytes
+ /// should be included in the next call to [`Self::decode`]
+ ///
+ /// There is no requirement that `buf` contains a whole number of records, facilitating
+ /// integration with arbitrary byte streams, such as that yielded by [`BufRead`] or
+ /// network sources such as object storage
+ pub fn decode(&mut self, buf: &[u8]) -> Result<usize, ArrowError> {
if self.to_skip != 0 {
- if let Err(e) = self.reader.skip(std::mem::take(&mut self.to_skip)) {
- return Some(Err(e));
- }
+ // Skip in units of `to_read` to avoid over-allocating buffers
+ let to_skip = self.to_skip.min(self.batch_size);
+ let (skipped, bytes) = self.record_decoder.decode(buf, to_skip)?;
+ self.to_skip -= skipped;
+ self.record_decoder.clear();
+ return Ok(bytes);
}
- let remaining = self.end - self.line_number;
- let to_read = self.batch_size.min(remaining);
+ let to_read =
+ self.batch_size.min(self.end - self.line_number) - self.record_decoder.len();
+ let (_, bytes) = self.record_decoder.decode(buf, to_read)?;
+ Ok(bytes)
+ }
- let batch = match self.reader.read(to_read) {
- Ok(b) if b.is_empty() => return None,
- Ok(b) => b,
- Err(e) => return Some(Err(e)),
- };
+ /// Flushes the currently buffered data to a [`RecordBatch`]
+ ///
+ /// This should only be called after [`Self::decode`] has returned `Ok(0)`,
+ /// otherwise may return an error if part way through decoding a record
+ ///
+ /// Returns `Ok(None)` if no buffered data
+ pub fn flush(&mut self) -> Result<Option<RecordBatch>, ArrowError> {
+ if self.record_decoder.is_empty() {
+ return Ok(None);
+ }
- // parse the batches into a RecordBatch
- let result = parse(
- &batch,
+ let rows = self.record_decoder.flush()?;
+ let batch = parse(
+ &rows,
self.schema.fields(),
Some(self.schema.metadata.clone()),
self.projection.as_ref(),
self.line_number,
self.datetime_format.as_deref(),
- );
-
- self.line_number += batch.len();
-
- Some(result)
+ )?;
+ self.line_number += rows.len();
+ Ok(Some(batch))
}
}
@@ -1055,29 +1152,35 @@ impl ReaderBuilder {
mut reader: R,
) -> Result<BufReader<R>, ArrowError> {
// check if schema should be inferred
- let delimiter = self.delimiter.unwrap_or(b',');
- let schema = match self.schema.take() {
- Some(schema) => schema,
- None => {
- let roptions = ReaderOptions {
- delimiter: Some(delimiter),
- max_read_records: self.max_records,
- has_header: self.has_header,
- escape: self.escape,
- quote: self.quote,
- terminator: self.terminator,
- datetime_re: self.datetime_re.take(),
- };
- let (inferred_schema, _) =
- infer_file_schema_with_csv_options(&mut reader, roptions)?;
-
- Arc::new(inferred_schema)
- }
- };
- Ok(self.build_with_schema(reader, schema))
+ if self.schema.is_none() {
+ let delimiter = self.delimiter.unwrap_or(b',');
+ let roptions = ReaderOptions {
+ delimiter: Some(delimiter),
+ max_read_records: self.max_records,
+ has_header: self.has_header,
+ escape: self.escape,
+ quote: self.quote,
+ terminator: self.terminator,
+ datetime_re: self.datetime_re.take(),
+ };
+ let (inferred_schema, _) =
+ infer_file_schema_with_csv_options(&mut reader, roptions)?;
+ self.schema = Some(Arc::new(inferred_schema))
+ }
+
+ Ok(BufReader {
+ reader,
+ decoder: self.build_decoder(),
+ })
}
- fn build_with_schema<R: BufRead>(self, reader: R, schema: SchemaRef) -> BufReader<R> {
+ /// Builds a decoder that can be used to decode CSV from an arbitrary byte stream
+ ///
+ /// # Panics
+ ///
+ /// This method panics if no schema provided
+ pub fn build_decoder(self) -> Decoder {
+ let schema = self.schema.expect("schema should be provided");
let mut reader_builder = csv_core::ReaderBuilder::new();
reader_builder.escape(self.escape);
@@ -1091,7 +1194,7 @@ impl ReaderBuilder {
reader_builder.terminator(csv_core::Terminator::Any(t));
}
let delimiter = reader_builder.build();
- let reader = RecordReader::new(reader, delimiter, schema.fields().len());
+ let record_decoder = RecordDecoder::new(delimiter, schema.fields().len());
let header = self.has_header as usize;
@@ -1100,15 +1203,15 @@ impl ReaderBuilder {
None => (header, usize::MAX),
};
- BufReader {
+ Decoder {
schema,
- projection: self.projection,
- reader,
to_skip: start,
+ record_decoder,
line_number: start,
end,
- batch_size: self.batch_size,
+ projection: self.projection,
datetime_format: self.datetime_format,
+ batch_size: self.batch_size,
}
}
}
@@ -1125,49 +1228,46 @@ mod tests {
#[test]
fn test_csv() {
- let _: Vec<()> = vec![None, Some("%Y-%m-%dT%H:%M:%S%.f%:z".to_string())]
- .into_iter()
- .map(|format| {
- let schema = Schema::new(vec![
- Field::new("city", DataType::Utf8, false),
- Field::new("lat", DataType::Float64, false),
- Field::new("lng", DataType::Float64, false),
- ]);
-
- let file = File::open("test/data/uk_cities.csv").unwrap();
- let mut csv = Reader::new(
- file,
- Arc::new(schema.clone()),
- false,
- None,
- 1024,
- None,
- None,
- format,
- );
- assert_eq!(Arc::new(schema), csv.schema());
- let batch = csv.next().unwrap().unwrap();
- assert_eq!(37, batch.num_rows());
- assert_eq!(3, batch.num_columns());
-
- // access data from a primitive array
- let lat = batch
- .column(1)
- .as_any()
- .downcast_ref::<Float64Array>()
- .unwrap();
- assert_eq!(57.653484, lat.value(0));
-
- // access data from a string array (ListArray<u8>)
- let city = batch
- .column(0)
- .as_any()
- .downcast_ref::<StringArray>()
- .unwrap();
-
- assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
- })
- .collect();
+ for format in [None, Some("%Y-%m-%dT%H:%M:%S%.f%:z".to_string())] {
+ let schema = Schema::new(vec![
+ Field::new("city", DataType::Utf8, false),
+ Field::new("lat", DataType::Float64, false),
+ Field::new("lng", DataType::Float64, false),
+ ]);
+
+ let file = File::open("test/data/uk_cities.csv").unwrap();
+ let mut csv = Reader::new(
+ file,
+ Arc::new(schema.clone()),
+ false,
+ None,
+ 1024,
+ None,
+ None,
+ format,
+ );
+ assert_eq!(Arc::new(schema), csv.schema());
+ let batch = csv.next().unwrap().unwrap();
+ assert_eq!(37, batch.num_rows());
+ assert_eq!(3, batch.num_columns());
+
+ // access data from a primitive array
+ let lat = batch
+ .column(1)
+ .as_any()
+ .downcast_ref::<Float64Array>()
+ .unwrap();
+ assert_eq!(57.653484, lat.value(0));
+
+ // access data from a string array (ListArray<u8>)
+ let city = batch
+ .column(0)
+ .as_any()
+ .downcast_ref::<StringArray>()
+ .unwrap();
+
+ assert_eq!("Aberdeen, Aberdeen City, UK", city.value(13));
+ }
}
#[test]
@@ -2089,4 +2189,46 @@ mod tests {
assert!(c.value(2));
assert!(c.is_null(3));
}
+
+ #[test]
+ fn test_buffered() {
+ let tests = [
+ ("test/data/uk_cities.csv", false, 37),
+ ("test/data/various_types.csv", true, 7),
+ ("test/data/decimal_test.csv", false, 10),
+ ];
+
+ for (path, has_header, expected_rows) in tests {
+ for batch_size in [1, 4] {
+ for capacity in [1, 3, 7, 100] {
+ let reader = ReaderBuilder::new()
+ .with_batch_size(batch_size)
+ .has_header(has_header)
+ .build(File::open(path).unwrap())
+ .unwrap();
+
+ let expected = reader.collect::<Result<Vec<_>, _>>().unwrap();
+
+ assert_eq!(
+ expected.iter().map(|x| x.num_rows()).sum::<usize>(),
+ expected_rows
+ );
+
+ let buffered = std::io::BufReader::with_capacity(
+ capacity,
+ File::open(path).unwrap(),
+ );
+
+ let reader = ReaderBuilder::new()
+ .with_batch_size(batch_size)
+ .has_header(has_header)
+ .build_buffered(buffered)
+ .unwrap();
+
+ let actual = reader.collect::<Result<Vec<_>, _>>().unwrap();
+ assert_eq!(expected, actual)
+ }
+ }
+ }
+ }
}
diff --git a/arrow-csv/src/reader/records.rs b/arrow-csv/src/reader/records.rs
index 76adb719e..c4da36ca4 100644
--- a/arrow-csv/src/reader/records.rs
+++ b/arrow-csv/src/reader/records.rs
@@ -17,7 +17,6 @@
use arrow_schema::ArrowError;
use csv_core::{ReadRecordResult, Reader};
-use std::io::BufRead;
/// The estimated length of a field in bytes
const AVERAGE_FIELD_SIZE: usize = 8;
@@ -25,112 +24,165 @@ const AVERAGE_FIELD_SIZE: usize = 8;
/// The minimum amount of data in a single read
const MIN_CAPACITY: usize = 1024;
-pub struct RecordReader<R> {
- reader: R,
+/// [`RecordDecoder`] provides a push-based interface to decode [`StringRecords`]
+#[derive(Debug)]
+pub struct RecordDecoder {
delimiter: Reader,
+ /// The expected number of fields per row
num_columns: usize,
+ /// The current line number
line_number: usize,
+
+ /// Offsets delimiting field start positions
offsets: Vec<usize>,
+
+ /// The current offset into `self.offsets`
+ ///
+ /// We track this independently of Vec to avoid re-zeroing memory
+ offsets_len: usize,
+
+ /// The number of fields read for the current record
+ current_field: usize,
+
+ /// The number of rows buffered
+ num_rows: usize,
+
+ /// Decoded field data
data: Vec<u8>,
+
+ /// Offsets into data
+ ///
+ /// We track this independently of Vec to avoid re-zeroing memory
+ data_len: usize,
}
-impl<R: BufRead> RecordReader<R> {
- pub fn new(reader: R, delimiter: Reader, num_columns: usize) -> Self {
+impl RecordDecoder {
+ pub fn new(delimiter: Reader, num_columns: usize) -> Self {
Self {
- reader,
delimiter,
num_columns,
line_number: 1,
offsets: vec![],
+ offsets_len: 1, // The first offset is always 0
+ current_field: 0,
+ data_len: 0,
data: vec![],
+ num_rows: 0,
}
}
- /// Clears and then fills the buffers on this [`RecordReader`]
- /// returning the number of records read
- fn fill_buf(&mut self, to_read: usize) -> Result<usize, ArrowError> {
- // Reserve sufficient capacity in offsets
- self.offsets.resize(to_read * self.num_columns + 1, 0);
-
- let mut read = 0;
+ /// Decodes records from `input` returning the number of records and bytes read
+ ///
+ /// Note: this expects to be called with an empty `input` to signal EOF
+ pub fn decode(
+ &mut self,
+ input: &[u8],
+ to_read: usize,
+ ) -> Result<(usize, usize), ArrowError> {
if to_read == 0 {
- return Ok(0);
+ return Ok((0, 0));
}
- // The current offset into `self.data`
- let mut output_offset = 0;
+ // Reserve sufficient capacity in offsets
+ self.offsets
+ .resize(self.offsets_len + to_read * self.num_columns, 0);
+
// The current offset into `input`
let mut input_offset = 0;
- // The current offset into `self.offsets`
- let mut field_offset = 1;
- // The number of fields read for the current row
- let mut field_count = 0;
-
- 'outer: loop {
- let input = self.reader.fill_buf()?;
-
- 'input: loop {
- // Reserve necessary space in output data based on best estimate
- let remaining_rows = to_read - read;
- let capacity = remaining_rows * self.num_columns * AVERAGE_FIELD_SIZE;
- let estimated_data = capacity.max(MIN_CAPACITY);
- self.data.resize(output_offset + estimated_data, 0);
-
- loop {
- let (result, bytes_read, bytes_written, end_positions) =
- self.delimiter.read_record(
- &input[input_offset..],
- &mut self.data[output_offset..],
- &mut self.offsets[field_offset..],
- );
-
- field_count += end_positions;
- field_offset += end_positions;
- input_offset += bytes_read;
- output_offset += bytes_written;
-
- match result {
- ReadRecordResult::End => break 'outer, // Reached end of file
- ReadRecordResult::InputEmpty => break 'input, // Input exhausted, need to read more
- ReadRecordResult::OutputFull => break, // Need to allocate more capacity
- ReadRecordResult::OutputEndsFull => {
- let line_number = self.line_number + read;
- return Err(ArrowError::CsvError(format!("incorrect number of fields for line {}, expected {} got more than {}", line_number, self.num_columns, field_count)));
+
+ // The number of rows decoded in this pass
+ let mut read = 0;
+
+ loop {
+ // Reserve necessary space in output data based on best estimate
+ let remaining_rows = to_read - read;
+ let capacity = remaining_rows * self.num_columns * AVERAGE_FIELD_SIZE;
+ let estimated_data = capacity.max(MIN_CAPACITY);
+ self.data.resize(self.data_len + estimated_data, 0);
+
+ // Try to read a record
+ loop {
+ let (result, bytes_read, bytes_written, end_positions) =
+ self.delimiter.read_record(
+ &input[input_offset..],
+ &mut self.data[self.data_len..],
+ &mut self.offsets[self.offsets_len..],
+ );
+
+ self.current_field += end_positions;
+ self.offsets_len += end_positions;
+ input_offset += bytes_read;
+ self.data_len += bytes_written;
+
+ match result {
+ ReadRecordResult::End | ReadRecordResult::InputEmpty => {
+ // Reached end of input
+ return Ok((read, input_offset));
+ }
+ // Need to allocate more capacity
+ ReadRecordResult::OutputFull => break,
+ ReadRecordResult::OutputEndsFull => {
+ return Err(ArrowError::CsvError(format!("incorrect number of fields for line {}, expected {} got more than {}", self.line_number, self.num_columns, self.current_field)));
+ }
+ ReadRecordResult::Record => {
+ if self.current_field != self.num_columns {
+ return Err(ArrowError::CsvError(format!("incorrect number of fields for line {}, expected {} got {}", self.line_number, self.num_columns, self.current_field)));
}
- ReadRecordResult::Record => {
- if field_count != self.num_columns {
- let line_number = self.line_number + read;
- return Err(ArrowError::CsvError(format!("incorrect number of fields for line {}, expected {} got {}", line_number, self.num_columns, field_count)));
- }
- read += 1;
- field_count = 0;
-
- if read == to_read {
- break 'outer; // Read sufficient rows
- }
-
- if input.len() == input_offset {
- // Input exhausted, need to read more
- // Without this read_record will interpret the empty input
- // byte array as indicating the end of the file
- break 'input;
- }
+ read += 1;
+ self.current_field = 0;
+ self.line_number += 1;
+ self.num_rows += 1;
+
+ if read == to_read {
+ // Read sufficient rows
+ return Ok((read, input_offset));
+ }
+
+ if input.len() == input_offset {
+ // Input exhausted, need to read more
+ // Without this read_record will interpret the empty input
+ // byte array as indicating the end of the file
+ return Ok((read, input_offset));
}
}
}
}
- self.reader.consume(input_offset);
- input_offset = 0;
}
- self.reader.consume(input_offset);
+ }
+
+ /// Returns the current number of buffered records
+ pub fn len(&self) -> usize {
+ self.num_rows
+ }
+
+ /// Returns true if the decoder is empty
+ pub fn is_empty(&self) -> bool {
+ self.num_rows == 0
+ }
+
+ /// Clears the current contents of the decoder
+ pub fn clear(&mut self) {
+ // This does not reset current_field to allow clearing part way through a record
+ self.offsets_len = 1;
+ self.data_len = 0;
+ self.num_rows = 0;
+ }
+
+ /// Flushes the current contents of the reader
+ pub fn flush(&mut self) -> Result<StringRecords<'_>, ArrowError> {
+ if self.current_field != 0 {
+ return Err(ArrowError::CsvError(
+ "Cannot flush part way through record".to_string(),
+ ));
+ }
// csv_core::Reader writes end offsets relative to the start of the row
// Therefore scan through and offset these based on the cumulative row offsets
let mut row_offset = 0;
- self.offsets[1..]
- .chunks_mut(self.num_columns)
+ self.offsets[1..self.offsets_len]
+ .chunks_exact_mut(self.num_columns)
.for_each(|row| {
let offset = row_offset;
row.iter_mut().for_each(|x| {
@@ -139,48 +191,23 @@ impl<R: BufRead> RecordReader<R> {
});
});
- self.line_number += read;
-
- Ok(read)
- }
-
- /// Skips forward `to_skip` rows, returning an error if insufficient lines in source
- pub fn skip(&mut self, to_skip: usize) -> Result<(), ArrowError> {
- // TODO: This could be done by scanning for unquoted newline delimiters
- let mut skipped = 0;
- while to_skip > skipped {
- let read = self.fill_buf(to_skip.min(1024))?;
- if read == 0 {
- return Err(ArrowError::CsvError(format!(
- "Failed to skip {to_skip} rows only found {skipped}"
- )));
- }
-
- skipped += read;
- }
- Ok(())
- }
-
- /// Reads up to `to_read` rows from the reader
- pub fn read(&mut self, to_read: usize) -> Result<StringRecords<'_>, ArrowError> {
- let num_rows = self.fill_buf(to_read)?;
-
- // Need to slice fields to the actual number of rows read
- //
- // We intentionally avoid using `Vec::truncate` to avoid having
- // to re-initialize the data again
- let num_fields = num_rows * self.num_columns;
- let last_offset = self.offsets[num_fields];
-
- // Need to truncate data to the actual amount of data read
- let data = std::str::from_utf8(&self.data[..last_offset]).map_err(|e| {
+ // Need to truncate data to the actual amount of data read
+ let data = std::str::from_utf8(&self.data[..self.data_len]).map_err(|e| {
ArrowError::CsvError(format!("Encountered invalid UTF-8 data: {e}"))
})?;
+ let offsets = &self.offsets[..self.offsets_len];
+ let num_rows = self.num_rows;
+
+ // Reset state
+ self.offsets_len = 1;
+ self.data_len = 0;
+ self.num_rows = 0;
+
Ok(StringRecords {
num_rows,
num_columns: self.num_columns,
- offsets: &self.offsets[..num_fields + 1],
+ offsets,
data,
})
}
@@ -208,10 +235,6 @@ impl<'a> StringRecords<'a> {
self.num_rows
}
- pub fn is_empty(&self) -> bool {
- self.num_rows == 0
- }
-
pub fn iter(&self) -> impl Iterator<Item = StringRecord<'a>> + '_ {
(0..self.num_rows).map(|x| self.get(x))
}
@@ -237,9 +260,9 @@ impl<'a> StringRecord<'a> {
#[cfg(test)]
mod tests {
- use crate::reader::records::RecordReader;
+ use crate::reader::records::RecordDecoder;
use csv_core::Reader;
- use std::io::Cursor;
+ use std::io::{BufRead, BufReader, Cursor};
#[test]
fn test_basic() {
@@ -259,30 +282,43 @@ mod tests {
]
.into_iter();
- let cursor = Cursor::new(csv.as_bytes());
- let mut reader = RecordReader::new(cursor, Reader::new(), 3);
+ let mut reader = BufReader::with_capacity(3, Cursor::new(csv.as_bytes()));
+ let mut decoder = RecordDecoder::new(Reader::new(), 3);
loop {
- let b = reader.read(3).unwrap();
- if b.is_empty() {
+ let to_read = 3;
+ let mut read = 0;
+ loop {
+ let buf = reader.fill_buf().unwrap();
+ let (records, bytes) = decoder.decode(buf, to_read - read).unwrap();
+
+ reader.consume(bytes);
+ read += records;
+
+ if read == to_read || bytes == 0 {
+ break;
+ }
+ }
+ if read == 0 {
break;
}
+ let b = decoder.flush().unwrap();
b.iter().zip(&mut expected).for_each(|(record, expected)| {
let actual = (0..3)
.map(|field_idx| record.get(field_idx))
.collect::<Vec<_>>();
assert_eq!(actual, expected)
- })
+ });
}
+ assert!(expected.next().is_none());
}
#[test]
fn test_invalid_fields() {
let csv = "a,b\nb,c\na\n";
- let cursor = Cursor::new(csv.as_bytes());
- let mut reader = RecordReader::new(cursor, Reader::new(), 2);
- let err = reader.read(4).unwrap_err().to_string();
+ let mut decoder = RecordDecoder::new(Reader::new(), 2);
+ let err = decoder.decode(csv.as_bytes(), 4).unwrap_err().to_string();
let expected =
"Csv error: incorrect number of fields for line 3, expected 2 got 1";
@@ -290,19 +326,22 @@ mod tests {
assert_eq!(err, expected);
// Test with initial skip
- let cursor = Cursor::new(csv.as_bytes());
- let mut reader = RecordReader::new(cursor, Reader::new(), 2);
- reader.skip(1).unwrap();
- let err = reader.read(4).unwrap_err().to_string();
+ let mut decoder = RecordDecoder::new(Reader::new(), 2);
+ let (skipped, bytes) = decoder.decode(csv.as_bytes(), 1).unwrap();
+ assert_eq!(skipped, 1);
+ decoder.clear();
+
+ let remaining = &csv.as_bytes()[bytes..];
+ let err = decoder.decode(remaining, 3).unwrap_err().to_string();
assert_eq!(err, expected);
}
#[test]
fn test_skip_insufficient_rows() {
let csv = "a\nv\n";
- let cursor = Cursor::new(csv.as_bytes());
- let mut reader = RecordReader::new(cursor, Reader::new(), 1);
- let err = reader.skip(3).unwrap_err().to_string();
- assert_eq!(err, "Csv error: Failed to skip 3 rows only found 2");
+ let mut decoder = RecordDecoder::new(Reader::new(), 1);
+ let (read, bytes) = decoder.decode(csv.as_bytes(), 3).unwrap();
+ assert_eq!(read, 2);
+ assert_eq!(bytes, csv.len());
}
}