You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ko...@apache.org on 2019/04/04 02:35:21 UTC

[arrow] branch master updated: ARROW-5019: [C#] ArrowStreamWriter doesn't work on a non-seekable stream

This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 384a3b0  ARROW-5019: [C#] ArrowStreamWriter doesn't work on a non-seekable stream
384a3b0 is described below

commit 384a3b070bab958ae0a7d0a45f148d2ef1f86da6
Author: Eric Erhardt <er...@microsoft.com>
AuthorDate: Thu Apr 4 11:35:09 2019 +0900

    ARROW-5019: [C#] ArrowStreamWriter doesn't work on a non-seekable stream
    
    Allow ArrowStreamWriter to write to a non-seekable stream, like a network stream.
    
    @chutchinson @stephentoub @pgovind
    
    Author: Eric Erhardt <er...@microsoft.com>
    
    Closes #4052 from eerhardt/WriteToNetworkStream and squashes the following commits:
    
    e7125bf8 <Eric Erhardt> PR feedback
    7333b8cd <Eric Erhardt> ArrowStreamWriter doesn't work on a non-seekable stream
---
 csharp/src/Apache.Arrow/ArrowBuffer.Builder.cs     |  7 +-
 csharp/src/Apache.Arrow/BitUtility.cs              |  4 +-
 csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs     | 43 +++++++++--
 csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs   | 86 ++++++++++------------
 csharp/src/Apache.Arrow/Ipc/Block.cs               | 20 ++---
 .../Apache.Arrow.Tests/ArrowStreamWriterTests.cs   | 57 ++++++++++++++
 6 files changed, 153 insertions(+), 64 deletions(-)

diff --git a/csharp/src/Apache.Arrow/ArrowBuffer.Builder.cs b/csharp/src/Apache.Arrow/ArrowBuffer.Builder.cs
index 7ab26fa..36c2bc2 100644
--- a/csharp/src/Apache.Arrow/ArrowBuffer.Builder.cs
+++ b/csharp/src/Apache.Arrow/ArrowBuffer.Builder.cs
@@ -111,7 +111,12 @@ namespace Apache.Arrow
 
             public ArrowBuffer Build(MemoryPool pool = default)
             {
-                var length = BitUtility.RoundUpToMultipleOf64(_buffer.Length);
+                int length;
+                checked
+                {
+                    length = (int)BitUtility.RoundUpToMultipleOf64(_buffer.Length);
+                }
+
                 var memoryPool = pool ?? MemoryPool.Default.Value;
                 var memory = memoryPool.Allocate(length);
 
diff --git a/csharp/src/Apache.Arrow/BitUtility.cs b/csharp/src/Apache.Arrow/BitUtility.cs
index fccdfe0..efb426c 100644
--- a/csharp/src/Apache.Arrow/BitUtility.cs
+++ b/csharp/src/Apache.Arrow/BitUtility.cs
@@ -99,7 +99,7 @@ namespace Apache.Arrow
         /// </summary>
         /// <param name="n">Integer to round.</param>
         /// <returns>Integer rounded to the nearest multiple of 64.</returns>
-        public static int RoundUpToMultipleOf64(int n) =>
+        public static long RoundUpToMultipleOf64(long n) =>
             RoundUpToMultiplePowerOfTwo(n, 64);
 
         /// <summary>
@@ -111,7 +111,7 @@ namespace Apache.Arrow
         /// <param name="n">Integer to round up.</param>
         /// <param name="factor">Power of two factor to round up to.</param>
         /// <returns>Integer rounded up to the nearest power of two.</returns>
-        public static int RoundUpToMultiplePowerOfTwo(int n, int factor)
+        public static long RoundUpToMultiplePowerOfTwo(long n, int factor)
         {
             // Assert that factor is a power of two.
             Debug.Assert(factor > 0 && (factor & (factor - 1)) == 0);
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
index b74bcc4..d6e124a 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowFileWriter.cs
@@ -16,6 +16,7 @@
 using System;
 using System.Buffers.Binary;
 using System.Collections.Generic;
+using System.Diagnostics;
 using System.IO;
 using System.Threading;
 using System.Threading.Tasks;
@@ -23,7 +24,9 @@ using System.Threading.Tasks;
 namespace Apache.Arrow.Ipc
 {
     public class ArrowFileWriter: ArrowStreamWriter
-    { 
+    {
+        private long _currentRecordBatchOffset = -1;
+
         private bool HasWrittenHeader { get; set; }
         private bool HasWrittenFooter { get; set; }
 
@@ -67,10 +70,35 @@ namespace Apache.Arrow.Ipc
 
             cancellationToken.ThrowIfCancellationRequested();
 
-            var block = await WriteRecordBatchInternalAsync(recordBatch, cancellationToken)
+            await WriteRecordBatchInternalAsync(recordBatch, cancellationToken)
                 .ConfigureAwait(false);
+        }
+
+        private protected override void StartingWritingRecordBatch()
+        {
+            _currentRecordBatchOffset = BaseStream.Position;
+        }
+
+        private protected override void FinishedWritingRecordBatch(long bodyLength, long metadataLength)
+        {
+            // Record batches only appear after a Schema is written, so the record batch offsets must
+            // always be greater than 0.
+            Debug.Assert(_currentRecordBatchOffset > 0, "_currentRecordBatchOffset must be positive.");
+
+            int metadataLengthInt;
+            checked
+            {
+                metadataLengthInt = (int)metadataLength;
+            }
+
+            var block = new Block(
+                offset: _currentRecordBatchOffset,
+                length: bodyLength,
+                metadataLength: metadataLengthInt);
 
             RecordBatchBlocks.Add(block);
+
+            _currentRecordBatchOffset = -1;
         }
 
         public async Task WriteFooterAsync(CancellationToken cancellationToken = default)
@@ -112,7 +140,7 @@ namespace Apache.Arrow.Ipc
             foreach (var recordBatch in RecordBatchBlocks)
             {
                 Flatbuf.Block.CreateBlock(
-                    Builder, recordBatch.Offset, recordBatch.MetadataLength, recordBatch.Length);
+                    Builder, recordBatch.Offset, recordBatch.MetadataLength, recordBatch.BodyLength);
             }
 
             var recordBatchesVectorOffset = Builder.EndVector();
@@ -141,8 +169,13 @@ namespace Apache.Arrow.Ipc
 
             await Buffers.RentReturnAsync(4, async (buffer) =>
             {
-                BinaryPrimitives.WriteInt32LittleEndian(buffer.Span,
-                    Convert.ToInt32(BaseStream.Position - offset));
+                int footerLength;
+                checked
+                {
+                    footerLength = (int)(BaseStream.Position - offset);
+                }
+
+                BinaryPrimitives.WriteInt32LittleEndian(buffer.Span, footerLength);
 
                 await BaseStream.WriteAsync(buffer, cancellationToken).ConfigureAwait(false);
             }).ConfigureAwait(false);
diff --git a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
index c1a6646..90cbfd3 100644
--- a/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
+++ b/csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs
@@ -125,20 +125,6 @@ namespace Apache.Arrow.Ipc
             }
         }
 
-        protected struct Block
-        {
-            public readonly int Offset;
-            public readonly int Length;
-            public readonly int MetadataLength;
-
-            public Block(int offset, int length, int metadataLength)
-            {
-                Offset = offset;
-                Length = length;
-                MetadataLength = metadataLength;
-            }
-        }
-
         protected Stream BaseStream { get; }
 
         protected ArrayPool<byte> Buffers { get; }
@@ -174,7 +160,7 @@ namespace Apache.Arrow.Ipc
             _fieldTypeBuilder = new ArrowTypeFlatbufferBuilder(Builder);
         }
 
-        protected virtual async Task<Block> WriteRecordBatchInternalAsync(RecordBatch recordBatch,
+        private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch,
             CancellationToken cancellationToken = default)
         {
             // TODO: Truncate buffers with extraneous padding / unused capacity
@@ -228,63 +214,55 @@ namespace Apache.Arrow.Ipc
 
             // Serialize record batch
 
+            StartingWritingRecordBatch();
+
             var recordBatchOffset = Flatbuf.RecordBatch.CreateRecordBatch(Builder, recordBatch.Length,
                 fieldNodesVectorOffset,
                 buffersVectorOffset);
 
-            var metadataOffset = BaseStream.Position;
-
-            await WriteMessageAsync(Flatbuf.MessageHeader.RecordBatch,
+            long metadataLength = await WriteMessageAsync(Flatbuf.MessageHeader.RecordBatch,
                 recordBatchOffset, recordBatchBuilder.TotalLength,
                 cancellationToken).ConfigureAwait(false);
 
-            var metadataLength = BaseStream.Position - metadataOffset;
-
             // Write buffer data
 
-            var lengthOffset = BaseStream.Position;
+            long bodyLength = 0;
 
             for (var i = 0; i < buffers.Count; i++)
             {
                 if (buffers[i].DataBuffer.IsEmpty)
                     continue;
 
-                
                 await WriteBufferAsync(buffers[i].DataBuffer, cancellationToken).ConfigureAwait(false);
+                bodyLength += buffers[i].DataBuffer.Length;
             }
 
             // Write padding so the record batch message body length is a multiple of 8 bytes
 
-            var bodyLength = Convert.ToInt32(BaseStream.Position - lengthOffset);
-            var bodyPaddingLength = CalculatePadding(bodyLength);
+            int bodyPaddingLength = CalculatePadding(bodyLength);
 
             await WritePaddingAsync(bodyPaddingLength).ConfigureAwait(false);
 
-            return new Block(
-                offset: Convert.ToInt32(metadataOffset),
-                length: bodyLength + bodyPaddingLength, 
-                metadataLength: Convert.ToInt32(metadataLength));
+            FinishedWritingRecordBatch(bodyLength + bodyPaddingLength, metadataLength);
+        }
+
+        private protected virtual void StartingWritingRecordBatch()
+        {
+        }
+
+        private protected virtual void FinishedWritingRecordBatch(long bodyLength, long metadataLength)
+        {
         }
 
         public virtual Task WriteRecordBatchAsync(RecordBatch recordBatch, CancellationToken cancellationToken = default)
         {
             return WriteRecordBatchInternalAsync(recordBatch, cancellationToken);
         }
-    
-        public Task WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken cancellationToken = default)
+
+        public async Task WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken cancellationToken = default)
         {
-            byte[] buffer = null;
-            try
-            {
-                var span = arrowBuffer.Span;
-                buffer = ArrayPool<byte>.Shared.Rent(span.Length);
-                span.CopyTo(buffer);
-                return BaseStream.WriteAsync(buffer, 0, span.Length, cancellationToken);
-            }
-            finally
-            {
-                ArrayPool<byte>.Shared.Return(buffer);
-            }
+            await BaseStream.WriteAsync(arrowBuffer.Memory, cancellationToken)
+                .ConfigureAwait(false);
         }
 
         private protected Offset<Flatbuf.Schema> SerializeSchema(Schema schema)
@@ -319,7 +297,6 @@ namespace Apache.Arrow.Ipc
                 Builder, endianness, fieldsVectorOffset);
         }
 
