Posted to commits@arrow.apache.org by we...@apache.org on 2018/07/24 19:23:11 UTC
[arrow] branch master updated: ARROW-2859: [Python] Accept buffer-like objects as sources in open_file, open_stream APIs
This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 2a0128d ARROW-2859: [Python] Accept buffer-like objects as sources in open_file, open_stream APIs
2a0128d is described below
commit 2a0128dec863ebf14139372cca8b618d54e7e7dc
Author: Wes McKinney <we...@apache.org>
AuthorDate: Tue Jul 24 15:23:06 2018 -0400
ARROW-2859: [Python] Accept buffer-like objects as sources in open_file, open_stream APIs
The behavior had been to treat a string-like object as a file name. We didn't have any APIs that relied on this behavior, and being able to read a stream from an object exporting the buffer protocol as `pa.open_stream(buf)` is much more convenient and natural than `pa.open_stream(pa.BufferReader(buf))`.
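For illustration, a minimal sketch of the difference (the toy record batch is hypothetical; everything else uses pyarrow APIs touched by this patch):

    import pyarrow as pa

    # Write a tiny stream into an in-memory sink
    batch = pa.RecordBatch.from_arrays([pa.array([1, 2, 3])], ['f0'])
    sink = pa.BufferOutputStream()
    writer = pa.RecordBatchStreamWriter(sink, batch.schema)
    writer.write_batch(batch)
    writer.close()
    buf = sink.getvalue()

    # Before this change: the buffer had to be wrapped explicitly
    reader = pa.open_stream(pa.BufferReader(buf))

    # After this change: buffer-like sources are accepted directly
    reader = pa.open_stream(buf)
    table = reader.read_all()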
I may look at quickly adding support for pathlib.Path objects here.
I also added the precursor for addressing ARROW-2807: a use_memory_map flag on the internal file readers.
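The flag is threaded through the internal get_reader() helper, presumably toward making memory-mapping optional when opening files from a path. Below is a rough pure-Python mirror of that dispatch, for illustration only (get_reader_sketch is a hypothetical name, not part of the API):

    import pyarrow as pa

    def get_reader_sketch(source, use_memory_map=True):
        # Mirrors the Cython get_reader() dispatch in io.pxi
        if isinstance(source, pa.NativeFile):
            return source
        if isinstance(source, str):
            # File path: memory-map by default, else open a plain OS file
            if use_memory_map:
                return pa.memory_map(source, mode='r')
            return pa.OSFile(source, mode='r')
        # Optimistically hope this is file-like
        return pa.PythonFile(source, mode='r')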
Author: Wes McKinney <we...@apache.org>
Closes #2314 from wesm/ARROW-2859 and squashes the following commits:
b64a828c <Wes McKinney> Fix docstrings
5cc363f8 <Wes McKinney> Amend usages of get_result, add FutureWarning
b11e5328 <Wes McKinney> Add pathlib test. Refactor to use pytest
53f32e84 <Wes McKinney> Add test for stream from buffer protocol
a6fc8f1c <Wes McKinney> Do not try to open file from buffer input, add use_memory_map flag
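One user-visible rename from these commits: BufferOutputStream.get_result() becomes getvalue(), with the old name kept as a deprecated alias that emits a FutureWarning. A minimal sketch of the behavior after this patch:

    import warnings
    import pyarrow as pa

    sink = pa.BufferOutputStream()
    sink.write(b'hello, friends')
    buf = sink.getvalue()            # new spelling; finalizes the stream

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        pa.BufferOutputStream().get_result()   # deprecated spelling
    assert issubclass(caught[-1].category, FutureWarning)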
---
python/doc/source/ipc.rst | 4 +-
python/doc/source/memory.rst | 2 +-
python/pyarrow/_orc.pyx | 4 +-
python/pyarrow/_parquet.pyx | 5 +-
python/pyarrow/feather.pxi | 4 +-
python/pyarrow/io.pxi | 30 +-
python/pyarrow/ipc.pxi | 17 +-
python/pyarrow/ipc.py | 16 +-
python/pyarrow/lib.pxd | 3 +-
python/pyarrow/serialization.pxi | 2 +-
python/pyarrow/tests/test_io.py | 10 +-
python/pyarrow/tests/test_ipc.py | 543 +++++++++++++++++++++--------------
python/pyarrow/tests/test_parquet.py | 12 +-
13 files changed, 395 insertions(+), 257 deletions(-)
diff --git a/python/doc/source/ipc.rst b/python/doc/source/ipc.rst
index 51d523c..738ae1d 100644
--- a/python/doc/source/ipc.rst
+++ b/python/doc/source/ipc.rst
@@ -79,7 +79,7 @@ particular stream. Now we can do:
writer.write_batch(batch)
writer.close()
- buf = sink.get_result()
+ buf = sink.getvalue()
buf.size
Now ``buf`` contains the complete stream as an in-memory byte buffer. We can
@@ -119,7 +119,7 @@ The :class:`~pyarrow.RecordBatchFileWriter` has the same API as
writer.write_batch(batch)
writer.close()
- buf = sink.get_result()
+ buf = sink.getvalue()
buf.size
The difference between :class:`~pyarrow.RecordBatchFileReader` and
diff --git a/python/doc/source/memory.rst b/python/doc/source/memory.rst
index cd8983a..8fcf5f5 100644
--- a/python/doc/source/memory.rst
+++ b/python/doc/source/memory.rst
@@ -206,7 +206,7 @@ file interfaces that can read and write to Arrow Buffers.
writer = pa.BufferOutputStream()
writer.write(b'hello, friends')
- buf = writer.get_result()
+ buf = writer.getvalue()
buf
buf.size
reader = pa.BufferReader(buf)
diff --git a/python/pyarrow/_orc.pyx b/python/pyarrow/_orc.pyx
index cf04f48..c95bea2 100644
--- a/python/pyarrow/_orc.pyx
+++ b/python/pyarrow/_orc.pyx
@@ -41,13 +41,13 @@ cdef class ORCReader:
def __cinit__(self, MemoryPool memory_pool=None):
self.allocator = maybe_unbox_memory_pool(memory_pool)
- def open(self, object source):
+ def open(self, object source, c_bool use_memory_map=True):
cdef:
shared_ptr[RandomAccessFile] rd_handle
self.source = source
- get_reader(source, &rd_handle)
+ get_reader(source, use_memory_map, &rd_handle)
with nogil:
check_status(ORCFileReader.Open(rd_handle, self.allocator,
&self.reader))
diff --git a/python/pyarrow/_parquet.pyx b/python/pyarrow/_parquet.pyx
index e40a57c..983ff8d 100644
--- a/python/pyarrow/_parquet.pyx
+++ b/python/pyarrow/_parquet.pyx
@@ -636,7 +636,8 @@ cdef class ParquetReader:
self.allocator = maybe_unbox_memory_pool(memory_pool)
self._metadata = None
- def open(self, object source, FileMetaData metadata=None):
+ def open(self, object source, c_bool use_memory_map=True,
+ FileMetaData metadata=None):
cdef:
shared_ptr[RandomAccessFile] rd_handle
shared_ptr[CFileMetaData] c_metadata
@@ -648,7 +649,7 @@ cdef class ParquetReader:
self.source = source
- get_reader(source, &rd_handle)
+ get_reader(source, use_memory_map, &rd_handle)
with nogil:
check_status(OpenFile(rd_handle, self.allocator, properties,
c_metadata, &self.reader))
diff --git a/python/pyarrow/feather.pxi b/python/pyarrow/feather.pxi
index 37c7f92..937e275 100644
--- a/python/pyarrow/feather.pxi
+++ b/python/pyarrow/feather.pxi
@@ -75,9 +75,9 @@ cdef class FeatherReader:
def __cinit__(self):
pass
- def open(self, source):
+ def open(self, source, c_bool use_memory_map=True):
cdef shared_ptr[RandomAccessFile] reader
- get_reader(source, &reader)
+ get_reader(source, use_memory_map, &reader)
with nogil:
check_status(CFeatherReader.Open(reader, &self.reader))
diff --git a/python/pyarrow/io.pxi b/python/pyarrow/io.pxi
index ad9bf0e..8d35d0d 100644
--- a/python/pyarrow/io.pxi
+++ b/python/pyarrow/io.pxi
@@ -930,6 +930,22 @@ cdef class BufferOutputStream(NativeFile):
self.closed = False
def get_result(self):
+ """
+ Deprecated as of 0.10.0. Alias for getvalue()
+ """
+ warnings.warn("BufferOutputStream.get_result() has been renamed "
+ "to getvalue(), will be removed in 0.11.0",
+ FutureWarning)
+ return self.getvalue()
+
+ def getvalue(self):
+ """
+ Finalize output stream and return result as pyarrow.Buffer.
+
+ Returns
+ -------
+ value : Buffer
+ """
with nogil:
check_status(self.wr_file.get().Close())
self.closed = True
@@ -994,7 +1010,14 @@ def foreign_buffer(address, size, base):
return pyarrow_wrap_buffer(buf)
-cdef get_reader(object source, shared_ptr[RandomAccessFile]* reader):
+def as_buffer(object o):
+ if isinstance(o, Buffer):
+ return o
+ return py_buffer(o)
+
+
+cdef get_reader(object source, c_bool use_memory_map,
+ shared_ptr[RandomAccessFile]* reader):
cdef NativeFile nf
try:
@@ -1006,7 +1029,10 @@ cdef get_reader(object source, shared_ptr[RandomAccessFile]* reader):
# Optimistically hope this is file-like
source = PythonFile(source, mode='r')
else:
- source = memory_map(source_path, mode='r')
+ if use_memory_map:
+ source = memory_map(source_path, mode='r')
+ else:
+ source = OSFile(source_path, mode='r')
if isinstance(source, NativeFile):
nf = source
diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi
index ccc2f64..2f51142 100644
--- a/python/pyarrow/ipc.pxi
+++ b/python/pyarrow/ipc.pxi
@@ -239,7 +239,13 @@ cdef get_input_stream(object source, shared_ptr[InputStream]* out):
cdef:
shared_ptr[RandomAccessFile] file_handle
- get_reader(source, &file_handle)
+ try:
+ source = as_buffer(source)
+ except TypeError:
+ # Non-buffer-like
+ pass
+
+ get_reader(source, True, &file_handle)
out[0] = <shared_ptr[InputStream]> file_handle
@@ -334,7 +340,12 @@ cdef class _RecordBatchFileReader:
pass
def _open(self, source, footer_offset=None):
- get_reader(source, &self.file)
+ try:
+ source = as_buffer(source)
+ except TypeError:
+ pass
+
+ get_reader(source, True, &self.file)
cdef int64_t offset = 0
if footer_offset is not None:
@@ -522,7 +533,7 @@ def read_schema(obj):
if isinstance(obj, Message):
raise NotImplementedError(type(obj))
- get_reader(obj, &cpp_file)
+ get_reader(obj, True, &cpp_file)
with nogil:
check_status(ReadSchema(cpp_file.get(), &result))
diff --git a/python/pyarrow/ipc.py b/python/pyarrow/ipc.py
index bed2dd6..989e976 100644
--- a/python/pyarrow/ipc.py
+++ b/python/pyarrow/ipc.py
@@ -52,8 +52,8 @@ class RecordBatchStreamReader(lib._RecordBatchReader, _ReadPandasOption):
Parameters
----------
- source : str, pyarrow.NativeFile, or file-like Python object
- Either a file path, or a readable file object
+ source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object
+ Either an in-memory buffer, or a readable file object
"""
def __init__(self, source):
self._open(source)
@@ -80,8 +80,8 @@ class RecordBatchFileReader(lib._RecordBatchFileReader, _ReadPandasOption):
Parameters
----------
- source : str, pyarrow.NativeFile, or file-like Python object
- Either a file path, or a readable file object
+ source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object
+ Either an in-memory buffer, or a readable file object
footer_offset : int, default None
If the file is embedded in some larger file, this is the byte offset to
the very end of the file data
@@ -111,8 +111,8 @@ def open_stream(source):
Parameters
----------
- source : str, pyarrow.NativeFile, or file-like Python object
- Either a file path, or a readable file object
+ source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object
+ Either an in-memory buffer, or a readable file object
footer_offset : int, default None
If the file is embedded in some larger file, this is the byte offset to
the very end of the file data
@@ -130,8 +130,8 @@ def open_file(source, footer_offset=None):
Parameters
----------
- source : str, pyarrow.NativeFile, or file-like Python object
- Either a file path, or a readable file object
+ source : bytes/buffer-like, pyarrow.NativeFile, or file-like Python object
+ Either an in-memory buffer, or a readable file object
footer_offset : int, default None
If the file is embedded in some larger file, this is the byte offset to
the very end of the file data
diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd
index 29e3e3a..e392361 100644
--- a/python/pyarrow/lib.pxd
+++ b/python/pyarrow/lib.pxd
@@ -357,7 +357,8 @@ cdef class NativeFile:
cdef read_handle(self, shared_ptr[RandomAccessFile]* file)
cdef write_handle(self, shared_ptr[OutputStream]* file)
-cdef get_reader(object source, shared_ptr[RandomAccessFile]* reader)
+cdef get_reader(object source, c_bool use_memory_map,
+ shared_ptr[RandomAccessFile]* reader)
cdef get_writer(object source, shared_ptr[OutputStream]* writer)
cdef dict box_metadata(const CKeyValueMetadata* sp_metadata)
diff --git a/python/pyarrow/serialization.pxi b/python/pyarrow/serialization.pxi
index 1ec6073..6407347 100644
--- a/python/pyarrow/serialization.pxi
+++ b/python/pyarrow/serialization.pxi
@@ -372,7 +372,7 @@ def read_serialized(source, base=None):
serialized : the serialized data
"""
cdef shared_ptr[RandomAccessFile] stream
- get_reader(source, &stream)
+ get_reader(source, True, &stream)
cdef SerializedPyObject serialized = SerializedPyObject()
serialized.base = base
diff --git a/python/pyarrow/tests/test_io.py b/python/pyarrow/tests/test_io.py
index f1994b3..eafa40c 100644
--- a/python/pyarrow/tests/test_io.py
+++ b/python/pyarrow/tests/test_io.py
@@ -545,7 +545,7 @@ def test_memory_output_stream():
for i in range(K):
f.write(val)
- buf = f.get_result()
+ buf = f.getvalue()
assert len(buf) == len(val) * K
assert buf.to_pybytes() == val * K
@@ -554,7 +554,7 @@ def test_memory_output_stream():
def test_inmemory_write_after_closed():
f = pa.BufferOutputStream()
f.write(b'ok')
- f.get_result()
+ f.getvalue()
with pytest.raises(ValueError):
f.write(b'not ok')
@@ -586,7 +586,7 @@ def test_nativefile_write_memoryview():
f.write(arr)
f.write(bytearray(data))
- buf = f.get_result()
+ buf = f.getvalue()
assert buf.to_pybytes() == data * 2
@@ -610,7 +610,7 @@ def test_mock_output_stream():
f1.write(val)
f2.write(val)
- assert f1.size() == len(f2.get_result())
+ assert f1.size() == len(f2.getvalue())
# Do the same test with a pandas DataFrame
val = pd.DataFrame({'a': [1, 2, 3]})
@@ -627,7 +627,7 @@ def test_mock_output_stream():
stream_writer1.close()
stream_writer2.close()
- assert f1.size() == len(f2.get_result())
+ assert f1.size() == len(f2.getvalue())
# ----------------------------------------------------------------------
diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py
index a779fb2..115d6bd 100644
--- a/python/pyarrow/tests/test_ipc.py
+++ b/python/pyarrow/tests/test_ipc.py
@@ -18,6 +18,7 @@
import io
import pytest
import socket
+import sys
import threading
import numpy as np
@@ -26,19 +27,19 @@ from pandas.util.testing import (assert_frame_equal,
assert_series_equal)
import pandas as pd
-from pyarrow.compat import unittest
import pyarrow as pa
-class MessagingTest(object):
+class IpcFixture(object):
- def setUp(self):
- self.sink = self._get_sink()
+ def __init__(self, sink_factory=lambda: io.BytesIO()):
+ self._sink_factory = sink_factory
+ self.sink = self.get_sink()
- def _get_sink(self):
- return io.BytesIO()
+ def get_sink(self):
+ return self._sink_factory()
- def _get_source(self):
+ def get_source(self):
return self.sink.getvalue()
def write_batches(self, num_batches=5, as_table=False):
@@ -70,20 +71,14 @@ class MessagingTest(object):
return frames, batches
-class TestFile(MessagingTest, unittest.TestCase):
- # Also tests writing zero-copy NumPy array with additional padding
+class FileFormatFixture(IpcFixture):
def _get_writer(self, sink, schema):
return pa.RecordBatchFileWriter(sink, schema)
- def test_empty_file(self):
- buf = io.BytesIO(b'')
- with pytest.raises(pa.ArrowInvalid):
- pa.open_file(buf)
-
def _check_roundtrip(self, as_table=False):
_, batches = self.write_batches(as_table=as_table)
- file_contents = pa.BufferReader(self._get_source())
+ file_contents = pa.BufferReader(self.get_source())
reader = pa.open_file(file_contents)
@@ -95,231 +90,333 @@ class TestFile(MessagingTest, unittest.TestCase):
assert batches[i].equals(batch)
assert reader.schema.equals(batches[0].schema)
- def test_simple_roundtrip(self):
- self._check_roundtrip(as_table=False)
- def test_write_table(self):
- self._check_roundtrip(as_table=True)
+class StreamFormatFixture(IpcFixture):
- def test_read_all(self):
- _, batches = self.write_batches()
- file_contents = pa.BufferReader(self._get_source())
+ def _get_writer(self, sink, schema):
+ return pa.RecordBatchStreamWriter(sink, schema)
- reader = pa.open_file(file_contents)
- result = reader.read_all()
- expected = pa.Table.from_batches(batches)
- assert result.equals(expected)
+class MessageFixture(IpcFixture):
- def test_read_pandas(self):
- frames, _ = self.write_batches()
+ def _get_writer(self, sink, schema):
+ return pa.RecordBatchStreamWriter(sink, schema)
- file_contents = pa.BufferReader(self._get_source())
- reader = pa.open_file(file_contents)
- result = reader.read_pandas()
- expected = pd.concat(frames)
- assert_frame_equal(result, expected)
+@pytest.fixture
+def ipc_fixture():
+ return IpcFixture()
-class TestStream(MessagingTest, unittest.TestCase):
+@pytest.fixture
+def file_fixture():
+ return FileFormatFixture()
- def _get_writer(self, sink, schema):
- return pa.RecordBatchStreamWriter(sink, schema)
- def test_empty_stream(self):
- buf = io.BytesIO(b'')
- with pytest.raises(pa.ArrowInvalid):
- pa.open_stream(buf)
+@pytest.fixture
+def stream_fixture():
+ return StreamFormatFixture()
- def test_categorical_roundtrip(self):
- df = pd.DataFrame({
- 'one': np.random.randn(5),
- 'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'],
- categories=['foo', 'bar'],
- ordered=True)
- })
- batch = pa.RecordBatch.from_pandas(df)
- writer = self._get_writer(self.sink, batch.schema)
- writer.write_batch(pa.RecordBatch.from_pandas(df))
- writer.close()
- table = (pa.open_stream(pa.BufferReader(self._get_source()))
- .read_all())
- assert_frame_equal(table.to_pandas(), df)
+def test_empty_file():
+ buf = b''
+ with pytest.raises(pa.ArrowInvalid):
+ pa.open_file(pa.BufferReader(buf))
- def test_stream_write_dispatch(self):
- # ARROW-1616
- df = pd.DataFrame({
- 'one': np.random.randn(5),
- 'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'],
- categories=['foo', 'bar'],
- ordered=True)
- })
- table = pa.Table.from_pandas(df, preserve_index=False)
- batch = pa.RecordBatch.from_pandas(df, preserve_index=False)
- writer = self._get_writer(self.sink, table.schema)
- writer.write(table)
- writer.write(batch)
- writer.close()
- table = (pa.open_stream(pa.BufferReader(self._get_source()))
- .read_all())
- assert_frame_equal(table.to_pandas(),
- pd.concat([df, df], ignore_index=True))
+def test_file_simple_roundtrip(file_fixture):
+ file_fixture._check_roundtrip(as_table=False)
- def test_stream_write_table_batches(self):
- # ARROW-504
- df = pd.DataFrame({
- 'one': np.random.randn(20),
- })
- b1 = pa.RecordBatch.from_pandas(df[:10], preserve_index=False)
- b2 = pa.RecordBatch.from_pandas(df, preserve_index=False)
+def test_file_write_table(file_fixture):
+ file_fixture._check_roundtrip(as_table=True)
- table = pa.Table.from_batches([b1, b2, b1])
- writer = self._get_writer(self.sink, table.schema)
- writer.write_table(table, chunksize=15)
- writer.close()
+@pytest.mark.parametrize("sink_factory", [
+ lambda: io.BytesIO(),
+ lambda: pa.BufferOutputStream()
+])
+def test_file_read_all(sink_factory):
+ fixture = FileFormatFixture(sink_factory)
- batches = list(pa.open_stream(pa.BufferReader(self._get_source())))
+ _, batches = fixture.write_batches()
+ file_contents = pa.BufferReader(fixture.get_source())
- assert list(map(len, batches)) == [10, 15, 5, 10]
- result_table = pa.Table.from_batches(batches)
- assert_frame_equal(result_table.to_pandas(),
- pd.concat([df[:10], df, df[:10]],
- ignore_index=True))
+ reader = pa.open_file(file_contents)
- def test_simple_roundtrip(self):
- _, batches = self.write_batches()
- file_contents = pa.BufferReader(self._get_source())
- reader = pa.open_stream(file_contents)
+ result = reader.read_all()
+ expected = pa.Table.from_batches(batches)
+ assert result.equals(expected)
- assert reader.schema.equals(batches[0].schema)
- total = 0
- for i, next_batch in enumerate(reader):
- assert next_batch.equals(batches[i])
- total += 1
+def test_open_file_from_buffer(file_fixture):
+ # ARROW-2859; APIs accept the buffer protocol
+ _, batches = file_fixture.write_batches()
+ source = file_fixture.get_source()
- assert total == len(batches)
+ reader1 = pa.open_file(source)
+ reader2 = pa.open_file(pa.BufferReader(source))
+ reader3 = pa.RecordBatchFileReader(source)
- with pytest.raises(StopIteration):
- reader.get_next_batch()
+ result1 = reader1.read_all()
+ result2 = reader2.read_all()
+ result3 = reader3.read_all()
- def test_read_all(self):
- _, batches = self.write_batches()
- file_contents = pa.BufferReader(self._get_source())
- reader = pa.open_stream(file_contents)
+ assert result1.equals(result2)
+ assert result1.equals(result3)
- result = reader.read_all()
- expected = pa.Table.from_batches(batches)
- assert result.equals(expected)
+def test_file_read_pandas(file_fixture):
+ frames, _ = file_fixture.write_batches()
-class TestMessageReader(MessagingTest, unittest.TestCase):
+ file_contents = pa.BufferReader(file_fixture.get_source())
+ reader = pa.open_file(file_contents)
+ result = reader.read_pandas()
- def _get_example_messages(self):
- _, batches = self.write_batches()
- file_contents = self._get_source()
- buf_reader = pa.BufferReader(file_contents)
- reader = pa.MessageReader.open_stream(buf_reader)
- return batches, list(reader)
+ expected = pd.concat(frames)
+ assert_frame_equal(result, expected)
+
+
+@pytest.mark.skipif(sys.version_info < (3, 6),
+ reason="need Python 3.6")
+def test_file_pathlib(file_fixture, tmpdir):
+ import pathlib
+
+ _, batches = file_fixture.write_batches()
+ source = file_fixture.get_source()
+
+ path = tmpdir.join('file.arrow').strpath
+ with open(path, 'wb') as f:
+ f.write(source)
+
+ t1 = pa.open_file(pathlib.Path(path)).read_all()
+ t2 = pa.open_file(pa.OSFile(path)).read_all()
+
+ assert t1.equals(t2)
+
+
+def test_empty_stream():
+ buf = io.BytesIO(b'')
+ with pytest.raises(pa.ArrowInvalid):
+ pa.open_stream(buf)
+
+
+def test_stream_categorical_roundtrip(stream_fixture):
+ df = pd.DataFrame({
+ 'one': np.random.randn(5),
+ 'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'],
+ categories=['foo', 'bar'],
+ ordered=True)
+ })
+ batch = pa.RecordBatch.from_pandas(df)
+ writer = stream_fixture._get_writer(stream_fixture.sink, batch.schema)
+ writer.write_batch(pa.RecordBatch.from_pandas(df))
+ writer.close()
+
+ table = (pa.open_stream(pa.BufferReader(stream_fixture.get_source()))
+ .read_all())
+ assert_frame_equal(table.to_pandas(), df)
+
+
+def test_open_stream_from_buffer(stream_fixture):
+ # ARROW-2859
+ _, batches = stream_fixture.write_batches()
+ source = stream_fixture.get_source()
+
+ reader1 = pa.open_stream(source)
+ reader2 = pa.open_stream(pa.BufferReader(source))
+ reader3 = pa.RecordBatchStreamReader(source)
+
+ result1 = reader1.read_all()
+ result2 = reader2.read_all()
+ result3 = reader3.read_all()
+
+ assert result1.equals(result2)
+ assert result1.equals(result3)
+
+
+def test_stream_write_dispatch(stream_fixture):
+ # ARROW-1616
+ df = pd.DataFrame({
+ 'one': np.random.randn(5),
+ 'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'],
+ categories=['foo', 'bar'],
+ ordered=True)
+ })
+ table = pa.Table.from_pandas(df, preserve_index=False)
+ batch = pa.RecordBatch.from_pandas(df, preserve_index=False)
+ writer = stream_fixture._get_writer(stream_fixture.sink, table.schema)
+ writer.write(table)
+ writer.write(batch)
+ writer.close()
+
+ table = (pa.open_stream(pa.BufferReader(stream_fixture.get_source()))
+ .read_all())
+ assert_frame_equal(table.to_pandas(),
+ pd.concat([df, df], ignore_index=True))
- def _get_writer(self, sink, schema):
- return pa.RecordBatchStreamWriter(sink, schema)
- def test_ctors_no_segfault(self):
- with pytest.raises(TypeError):
- repr(pa.Message())
-
- with pytest.raises(TypeError):
- repr(pa.MessageReader())
-
- def test_message_reader(self):
- _, messages = self._get_example_messages()
-
- assert len(messages) == 6
- assert messages[0].type == 'schema'
- for msg in messages[1:]:
- assert msg.type == 'record batch'
-
- def test_serialize_read_message(self):
- _, messages = self._get_example_messages()
-
- msg = messages[0]
- buf = msg.serialize()
-
- restored = pa.read_message(buf)
- restored2 = pa.read_message(pa.BufferReader(buf))
- restored3 = pa.read_message(buf.to_pybytes())
-
- assert msg.equals(restored)
- assert msg.equals(restored2)
- assert msg.equals(restored3)
-
- def test_read_record_batch(self):
- batches, messages = self._get_example_messages()
-
- for batch, message in zip(batches, messages[1:]):
- read_batch = pa.read_record_batch(message, batch.schema)
- assert read_batch.equals(batch)
-
- def test_read_pandas(self):
- frames, _ = self.write_batches()
- file_contents = pa.BufferReader(self._get_source())
- reader = pa.open_stream(file_contents)
- result = reader.read_pandas()
-
- expected = pd.concat(frames)
- assert_frame_equal(result, expected)
-
-
-class TestSocket(MessagingTest, unittest.TestCase):
-
- class StreamReaderServer(threading.Thread):
-
- def init(self, do_read_all):
- self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self._sock.bind(('127.0.0.1', 0))
- self._sock.listen(1)
- host, port = self._sock.getsockname()
- self._do_read_all = do_read_all
- self._schema = None
- self._batches = []
- self._table = None
- return port
-
- def run(self):
- connection, client_address = self._sock.accept()
- try:
- source = connection.makefile(mode='rb')
- reader = pa.open_stream(source)
- self._schema = reader.schema
- if self._do_read_all:
- self._table = reader.read_all()
- else:
- for i, batch in enumerate(reader):
- self._batches.append(batch)
- finally:
- connection.close()
-
- def get_result(self):
- return(self._schema, self._table if self._do_read_all
- else self._batches)
-
- def setUp(self):
- # NOTE: must start and stop server in test
+def test_stream_write_table_batches(stream_fixture):
+ # ARROW-504
+ df = pd.DataFrame({
+ 'one': np.random.randn(20),
+ })
+
+ b1 = pa.RecordBatch.from_pandas(df[:10], preserve_index=False)
+ b2 = pa.RecordBatch.from_pandas(df, preserve_index=False)
+
+ table = pa.Table.from_batches([b1, b2, b1])
+
+ writer = stream_fixture._get_writer(stream_fixture.sink, table.schema)
+ writer.write_table(table, chunksize=15)
+ writer.close()
+
+ batches = list(pa.open_stream(stream_fixture.get_source()))
+
+ assert list(map(len, batches)) == [10, 15, 5, 10]
+ result_table = pa.Table.from_batches(batches)
+ assert_frame_equal(result_table.to_pandas(),
+ pd.concat([df[:10], df, df[:10]],
+ ignore_index=True))
+
+
+def test_stream_simple_roundtrip(stream_fixture):
+ _, batches = stream_fixture.write_batches()
+ file_contents = pa.BufferReader(stream_fixture.get_source())
+ reader = pa.open_stream(file_contents)
+
+ assert reader.schema.equals(batches[0].schema)
+
+ total = 0
+ for i, next_batch in enumerate(reader):
+ assert next_batch.equals(batches[i])
+ total += 1
+
+ assert total == len(batches)
+
+ with pytest.raises(StopIteration):
+ reader.get_next_batch()
+
+
+def test_stream_read_all(stream_fixture):
+ _, batches = stream_fixture.write_batches()
+ file_contents = pa.BufferReader(stream_fixture.get_source())
+ reader = pa.open_stream(file_contents)
+
+ result = reader.read_all()
+ expected = pa.Table.from_batches(batches)
+ assert result.equals(expected)
+
+
+def test_stream_read_pandas(stream_fixture):
+ frames, _ = stream_fixture.write_batches()
+ file_contents = stream_fixture.get_source()
+ reader = pa.open_stream(file_contents)
+ result = reader.read_pandas()
+
+ expected = pd.concat(frames)
+ assert_frame_equal(result, expected)
+
+
+@pytest.fixture
+def example_messages(stream_fixture):
+ _, batches = stream_fixture.write_batches()
+ file_contents = stream_fixture.get_source()
+ buf_reader = pa.BufferReader(file_contents)
+ reader = pa.MessageReader.open_stream(buf_reader)
+ return batches, list(reader)
+
+
+def test_message_ctors_no_segfault():
+ with pytest.raises(TypeError):
+ repr(pa.Message())
+
+ with pytest.raises(TypeError):
+ repr(pa.MessageReader())
+
+
+def test_message_reader(example_messages):
+ _, messages = example_messages
+
+ assert len(messages) == 6
+ assert messages[0].type == 'schema'
+ for msg in messages[1:]:
+ assert msg.type == 'record batch'
+
+
+def test_message_serialize_read_message(example_messages):
+ _, messages = example_messages
+
+ msg = messages[0]
+ buf = msg.serialize()
+
+ restored = pa.read_message(buf)
+ restored2 = pa.read_message(pa.BufferReader(buf))
+ restored3 = pa.read_message(buf.to_pybytes())
+
+ assert msg.equals(restored)
+ assert msg.equals(restored2)
+ assert msg.equals(restored3)
+
+
+def test_message_read_record_batch(example_messages):
+ batches, messages = example_messages
+
+ for batch, message in zip(batches, messages[1:]):
+ read_batch = pa.read_record_batch(message, batch.schema)
+ assert read_batch.equals(batch)
+
+
+# ----------------------------------------------------------------------
+# Socket streaming tests
+
+
+class StreamReaderServer(threading.Thread):
+
+ def init(self, do_read_all):
+ self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self._sock.bind(('127.0.0.1', 0))
+ self._sock.listen(1)
+ host, port = self._sock.getsockname()
+ self._do_read_all = do_read_all
+ self._schema = None
+ self._batches = []
+ self._table = None
+ return port
+
+ def run(self):
+ connection, client_address = self._sock.accept()
+ try:
+ source = connection.makefile(mode='rb')
+ reader = pa.open_stream(source)
+ self._schema = reader.schema
+ if self._do_read_all:
+ self._table = reader.read_all()
+ else:
+ for i, batch in enumerate(reader):
+ self._batches.append(batch)
+ finally:
+ connection.close()
+
+ def get_result(self):
+ return(self._schema, self._table if self._do_read_all
+ else self._batches)
+
+
+class SocketStreamFixture(IpcFixture):
+
+ def __init__(self):
+ # XXX(wesm): test will decide when to start socket server. This should
+ # probably be refactored
pass
def start_server(self, do_read_all):
- self._server = TestSocket.StreamReaderServer()
+ self._server = StreamReaderServer()
port = self._server.init(do_read_all)
self._server.start()
self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._sock.connect(('127.0.0.1', port))
- self.sink = self._get_sink()
+ self.sink = self.get_sink()
def stop_and_get_result(self):
import struct
@@ -329,38 +426,40 @@ class TestSocket(MessagingTest, unittest.TestCase):
self._server.join()
return self._server.get_result()
- def _get_sink(self):
+ def get_sink(self):
return self._sock.makefile(mode='wb')
def _get_writer(self, sink, schema):
return pa.RecordBatchStreamWriter(sink, schema)
- def test_simple_roundtrip(self):
- self.start_server(do_read_all=False)
- _, writer_batches = self.write_batches()
- reader_schema, reader_batches = self.stop_and_get_result()
- assert reader_schema.equals(writer_batches[0].schema)
- assert len(reader_batches) == len(writer_batches)
- for i, batch in enumerate(writer_batches):
- assert reader_batches[i].equals(batch)
+@pytest.fixture
+def socket_fixture():
+ return SocketStreamFixture()
+
+
+def test_socket_simple_roundtrip(socket_fixture):
+ socket_fixture.start_server(do_read_all=False)
+ _, writer_batches = socket_fixture.write_batches()
+ reader_schema, reader_batches = socket_fixture.stop_and_get_result()
- def test_read_all(self):
- self.start_server(do_read_all=True)
- _, writer_batches = self.write_batches()
- _, result = self.stop_and_get_result()
+ assert reader_schema.equals(writer_batches[0].schema)
+ assert len(reader_batches) == len(writer_batches)
+ for i, batch in enumerate(writer_batches):
+ assert reader_batches[i].equals(batch)
- expected = pa.Table.from_batches(writer_batches)
- assert result.equals(expected)
+def test_socket_read_all(socket_fixture):
+ socket_fixture.start_server(do_read_all=True)
+ _, writer_batches = socket_fixture.write_batches()
+ _, result = socket_fixture.stop_and_get_result()
-class TestInMemoryFile(TestFile):
+ expected = pa.Table.from_batches(writer_batches)
+ assert result.equals(expected)
- def _get_sink(self):
- return pa.BufferOutputStream()
- def _get_source(self):
- return self.sink.get_result()
+# ----------------------------------------------------------------------
+# Miscellaneous IPC tests
def test_ipc_zero_copy_numpy():
@@ -369,7 +468,7 @@ def test_ipc_zero_copy_numpy():
batch = pa.RecordBatch.from_pandas(df)
sink = pa.BufferOutputStream()
write_file(batch, sink)
- buffer = sink.get_result()
+ buffer = sink.getvalue()
reader = pa.BufferReader(buffer)
batches = read_file(reader)
@@ -389,7 +488,7 @@ def test_ipc_stream_no_batches():
writer = pa.RecordBatchStreamWriter(sink, table.schema)
writer.close()
- source = sink.get_result()
+ source = sink.getvalue()
reader = pa.open_stream(source)
result = reader.read_all()
diff --git a/python/pyarrow/tests/test_parquet.py b/python/pyarrow/tests/test_parquet.py
index d7473e9..1d30737 100644
--- a/python/pyarrow/tests/test_parquet.py
+++ b/python/pyarrow/tests/test_parquet.py
@@ -396,7 +396,7 @@ def test_pandas_parquet_native_file_roundtrip(tmpdir):
arrow_table = pa.Table.from_pandas(df)
imos = pa.BufferOutputStream()
_write_table(arrow_table, imos, version="2.0")
- buf = imos.get_result()
+ buf = imos.getvalue()
reader = pa.BufferReader(buf)
df_read = _read_table(reader).to_pandas()
tm.assert_frame_equal(df, df_read)
@@ -424,7 +424,7 @@ def test_parquet_incremental_file_build(tmpdir):
writer.close()
- buf = out.get_result()
+ buf = out.getvalue()
result = _read_table(pa.BufferReader(buf))
expected = pd.concat(frames, ignore_index=True)
@@ -439,7 +439,7 @@ def test_read_pandas_column_subset(tmpdir):
arrow_table = pa.Table.from_pandas(df)
imos = pa.BufferOutputStream()
_write_table(arrow_table, imos, version="2.0")
- buf = imos.get_result()
+ buf = imos.getvalue()
reader = pa.BufferReader(buf)
df_read = pq.read_pandas(reader, columns=['strings', 'uint8']).to_pandas()
tm.assert_frame_equal(df[['strings', 'uint8']], df_read)
@@ -451,7 +451,7 @@ def test_pandas_parquet_empty_roundtrip(tmpdir):
arrow_table = pa.Table.from_pandas(df)
imos = pa.BufferOutputStream()
_write_table(arrow_table, imos, version="2.0")
- buf = imos.get_result()
+ buf = imos.getvalue()
reader = pa.BufferReader(buf)
df_read = _read_table(reader).to_pandas()
tm.assert_frame_equal(df, df_read)
@@ -2108,7 +2108,7 @@ def test_parquet_writer_context_obj(tmpdir):
frames.append(df.copy())
- buf = out.get_result()
+ buf = out.getvalue()
result = _read_table(pa.BufferReader(buf))
expected = pd.concat(frames, ignore_index=True)
@@ -2143,7 +2143,7 @@ def test_parquet_writer_context_obj_with_exception(tmpdir):
except Exception as e:
assert str(e) == error_text
- buf = out.get_result()
+ buf = out.getvalue()
result = _read_table(pa.BufferReader(buf))
expected = pd.concat(frames, ignore_index=True)