You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by cu...@apache.org on 2018/05/31 22:45:41 UTC

[arrow] branch master updated: ARROW-2645: [Java] Refactor ArrowWriter to remove all ArrowFileWriter-specific logic

This is an automated email from the ASF dual-hosted git repository.

cutlerb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 0894d97  ARROW-2645: [Java] Refactor ArrowWriter to remove all ArrowFileWriter-specific logic
0894d97 is described below

commit 0894d97244951696ce880dfc3affdbae7a6c035c
Author: Bryan Cutler <cu...@gmail.com>
AuthorDate: Thu May 31 15:45:25 2018 -0700

    ARROW-2645: [Java] Refactor ArrowWriter to remove all ArrowFileWriter-specific logic
    
    Related to #2079, the DictionaryBatch `ArrowBlock`s were being accumulated in the base class and used by `ArrowFileWriter` but not by `ArrowStreamWriter`. This refactors `ArrowWriter` to move the Lists of ArrowBlocks out of the base class and into `ArrowFileWriter` only.
    
    Moved the tests that count the ArrowBlocks written from the ArrowStreamWriter tests to the ArrowFileWriter tests.
    
    Author: Bryan Cutler <cu...@gmail.com>
    
    Closes #2090 from BryanCutler/java-ArrowStreamWriter-accum-DictionaryBlocks-ARROW-2645 and squashes the following commits:
    
    fc7f061 <Bryan Cutler> added comment about saving ArrowBlocks
    5e4711a <Bryan Cutler> Moved lists of ArrowBlocks in ArrowWriter base class to ArrowFileWriter where they are used
---
 .../org/apache/arrow/tools/EchoServerTest.java     |  1 -
 .../apache/arrow/vector/ipc/ArrowFileWriter.java   | 41 ++++++++++++++++++----
 .../apache/arrow/vector/ipc/ArrowStreamWriter.java | 19 +---------
 .../org/apache/arrow/vector/ipc/ArrowWriter.java   | 36 ++++++++-----------
 .../org/apache/arrow/vector/ipc/TestArrowFile.java |  6 ++++
 .../apache/arrow/vector/ipc/TestArrowStream.java   |  4 ---
 6 files changed, 56 insertions(+), 51 deletions(-)

diff --git a/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java b/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java
index 2674c7b..47b5541 100644
--- a/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java
+++ b/java/tools/src/test/java/org/apache/arrow/tools/EchoServerTest.java
@@ -112,7 +112,6 @@ public class EchoServerTest {
         writer.writeBatch();
       }
       writer.end();
-      assertTrue(writer.getRecordBlocks().isEmpty());
 
       assertEquals(new Schema(asList(field)), reader.getVectorSchemaRoot().getSchema());
 
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java
index 1b687c9..832608a 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java
@@ -20,13 +20,17 @@ package org.apache.arrow.vector.ipc;
 
 import java.io.IOException;
 import java.nio.channels.WritableByteChannel;
+import java.util.ArrayList;
 import java.util.List;
 
+import com.google.common.annotations.VisibleForTesting;
+
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.arrow.vector.ipc.message.ArrowBlock;
+import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
 import org.apache.arrow.vector.ipc.message.ArrowFooter;
-import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -34,6 +38,10 @@ public class ArrowFileWriter extends ArrowWriter {
 
   private static final Logger LOGGER = LoggerFactory.getLogger(ArrowFileWriter.class);
 
+  // All ArrowBlocks written are saved in these lists to be passed to ArrowFooter in endInternal.
+  private final List<ArrowBlock> dictionaryBlocks = new ArrayList<>();
+  private final List<ArrowBlock> recordBlocks = new ArrayList<>();
+
   public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
     super(root, provider, out);
   }
@@ -44,12 +52,23 @@ public class ArrowFileWriter extends ArrowWriter {
   }
 
   @Override