-
         private async ValueTask<Offset<Flatbuf.Schema>> WriteSchemaAsync(Schema schema, CancellationToken cancellationToken)
         {
             Builder.Clear();
@@ -336,7 +313,13 @@ namespace Apache.Arrow.Ipc
             return schemaOffset;
         }
 
-        private async ValueTask WriteMessageAsync<T>(
+        /// <summary>
+        /// Writes the message to the <see cref="BaseStream"/>.
+        /// </summary>
+        /// <returns>
+        /// The number of bytes written to the stream.
+        /// </returns>
+        private async ValueTask<long> WriteMessageAsync<T>(
             Flatbuf.MessageHeader headerType, Offset<T> headerOffset, int bodyLength,
             CancellationToken cancellationToken)
             where T: struct
@@ -359,6 +342,11 @@ namespace Apache.Arrow.Ipc
 
             await BaseStream.WriteAsync(messageData, cancellationToken).ConfigureAwait(false);
             await WritePaddingAsync(messagePaddingLength).ConfigureAwait(false);
+
+            checked
+            {
+                return 4 + messageData.Length + messagePaddingLength;
+            }
         }
 
         private protected async ValueTask WriteFlatBufferAsync(CancellationToken cancellationToken = default)
@@ -368,8 +356,14 @@ namespace Apache.Arrow.Ipc
             await BaseStream.WriteAsync(segment, cancellationToken).ConfigureAwait(false);
         }
 
