You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2022/07/06 18:12:01 UTC
[beam] branch master updated: [BEAM-14545] Optimize copies in dataflow v1 shuffle reader. (#17802)
This is an automated email from the ASF dual-hosted git repository.
lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 5465f38c750 [BEAM-14545] Optimize copies in dataflow v1 shuffle reader. (#17802)
5465f38c750 is described below
commit 5465f38c750d9c87024c82a9e694c3b34282fc37
Author: Steven Niemitz <st...@gmail.com>
AuthorDate: Wed Jul 6 14:11:52 2022 -0400
[BEAM-14545] Optimize copies in dataflow v1 shuffle reader. (#17802)
* shuffle tuning
* review cleanup
---
.../runners/dataflow/worker/ByteArrayReader.java | 53 ++++++++++++++++++++++
.../worker/ChunkingShuffleBatchReader.java | 31 ++++++-------
.../dataflow/worker/GroupingShuffleReader.java | 10 ++--
.../dataflow/worker/PartitioningShuffleReader.java | 4 +-
.../runners/dataflow/worker/ShuffleReader.java | 7 +--
.../beam/runners/dataflow/worker/ShuffleSink.java | 2 +-
.../dataflow/worker/UngroupedShuffleReader.java | 5 +-
.../common/worker/ByteArrayShufflePosition.java | 31 ++++++++-----
.../worker/GroupingShuffleEntryIterator.java | 17 +++----
.../common/worker/KeyGroupedShuffleEntries.java | 5 +-
.../util/common/worker/ShuffleBatchReader.java | 4 +-
.../worker/util/common/worker/ShuffleEntry.java | 47 +++++++++----------
.../util/common/worker/ShuffleEntryReader.java | 2 +-
.../dataflow/worker/GroupingShuffleReaderTest.java | 40 ++++++++++------
.../runners/dataflow/worker/ShuffleSinkTest.java | 33 +++++++-------
.../runners/dataflow/worker/TestShuffleReader.java | 43 ++++++++++--------
.../dataflow/worker/TestShuffleReaderTest.java | 7 +--
.../runners/dataflow/worker/TestShuffleWriter.java | 5 +-
.../worker/BatchingShuffleEntryReaderTest.java | 20 +++++---
.../worker/GroupingShuffleEntryIteratorTest.java | 7 +--
.../util/common/worker/ShuffleEntryTest.java | 17 +++----
.../java/org/apache/beam/sdk/util/CoderUtils.java | 24 ++++++++++
.../org/apache/beam/sdk/util/CoderUtilsTest.java | 25 ++++++++++
23 files changed, 290 insertions(+), 149 deletions(-)
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ByteArrayReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ByteArrayReader.java
new file mode 100644
index 00000000000..3c19c4ec661
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ByteArrayReader.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker;
+
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.UnsafeByteOperations;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Ints;
+
+class ByteArrayReader {
+
+ private final byte[] arr;
+ private int pos;
+
+ public ByteArrayReader(byte[] arr) {
+ this.arr = arr;
+ this.pos = 0;
+ }
+
+ public int available() {
+ return arr.length - pos;
+ }
+
+ public int readInt() {
+ int ret = Ints.fromBytes(arr[pos], arr[pos + 1], arr[pos + 2], arr[pos + 3]);
+ pos += 4;
+ return ret;
+ }
+
+ public ByteString read(int size) {
+ if (size == 0) {
+ return ByteString.EMPTY;
+ }
+
+ ByteString ret = UnsafeByteOperations.unsafeWrap(arr, pos, size);
+ pos += size;
+ return ret;
+ }
+}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ChunkingShuffleBatchReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ChunkingShuffleBatchReader.java
index 3a8cbc8041e..2157a63d0ac 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ChunkingShuffleBatchReader.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ChunkingShuffleBatchReader.java
@@ -17,9 +17,7 @@
*/
package org.apache.beam.runners.dataflow.worker;
-import java.io.ByteArrayInputStream;
import java.io.Closeable;
-import java.io.DataInputStream;
import java.io.IOException;
import java.util.ArrayList;
import org.apache.beam.runners.core.metrics.ExecutionStateTracker;
@@ -28,7 +26,7 @@ import org.apache.beam.runners.dataflow.worker.util.common.worker.ByteArrayShuff
import org.apache.beam.runners.dataflow.worker.util.common.worker.ShuffleBatchReader;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ShuffleEntry;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ShufflePosition;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.ByteStreams;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.checkerframework.checker.nullness.qual.Nullable;
/** ChunkingShuffleBatchReader reads data from a shuffle dataset using a ShuffleReader. */
@@ -57,11 +55,13 @@ final class ChunkingShuffleBatchReader implements ShuffleBatchReader {
try (Closeable trackedReadState = tracker.enterState(readState)) {
result = reader.readIncludingPosition(startPosition, endPosition);
}
- DataInputStream input = new DataInputStream(new ByteArrayInputStream(result.chunk));
+ ByteArrayReader input = new ByteArrayReader(result.chunk);
ArrayList<ShuffleEntry> entries = new ArrayList<>();
+
while (input.available() > 0) {
entries.add(getShuffleEntry(input));
}
+
return new Batch(
entries,
result.nextStartPosition == null
@@ -72,31 +72,30 @@ final class ChunkingShuffleBatchReader implements ShuffleBatchReader {
/**
* Extracts a ShuffleEntry by parsing bytes from the given chunk reader.
*
- * @param input stream to read from
+ * @param chunk chunk to read from
* @return parsed ShuffleEntry
*/
- static ShuffleEntry getShuffleEntry(DataInputStream input) throws IOException {
- byte[] position = getFixedLengthPrefixedByteArray(input);
- byte[] key = getFixedLengthPrefixedByteArray(input);
- byte[] skey = getFixedLengthPrefixedByteArray(input);
- byte[] value = getFixedLengthPrefixedByteArray(input);
+ private ShuffleEntry getShuffleEntry(ByteArrayReader chunk) throws IOException {
+ ByteString position = getFixedLengthPrefixedByteArray(chunk);
+ ByteString key = getFixedLengthPrefixedByteArray(chunk);
+ ByteString skey = getFixedLengthPrefixedByteArray(chunk);
+ ByteString value = getFixedLengthPrefixedByteArray(chunk);
+
return new ShuffleEntry(ByteArrayShufflePosition.of(position), key, skey, value);
}
/**
* Extracts a length-prefix-encoded byte string from the given chunk reader.
*
- * @param dataInputStream stream to read from
+ * @param chunk chunk to read from
* @return parsed byte array
*/
- static byte[] getFixedLengthPrefixedByteArray(DataInputStream dataInputStream)
+ private static ByteString getFixedLengthPrefixedByteArray(ByteArrayReader chunk)
throws IOException {
- int length = dataInputStream.readInt();
+ int length = chunk.readInt();
if (length < 0) {
throw new IOException("invalid length: " + length);
}
- byte[] data = new byte[length];
- ByteStreams.readFully(dataInputStream, data);
- return data;
+ return chunk.read(length);
}
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/GroupingShuffleReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/GroupingShuffleReader.java
index c0084d87b41..2fefc9c6084 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/GroupingShuffleReader.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/GroupingShuffleReader.java
@@ -25,9 +25,9 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Prec
import com.google.api.services.dataflow.model.ApproximateReportedProgress;
import com.google.api.services.dataflow.model.ApproximateSplitRequest;
-import java.io.ByteArrayInputStream;
import java.io.Closeable;
import java.io.IOException;
+import java.io.InputStream;
import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
@@ -288,7 +288,7 @@ public class GroupingShuffleReader<K, V> extends NativeReader<WindowedValue<KV<K
}
}
- K key = CoderUtils.decodeFromByteArray(parentReader.keyCoder, groups.getCurrent().key);
+ K key = CoderUtils.decodeFromByteString(parentReader.keyCoder, groups.getCurrent().key);
parentReader.executionContext.setKey(key);
current =
new ValueInEmptyWindows<>(
@@ -452,20 +452,20 @@ public class GroupingShuffleReader<K, V> extends NativeReader<WindowedValue<KV<K
notifyValueReturned(currentGroupSize.getAndSet(0L));
try {
if (parentReader.secondaryKeyCoder != null) {
- ByteArrayInputStream bais = new ByteArrayInputStream(entry.getSecondaryKey());
+ InputStream bais = entry.getSecondaryKey().newInput();
@SuppressWarnings("unchecked")
V value =
(V)
KV.of(
// We ignore decoding the timestamp.
parentReader.secondaryKeyCoder.decode(bais),
- CoderUtils.decodeFromByteArray(
+ CoderUtils.decodeFromByteString(
parentReader.valueCoder, entry.getValue()));
return value;
} else {
@SuppressWarnings("unchecked")
V value =
- (V) CoderUtils.decodeFromByteArray(parentReader.valueCoder, entry.getValue());
+ (V) CoderUtils.decodeFromByteString(parentReader.valueCoder, entry.getValue());
return value;
}
} catch (IOException exn) {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PartitioningShuffleReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PartitioningShuffleReader.java
index 39c015c13d4..ed12052cb1b 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PartitioningShuffleReader.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PartitioningShuffleReader.java
@@ -144,9 +144,9 @@ public class PartitioningShuffleReader<K, V> extends NativeReader<WindowedValue<
return false;
}
ShuffleEntry record = iterator.next();
- K key = CoderUtils.decodeFromByteArray(shuffleReader.keyCoder, record.getKey());
+ K key = CoderUtils.decodeFromByteString(shuffleReader.keyCoder, record.getKey());
WindowedValue<V> windowedValue =
- CoderUtils.decodeFromByteArray(shuffleReader.windowedValueCoder, record.getValue());
+ CoderUtils.decodeFromByteString(shuffleReader.windowedValueCoder, record.getValue());
shuffleReader.notifyElementRead(record.length());
current = windowedValue.withValue(KV.of(key, windowedValue.getValue()));
return true;
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ShuffleReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ShuffleReader.java
index f663abad85f..b894fe43964 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ShuffleReader.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ShuffleReader.java
@@ -17,12 +17,13 @@
*/
package org.apache.beam.runners.dataflow.worker;
+import java.io.Closeable;
import java.io.IOException;
/** ShuffleReader reads chunks of data from a shuffle dataset for a given position range. */
-interface ShuffleReader {
+interface ShuffleReader extends Closeable {
/** Represents a chunk of data read from a shuffle dataset. */
- public static class ReadChunkResult {
+ class ReadChunkResult {
public final byte[] chunk;
public final byte[] nextStartPosition;
@@ -41,6 +42,6 @@ interface ShuffleReader {
* @param startPosition the start of the requested range (inclusive)
* @param endPosition the end of the requested range (exclusive)
*/
- public ReadChunkResult readIncludingPosition(byte[] startPosition, byte[] endPosition)
+ ReadChunkResult readIncludingPosition(byte[] startPosition, byte[] endPosition)
throws IOException;
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ShuffleSink.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ShuffleSink.java
index 955902c9a4a..7f1dd75ad0f 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ShuffleSink.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/ShuffleSink.java
@@ -328,9 +328,9 @@ public class ShuffleSink<T> extends Sink<WindowedValue<T>> {
// Move forward enough bytes so we can prefix the size on after performing the write
int initialChunkSize = chunk.size();
chunk.resetTo(initialChunkSize + Ints.BYTES);
+
coder.encode(value, chunk.asOutputStream(), Context.OUTER);
int elementSize = chunk.size() - initialChunkSize - Ints.BYTES;
-
byte[] internalBytes = chunk.array();
internalBytes[initialChunkSize] = (byte) ((elementSize >>> 24) & 0xFF);
internalBytes[initialChunkSize + 1] = (byte) ((elementSize >>> 16) & 0xFF);
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedShuffleReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedShuffleReader.java
index 77b5c00042a..2947aae9751 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedShuffleReader.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UngroupedShuffleReader.java
@@ -27,6 +27,7 @@ import org.apache.beam.runners.dataflow.worker.util.common.worker.ShuffleEntryRe
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.checkerframework.checker.nullness.qual.Nullable;
@@ -114,9 +115,9 @@ public class UngroupedShuffleReader<T> extends NativeReader<T> {
}
ShuffleEntry record = iterator.next();
// Throw away the primary and the secondary keys.
- byte[] value = record.getValue();
+ ByteString value = record.getValue();
shuffleReader.notifyElementRead(record.length());
- current = CoderUtils.decodeFromByteArray(shuffleReader.coder, value);
+ current = CoderUtils.decodeFromByteString(shuffleReader.coder, value);
return true;
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ByteArrayShufflePosition.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ByteArrayShufflePosition.java
index 1ed816ced63..36c0c611bdb 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ByteArrayShufflePosition.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ByteArrayShufflePosition.java
@@ -20,10 +20,9 @@ package org.apache.beam.runners.dataflow.worker.util.common.worker;
import static com.google.api.client.util.Base64.decodeBase64;
import static com.google.api.client.util.Base64.encodeBase64URLSafeString;
-import java.util.Arrays;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.UnsafeByteOperations;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.Bytes;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.UnsignedBytes;
import org.checkerframework.checker.nullness.qual.Nullable;
/**
@@ -35,9 +34,10 @@ import org.checkerframework.checker.nullness.qual.Nullable;
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
public class ByteArrayShufflePosition implements Comparable<ShufflePosition>, ShufflePosition {
- private final byte[] position;
+ private static final ByteString ZERO = ByteString.copyFrom(new byte[] {0});
+ private final ByteString position;
- public ByteArrayShufflePosition(byte[] position) {
+ public ByteArrayShufflePosition(ByteString position) {
this.position = position;
}
@@ -46,6 +46,13 @@ public class ByteArrayShufflePosition implements Comparable<ShufflePosition>, Sh
}
public static ByteArrayShufflePosition of(byte[] position) {
+ if (position == null) {
+ return null;
+ }
+ return new ByteArrayShufflePosition(UnsafeByteOperations.unsafeWrap(position));
+ }
+
+ public static ByteArrayShufflePosition of(ByteString position) {
if (position == null) {
return null;
}
@@ -58,15 +65,15 @@ public class ByteArrayShufflePosition implements Comparable<ShufflePosition>, Sh
}
Preconditions.checkArgument(shufflePosition instanceof ByteArrayShufflePosition);
ByteArrayShufflePosition adapter = (ByteArrayShufflePosition) shufflePosition;
- return adapter.getPosition();
+ return adapter.getPosition().toByteArray();
}
- public byte[] getPosition() {
+ public ByteString getPosition() {
return position;
}
public String encodeBase64() {
- return encodeBase64URLSafeString(position);
+ return encodeBase64URLSafeString(position.toByteArray());
}
/**
@@ -75,7 +82,7 @@ public class ByteArrayShufflePosition implements Comparable<ShufflePosition>, Sh
* successor.
*/
public ByteArrayShufflePosition immediateSuccessor() {
- return new ByteArrayShufflePosition(Bytes.concat(position, new byte[] {0}));
+ return new ByteArrayShufflePosition(position.concat(ZERO));
}
@Override
@@ -85,14 +92,14 @@ public class ByteArrayShufflePosition implements Comparable<ShufflePosition>, Sh
}
if (o instanceof ByteArrayShufflePosition) {
ByteArrayShufflePosition that = (ByteArrayShufflePosition) o;
- return Arrays.equals(this.position, that.position);
+ return this.position.equals(that.position);
}
return false;
}
@Override
public int hashCode() {
- return Arrays.hashCode(position);
+ return position.hashCode();
}
@Override
@@ -107,6 +114,6 @@ public class ByteArrayShufflePosition implements Comparable<ShufflePosition>, Sh
return 0;
}
ByteArrayShufflePosition other = (ByteArrayShufflePosition) o;
- return UnsignedBytes.lexicographicalComparator().compare(position, other.position);
+ return ByteString.unsignedLexicographicalComparator().compare(position, other.position);
}
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/GroupingShuffleEntryIterator.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/GroupingShuffleEntryIterator.java
index 064d6b232d0..e5d385f7a77 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/GroupingShuffleEntryIterator.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/GroupingShuffleEntryIterator.java
@@ -20,12 +20,12 @@ package org.apache.beam.runners.dataflow.worker.util.common.worker;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
-import java.util.Arrays;
import java.util.NoSuchElementException;
import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterable;
import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterator;
import org.apache.beam.sdk.util.common.Reiterable;
import org.apache.beam.sdk.util.common.Reiterator;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -59,7 +59,7 @@ public abstract class GroupingShuffleEntryIterator {
* shuffleIterator.next() is the key of the next KeyGroupedShuffleEntries to return via {@link
* #advance}/{@link #getCurrent}.
*/
- private byte @Nullable [] currentKeyBytes = null;
+ private @Nullable ByteString currentKeyBytes = null;
private ShufflePosition lastGroupStart;
@@ -118,7 +118,7 @@ public abstract class GroupingShuffleEntryIterator {
// start ValuesIterable below from this entry.
atCurrentEntry = shuffleIterator.copy();
entry = shuffleIterator.next();
- if (!Arrays.equals(entry.getKey(), currentKeyBytes)) {
+ if (!entry.getKey().equals(currentKeyBytes)) {
break;
}
// Note: we can get here only if the ValuesIterable of the preceding key has NOT been
@@ -190,12 +190,12 @@ public abstract class GroupingShuffleEntryIterator {
extends ElementByteSizeObservableIterable<ShuffleEntry, ValuesIterator>
implements Reiterable<ShuffleEntry> {
private final GroupingShuffleEntryIterator parent;
- private final byte[] currentKeyBytes;
+ private final ByteString currentKeyBytes;
private Reiterator<ShuffleEntry> baseValuesIterator;
public ValuesIterable(
GroupingShuffleEntryIterator parent,
- byte[] keyBytes,
+ ByteString keyBytes,
Reiterator<ShuffleEntry> baseValuesIterator) {
this.parent = parent;
this.currentKeyBytes = keyBytes;
@@ -217,7 +217,7 @@ public abstract class GroupingShuffleEntryIterator {
private final GroupingShuffleEntryIterator parent;
private final Reiterator<ShuffleEntry> valuesIterator;
private final ProgressTracker<ShuffleEntry> tracker;
- private final byte[] expectedKeyBytes;
+ private final ByteString expectedKeyBytes;
private Boolean cachedHasNext;
private long byteSizeRead = 0L;
@@ -226,7 +226,7 @@ public abstract class GroupingShuffleEntryIterator {
public ValuesIterator(
GroupingShuffleEntryIterator parent,
Reiterator<ShuffleEntry> valuesIterator,
- byte[] expectedKeyBytes) {
+ ByteString expectedKeyBytes) {
this.parent = parent;
this.valuesIterator = valuesIterator;
this.expectedKeyBytes = checkNotNull(expectedKeyBytes);
@@ -267,13 +267,14 @@ public abstract class GroupingShuffleEntryIterator {
return cachedHasNext;
}
+ @SuppressWarnings("ReferenceEquality")
private boolean advance() {
// Save a copy of the iterator pointing at the next entry, to use below in case we're right
// before a key boundary (or end of stream).
Reiterator<ShuffleEntry> possibleStartOfNextKey = valuesIterator.copy();
if (valuesIterator.hasNext()) {
ShuffleEntry entry = valuesIterator.next();
- if (Arrays.equals(entry.getKey(), expectedKeyBytes)) {
+ if (entry.getKey().equals(expectedKeyBytes)) {
byteSizeRead += entry.length();
tracker.saw(entry);
current = entry;
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/KeyGroupedShuffleEntries.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/KeyGroupedShuffleEntries.java
index 6ec7109aa9b..0ff3d0b13bb 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/KeyGroupedShuffleEntries.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/KeyGroupedShuffleEntries.java
@@ -18,15 +18,16 @@
package org.apache.beam.runners.dataflow.worker.util.common.worker;
import org.apache.beam.sdk.util.common.Reiterable;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
/** A collection of ShuffleEntries, all with the same key. */
public class KeyGroupedShuffleEntries {
public final ShufflePosition position;
- public final byte[] key;
+ public final ByteString key;
public final Reiterable<ShuffleEntry> values;
public KeyGroupedShuffleEntries(
- ShufflePosition position, byte[] key, Reiterable<ShuffleEntry> values) {
+ ShufflePosition position, ByteString key, Reiterable<ShuffleEntry> values) {
this.position = position;
this.key = key;
this.values = values;
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ShuffleBatchReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ShuffleBatchReader.java
index b78b796bb35..1b13403e1ef 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ShuffleBatchReader.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ShuffleBatchReader.java
@@ -27,7 +27,7 @@ import org.checkerframework.checker.nullness.qual.Nullable;
*/
public interface ShuffleBatchReader {
/** The result returned by #read. */
- public static class Batch {
+ class Batch {
public final List<ShuffleEntry> entries;
public final @Nullable ShufflePosition nextStartPosition;
@@ -49,6 +49,6 @@ public interface ShuffleBatchReader {
* key is greater than or equal to startPosition).
* @return the first {@link Batch} of entries
*/
- public Batch read(@Nullable ShufflePosition startPosition, @Nullable ShufflePosition endPosition)
+ Batch read(@Nullable ShufflePosition startPosition, @Nullable ShufflePosition endPosition)
throws IOException;
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ShuffleEntry.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ShuffleEntry.java
index f9d98adb9da..89778cd3a3e 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ShuffleEntry.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ShuffleEntry.java
@@ -18,6 +18,8 @@
package org.apache.beam.runners.dataflow.worker.util.common.worker;
import java.util.Arrays;
+import java.util.Objects;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.checkerframework.checker.nullness.qual.Nullable;
/** Entry written to/read from a shuffle dataset. */
@@ -26,18 +28,19 @@ import org.checkerframework.checker.nullness.qual.Nullable;
})
public class ShuffleEntry {
final ShufflePosition position;
- final byte[] key;
- final byte[] secondaryKey;
- final byte[] value;
+ final ByteString key;
+ final ByteString secondaryKey;
+ final ByteString value;
- public ShuffleEntry(byte[] key, byte[] secondaryKey, byte[] value) {
+ public ShuffleEntry(ByteString key, ByteString secondaryKey, ByteString value) {
this.position = null;
this.key = key;
this.secondaryKey = secondaryKey;
this.value = value;
}
- public ShuffleEntry(ShufflePosition position, byte[] key, byte[] secondaryKey, byte[] value) {
+ public ShuffleEntry(
+ ShufflePosition position, ByteString key, ByteString secondaryKey, ByteString value) {
this.position = position;
this.key = key;
this.secondaryKey = secondaryKey;
@@ -48,23 +51,23 @@ public class ShuffleEntry {
return position;
}
- public byte[] getKey() {
+ public ByteString getKey() {
return key;
}
- public byte[] getSecondaryKey() {
+ public ByteString getSecondaryKey() {
return secondaryKey;
}
- public byte[] getValue() {
+ public ByteString getValue() {
return value;
}
/** Returns the size of this entry in bytes, excluding {@code position}. */
public int length() {
- return (key == null ? 0 : key.length)
- + (secondaryKey == null ? 0 : secondaryKey.length)
- + (value == null ? 0 : value.length);
+ return (key == null ? 0 : key.size())
+ + (secondaryKey == null ? 0 : secondaryKey.size())
+ + (value == null ? 0 : value.size());
}
@Override
@@ -72,11 +75,11 @@ public class ShuffleEntry {
return "ShuffleEntry("
+ position.toString()
+ ","
- + byteArrayToString(key)
+ + byteArrayToString(key.toByteArray())
+ ","
- + byteArrayToString(secondaryKey)
+ + byteArrayToString(secondaryKey.toByteArray())
+ ","
- + byteArrayToString(value)
+ + byteArrayToString(value.toByteArray())
+ ")";
}
@@ -93,12 +96,10 @@ public class ShuffleEntry {
}
if (o instanceof ShuffleEntry) {
ShuffleEntry that = (ShuffleEntry) o;
- return (this.position == null ? that.position == null : this.position.equals(that.position))
- && (this.key == null ? that.key == null : Arrays.equals(this.key, that.key))
- && (this.secondaryKey == null
- ? that.secondaryKey == null
- : Arrays.equals(this.secondaryKey, that.secondaryKey))
- && (this.value == null ? that.value == null : Arrays.equals(this.value, that.value));
+ return (Objects.equals(this.position, that.position))
+ && (Objects.equals(this.key, that.key))
+ && (Objects.equals(this.secondaryKey, that.secondaryKey))
+ && (Objects.equals(this.value, that.value));
}
return false;
}
@@ -107,8 +108,8 @@ public class ShuffleEntry {
public int hashCode() {
return getClass().hashCode()
+ (position == null ? 0 : position.hashCode())
- + (key == null ? 0 : Arrays.hashCode(key))
- + (secondaryKey == null ? 0 : Arrays.hashCode(secondaryKey))
- + (value == null ? 0 : Arrays.hashCode(value));
+ + (key == null ? 0 : key.hashCode())
+ + (secondaryKey == null ? 0 : secondaryKey.hashCode())
+ + (value == null ? 0 : value.hashCode());
}
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ShuffleEntryReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ShuffleEntryReader.java
index a1725dc2c83..def545c188f 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ShuffleEntryReader.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ShuffleEntryReader.java
@@ -39,6 +39,6 @@ public interface ShuffleEntryReader extends Closeable {
* key is greater than or equal to startPosition).
* @return a {@link Reiterator} over the requested range of entries.
*/
- public Reiterator<ShuffleEntry> read(
+ Reiterator<ShuffleEntry> read(
@Nullable ShufflePosition startPosition, @Nullable ShufflePosition endPosition);
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/GroupingShuffleReaderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/GroupingShuffleReaderTest.java
index 3793aa3bb75..5eaf2b626ad 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/GroupingShuffleReaderTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/GroupingShuffleReaderTest.java
@@ -62,6 +62,7 @@ import org.apache.beam.runners.dataflow.worker.util.common.worker.ByteArrayShuff
import org.apache.beam.runners.dataflow.worker.util.common.worker.ExecutorTestUtils;
import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ShuffleEntry;
+import org.apache.beam.runners.dataflow.worker.util.common.worker.ShufflePosition;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ShuffleReadCounter;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ShuffleReadCounterFactory;
import org.apache.beam.runners.dataflow.worker.util.common.worker.Sink;
@@ -79,6 +80,7 @@ import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.common.Reiterable;
import org.apache.beam.sdk.util.common.Reiterator;
import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;
@@ -589,6 +591,15 @@ public class GroupingShuffleReaderTest {
return fabricatePosition(shard, key == null ? null : Arrays.hashCode(key));
}
+ static ShuffleEntry newShuffleEntry(
+ ShufflePosition position, byte[] key, byte[] secondaryKey, byte[] value) {
+ return new ShuffleEntry(
+ position,
+ key == null ? null : ByteString.copyFrom(key),
+ secondaryKey == null ? null : ByteString.copyFrom(secondaryKey),
+ value == null ? null : ByteString.copyFrom(value));
+ }
+
static ByteArrayShufflePosition fabricatePosition(int shard, @Nullable Integer keyHash)
throws Exception {
ByteArrayOutputStream os = new ByteArrayOutputStream();
@@ -611,13 +622,14 @@ public class GroupingShuffleReaderTest {
for (int i = 0; i < kNumRecords; ++i) {
byte[] key = CoderUtils.encodeToByteArray(BigEndianIntegerCoder.of(), i);
shuffleReader.addEntry(
- new ShuffleEntry(fabricatePosition(kFirstShard, key), key, EMPTY_BYTE_ARRAY, key));
+ newShuffleEntry(fabricatePosition(kFirstShard, key), key, EMPTY_BYTE_ARRAY, key));
}
// Note that TestShuffleReader start/end positions are in the
// space of keys not the positions (TODO: should probably always
// use positions instead).
- String stop = encodeBase64URLSafeString(fabricatePosition(kNumRecords).getPosition());
+ String stop =
+ encodeBase64URLSafeString(fabricatePosition(kNumRecords).getPosition().toByteArray());
TestOperationContext operationContext = TestOperationContext.create();
GroupingShuffleReader<Integer, Integer> groupingShuffleReader =
new GroupingShuffleReader<>(
@@ -678,7 +690,7 @@ public class GroupingShuffleReaderTest {
for (int i = 0; i < kNumRecords; ++i) {
byte[] key = CoderUtils.encodeToByteArray(BigEndianIntegerCoder.of(), i);
ShuffleEntry entry =
- new ShuffleEntry(fabricatePosition(kFirstShard, i), key, EMPTY_BYTE_ARRAY, key);
+ newShuffleEntry(fabricatePosition(kFirstShard, i), key, EMPTY_BYTE_ARRAY, key);
shuffleReader.addEntry(entry);
}
@@ -742,11 +754,12 @@ public class GroupingShuffleReaderTest {
private Position makeShufflePosition(int shard, byte[] key) throws Exception {
return new Position()
- .setShufflePosition(encodeBase64URLSafeString(fabricatePosition(shard, key).getPosition()));
+ .setShufflePosition(
+ encodeBase64URLSafeString(fabricatePosition(shard, key).getPosition().toByteArray()));
}
- private Position makeShufflePosition(byte[] position) throws Exception {
- return new Position().setShufflePosition(encodeBase64URLSafeString(position));
+ private Position makeShufflePosition(ByteString position) throws Exception {
+ return new Position().setShufflePosition(encodeBase64URLSafeString(position.toByteArray()));
}
@Test
@@ -784,7 +797,7 @@ public class GroupingShuffleReaderTest {
for (int i = 0; i < kNumRecords; ++i) {
byte[] keyByte = CoderUtils.encodeToByteArray(BigEndianIntegerCoder.of(), i);
ShuffleEntry entry =
- new ShuffleEntry(
+ newShuffleEntry(
fabricatePosition(kFirstShard, keyByte), keyByte, EMPTY_BYTE_ARRAY, keyByte);
shuffleReader.addEntry(entry);
}
@@ -793,7 +806,7 @@ public class GroupingShuffleReaderTest {
byte[] keyByte = CoderUtils.encodeToByteArray(BigEndianIntegerCoder.of(), i);
ShuffleEntry entry =
- new ShuffleEntry(
+ newShuffleEntry(
fabricatePosition(kSecondShard, keyByte), keyByte, EMPTY_BYTE_ARRAY, keyByte);
shuffleReader.addEntry(entry);
}
@@ -813,7 +826,7 @@ public class GroupingShuffleReaderTest {
iter.requestDynamicSplit(splitRequestAtPosition(makeShufflePosition(kSecondShard, null)));
assertNotNull(dynamicSplitResult);
assertEquals(
- encodeBase64URLSafeString(fabricatePosition(kSecondShard).getPosition()),
+ encodeBase64URLSafeString(fabricatePosition(kSecondShard).getPosition().toByteArray()),
positionFromSplitResult(dynamicSplitResult).getShufflePosition());
for (; iter.advance(); ++i) {
@@ -882,7 +895,7 @@ public class GroupingShuffleReaderTest {
ByteArrayShufflePosition position = fabricatePosition(i);
byte[] keyByte = CoderUtils.encodeToByteArray(BigEndianIntegerCoder.of(), i);
positionsList.add(position);
- ShuffleEntry entry = new ShuffleEntry(position, keyByte, EMPTY_BYTE_ARRAY, keyByte);
+ ShuffleEntry entry = newShuffleEntry(position, keyByte, EMPTY_BYTE_ARRAY, keyByte);
shuffleReader.addEntry(entry);
}
@@ -906,7 +919,7 @@ public class GroupingShuffleReaderTest {
// Cannot split since all input was consumed.
Position proposedSplitPosition = new Position();
- String stop = encodeBase64URLSafeString(fabricatePosition(0).getPosition());
+ String stop = encodeBase64URLSafeString(fabricatePosition(0).getPosition().toByteArray());
proposedSplitPosition.setShufflePosition(stop);
assertNull(
iter.requestDynamicSplit(
@@ -929,14 +942,15 @@ public class GroupingShuffleReaderTest {
for (int i = 0; i < kNumRecords; ++i) {
byte[] key = CoderUtils.encodeToByteArray(BigEndianIntegerCoder.of(), i);
shuffleReader.addEntry(
- new ShuffleEntry(fabricatePosition(kFirstShard, key), key, EMPTY_BYTE_ARRAY, key));
+ newShuffleEntry(fabricatePosition(kFirstShard, key), key, EMPTY_BYTE_ARRAY, key));
}
TestShuffleReadCounterFactory shuffleReadCounterFactory = new TestShuffleReadCounterFactory();
// Note that TestShuffleReader start/end positions are in the
// space of keys not the positions (TODO: should probably always
// use positions instead).
- String stop = encodeBase64URLSafeString(fabricatePosition(kNumRecords).getPosition());
+ String stop =
+ encodeBase64URLSafeString(fabricatePosition(kNumRecords).getPosition().toByteArray());
TestOperationContext operationContext = TestOperationContext.create();
GroupingShuffleReader<Integer, Integer> groupingShuffleReader =
new GroupingShuffleReader<>(
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/ShuffleSinkTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/ShuffleSinkTest.java
index 140449628e9..df14e60ad1e 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/ShuffleSinkTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/ShuffleSinkTest.java
@@ -17,7 +17,6 @@
*/
package org.apache.beam.runners.dataflow.worker;
-import java.io.ByteArrayInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -39,6 +38,7 @@ import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.joda.time.Duration;
import org.joda.time.Instant;
@@ -113,8 +113,9 @@ public class ShuffleSinkTest {
List<Integer> actual = new ArrayList<>();
for (ShuffleEntry record : records) {
// Ignore the key.
- byte[] valueBytes = record.getValue();
- WindowedValue<Integer> value = CoderUtils.decodeFromByteArray(windowedValueCoder, valueBytes);
+ ByteString valueBytes = record.getValue();
+ WindowedValue<Integer> value =
+ CoderUtils.decodeFromByteString(windowedValueCoder, valueBytes);
Assert.assertEquals(Lists.newArrayList(GlobalWindow.INSTANCE), value.getWindows());
actual.add(value.getValue());
}
@@ -156,13 +157,13 @@ public class ShuffleSinkTest {
List<KV<Integer, String>> actual = new ArrayList<>();
for (ShuffleEntry record : records) {
- byte[] keyBytes = record.getKey();
- byte[] valueBytes = record.getValue();
+ ByteString keyBytes = record.getKey();
+ ByteString valueBytes = record.getValue();
Assert.assertEquals(
- timestamp, CoderUtils.decodeFromByteArray(InstantCoder.of(), record.getSecondaryKey()));
+ timestamp, CoderUtils.decodeFromByteString(InstantCoder.of(), record.getSecondaryKey()));
- Integer key = CoderUtils.decodeFromByteArray(BigEndianIntegerCoder.of(), keyBytes);
- String valueElem = CoderUtils.decodeFromByteArray(StringUtf8Coder.of(), valueBytes);
+ Integer key = CoderUtils.decodeFromByteString(BigEndianIntegerCoder.of(), keyBytes);
+ String valueElem = CoderUtils.decodeFromByteString(StringUtf8Coder.of(), valueBytes);
actual.add(KV.of(key, valueElem));
}
@@ -201,14 +202,14 @@ public class ShuffleSinkTest {
List<KV<Integer, KV<String, Integer>>> actual = new ArrayList<>();
for (ShuffleEntry record : records) {
- byte[] keyBytes = record.getKey();
- byte[] valueBytes = record.getValue();
- byte[] sortKeyBytes = record.getSecondaryKey();
-
- Integer key = CoderUtils.decodeFromByteArray(BigEndianIntegerCoder.of(), keyBytes);
- ByteArrayInputStream bais = new ByteArrayInputStream(sortKeyBytes);
- String sortKey = StringUtf8Coder.of().decode(bais);
- Integer sortValue = CoderUtils.decodeFromByteArray(BigEndianIntegerCoder.of(), valueBytes);
+ ByteString keyBytes = record.getKey();
+ ByteString valueBytes = record.getValue();
+ ByteString sortKeyBytes = record.getSecondaryKey();
+
+ Integer key = CoderUtils.decodeFromByteString(BigEndianIntegerCoder.of(), keyBytes);
+ String sortKey =
+ CoderUtils.decodeFromByteString(StringUtf8Coder.of(), sortKeyBytes, Coder.Context.NESTED);
+ Integer sortValue = CoderUtils.decodeFromByteString(BigEndianIntegerCoder.of(), valueBytes);
actual.add(KV.of(key, KV.of(sortKey, sortValue)));
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/TestShuffleReader.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/TestShuffleReader.java
index 11808e98936..6c0a6cafc4a 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/TestShuffleReader.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/TestShuffleReader.java
@@ -17,7 +17,6 @@
*/
package org.apache.beam.runners.dataflow.worker;
-import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
@@ -30,13 +29,13 @@ import org.apache.beam.runners.dataflow.worker.util.common.worker.ShuffleEntry;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ShuffleEntryReader;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ShufflePosition;
import org.apache.beam.sdk.util.common.Reiterator;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.primitives.UnsignedBytes;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.checkerframework.checker.nullness.qual.Nullable;
/** A fake implementation of a ShuffleEntryReader, for testing. */
public class TestShuffleReader implements ShuffleEntryReader {
// Sorts by secondary key where an empty secondary key sorts before all other secondary keys.
- static final Comparator<byte[]> SHUFFLE_KEY_COMPARATOR =
+ static final Comparator<ByteString> SHUFFLE_KEY_COMPARATOR =
(o1, o2) -> {
if (o1 == o2) {
return 0;
@@ -47,10 +46,10 @@ public class TestShuffleReader implements ShuffleEntryReader {
if (o2 == null) {
return 1;
}
- return UnsignedBytes.lexicographicalComparator().compare(o1, o2);
+ return ByteString.unsignedLexicographicalComparator().compare(o1, o2);
};
- final TreeMap<byte[], TreeMap<byte[], List<ShuffleEntry>>> records =
+ final TreeMap<ByteString, TreeMap<ByteString, List<ShuffleEntry>>> records =
new TreeMap<>(SHUFFLE_KEY_COMPARATOR);
boolean closed = false;
@@ -59,13 +58,13 @@ public class TestShuffleReader implements ShuffleEntryReader {
public void addEntry(String key, String secondaryKey, String value) {
addEntry(
new ShuffleEntry(
- key.getBytes(StandardCharsets.UTF_8),
- secondaryKey.getBytes(StandardCharsets.UTF_8),
- value.getBytes(StandardCharsets.UTF_8)));
+ ByteString.copyFromUtf8(key),
+ ByteString.copyFromUtf8(secondaryKey),
+ ByteString.copyFromUtf8(value)));
}
public void addEntry(ShuffleEntry entry) {
- TreeMap<byte[], List<ShuffleEntry>> valuesBySecondaryKey = records.get(entry.getKey());
+ TreeMap<ByteString, List<ShuffleEntry>> valuesBySecondaryKey = records.get(entry.getKey());
if (valuesBySecondaryKey == null) {
valuesBySecondaryKey = new TreeMap<>(SHUFFLE_KEY_COMPARATOR);
records.put(entry.getKey(), valuesBySecondaryKey);
@@ -79,7 +78,15 @@ public class TestShuffleReader implements ShuffleEntryReader {
}
public Iterator<ShuffleEntry> read() {
- return read((byte[]) null, (byte[]) null);
+ return read((ByteString) null, null);
+ }
+
+ private ByteString toByteString(ShufflePosition pos) {
+ byte[] posBytes = ByteArrayShufflePosition.getPosition(pos);
+ if (posBytes == null) {
+ return null;
+ }
+ return ByteString.copyFrom(posBytes);
}
@Override
@@ -88,24 +95,22 @@ public class TestShuffleReader implements ShuffleEntryReader {
if (closed) {
throw new RuntimeException("Cannot read from a closed reader.");
}
- return read(
- ByteArrayShufflePosition.getPosition(startPosition),
- ByteArrayShufflePosition.getPosition(endPosition));
+ return read(toByteString(startPosition), toByteString(endPosition));
}
public Reiterator<ShuffleEntry> read(@Nullable String startKey, @Nullable String endKey) {
return read(
- startKey == null ? null : startKey.getBytes(StandardCharsets.UTF_8),
- endKey == null ? null : endKey.getBytes(StandardCharsets.UTF_8));
+ startKey == null ? null : ByteString.copyFromUtf8(startKey),
+ endKey == null ? null : ByteString.copyFromUtf8(endKey));
}
- public Reiterator<ShuffleEntry> read(byte @Nullable [] startKey, byte @Nullable [] endKey) {
+ public Reiterator<ShuffleEntry> read(@Nullable ByteString startKey, @Nullable ByteString endKey) {
List<ShuffleEntry> res = new ArrayList<>();
- for (byte[] key : records.keySet()) {
+ for (ByteString key : records.keySet()) {
if ((startKey == null || SHUFFLE_KEY_COMPARATOR.compare(startKey, key) <= 0)
&& (endKey == null || SHUFFLE_KEY_COMPARATOR.compare(key, endKey) < 0)) {
- TreeMap<byte[], List<ShuffleEntry>> entriesBySecondaryKey = records.get(key);
- for (Map.Entry<byte[], List<ShuffleEntry>> entries : entriesBySecondaryKey.entrySet()) {
+ TreeMap<ByteString, List<ShuffleEntry>> entriesBySecondaryKey = records.get(key);
+ for (Map.Entry<ByteString, List<ShuffleEntry>> entries : entriesBySecondaryKey.entrySet()) {
res.addAll(entries.getValue());
}
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/TestShuffleReaderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/TestShuffleReaderTest.java
index ace0eb0597a..88530afa35a 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/TestShuffleReaderTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/TestShuffleReaderTest.java
@@ -20,7 +20,6 @@ package org.apache.beam.runners.dataflow.worker;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
-import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -116,10 +115,8 @@ public class TestShuffleReaderTest {
ShuffleEntry entry = iter.next();
actual.add(
KV.of(
- new String(entry.getKey(), StandardCharsets.UTF_8),
- KV.of(
- new String(entry.getSecondaryKey(), StandardCharsets.UTF_8),
- new String(entry.getValue(), StandardCharsets.UTF_8))));
+ (entry.getKey().toStringUtf8()),
+ KV.of((entry.getSecondaryKey().toStringUtf8()), (entry.getValue().toStringUtf8()))));
}
return actual;
}
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/TestShuffleWriter.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/TestShuffleWriter.java
index d22bce211e7..d9507cf38ea 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/TestShuffleWriter.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/TestShuffleWriter.java
@@ -23,6 +23,7 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ShuffleEntry;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
/** A fake implementation of a ShuffleEntryWriter, for testing. */
public class TestShuffleWriter implements ShuffleWriter {
@@ -46,7 +47,9 @@ public class TestShuffleWriter implements ShuffleWriter {
byte[] value = new byte[dais.readInt()];
dais.readFully(value);
- ShuffleEntry entry = new ShuffleEntry(key, sortKey, value);
+ ShuffleEntry entry =
+ new ShuffleEntry(
+ ByteString.copyFrom(key), ByteString.copyFrom(sortKey), ByteString.copyFrom(value));
records.add(entry);
long size = entry.length();
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/BatchingShuffleEntryReaderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/BatchingShuffleEntryReaderTest.java
index 798acf39bbb..6763068c221 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/BatchingShuffleEntryReaderTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/BatchingShuffleEntryReaderTest.java
@@ -30,6 +30,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.beam.sdk.util.common.Reiterator;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -52,6 +53,11 @@ public final class BatchingShuffleEntryReaderTest {
private static final ShufflePosition SECOND_NEXT_START_POSITION =
ByteArrayShufflePosition.of("next-second".getBytes(StandardCharsets.UTF_8));
+ static ShuffleEntry newShuffleEntry(byte[] key, byte[] secondaryKey, byte[] value) {
+ return new ShuffleEntry(
+ ByteString.copyFrom(key), ByteString.copyFrom(secondaryKey), ByteString.copyFrom(value));
+ }
+
@Mock private ShuffleBatchReader batchReader;
private ShuffleEntryReader reader;
@@ -63,8 +69,8 @@ public final class BatchingShuffleEntryReaderTest {
@Test
public void readerCanRead() throws Exception {
- ShuffleEntry e1 = new ShuffleEntry(KEY, SKEY, VALUE);
- ShuffleEntry e2 = new ShuffleEntry(KEY, SKEY, VALUE);
+ ShuffleEntry e1 = newShuffleEntry(KEY, SKEY, VALUE);
+ ShuffleEntry e2 = newShuffleEntry(KEY, SKEY, VALUE);
ArrayList<ShuffleEntry> entries = new ArrayList<>();
entries.add(e1);
entries.add(e2);
@@ -76,8 +82,8 @@ public final class BatchingShuffleEntryReaderTest {
@Test
public void readerIteratorCanBeCopied() throws Exception {
- ShuffleEntry e1 = new ShuffleEntry(KEY, SKEY, VALUE);
- ShuffleEntry e2 = new ShuffleEntry(KEY, SKEY, VALUE);
+ ShuffleEntry e1 = newShuffleEntry(KEY, SKEY, VALUE);
+ ShuffleEntry e2 = newShuffleEntry(KEY, SKEY, VALUE);
ArrayList<ShuffleEntry> entries = new ArrayList<>();
entries.add(e1);
entries.add(e2);
@@ -97,9 +103,9 @@ public final class BatchingShuffleEntryReaderTest {
@Test
public void readerShouldMergeMultipleBatchResults() throws Exception {
- ShuffleEntry e1 = new ShuffleEntry(KEY, SKEY, VALUE);
+ ShuffleEntry e1 = newShuffleEntry(KEY, SKEY, VALUE);
List<ShuffleEntry> e1s = Collections.singletonList(e1);
- ShuffleEntry e2 = new ShuffleEntry(KEY, SKEY, VALUE);
+ ShuffleEntry e2 = newShuffleEntry(KEY, SKEY, VALUE);
List<ShuffleEntry> e2s = Collections.singletonList(e2);
when(batchReader.read(START_POSITION, END_POSITION))
.thenReturn(new ShuffleBatchReader.Batch(e1s, NEXT_START_POSITION));
@@ -117,7 +123,7 @@ public final class BatchingShuffleEntryReaderTest {
public void readerShouldMergeMultipleBatchResultsIncludingEmptyShards() throws Exception {
List<ShuffleEntry> e1s = new ArrayList<>();
List<ShuffleEntry> e2s = new ArrayList<>();
- ShuffleEntry e3 = new ShuffleEntry(KEY, SKEY, VALUE);
+ ShuffleEntry e3 = newShuffleEntry(KEY, SKEY, VALUE);
List<ShuffleEntry> e3s = Collections.singletonList(e3);
when(batchReader.read(START_POSITION, END_POSITION))
.thenReturn(new ShuffleBatchReader.Batch(e1s, NEXT_START_POSITION));
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/GroupingShuffleEntryIteratorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/GroupingShuffleEntryIteratorTest.java
index 446189cb49c..d6f7f565cd7 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/GroupingShuffleEntryIteratorTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/GroupingShuffleEntryIteratorTest.java
@@ -41,6 +41,7 @@ import org.apache.beam.runners.dataflow.worker.counters.NameContext;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.util.common.Reiterator;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Charsets;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.checkerframework.checker.nullness.qual.Nullable;
@@ -130,9 +131,9 @@ public class GroupingShuffleEntryIteratorTest {
return new ShuffleEntry(
/* use key itself as position */
ByteArrayShufflePosition.of(key.getBytes(Charsets.UTF_8)),
- key.getBytes(Charsets.UTF_8),
- new byte[0],
- value.getBytes(Charsets.UTF_8));
+ ByteString.copyFrom(key.getBytes(Charsets.UTF_8)),
+ ByteString.copyFrom(new byte[0]),
+ ByteString.copyFrom(value.getBytes(Charsets.UTF_8)));
}
@Test
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ShuffleEntryTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ShuffleEntryTest.java
index 38cb642b2d3..262687fc9af 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ShuffleEntryTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/common/worker/ShuffleEntryTest.java
@@ -23,6 +23,7 @@ import static org.hamcrest.Matchers.not;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -30,9 +31,9 @@ import org.junit.runners.JUnit4;
/** Unit tests for {@link ShuffleEntry}. */
@RunWith(JUnit4.class)
public class ShuffleEntryTest {
- private static final byte[] KEY = {0xA};
- private static final byte[] SKEY = {0xB};
- private static final byte[] VALUE = {0xC};
+ private static final ByteString KEY = ByteString.copyFrom(new byte[] {0xA});
+ private static final ByteString SKEY = ByteString.copyFrom(new byte[] {0xB});
+ private static final ByteString VALUE = ByteString.copyFrom(new byte[] {0xC});
@Test
public void accessors() {
@@ -52,7 +53,7 @@ public class ShuffleEntryTest {
@Test
public void equalsForEqualEntries() {
ShuffleEntry entry0 = new ShuffleEntry(KEY, SKEY, VALUE);
- ShuffleEntry entry1 = new ShuffleEntry(KEY.clone(), SKEY.clone(), VALUE.clone());
+ ShuffleEntry entry1 = new ShuffleEntry(KEY.concat(ByteString.EMPTY), SKEY, VALUE);
assertEquals(entry0, entry1);
assertEquals(entry1, entry0);
@@ -71,7 +72,7 @@ public class ShuffleEntryTest {
@Test
public void notEqualsWhenKeysDiffer() {
- final byte[] otherKey = {0x1};
+ final ByteString otherKey = ByteString.copyFrom(new byte[] {0x1});
ShuffleEntry entry0 = new ShuffleEntry(KEY, SKEY, VALUE);
ShuffleEntry entry1 = new ShuffleEntry(otherKey, SKEY, VALUE);
@@ -92,7 +93,7 @@ public class ShuffleEntryTest {
@Test
public void notEqualsWhenSecondaryKeysDiffer() {
- final byte[] otherSKey = {0x2};
+ final ByteString otherSKey = ByteString.copyFrom(new byte[] {0x2});
ShuffleEntry entry0 = new ShuffleEntry(KEY, SKEY, VALUE);
ShuffleEntry entry1 = new ShuffleEntry(KEY, otherSKey, VALUE);
@@ -113,7 +114,7 @@ public class ShuffleEntryTest {
@Test
public void notEqualsWhenValuesDiffer() {
- final byte[] otherValue = {0x2};
+ final ByteString otherValue = ByteString.copyFrom(new byte[] {0x2});
ShuffleEntry entry0 = new ShuffleEntry(KEY, SKEY, VALUE);
ShuffleEntry entry1 = new ShuffleEntry(KEY, SKEY, otherValue);
@@ -134,7 +135,7 @@ public class ShuffleEntryTest {
@Test
public void emptyNotTheSameAsNull() {
- final byte[] empty = {};
+ final ByteString empty = ByteString.EMPTY;
ShuffleEntry entry0 = new ShuffleEntry(null, null, null);
ShuffleEntry entry1 = new ShuffleEntry(empty, empty, empty);
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/CoderUtils.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/CoderUtils.java
index a650efcf1ce..53edd1062f7 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/CoderUtils.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/CoderUtils.java
@@ -27,6 +27,7 @@ import java.lang.reflect.ParameterizedType;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Throwables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.BaseEncoding;
@@ -107,6 +108,29 @@ public final class CoderUtils {
}
}
+ /**
+ * Decodes a value from the given ByteString, validating that no bytes are remaining once decoded.
+ */
+ public static <T> T decodeFromByteString(Coder<T> coder, ByteString encodedValue)
+ throws IOException {
+ return decodeFromByteString(coder, encodedValue, Coder.Context.OUTER);
+ }
+
+ /**
+ * Decodes a value from the given ByteString using a given context, validating that no bytes are
+ * remaining once decoded.
+ */
+ public static <T> T decodeFromByteString(
+ Coder<T> coder, ByteString encodedValue, Coder.Context context) throws IOException {
+ InputStream stream = encodedValue.newInput();
+ T result = coder.decode(stream, context);
+ if (stream.available() != 0) {
+ throw new CoderException(
+ stream.available() + " unexpected extra bytes after decoding " + result);
+ }
+ return result;
+ }
+
/**
* Decodes a value from the given {@code stream}, which should be a stream that never throws
* {@code IOException}, such as {@code ByteArrayInputStream} or {@link
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/CoderUtilsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/CoderUtilsTest.java
index a681fbe8e7f..10681618698 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/CoderUtilsTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/CoderUtilsTest.java
@@ -17,6 +17,8 @@
*/
package org.apache.beam.sdk.util;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doThrow;
@@ -28,7 +30,9 @@ import org.apache.beam.sdk.coders.AtomicCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.Coder.Context;
import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.testing.CoderPropertiesTest.ClosingCoder;
+import org.apache.beam.vendor.grpc.v1p43p2.com.google.protobuf.ByteString;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
@@ -117,4 +121,25 @@ public class CoderUtilsTest {
expectedException.expectMessage("Caller does not own the underlying");
CoderUtils.encodeToByteArray(new ClosingCoder(), "test-value", Context.NESTED);
}
+
+ @Test
+ public void testDecodeFromByteString() throws Exception {
+ String expected = "test string";
+ byte[] data = CoderUtils.encodeToByteArray(StringUtf8Coder.of(), expected);
+ ByteString byteString = ByteString.copyFrom(data);
+ String result = CoderUtils.decodeFromByteString(StringUtf8Coder.of(), byteString);
+ assertEquals(expected, result);
+ }
+
+ @Test
+ public void testDecodeFromByteStringWithExtraDataThrows() throws Exception {
+ String expected = "test string";
+ byte[] data = CoderUtils.encodeToByteArray(StringUtf8Coder.of(), expected, Context.NESTED);
+ ByteString byteString = ByteString.copyFrom(data).concat(ByteString.copyFromUtf8("more text"));
+
+ assertThrows(
+ "9 unexpected extra bytes after decoding test string",
+ CoderException.class,
+ () -> CoderUtils.decodeFromByteString(StringUtf8Coder.of(), byteString, Context.NESTED));
+ }
}