You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by we...@apache.org on 2017/04/21 01:42:29 UTC

arrow git commit: ARROW-822: [Python] StreamWriter Wrapper for Socket and File-like Objects without tell()

Repository: arrow
Updated Branches:
  refs/heads/master 7c1fef51c -> 6c352e205


ARROW-822: [Python] StreamWriter Wrapper for Socket and File-like Objects without tell()

Track the output position inside PyOutputStream so that Python sockets and other file-like objects without a tell() method can be used as StreamWriter sinks. Instead of calling the Python object's tell(), Tell() now reports the number of bytes written through the stream. The count starts at 0 when the PyOutputStream is created and grows by the size of each write.

Added unit tests that use a local socket as the source and sink for streaming.

Author: Bryan Cutler <cu...@gmail.com>

Closes #569 from BryanCutler/pyarrow-stream-writer-socket-ARROW-822 and squashes the following commits:

6cdec4f [Bryan Cutler] Removed StreamWriter wrapper and put position handling in PyStreamWriter instead
2bd669f [Bryan Cutler] Added StreamSinkWrapper to ensure stream sink has tell() method, added unittest for StreamWriter and StreamReader over local socket


Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/6c352e20
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/6c352e20
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/6c352e20

Branch: refs/heads/master
Commit: 6c352e2057d5f9a442c1ebf0d35c716f475fd343
Parents: 7c1fef5
Author: Bryan Cutler <cu...@gmail.com>
Authored: Thu Apr 20 21:42:20 2017 -0400
Committer: Wes McKinney <we...@twosigma.com>
Committed: Thu Apr 20 21:42:20 2017 -0400