-        protected int CalculatePadding(int offset, int alignment = 8) =>
-            BitUtility.RoundUpToMultiplePowerOfTwo(offset, alignment) - offset;
+        protected int CalculatePadding(long offset, int alignment = 8)
+        {
+            long result = BitUtility.RoundUpToMultiplePowerOfTwo(offset, alignment) - offset;
+            checked
+            {
+                return (int)result;
+            }
+        }
 
         protected Task WritePaddingAsync(int length)
         {
diff --git a/csharp/src/Apache.Arrow/Ipc/Block.cs b/csharp/src/Apache.Arrow/Ipc/Block.cs
index f12e7e5..4aaa3b4 100644
--- a/csharp/src/Apache.Arrow/Ipc/Block.cs
+++ b/csharp/src/Apache.Arrow/Ipc/Block.cs
@@ -17,24 +17,24 @@ using System;
 
 namespace Apache.Arrow.Ipc
 {
-    internal class Block
+    internal readonly struct Block
     {
-        public long Offset { get; }
-        public int MetaDataLength { get; }
-        public long BodyLength { get; }
+        public readonly long Offset;
+        public readonly long BodyLength;
+        public readonly int MetadataLength;
 
-        public Block(long offset, int metadataLength, long bodyLength)
+        public Block(long offset, long length, int metadataLength)
         {
             Offset = offset;
-            MetaDataLength = metadataLength;
-            BodyLength = bodyLength;
+            BodyLength = length;
+            MetadataLength = metadataLength;
         }
 
         public Block(Flatbuf.Block block)
         {
-            Offset = Convert.ToInt32(block.Offset);
-            MetaDataLength = Convert.ToInt32(block.MetaDataLength);
-            BodyLength = Convert.ToInt32(block.BodyLength);
+            Offset = block.Offset;
+            BodyLength = block.BodyLength;
+            MetadataLength = block.MetaDataLength;
         }
     }
 }
diff --git a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs
index b4cfef0..3ef747d 100644
--- a/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs
+++ b/csharp/test/Apache.Arrow.Tests/ArrowStreamWriterTests.cs
@@ -16,6 +16,9 @@
 using Apache.Arrow.Ipc;
 using System;
 using System.IO;
+using System.Net;
+using System.Net.Sockets;
+using System.Threading.Tasks;
 using Xunit;
 
 namespace Apache.Arrow.Tests
@@ -48,5 +51,59 @@ namespace Apache.Arrow.Tests
             new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true).Dispose();
             Assert.Equal(0, stream.Position);
         }
