You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/06/09 15:16:13 UTC
[arrow] branch main updated: GH-18547: [Java] Support re-emitting dictionaries in ArrowStreamWriter (#35920)
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 8b2ab4d820 GH-18547: [Java] Support re-emitting dictionaries in ArrowStreamWriter (#35920)
8b2ab4d820 is described below
commit 8b2ab4d8200fdd414b85d531f1cef4b58a3ce351
Author: Adam Reeve <ad...@gmail.com>
AuthorDate: Sat Jun 10 03:16:00 2023 +1200
GH-18547: [Java] Support re-emitting dictionaries in ArrowStreamWriter (#35920)
### Rationale for this change
This allows writing IPC streams where dictionary values change between record batches.
### What changes are included in this PR?
* Add a new protected abstract method `void ensureDictionariesWritten(DictionaryProvider provider, Set<Long> dictionaryIdsUsed) throws IOException` to the base `ArrowWriter` class
* Move the existing logic, which writes dictionaries only once, into the `ArrowFileWriter` class (the IPC file format does not support replacement dictionaries)
* Implement replacement-dictionary writing in `ArrowStreamWriter`: keep a copy of each previously written dictionary, and re-emit a dictionary whenever its values differ from that copy
### Are these changes tested?
Yes, I've added a new unit test for this (`TestArrowReaderWriter.testWriteReadStreamWithDictionaryReplacement`).
### Are there any user-facing changes?
Yes, `ArrowStreamWriter` will now write replacement dictionaries when dictionary values change between batches.
**This PR includes breaking changes to public APIs.**
`ArrowWriter` has a new abstract `ensureDictionariesWritten` method. This only affects users who subclass `ArrowWriter` directly, rather than using `ArrowFileWriter` or `ArrowStreamWriter`.
There is also a behaviour change in `ArrowWriter`. Previously, dictionary data was read from the `DictionaryProvider` at construction time; it is now read only when batches are written. `ArrowFileWriter` reads it once, on the first batch, and `ArrowStreamWriter` reads it on every batch.
* Closes: #18547
Authored-by: Adam Reeve <ad...@gmail.com>
Signed-off-by: David Li <li...@gmail.com>
---
.../java/org/apache/arrow/tools/EchoServer.java | 1 +
.../apache/arrow/vector/ipc/ArrowFileWriter.java | 18 +++++
.../apache/arrow/vector/ipc/ArrowStreamWriter.java | 49 +++++++++++++
.../org/apache/arrow/vector/ipc/ArrowWriter.java | 69 ++++++++----------
.../arrow/vector/ipc/TestArrowReaderWriter.java | 83 ++++++++++++++++++++++
5 files changed, 179 insertions(+), 41 deletions(-)
diff --git a/java/tools/src/main/java/org/apache/arrow/tools/EchoServer.java b/java/tools/src/main/java/org/apache/arrow/tools/EchoServer.java
index 0ddd1e9464..36ba24dbee 100644
--- a/java/tools/src/main/java/org/apache/arrow/tools/EchoServer.java
+++ b/java/tools/src/main/java/org/apache/arrow/tools/EchoServer.java
@@ -135,6 +135,7 @@ public class EchoServer {
Preconditions.checkState(reader.bytesRead() == writer.bytesWritten());
LOGGER.debug(String.format("Echoed %d records", echoed));
reader.close(false);
+ writer.close();
}
}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java
index 0b0931f7bb..71db79087a 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowFileWriter.java
@@ -23,11 +23,13 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.compression.CompressionCodec;
import org.apache.arrow.vector.compression.CompressionUtil;
+import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.message.ArrowBlock;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
@@ -50,6 +52,7 @@ public class ArrowFileWriter extends ArrowWriter {
private final List<ArrowBlock> recordBlocks = new ArrayList<>();
private Map<String, String> metaData;
+ private boolean dictionariesWritten = false;
public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
super(root, provider, out);
@@ -123,6 +126,21 @@ public class ArrowFileWriter extends ArrowWriter {
LOGGER.debug("magic written, now at {}", out.getCurrentPosition());
}
+ @Override
+ protected void ensureDictionariesWritten(DictionaryProvider provider, Set<Long> dictionaryIdsUsed)
+ throws IOException {
+ if (dictionariesWritten) {
+ return;
+ }
+ dictionariesWritten = true;
+ // Write out all dictionaries required.
+ // Replacement dictionaries are not supported in the IPC file format.
+ for (long id : dictionaryIdsUsed) {
+ Dictionary dictionary = provider.lookup(id);
+ writeDictionaryBatch(dictionary);
+ }
+ }
+
@VisibleForTesting
public List<ArrowBlock> getRecordBlocks() {
return recordBlocks;
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java
index 7200851620..928e1de4c5 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowStreamWriter.java
@@ -21,11 +21,18 @@ import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
+import java.util.HashMap;
+import java.util.Map;
import java.util.Optional;
+import java.util.Set;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.compare.VectorEqualsVisitor;
import org.apache.arrow.vector.compression.CompressionCodec;
import org.apache.arrow.vector.compression.CompressionUtil;
+import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.message.IpcOption;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
@@ -34,6 +41,7 @@ import org.apache.arrow.vector.ipc.message.MessageSerializer;
* Writer for the Arrow stream format to send ArrowRecordBatches over a WriteChannel.
*/
public class ArrowStreamWriter extends ArrowWriter {
+ private final Map<Long, FieldVector> previousDictionaries = new HashMap<>();
/**
* Construct an ArrowStreamWriter with an optional DictionaryProvider for the OutputStream.
@@ -121,4 +129,45 @@ public class ArrowStreamWriter extends ArrowWriter {
protected void endInternal(WriteChannel out) throws IOException {
writeEndOfStream(out, option);
}
+
+ @Override
+ protected void ensureDictionariesWritten(DictionaryProvider provider, Set<Long> dictionaryIdsUsed)
+ throws IOException {
+ // write out any dictionaries that have changes
+ for (long id : dictionaryIdsUsed) {
+ Dictionary dictionary = provider.lookup(id);
+ FieldVector vector = dictionary.getVector();
+ if (previousDictionaries.containsKey(id) &&
+ VectorEqualsVisitor.vectorEquals(vector, previousDictionaries.get(id))) {
+ // Dictionary was previously written and hasn't changed
+ continue;
+ }
+ writeDictionaryBatch(dictionary);
+ // Store a copy of the vector in case it is later mutated
+ if (previousDictionaries.containsKey(id)) {
+ previousDictionaries.get(id).close();
+ }
+ previousDictionaries.put(id, copyVector(vector));
+ }
+ }
+
+ @Override
+ public void close() {
+ super.close();
+ try {
+ AutoCloseables.close(previousDictionaries.values());
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private static FieldVector copyVector(FieldVector source) {
+ FieldVector copy = source.getField().createVector(source.getAllocator());
+ copy.allocateNew();
+ for (int i = 0; i < source.getValueCount(); i++) {
+ copy.copyFromSafe(i, i, source);
+ }
+ copy.setValueCount(source.getValueCount());
+ return copy;
+ }
}
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
index 2c524b81b7..a33c55de53 100644
--- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
+++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/ArrowWriter.java
@@ -26,7 +26,6 @@ import java.util.List;
import java.util.Optional;
import java.util.Set;
-import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
@@ -59,13 +58,12 @@ public abstract class ArrowWriter implements AutoCloseable {
protected final WriteChannel out;
private final VectorUnloader unloader;
- private final List<ArrowDictionaryBatch> dictionaries;
+ private final DictionaryProvider dictionaryProvider;
+ private final Set<Long> dictionaryIdsUsed = new HashSet<>();
private boolean started = false;
private boolean ended = false;
- private boolean dictWritten = false;
-
protected IpcOption option;
protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
@@ -99,9 +97,9 @@ public abstract class ArrowWriter implements AutoCloseable {
/*alignBuffers*/ true);
this.out = new WriteChannel(out);
this.option = option;
+ this.dictionaryProvider = provider;
List<Field> fields = new ArrayList<>(root.getSchema().getFields().size());
- Set<Long> dictionaryIdsUsed = new HashSet<>();
MetadataV4UnionChecker.checkForUnion(root.getSchema().getFields().iterator(), option.metadataVersion);
// Convert fields with dictionaries to have dictionary type
@@ -109,21 +107,6 @@ public abstract class ArrowWriter implements AutoCloseable {
fields.add(DictionaryUtility.toMessageFormat(field, provider, dictionaryIdsUsed));
}
- // Create a record batch for each dictionary
- this.dictionaries = new ArrayList<>(dictionaryIdsUsed.size());
- for (long id : dictionaryIdsUsed) {
- Dictionary dictionary = provider.lookup(id);
- FieldVector vector = dictionary.getVector();
- int count = vector.getValueCount();
- VectorSchemaRoot dictRoot = new VectorSchemaRoot(
- Collections.singletonList(vector.getField()),
- Collections.singletonList(vector),
- count);
- VectorUnloader unloader = new VectorUnloader(dictRoot);
- ArrowRecordBatch batch = unloader.getRecordBatch();
- this.dictionaries.add(new ArrowDictionaryBatch(id, batch));
- }
-
this.schema = new Schema(fields, root.getSchema().getCustomMetadata());
}
@@ -136,12 +119,34 @@ public abstract class ArrowWriter implements AutoCloseable {
*/
public void writeBatch() throws IOException {
ensureStarted();
- ensureDictionariesWritten();
+ ensureDictionariesWritten(dictionaryProvider, dictionaryIdsUsed);
try (ArrowRecordBatch batch = unloader.getRecordBatch()) {
writeRecordBatch(batch);
}
}
+ protected void writeDictionaryBatch(Dictionary dictionary) throws IOException {
+ FieldVector vector = dictionary.getVector();
+ long id = dictionary.getEncoding().getId();
+ int count = vector.getValueCount();
+ VectorSchemaRoot dictRoot = new VectorSchemaRoot(
+ Collections.singletonList(vector.getField()),
+ Collections.singletonList(vector),
+ count);
+ VectorUnloader unloader = new VectorUnloader(dictRoot);
+ ArrowRecordBatch batch = unloader.getRecordBatch();
+ ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch, false);
+ try {
+ writeDictionaryBatch(dictionaryBatch);
+ } finally {
+ try {
+ dictionaryBatch.close();
+ } catch (Exception e) {
+ throw new RuntimeException("Error occurred while closing dictionary.", e);
+ }
+ }
+ }
+
protected ArrowBlock writeDictionaryBatch(ArrowDictionaryBatch batch) throws IOException {
ArrowBlock block = MessageSerializer.serialize(out, batch, option);
if (LOGGER.isDebugEnabled()) {
@@ -183,23 +188,8 @@ public abstract class ArrowWriter implements AutoCloseable {
* Write dictionaries after schema and before recordBatches, dictionaries won't be
* written if empty stream (only has schema data in IPC).
*/
- private void ensureDictionariesWritten() throws IOException {
- if (!dictWritten) {
- dictWritten = true;
- // write out any dictionaries
- try {
- for (ArrowDictionaryBatch batch : dictionaries) {
- writeDictionaryBatch(batch);
- }
- } finally {
- try {
- AutoCloseables.close(dictionaries);
- } catch (Exception e) {
- throw new RuntimeException("Error occurred while closing dictionaries.", e);
- }
- }
- }
- }
+ protected abstract void ensureDictionariesWritten(DictionaryProvider provider, Set<Long> dictionaryIdsUsed)
+ throws IOException;
private void ensureEnded() throws IOException {
if (!ended) {
@@ -219,9 +209,6 @@ public abstract class ArrowWriter implements AutoCloseable {
try {
end();
out.close();
- if (!dictWritten) {
- AutoCloseables.close(dictionaries);
- }
} catch (Exception e) {
throw new RuntimeException(e);
}
diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java
index 08c4d34732..07875b2502 100644
--- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java
+++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestArrowReaderWriter.java
@@ -86,6 +86,7 @@ import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
import org.apache.arrow.vector.util.DictionaryUtility;
+import org.apache.arrow.vector.util.TransferPair;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -612,6 +613,88 @@ public class TestArrowReaderWriter {
}
+ // Tests that the ArrowStreamWriter re-emits dictionaries when they change
+ @Test
+ public void testWriteReadStreamWithDictionaryReplacement() throws Exception {
+ DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider();
+ provider.put(dictionary1);
+
+ String[] batch0 = {"foo", "bar", "baz", "bar", "baz"};
+ String[] batch1 = {"foo", "aa", "bar", "bb", "baz", "cc"};
+
+ VarCharVector vector = newVarCharVector("varchar", allocator);
+ vector.allocateNewSafe();
+ for (int i = 0; i < batch0.length; ++i) {
+ vector.set(i, batch0[i].getBytes(StandardCharsets.UTF_8));
+ }
+ vector.setValueCount(batch0.length);
+ FieldVector encodedVector1 = (FieldVector) DictionaryEncoder.encode(vector, dictionary1);
+
+ List<Field> fields = Arrays.asList(encodedVector1.getField());
+ try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
+ try (VectorSchemaRoot root =
+ new VectorSchemaRoot(fields, Arrays.asList(encodedVector1), encodedVector1.getValueCount());
+ ArrowStreamWriter writer = new ArrowStreamWriter(root, provider, newChannel(out))) {
+ writer.start();
+
+ // Write batch with initial data and dictionary
+ writer.writeBatch();
+
+ // Create data for the next batch, using an extended dictionary with the same id
+ vector.reset();
+ for (int i = 0; i < batch1.length; ++i) {
+ vector.set(i, batch1[i].getBytes(StandardCharsets.UTF_8));
+ }
+ vector.setValueCount(batch1.length);
+
+ // Re-encode and move encoded data into the vector schema root
+ provider.put(dictionary3);
+ FieldVector encodedVector2 = (FieldVector) DictionaryEncoder.encode(vector, dictionary3);
+ TransferPair transferPair = encodedVector2.makeTransferPair(root.getVector(0));
+ transferPair.transfer();
+
+ // Write second batch
+ root.setRowCount(batch1.length);
+ writer.writeBatch();
+
+ writer.end();
+ }
+
+ try (ArrowStreamReader reader = new ArrowStreamReader(
+ new ByteArrayReadableSeekableByteChannel(out.toByteArray()), allocator)) {
+ VectorSchemaRoot root = reader.getVectorSchemaRoot();
+
+ // Read and verify first batch
+ assertTrue(reader.loadNextBatch());
+ assertEquals(batch0.length, root.getRowCount());
+ FieldVector readEncoded1 = root.getVector(0);
+ long dictionaryId = readEncoded1.getField().getDictionary().getId();
+ try (VarCharVector decodedValues =
+ (VarCharVector) DictionaryEncoder.decode(readEncoded1, reader.lookup(dictionaryId))) {
+ for (int i = 0; i < batch0.length; ++i) {
+ assertEquals(batch0[i], new String(decodedValues.get(i), StandardCharsets.UTF_8));
+ }
+ }
+
+ // Read and verify second batch
+ assertTrue(reader.loadNextBatch());
+ assertEquals(batch1.length, root.getRowCount());
+ FieldVector readEncoded2 = root.getVector(0);
+ dictionaryId = readEncoded2.getField().getDictionary().getId();
+ try (VarCharVector decodedValues =
+ (VarCharVector) DictionaryEncoder.decode(readEncoded2, reader.lookup(dictionaryId))) {
+ for (int i = 0; i < batch1.length; ++i) {
+ assertEquals(batch1[i], new String(decodedValues.get(i), StandardCharsets.UTF_8));
+ }
+ }
+
+ assertFalse(reader.loadNextBatch());
+ }
+ }
+
+ vector.close();
+ }
+
private void serializeDictionaryBatch(
WriteChannel out,
Dictionary dictionary,