You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/06/09 15:16:13 UTC

[arrow] branch main updated: GH-18547: [Java] Support re-emitting dictionaries in ArrowStreamWriter (#35920)

This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8b2ab4d820 GH-18547: [Java] Support re-emitting dictionaries in ArrowStreamWriter (#35920)
8b2ab4d820 is described below

commit 8b2ab4d8200fdd414b85d531f1cef4b58a3ce351
Author: Adam Reeve <ad...@gmail.com>
AuthorDate: Sat Jun 10 03:16:00 2023 +1200

    GH-18547: [Java] Support re-emitting dictionaries in ArrowStreamWriter (#35920)
    
    ### Rationale for this change
    
    This allows writing IPC streams where dictionary values change between record batches.
    
    ### What changes are included in this PR?
    
    * Add new abstract `void ensureDictionariesWritten(DictionaryProvider provider, Set<Long> dictionaryIdsUsed)` to the base `ArrowWriter` class
    * Move existing logic that only writes dictionaries once into the `ArrowFileWriter` class
    * Implement replacement dictionary writing in `ArrowStreamWriter` by keeping copies of previously written dictionaries
    
    ### Are these changes tested?
    
    Yes, I've added a new unit test for this.
    
    ### Are there any user-facing changes?
    
    Yes, `ArrowStreamWriter` will now write replacement dictionaries when dictionary values change between batches.
    
    **This PR includes breaking changes to public APIs.**
    
    `ArrowWriter` has a new abstract `ensureDictionariesWritten` method. This only affects users who inherit directly from `ArrowWriter` rather than from `ArrowFileWriter` or `ArrowStreamWriter`.
    
    There's also a behaviour change to `ArrowWriter`: previously, dictionaries were read from the `DictionaryProvider` on construction, but this is now deferred until the first batch is written.
    * Closes: #18547
    
    Authored-by: Adam Reeve <ad...@gmail.com>
    Signed-off-by: David Li <li...@gmail.com>
---
 .../java/org/apache/arrow/tools/EchoServer.java    |  1 +
 .../apache/arrow/vector/ipc/ArrowFileWriter.java   | 18 +++++
 .../apache/arrow/vector/ipc/ArrowStreamWriter.java | 49 +++++++++++++
 .../org/apache/arrow/vector/ipc/ArrowWriter.java   | 69 ++++++++----------
 .../arrow/vector/ipc/TestArrowReaderWriter.java    | 83 ++++++++++++++++++++++
 5 files changed, 179 insertions(+), 41 deletions(-)

diff --git a/java/tools/src/main/java/org/apache/arrow/tools/EchoServer.java b/java/tools/src/main/java/org/apache/arrow/tools/EchoServer.java
index 0ddd1e9464..36ba24dbee 100644
--- a/java/tools/src/main/java/org/apache/arrow/tools/EchoServer.java
+++ b/java/tools/src/main/java/org/apache/arrow/tools/EchoServer.java
@@ -135,6 +135,7 @@ public class EchoServer {
         Preconditions.checkState(reader.bytesRead() == writer.bytesWritten());
         LOGGER.debug(String.format("Echoed %d records", echoed));
         reader.close(false);
+        writer.close();
       }
     }
 
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java
index 0b0931f7bb..71db79087a 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java
@@ -23,11 +23,13 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 
 import org.apache.arrow.util.VisibleForTesting;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.compression.CompressionCodec;
 import org.apache.arrow.vector.compression.CompressionUtil;
+import org.apache.arrow.vector.dictionary.Dictionary;
 import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.arrow.vector.ipc.message.ArrowBlock;
 import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
@@ -50,6 +52,7 @@ public class ArrowFileWriter extends ArrowWriter {
   private final List<ArrowBlock> recordBlocks = new ArrayList<>();
 
   private Map<String, String> metaData;
+  private boolean dictionariesWritten = false;
 
   public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
     super(root, provider, out);
@@ -123,6 +126,21 @@ public class ArrowFileWriter extends ArrowWriter {
     LOGGER.debug("magic written, now at {}", out.getCurrentPosition());
   }
 
+  @Override
+  protected void ensureDictionariesWritten(DictionaryProvider provider, Set<Long> dictionaryIdsUsed)
+      throws IOException {
+    if (dictionariesWritten) {
+      return;
+    }
+    dictionariesWritten = true;
+    // Write out all dictionaries required.
+    // Replacement dictionaries are not supported in the IPC file format.
+    for (long id : dictionaryIdsUsed) {
+      Dictionary dictionary = provider.lookup(id);
+      writeDictionaryBatch(dictionary);
+    }
+  }
+
   @VisibleForTesting
   public List<ArrowBlock> getRecordBlocks() {
     return recordBlocks;
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java
index 7200851620..928e1de4c5 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java
@@ -21,11 +21,18 @@ import java.io.IOException;
 import java.io.OutputStream;
 import java.nio.channels.Channels;
 import java.nio.channels.WritableByteChannel;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.compare.VectorEqualsVisitor;
 import org.apache.arrow.vector.compression.CompressionCodec;
 import org.apache.arrow.vector.compression.CompressionUtil;
+import org.apache.arrow.vector.dictionary.Dictionary;
 import org.apache.arrow.vector.dictionary.DictionaryProvider;
 import org.apache.arrow.vector.ipc.message.IpcOption;
 import org.apache.arrow.vector.ipc.message.MessageSerializer;
@@ -34,6 +41,7 @@ import org.apache.arrow.vector.ipc.message.MessageSerializer;
  * Writer for the Arrow stream format to send ArrowRecordBatches over a WriteChannel.
  */
 public class ArrowStreamWriter extends ArrowWriter {
+  private final Map<Long, FieldVector> previousDictionaries = new HashMap<>();
 
   /**
    * Construct an ArrowStreamWriter with an optional DictionaryProvider for the OutputStream.
@@ -121,4 +129,45 @@ public class ArrowStreamWriter extends ArrowWriter {
   protected void endInternal(WriteChannel out) throws IOException {
     writeEndOfStream(out, option);
   }
+
+  @Override
+  protected void ensureDictionariesWritten(DictionaryProvider provider, Set<Long> dictionaryIdsUsed)
+      throws IOException {
+    // Write out any dictionaries that are new or have changed since they were last written
+    for (long id : dictionaryIdsUsed) {
+      Dictionary dictionary = provider.lookup(id);
+      FieldVector vector = dictionary.getVector();
+      if (previousDictionaries.containsKey(id) &&
+          VectorEqualsVisitor.vectorEquals(vector, previousDictionaries.get(id))) {
+        // Dictionary was previously written and hasn't changed
+        continue;
+      }
+      writeDictionaryBatch(dictionary);
+      // Store a copy of the vector in case it is later mutated
+      if (previousDictionaries.containsKey(id)) {
+        previousDictionaries.get(id).close();
+      }
+      previousDictionaries.put(id, copyVector(vector));
+    }
+  }
+
+  @Override
+  public void close() {
+    super.close();
+    try {
+      AutoCloseables.close(previousDictionaries.values());
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  private static FieldVector copyVector(FieldVector source) {
+    FieldVector copy = source.getField().createVector(source.getAllocator());
+    copy.allocateNew();
+    for (int i = 0; i < source.getValueCount(); i++) {
+      copy.copyFromSafe(i, i, source);
+    }
+    copy.setValueCount(source.getValueCount());
+    return copy;
+  }
 }
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
index 2c524b81b7..a33c55de53 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
@@ -26,7 +26,6 @@ import java.util.List;
 import java.util.Optional;
 import java.util.Set;
 
-import org.apache.arrow.util.AutoCloseables;
 import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.VectorUnloader;
@@ -59,13 +58,12 @@ public abstract class ArrowWriter implements AutoCloseable {
   protected final WriteChannel out;
 
   private final VectorUnloader unloader;
-  private final List<ArrowDictionaryBatch> dictionaries;
+  private final DictionaryProvider dictionaryProvider;
+  private final Set<Long> dictionaryIdsUsed = new HashSet<>();
 
   private boolean started = false;
   private boolean ended = false;
 
-  private boolean dictWritten = false;
-
   protected IpcOption option;
 
   protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
@@ -99,9 +97,9 @@ public abstract class ArrowWriter implements AutoCloseable {
         /*alignBuffers*/ true);
     this.out = new WriteChannel(out);
     this.option = option;
+    this.dictionaryProvider = provider;
 
     List<Field> fields = new ArrayList<>(root.getSchema().getFields().size());
-    Set<Long> dictionaryIdsUsed = new HashSet<>();
 
     MetadataV4UnionChecker.checkForUnion(root.getSchema().getFields().iterator(), option.metadataVersion);
     // Convert fields with dictionaries to have dictionary type
@@ -109,21 +107,6 @@ public abstract class ArrowWriter implements AutoCloseable {
       fields.add(DictionaryUtility.toMessageFormat(field, provider, dictionaryIdsUsed));
     }
 
-    // Create a record batch for each dictionary
-    this.dictionaries = new ArrayList<>(dictionaryIdsUsed.size());
-    for (long id : dictionaryIdsUsed) {
-      Dictionary dictionary = provider.lookup(id);
-      FieldVector vector = dictionary.getVector();
-      int count = vector.getValueCount();
-      VectorSchemaRoot dictRoot = new VectorSchemaRoot(
-          Collections.singletonList(vector.getField()),
-          Collections.singletonList(vector),
-          count);
-      VectorUnloader unloader = new VectorUnloader(dictRoot);
-      ArrowRecordBatch batch = unloader.getRecordBatch();
-      this.dictionaries.add(new ArrowDictionaryBatch(id, batch));
-    }
-
     this.schema = new Schema(fields, root.getSchema().getCustomMetadata());
   }
 
@@ -136,12 +119,34 @@ public abstract class ArrowWriter implements AutoCloseable {
    */
   public void writeBatch() throws IOException {
     ensureStarted();
-    ensureDictionariesWritten();
+    ensureDictionariesWritten(dictionaryProvider, dictionaryIdsUsed);
     try (ArrowRecordBatch batch = unloader.getRecordBatch()) {
       writeRecordBatch(batch);
     }
   }
 
+  protected void writeDictionaryBatch(Dictionary dictionary) throws IOException {
+    FieldVector vector = dictionary.getVector();
+    long id = dictionary.getEncoding().getId();
+    int count = vector.getValueCount();
+    VectorSchemaRoot dictRoot = new VectorSchemaRoot(
+        Collections.singletonList(vector.getField()),
+        Collections.singletonList(vector),
+        count);
+    VectorUnloader unloader = new VectorUnloader(dictRoot);
+    ArrowRecordBatch batch = unloader.getRecordBatch();
+    ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch, false);
+    try {
+      writeDictionaryBatch(dictionaryBatch);
+    } finally {
+      try {
+        dictionaryBatch.close();
+      } catch (Exception e) {
+        throw new RuntimeException("Error occurred while closing dictionary.", e);
+      }
+    }
+  }
+
   protected ArrowBlock writeDictionaryBatch(ArrowDictionaryBatch batch) throws IOException {
     ArrowBlock block = MessageSerializer.serialize(out, batch, option);
     if (LOGGER.isDebugEnabled()) {
@@ -183,23 +188,8 @@ public abstract class ArrowWriter implements AutoCloseable {
    * Write dictionaries after schema and before recordBatches, dictionaries won't be
    * written if empty stream (only has schema data in IPC).
    */
-  private void ensureDictionariesWritten() throws IOException {
-    if (!dictWritten) {
-      dictWritten = true;
-      // write out any dictionaries
-      try {
-        for (ArrowDictionaryBatch batch : dictionaries) {
-          writeDictionaryBatch(batch);
-        }
-      } finally {
-        try {
-          AutoCloseables.close(dictionaries);
-        } catch (Exception e) {
-          throw new RuntimeException("Error occurred while closing dictionaries.", e);
-        }
-      }
-    }
-  }
+  protected abstract void ensureDictionariesWritten(DictionaryProvider provider, Set<Long> dictionaryIdsUsed)
+      throws IOException;
 
   private void ensureEnded() throws IOException {
     if (!ended) {
@@ -219,9 +209,6 @@ public abstract class ArrowWriter implements AutoCloseable {
     try {
       end();
       out.close();
-      if (!dictWritten) {
-        AutoCloseables.close(dictionaries);
-      }
     } catch (Exception e) {
       throw new RuntimeException(e);
     }
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java
index 08c4d34732..07875b2502 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java
@@ -86,6 +86,7 @@ import org.apache.arrow.vector.types.pojo.FieldType;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
 import org.apache.arrow.vector.util.DictionaryUtility;
+import org.apache.arrow.vector.util.TransferPair;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -612,6 +613,88 @@ public class TestArrowReaderWriter {
 
   }
 
+  // Tests that the ArrowStreamWriter re-emits dictionaries when they change
+  @Test
+  public void testWriteReadStreamWithDictionaryReplacement() throws Exception {
+    DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider();
+    provider.put(dictionary1);
+
+    String[] batch0 = {"foo", "bar", "baz", "bar", "baz"};
+    String[] batch1 = {"foo", "aa", "bar", "bb", "baz", "cc"};
+
+    VarCharVector vector = newVarCharVector("varchar", allocator);
+    vector.allocateNewSafe();
+    for (int i = 0; i < batch0.length; ++i) {
+      vector.set(i, batch0[i].getBytes(StandardCharsets.UTF_8));
+    }
+    vector.setValueCount(batch0.length);
+    FieldVector encodedVector1 = (FieldVector) DictionaryEncoder.encode(vector, dictionary1);
+
+    List<Field> fields = Arrays.asList(encodedVector1.getField());
+    try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
+      try (VectorSchemaRoot root =
+               new VectorSchemaRoot(fields, Arrays.asList(encodedVector1), encodedVector1.getValueCount());
+           ArrowStreamWriter writer = new ArrowStreamWriter(root, provider, newChannel(out))) {
+        writer.start();
+
+        // Write batch with initial data and dictionary
+        writer.writeBatch();
+
+        // Create data for the next batch, using an extended dictionary with the same id
+        vector.reset();
+        for (int i = 0; i < batch1.length; ++i) {
+          vector.set(i, batch1[i].getBytes(StandardCharsets.UTF_8));
+        }
+        vector.setValueCount(batch1.length);
+
+        // Re-encode and move encoded data into the vector schema root
+        provider.put(dictionary3);
+        FieldVector encodedVector2 = (FieldVector) DictionaryEncoder.encode(vector, dictionary3);
+        TransferPair transferPair = encodedVector2.makeTransferPair(root.getVector(0));
+        transferPair.transfer();
+
+        // Write second batch
+        root.setRowCount(batch1.length);
+        writer.writeBatch();
+
+        writer.end();
+      }
+
+      try (ArrowStreamReader reader = new ArrowStreamReader(
+          new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator)) {
+        VectorSchemaRoot root = reader.getVectorSchemaRoot();
+
+        // Read and verify first batch
+        assertTrue(reader.loadNextBatch());
+        assertEquals(batch0.length, root.getRowCount());
+        FieldVector readEncoded1 = root.getVector(0);
+        long dictionaryId = readEncoded1.getField().getDictionary().getId();
+        try (VarCharVector decodedValues =
+                 (VarCharVector) DictionaryEncoder.decode(readEncoded1, reader.lookup(dictionaryId))) {
+          for (int i = 0; i < batch0.length; ++i) {
+            assertEquals(batch0[i], new String(decodedValues.get(i), StandardCharsets.UTF_8));
+          }
+        }
+
+        // Read and verify second batch
+        assertTrue(reader.loadNextBatch());
+        assertEquals(batch1.length, root.getRowCount());
+        FieldVector readEncoded2 = root.getVector(0);
+        dictionaryId = readEncoded2.getField().getDictionary().getId();
+        try (VarCharVector decodedValues =
+                 (VarCharVector) DictionaryEncoder.decode(readEncoded2, reader.lookup(dictionaryId))) {
+          for (int i = 0; i < batch1.length; ++i) {
+            assertEquals(batch1[i], new String(decodedValues.get(i), StandardCharsets.UTF_8));
+          }
+        }
+
+        assertFalse(reader.loadNextBatch());
+      }
+    }
+
+    vector.close();
+  }
+
   private void serializeDictionaryBatch(
       WriteChannel out,
       Dictionary dictionary,