+
+        [Fact]
+        public async Task CanWriteToNetworkStream()
+        {
+            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 100);
+
+            const int port = 32154;
+            TcpListener listener = new TcpListener(IPAddress.Loopback, port);
+            listener.Start();
+
+            using (TcpClient sender = new TcpClient())
+            {
+                sender.Connect(IPAddress.Loopback, port);
+                NetworkStream stream = sender.GetStream();
+
+                using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema))
+                {
+                    await writer.WriteRecordBatchAsync(originalBatch);
+                    stream.Flush();
+                }
+            }
+
+            using (TcpClient receiver = listener.AcceptTcpClient())
+            {
+                NetworkStream stream = receiver.GetStream();
+                using (var reader = new ArrowStreamReader(stream))
+                {
+                    RecordBatch newBatch = reader.ReadNextRecordBatch();
+                    ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
+                }
+            }
+        }
+
+        [Fact]
+        public async Task WriteEmptyBatch()
+        {
+            RecordBatch originalBatch = TestData.CreateSampleRecordBatch(length: 0);
+
+            using (MemoryStream stream = new MemoryStream())
+            {
+                using (var writer = new ArrowStreamWriter(stream, originalBatch.Schema, leaveOpen: true))
+                {
+                    await writer.WriteRecordBatchAsync(originalBatch);
+                }
+
+                stream.Position = 0;
+
+                using (var reader = new ArrowStreamReader(stream))
+                {
+                    RecordBatch newBatch = reader.ReadNextRecordBatch();
+                    ArrowReaderVerifier.CompareBatches(originalBatch, newBatch);
+                }
+            }
+        }
     }
 }