----------------------------------------------------------------------
 cpp/src/arrow/python/io.cc       |  7 ++--
 cpp/src/arrow/python/io.h        |  1 +
 python/pyarrow/tests/test_ipc.py | 79 +++++++++++++++++++++++++++++++++++
 3 files changed, 84 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/arrow/blob/6c352e20/cpp/src/arrow/python/io.cc
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/python/io.cc b/cpp/src/arrow/python/io.cc
index ba82a45..327e8fe 100644
--- a/cpp/src/arrow/python/io.cc
+++ b/cpp/src/arrow/python/io.cc
@@ -189,7 +189,7 @@ bool PyReadableFile::supports_zero_copy() const {
 // ----------------------------------------------------------------------
 // Output stream
 
-PyOutputStream::PyOutputStream(PyObject* file) {
+PyOutputStream::PyOutputStream(PyObject* file) : position_(0) {
   file_.reset(new PythonFile(file));
 }
 
@@ -201,12 +201,17 @@ Status PyOutputStream::Close() {
 }
 
 Status PyOutputStream::Tell(int64_t* position) {
-  PyAcquireGIL lock;
-  return file_->Tell(position);
+  // Report the bytes written through this stream instead of calling the
+  // Python object's tell(), which sockets and some file-likes lack.
+  *position = position_;
+  return Status::OK();
 }
 
 Status PyOutputStream::Write(const uint8_t* data, int64_t nbytes) {
   PyAcquireGIL lock;
-  return file_->Write(data, nbytes);
+  RETURN_NOT_OK(file_->Write(data, nbytes));
+  // Advance only after a successful write so Tell() never over-reports.
+  position_ += nbytes;
+  return Status::OK();
 }
 

http://git-wip-us.apache.org/repos/asf/arrow/blob/6c352e20/cpp/src/arrow/python/io.h
----------------------------------------------------------------------
diff --git a/cpp/src/arrow/python/io.h b/cpp/src/arrow/python/io.h
index bf14cd6..ebd4c5a 100644
--- a/cpp/src/arrow/python/io.h
+++ b/cpp/src/arrow/python/io.h
@@ -82,6 +82,7 @@ class ARROW_EXPORT PyOutputStream : public io::OutputStream {
 
  private:
   std::unique_ptr<PythonFile> file_;
+  int64_t position_ = 0;  // Bytes passed to Write(); reported by Tell().
 };
 
 // A zero-copy reader backed by a PyBuffer object

http://git-wip-us.apache.org/repos/asf/arrow/blob/6c352e20/python/pyarrow/tests/test_ipc.py
----------------------------------------------------------------------
diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py
index 31d418d..81213ed 100644
--- a/python/pyarrow/tests/test_ipc.py
+++ b/python/pyarrow/tests/test_ipc.py
@@ -17,6 +17,8 @@
 
 import io
 import pytest
+import socket
+import threading
 
 import numpy as np
 
@@ -126,6 +128,105 @@ class TestStream(MessagingTest, unittest.TestCase):
         assert result.equals(expected)
 
 
+class TestSocket(MessagingTest, unittest.TestCase):
+
+    class StreamReaderServer(threading.Thread):
+        """Reads an Arrow stream from one client of a local TCP socket."""
+
+        def init(self, do_read_all):
+            """Listen on an ephemeral localhost port and return the port.
+
+            Must be called before start().
+            """
+            self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            self._sock.bind(('127.0.0.1', 0))
+            self._sock.listen(1)
+            _, port = self._sock.getsockname()
+            self._do_read_all = do_read_all
+            self._schema = None
+            self._batches = []
+            self._table = None
+            # Do not keep the process alive if a test fails before the
+            # client connects and accept() never returns.
+            self.daemon = True
+            return port
+
+        def run(self):
+            try:
+                connection, _ = self._sock.accept()
+            finally:
+                # Only one client is expected, so stop listening now.
+                self._sock.close()
+            source = connection.makefile(mode='rb')
+            try:
+                reader = pa.StreamReader(source)
+                self._schema = reader.schema
+                if self._do_read_all:
+                    self._table = reader.read_all()
+                else:
+                    self._batches.extend(reader)
+            finally:
+                source.close()
+                connection.close()
+
+        def get_result(self):
+            """Return the schema and the table or batches that were read."""
+            if self._do_read_all:
+                return self._schema, self._table
+            return self._schema, self._batches
+
+    def setUp(self):
+        # NOTE: the sink is a socket that only exists once a server is
+        # running, so each test must start and stop the server itself.
+        pass
+
+    def start_server(self, do_read_all):
+        """Start a reader thread and connect a socket-backed sink to it."""
+        self._server = TestSocket.StreamReaderServer()
+        port = self._server.init(do_read_all)
+        self._server.start()
+        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self._sock.connect(('127.0.0.1', port))
+        self.sink = self._get_sink()
+
+    def stop_and_get_result(self):
+        """End the stream, disconnect the client and return the server's
+        (schema, result) after its thread has finished."""
+        import struct
+        # A zero int32 is the end-of-stream marker of the stream format.
+        self.sink.write(struct.pack('i', 0))
+        # close() flushes the buffered sink; the file object also holds a
+        # reference to the socket, so both must be closed to disconnect.
+        self.sink.close()
+        self._sock.close()
+        self._server.join()
+        return self._server.get_result()
+
+    def _get_sink(self):
+        return self._sock.makefile(mode='wb')
+
+    def _get_writer(self, sink, schema):
+        return pa.StreamWriter(sink, schema)
+
+    def test_simple_roundtrip(self):
+        self.start_server(do_read_all=False)
+        writer_batches = self.write_batches()
+        reader_schema, reader_batches = self.stop_and_get_result()
+
+        assert reader_schema.equals(writer_batches[0].schema)
+        assert len(reader_batches) == len(writer_batches)
+        for reader_batch, writer_batch in zip(reader_batches, writer_batches):
+            assert reader_batch.equals(writer_batch)
+
+    def test_read_all(self):
+        self.start_server(do_read_all=True)
+        writer_batches = self.write_batches()
+        _, result = self.stop_and_get_result()
+
+        expected = pa.Table.from_batches(writer_batches)
+        assert result.equals(expected)
+
+
 class TestInMemoryFile(TestFile):
 
     def _get_sink(self):