-  protected void endInternal(WriteChannel out,
-                             Schema schema,
-                             List<ArrowBlock> dictionaries,
-                             List<ArrowBlock> records) throws IOException {
+  protected ArrowBlock writeDictionaryBatch(ArrowDictionaryBatch batch) throws IOException {
+    ArrowBlock block = super.writeDictionaryBatch(batch);
+    dictionaryBlocks.add(block);
+    return block;
+  }
+
+  @Override
+  protected ArrowBlock writeRecordBatch(ArrowRecordBatch batch) throws IOException {
+    ArrowBlock block = super.writeRecordBatch(batch);
+    recordBlocks.add(block);
+    return block;
+  }
+
+  @Override
+  protected void endInternal(WriteChannel out) throws IOException {
     long footerStart = out.getCurrentPosition();
-    out.write(new ArrowFooter(schema, dictionaries, records), false);
+    out.write(new ArrowFooter(schema, dictionaryBlocks, recordBlocks), false);
     int footerLength = (int) (out.getCurrentPosition() - footerStart);
     if (footerLength <= 0) {
       throw new InvalidArrowFileException("invalid footer");
@@ -59,4 +78,14 @@ public class ArrowFileWriter extends ArrowWriter {
     ArrowMagic.writeMagic(out, false);
     LOGGER.debug(String.format("magic written, now at %d", out.getCurrentPosition()));
   }
+
+  @VisibleForTesting
+  public List<ArrowBlock> getRecordBlocks() {
+    return recordBlocks;
+  }
+
+  @VisibleForTesting
+  public List<ArrowBlock> getDictionaryBlocks() {
+    return dictionaryBlocks;
+  }
 }
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java
index 14e6add..784ce08 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java
@@ -20,18 +20,11 @@ package org.apache.arrow.vector.ipc;
 
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.dictionary.DictionaryProvider;
-import org.apache.arrow.vector.ipc.message.ArrowBlock;
-import org.apache.arrow.vector.ipc.ArrowWriter;
-import org.apache.arrow.vector.ipc.WriteChannel;
-import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
-import org.apache.arrow.vector.ipc.message.MessageSerializer;
-import org.apache.arrow.vector.types.pojo.Schema;
 
 import java.io.IOException;
 import java.io.OutputStream;
 import java.nio.channels.Channels;
 import java.nio.channels.WritableByteChannel;
-import java.util.List;
 
 public class ArrowStreamWriter extends ArrowWriter {
 
@@ -48,17 +41,7 @@ public class ArrowStreamWriter extends ArrowWriter {
   }
 
   @Override
-  protected void endInternal(WriteChannel out,
-                             Schema schema,
-                             List<ArrowBlock> dictionaries,
-                             List<ArrowBlock> records) throws IOException {
+  protected void endInternal(WriteChannel out) throws IOException {
     out.writeIntLittleEndian(0);
   }
-
-  @Override
-  protected void writeRecordBatch(ArrowRecordBatch batch) throws IOException {
-    ArrowBlock block = MessageSerializer.serialize(out, batch);
-    LOGGER.debug(String.format("RecordBatch at %d, metadata: %d, body: %d",
-        block.getOffset(), block.getMetadataLength(), block.getBodyLength()));
-  }
 }
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
index d9d0534..8bc6402 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
@@ -25,7 +25,6 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 
-import com.google.common.annotations.VisibleForTesting;
 import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.VectorUnloader;
@@ -48,15 +47,12 @@ public abstract class ArrowWriter implements AutoCloseable {
   protected static final Logger LOGGER = LoggerFactory.getLogger(ArrowWriter.class);
 
   // schema with fields in message format, not memory format
-  private final Schema schema;
+  protected final Schema schema;
   protected final WriteChannel out;
 
   private final VectorUnloader unloader;
   private final List<ArrowDictionaryBatch> dictionaries;
 
-  private final List<ArrowBlock> dictionaryBlocks = new ArrayList<>();
-  private final List<ArrowBlock> recordBlocks = new ArrayList<>();
-
   private boolean started = false;
   private boolean ended = false;
 
@@ -105,11 +101,18 @@ public abstract class ArrowWriter implements AutoCloseable {
     }
   }
 
-  protected void writeRecordBatch(ArrowRecordBatch batch) throws IOException {
+  protected ArrowBlock writeDictionaryBatch(ArrowDictionaryBatch batch) throws IOException {
+    ArrowBlock block = MessageSerializer.serialize(out, batch);
+    LOGGER.debug(String.format("DictionaryRecordBatch at %d, metadata: %d, body: %d",
+      block.getOffset(), block.getMetadataLength(), block.getBodyLength()));
+    return block;
+  }
+
+  protected ArrowBlock writeRecordBatch(ArrowRecordBatch batch) throws IOException {
     ArrowBlock block = MessageSerializer.serialize(out, batch);
     LOGGER.debug(String.format("RecordBatch at %d, metadata: %d, body: %d",
-        block.getOffset(), block.getMetadataLength(), block.getBodyLength()));
-    recordBlocks.add(block);
+      block.getOffset(), block.getMetadataLength(), block.getBodyLength()));
+    return block;
   }
 
   public void end() throws IOException {
@@ -131,10 +134,7 @@ public abstract class ArrowWriter implements AutoCloseable {
       // write out any dictionaries
       for (ArrowDictionaryBatch batch : dictionaries) {
         try {
-          ArrowBlock block = MessageSerializer.serialize(out, batch);
-          LOGGER.debug(String.format("DictionaryRecordBatch at %d, metadata: %d, body: %d",
-              block.getOffset(), block.getMetadataLength(), block.getBodyLength()));
-          dictionaryBlocks.add(block);
+          writeDictionaryBatch(batch);
         } finally {
           batch.close();
         }
@@ -145,16 +145,13 @@ public abstract class ArrowWriter implements AutoCloseable {
   private void ensureEnded() throws IOException {
     if (!ended) {
       ended = true;
-      endInternal(out, schema, dictionaryBlocks, recordBlocks);
+      endInternal(out);
     }
   }
 
   protected abstract void startInternal(WriteChannel out) throws IOException;
 
-  protected abstract void endInternal(WriteChannel out,
-                                      Schema schema,
-                                      List<ArrowBlock> dictionaries,
-                                      List<ArrowBlock> records) throws IOException;
+  protected abstract void endInternal(WriteChannel out) throws IOException;
 
   @Override
   public void close() {
@@ -165,9 +162,4 @@ public abstract class ArrowWriter implements AutoCloseable {
       throw new RuntimeException(e);
     }
   }
-
-  @VisibleForTesting
-  public List<ArrowBlock> getRecordBlocks() {
-    return recordBlocks;
-  }
 }
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java
index 6ddd14b..fcf738f 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowFile.java
@@ -445,6 +445,7 @@ public class TestArrowFile extends BaseFileTest {
   public void testWriteReadDictionary() throws IOException {
     File file = new File("target/mytest_dict.arrow");
     ByteArrayOutputStream stream = new ByteArrayOutputStream();
+    int numDictionaryBlocksWritten = 0;
 
     // write
     try (BufferAllocator originalVectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE)) {
@@ -462,6 +463,7 @@ public class TestArrowFile extends BaseFileTest {
         streamWriter.writeBatch();
         fileWriter.end();
         streamWriter.end();
+        numDictionaryBlocksWritten = fileWriter.getDictionaryBlocks().size();
       }
 
       // Need to close dictionary vectors
@@ -479,6 +481,7 @@ public class TestArrowFile extends BaseFileTest {
       LOGGER.debug("reading schema: " + schema);
       Assert.assertTrue(arrowReader.loadNextBatch());
       validateFlatDictionary(root, arrowReader);
+      Assert.assertEquals(numDictionaryBlocksWritten, arrowReader.getDictionaryBlocks().size());
     }
 
     // Read from stream
@@ -712,6 +715,7 @@ public class TestArrowFile extends BaseFileTest {
   @Test
   public void testReadWriteMultipleBatches() throws IOException {
     File file = new File("target/mytest_nulls_multibatch.arrow");
+    int numBlocksWritten = 0;
 
     try (IntVector vector = new IntVector("foo", allocator);) {
       Schema schema = new Schema(Collections.singletonList(vector.getField()), null);
@@ -719,6 +723,7 @@ public class TestArrowFile extends BaseFileTest {
            VectorSchemaRoot root = new VectorSchemaRoot(schema, Collections.singletonList((FieldVector) vector), vector.getValueCount());
            ArrowFileWriter writer = new ArrowFileWriter(root, null, fileOutputStream.getChannel());) {
         writeBatchData(writer, vector, root);
+        numBlocksWritten = writer.getRecordBlocks().size();
       }
     }
 
@@ -726,6 +731,7 @@ public class TestArrowFile extends BaseFileTest {
          ArrowFileReader reader = new ArrowFileReader(fileInputStream.getChannel(), allocator);) {
       IntVector vector = (IntVector) reader.getVectorSchemaRoot().getFieldVectors().get(0);
       validateBatchData(reader, vector);
+      Assert.assertEquals(numBlocksWritten, reader.getRecordBlocks().size());
     }
   }
 
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java
index bed4e63..431ebf0 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowStream.java
@@ -82,8 +82,6 @@ public class TestArrowStream extends BaseFileTest {
         }
         writer.end();
         bytesWritten = writer.bytesWritten();
-
-        assertTrue(writer.getRecordBlocks().isEmpty());
       }
 
       ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
@@ -110,8 +108,6 @@ public class TestArrowStream extends BaseFileTest {
       try (VectorSchemaRoot root = new VectorSchemaRoot(schema, Collections.singletonList(vector), vector.getValueCount());
            ArrowStreamWriter writer = new ArrowStreamWriter(root, null, Channels.newChannel(os));) {
         writeBatchData(writer, vector, root);
-
-        assertTrue(writer.getRecordBlocks().isEmpty());
       }
     }
 

-- 
To stop receiving notification emails like this one, please contact
cutlerb@apache.org.