You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by mm...@apache.org on 2022/04/07 16:28:59 UTC
[beam] branch master updated: [BEAM-14104] Support shard aware aggregation in Kinesis writer.
This is an automated email from the ASF dual-hosted git repository.
mmack pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 416fc9ba752 [BEAM-14104] Support shard aware aggregation in Kinesis writer.
new ad4561ea5d8 Merge pull request #17113 from mosche/BEAM-14104-ShardAwareAggregation
416fc9ba752 is described below
commit 416fc9ba75251f17543f8556071885a026ef5745
Author: Moritz Mack <mm...@talend.com>
AuthorDate: Thu Mar 17 18:00:38 2022 +0100
[BEAM-14104] Support shard aware aggregation in Kinesis writer.
---
.../apache/beam/sdk/io/aws2/common/ClientPool.java | 123 ----------
.../apache/beam/sdk/io/aws2/common/ObjectPool.java | 151 ++++++++++++
.../sdk/io/aws2/common/RetryConfiguration.java | 2 +-
.../apache/beam/sdk/io/aws2/kinesis/KinesisIO.java | 270 +++++++++++++++++----
.../sdk/io/aws2/kinesis/KinesisPartitioner.java | 28 ++-
.../sdk/io/aws2/kinesis/RecordsAggregator.java | 4 -
.../{ClientPoolTest.java => ObjectPoolTest.java} | 99 ++++----
.../sdk/io/aws2/kinesis/KinesisIOWriteTest.java | 243 ++++++++++++++++---
.../sdk/io/aws2/kinesis/PutRecordsHelpers.java | 8 +
.../sdk/io/aws2/kinesis/testing/KinesisIOIT.java | 5 +-
10 files changed, 677 insertions(+), 256 deletions(-)
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientPool.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientPool.java
deleted file mode 100644
index 1a7cd29ec98..00000000000
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientPool.java
+++ /dev/null
@@ -1,123 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.io.aws2.common;
-
-import java.util.function.BiFunction;
-import org.apache.beam.sdk.io.aws2.options.AwsOptions;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.BiMap;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashBiMap;
-import org.apache.commons.lang3.tuple.Pair;
-import org.checkerframework.checker.nullness.qual.Nullable;
-import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
-
-/**
- * Reference counting pool to easily share AWS clients or similar by individual client provider and
- * configuration (optional).
- *
- * <p>NOTE: This relies heavily on the implementation of {@link #equals(Object)} for {@link
- * ProviderT} and {@link ConfigT}. If not implemented properly, clients can't be shared between
- * instances of {@link org.apache.beam.sdk.transforms.DoFn}.
- *
- * @param <ProviderT> Client provider
- * @param <ConfigT> Optional, nullable configuration
- * @param <ClientT> Client
- */
-public class ClientPool<ProviderT, ConfigT, ClientT extends AutoCloseable> {
- private final BiMap<Pair<ProviderT, ConfigT>, RefCounted> pool = HashBiMap.create(2);
- private final BiFunction<ProviderT, ConfigT, ClientT> builder;
-
- public static <
- ClientT extends AutoCloseable, BuilderT extends AwsClientBuilder<BuilderT, ClientT>>
- ClientPool<AwsOptions, ClientConfiguration, ClientT> pooledClientFactory(BuilderT builder) {
- return new ClientPool<>((opts, conf) -> ClientBuilderFactory.buildClient(opts, builder, conf));
- }
-
- public ClientPool(BiFunction<ProviderT, ConfigT, ClientT> builder) {
- this.builder = builder;
- }
-
- /** Retain a reference to a shared client instance. If not available, an instance is created. */
- public ClientT retain(ProviderT provider, @Nullable ConfigT config) {
- @SuppressWarnings("nullness")
- Pair<ProviderT, ConfigT> key = Pair.of(provider, config);
- synchronized (pool) {
- RefCounted ref = pool.computeIfAbsent(key, RefCounted::new);
- ref.count++;
- return ref.client;
- }
- }
-
- /**
- * Release a reference to a shared client instance using {@link ProviderT} and {@link ConfigT} .
- * If that instance is not used anymore, it will be removed and destroyed.
- */
- public void release(ProviderT provider, @Nullable ConfigT config) throws Exception {
- @SuppressWarnings("nullness")
- Pair<ProviderT, ConfigT> key = Pair.of(provider, config);
- RefCounted ref;
- synchronized (pool) {
- ref = pool.get(key);
- if (ref == null || --ref.count > 0) {
- return;
- }
- pool.remove(key);
- }
- ref.client.close();
- }
-
- /**
- * Release a reference to a shared client instance. If that instance is not used anymore, it will
- * be removed and destroyed.
- */
- public void release(ClientT client) throws Exception {
- Pair<ProviderT, ConfigT> pair = pool.inverse().get(new RefCounted(client));
- if (pair != null) {
- release(pair.getLeft(), pair.getRight());
- }
- }
-
- private class RefCounted {
- private int count = 0;
- private final ClientT client;
-
- RefCounted(ClientT client) {
- this.client = client;
- }
-
- RefCounted(Pair<ProviderT, ConfigT> key) {
- this(builder.apply(key.getLeft(), key.getRight()));
- }
-
- @Override
- public boolean equals(@Nullable Object o) {
- if (this == o) {
- return true;
- }
- if (o == null || getClass() != o.getClass()) {
- return false;
- }
- // only identity of ref counted client matters
- return client == ((RefCounted) o).client;
- }
-
- @Override
- public int hashCode() {
- return client.hashCode();
- }
- }
-}
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ObjectPool.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ObjectPool.java
new file mode 100644
index 00000000000..a17c6b56e5f
--- /dev/null
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ObjectPool.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.aws2.common;
+
+import static org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory.buildClient;
+
+import java.util.function.Function;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.function.ThrowingConsumer;
+import org.apache.beam.sdk.io.aws2.options.AwsOptions;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.BiMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashBiMap;
+import org.apache.commons.lang3.tuple.Pair;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
+import software.amazon.awssdk.core.SdkClient;
+
+/**
+ * Reference counting object pool to easily share & destroy objects.
+ *
+ * <p>Internal only, subject to incompatible changes or removal at any time!
+ *
+ * <p>NOTE: This relies heavily on the implementation of {@link #equals(Object)} for {@code KeyT}.
+ * If not implemented properly, objects can't be shared between instances of {@link
+ * org.apache.beam.sdk.transforms.DoFn}.
+ *
+ * @param <KeyT> Key to share objects by
+ * @param <ObjectT> Shared object
+ */
+@Internal
+@Experimental
+public class ObjectPool<KeyT extends @NonNull Object, ObjectT extends @NonNull Object> {
+ private final BiMap<KeyT, RefCounted> pool = HashBiMap.create(2);
+ private final Function<KeyT, ObjectT> builder;
+ private final @Nullable ThrowingConsumer<Exception, ObjectT> finalizer;
+
+ public ObjectPool(Function<KeyT, ObjectT> builder) {
+ this(builder, null);
+ }
+
+ public ObjectPool(
+ Function<KeyT, ObjectT> builder, @Nullable ThrowingConsumer<Exception, ObjectT> finalizer) {
+ this.builder = builder;
+ this.finalizer = finalizer;
+ }
+
+ /** Retain a reference to a shared object instance. If not available, an instance is created. */
+ public ObjectT retain(KeyT key) {
+ synchronized (pool) {
+ RefCounted ref = pool.computeIfAbsent(key, k -> new RefCounted(builder.apply(k)));
+ ref.count++;
+ return ref.shared;
+ }
+ }
+
+ /**
+ * Release a reference to a shared object instance using {@link KeyT}. If that instance is not
+ * used anymore, it will be removed and destroyed.
+ */
+ public void releaseByKey(KeyT key) {
+ RefCounted ref;
+ synchronized (pool) {
+ ref = pool.get(key);
+ if (ref == null || --ref.count > 0) {
+ return;
+ }
+ pool.remove(key);
+ }
+ if (finalizer != null) {
+ try {
+ finalizer.accept(ref.shared);
+ } catch (Exception e) {
+ LoggerFactory.getLogger(ObjectPool.class).warn("Exception destroying pooled object.", e);
+ }
+ }
+ }
+
+ /**
+ * Release a reference to a shared object instance. If that instance is not used anymore, it will
+ * be removed and destroyed.
+ */
+ public void release(ObjectT object) {
+ KeyT key = pool.inverse().get(new RefCounted(object));
+ if (key != null) {
+ releaseByKey(key);
+ }
+ }
+
+ public static <ClientT extends SdkClient, BuilderT extends AwsClientBuilder<BuilderT, ClientT>>
+ ClientPool<ClientT> pooledClientFactory(BuilderT builder) {
+ return new ClientPool<>(c -> buildClient(c.getLeft(), builder, c.getRight()));
+ }
+
+ /** Client pool to easily share AWS clients per configuration. */
+ public static class ClientPool<ClientT extends SdkClient>
+ extends ObjectPool<Pair<AwsOptions, ClientConfiguration>, ClientT> {
+
+ private ClientPool(Function<Pair<AwsOptions, ClientConfiguration>, ClientT> builder) {
+ super(builder, c -> c.close());
+ }
+
+ /** Retain a reference to a shared client instance. If not available, an instance is created. */
+ public ClientT retain(AwsOptions provider, ClientConfiguration config) {
+ return retain(Pair.of(provider, config));
+ }
+ }
+
+ private class RefCounted {
+ private int count = 0;
+ private final ObjectT shared;
+
+ RefCounted(ObjectT client) {
+ this.shared = client;
+ }
+
+ @Override
+ public boolean equals(@Nullable Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ // only identity of ref counted shared object matters
+ return shared == ((RefCounted) o).shared;
+ }
+
+ @Override
+ public int hashCode() {
+ return shared.hashCode();
+ }
+ }
+}
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/RetryConfiguration.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/RetryConfiguration.java
index 6ca342927a8..4a816cb3d91 100644
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/RetryConfiguration.java
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/RetryConfiguration.java
@@ -68,7 +68,7 @@ public abstract class RetryConfiguration implements Serializable {
public abstract RetryConfiguration.Builder toBuilder();
public static Builder builder() {
- return Builder.builder();
+ return Builder.builder().numRetries(3);
}
@AutoValue.Builder
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java
index 69f4f28427c..5042a574592 100644
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java
@@ -22,6 +22,7 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Prec
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
import static org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY;
import static org.apache.commons.lang3.StringUtils.isEmpty;
+import static software.amazon.awssdk.services.kinesis.model.ShardFilterType.AT_LATEST;
import com.google.auto.value.AutoValue;
import java.io.Serializable;
@@ -33,17 +34,24 @@ import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
+import java.util.NavigableSet;
+import java.util.TreeSet;
+import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
+import javax.annotation.concurrent.NotThreadSafe;
+import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
import org.apache.beam.sdk.io.Read.Unbounded;
import org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory;
import org.apache.beam.sdk.io.aws2.common.ClientConfiguration;
-import org.apache.beam.sdk.io.aws2.common.ClientPool;
+import org.apache.beam.sdk.io.aws2.common.ObjectPool;
+import org.apache.beam.sdk.io.aws2.common.ObjectPool.ClientPool;
import org.apache.beam.sdk.io.aws2.common.RetryConfiguration;
+import org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.ExplicitPartitioner;
import org.apache.beam.sdk.io.aws2.options.AwsOptions;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Distribution;
@@ -65,7 +73,9 @@ import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSortedSet;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.checkerframework.dataflow.qual.Pure;
import org.joda.time.DateTimeUtils;
@@ -79,7 +89,9 @@ import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
+import software.amazon.awssdk.services.kinesis.model.ListShardsRequest;
import software.amazon.awssdk.services.kinesis.model.PutRecordsRequestEntry;
+import software.amazon.awssdk.services.kinesis.model.Shard;
import software.amazon.kinesis.common.InitialPositionInStream;
/**
@@ -173,7 +185,8 @@ import software.amazon.kinesis.common.InitialPositionInStream;
* utilized at all.
*
* <p>If you require finer control over the distribution of records, override {@link
- * KinesisPartitioner#getExplicitHashKey(Object)} according to your needs.
+ * KinesisPartitioner#getExplicitHashKey(Object)} according to your needs. However, this might
+ * impact record aggregation.
*
* <h4>Aggregation of records</h4>
*
@@ -182,10 +195,29 @@ import software.amazon.kinesis.common.InitialPositionInStream;
* href="https://docs.aws.amazon.com/streams/latest/dev/kinesis-kpl-concepts.html#kinesis-kpl-concepts-aggretation">aggregated
* KPL record</a>.
*
- * <p>However, only records with the same effective hash key are aggregated, in which the effective
- * hash key is either the explicit hash key if defined, or otherwise the hashed partition key.
+ * <p>Records of the same effective hash key get aggregated. The effective hash key is:
*
- * <p>Record aggregation can be explicitly disabled using {@link
+ * <ol>
+ * <li>the explicit hash key, if provided.
+ * <li>the lower bound of the hash key range of the target shard according to the given partition
+ * key, if available.
+ * <li>or otherwise the hashed partition key
+ * </ol>
+ *
+ * <p>To provide shard aware aggregation in 2., hash key ranges of shards are loaded and refreshed
+ * periodically. This allows aggregating records into a number of aggregates that matches the
+ * number of shards in the stream, making the best possible use of Kinesis API limits.
+ *
+ * <p><b>Note:</b> There's an important downside to consider when using shard aware aggregation:
+ * records get assigned to a shard (via an explicit hash key) on the client side, but respective
+ * client side state can't be guaranteed to always be up-to-date. If a shard gets split, all
+ * aggregates are mapped to the lower child shard until state is refreshed. Timing, however, will
+ * diverge between the different workers.
+ *
+ * <p>If using an {@link ExplicitPartitioner} or disabling shard refresh via {@link
+ * RecordAggregation}, no shard details will be loaded (and used).
+ *
+ * <p>Record aggregation can be entirely disabled using {@link
* Write#withRecordAggregationDisabled()}.
*
* <h3>Configuration of AWS clients</h3>
@@ -536,11 +568,30 @@ public final class KinesisIO {
abstract double maxBufferedTimeJitter();
+ abstract Duration shardRefreshInterval();
+
+ abstract double shardRefreshIntervalJitter();
+
+ Instant nextBufferTimeout() {
+ return nextInstant(maxBufferedTime(), maxBufferedTimeJitter());
+ }
+
+ Instant nextShardRefresh() {
+ return nextInstant(shardRefreshInterval(), shardRefreshIntervalJitter());
+ }
+
+ private Instant nextInstant(Duration duration, double jitter) {
+ double millis = (1 - jitter + jitter * Math.random()) * duration.getMillis();
+ return Instant.ofEpochMilli(DateTimeUtils.currentTimeMillis() + (long) millis);
+ }
+
public static Builder builder() {
return new AutoValue_KinesisIO_RecordAggregation.Builder()
.maxBytes(Write.MAX_BYTES_PER_RECORD)
.maxBufferedTimeJitter(0.7) // 70% jitter
- .maxBufferedTime(Duration.standardSeconds(1));
+ .maxBufferedTime(Duration.millis(500))
+ .shardRefreshIntervalJitter(0.5) // 50% jitter
+ .shardRefreshInterval(Duration.standardMinutes(2));
}
@AutoValue.Builder
@@ -556,8 +607,19 @@ public final class KinesisIO {
*/
public abstract Builder maxBufferedTime(Duration interval);
+ /**
+ * Refresh interval for shards.
+ *
+ * <p>This is used for shard aware record aggregation to assign all records hashed to a
+ * particular shard to the same explicit hash key. Set to {@link Duration#ZERO} to disable
+ * loading shards.
+ */
+ public abstract Builder shardRefreshInterval(Duration interval);
+
abstract Builder maxBufferedTimeJitter(double jitter);
+ abstract Builder shardRefreshIntervalJitter(double jitter);
+
abstract RecordAggregation autoBuild();
public RecordAggregation build() {
@@ -670,9 +732,6 @@ public final class KinesisIO {
* Enable record aggregation that is compatible with the KPL / KCL.
*
* <p>https://docs.aws.amazon.com/streams/latest/dev/kinesis-kpl-concepts.html#kinesis-kpl-concepts-aggretation
- *
- * <p>Note: The aggregation is a lot simpler than the one offered by KPL. It only aggregates
- * records with the same partition key as it's not aware of explicit hash key ranges per shard.
*/
public Write<T> withRecordAggregation(RecordAggregation aggregation) {
return builder().recordAggregation(aggregation).build();
@@ -682,9 +741,6 @@ public final class KinesisIO {
* Enable record aggregation that is compatible with the KPL / KCL.
*
* <p>https://docs.aws.amazon.com/streams/latest/dev/kinesis-kpl-concepts.html#kinesis-kpl-concepts-aggretation
- *
- * <p>Note: The aggregation is a lot simpler than the one offered by KPL. It only aggregates
- * records with the same partition key as it's not aware of explicit hash key ranges per shard.
*/
public Write<T> withRecordAggregation(Consumer<RecordAggregation.Builder> aggregation) {
RecordAggregation.Builder builder = RecordAggregation.builder();
@@ -803,14 +859,14 @@ public final class KinesisIO {
private static final int PARTIAL_RETRIES = 10; // Retries for partial success (throttling)
- private static final ClientPool<AwsOptions, ClientConfiguration, KinesisAsyncClient> CLIENTS =
- ClientPool.pooledClientFactory(KinesisAsyncClient.builder());
+ private static final ClientPool<KinesisAsyncClient> CLIENTS =
+ ObjectPool.pooledClientFactory(KinesisAsyncClient.builder());
protected final Write<T> spec;
protected final Stats stats;
protected final AsyncPutRecordsHandler handler;
+ protected final KinesisAsyncClient kinesis;
- private final KinesisAsyncClient kinesis;
private List<PutRecordsRequestEntry> requestEntries;
private int requestBytes = 0;
@@ -947,25 +1003,37 @@ public final class KinesisIO {
* with KCL to correctly implement the binary protocol, specifically {@link
* software.amazon.kinesis.retrieval.kpl.Messages.AggregatedRecord}.
*
- * <p>Note: The aggregation is a lot simpler than the one offered by KPL. While the KPL is aware
- * of effective hash key ranges assigned to each shard, we're not and don't want to be to keep
- * complexity manageable and avoid the risk of silently loosing records in the KCL:
+ * <p>To aggregate records the best possible way, records are assigned an explicit hash key that
+ * corresponds to the lower bound of the hash key range of the target shard. In case a record
+ * already has an explicit hash key assigned, it is kept unchanged.
*
- * <p>{@link software.amazon.kinesis.retrieval.AggregatorUtil#deaggregate(List, BigInteger,
- * BigInteger)} drops records not matching the expected hash key range.
+ * <p>Hash key ranges of shards are expected to be only slowly changing and get refreshed
+ * infrequently. If using an {@link ExplicitPartitioner} or disabling shard refresh via {@link
+ * RecordAggregation}, no shard details will be pulled.
*/
static class AggregatedWriter<T> extends Writer<T> {
private static final Logger LOG = LoggerFactory.getLogger(AggregatedWriter.class);
+ private static final ObjectPool<String, ShardRanges> SHARD_RANGES_BY_STREAM =
+ new ObjectPool<>(ShardRanges::of);
private final RecordAggregation aggSpec;
private final Map<BigInteger, RecordsAggregator> aggregators;
- private final MessageDigest md5Digest;
+ private final PartitionKeyHasher pkHasher;
+
+ private final ShardRanges shardRanges;
AggregatedWriter(PipelineOptions options, Write<T> spec, RecordAggregation aggSpec) {
super(options, spec);
this.aggSpec = aggSpec;
- this.aggregators = new LinkedHashMap<>();
- this.md5Digest = md5Digest();
+ aggregators = new LinkedHashMap<>();
+ pkHasher = new PartitionKeyHasher();
+ if (aggSpec.shardRefreshInterval().isLongerThan(Duration.ZERO)
+ && !(spec.partitioner() instanceof ExplicitPartitioner)) {
+ shardRanges = SHARD_RANGES_BY_STREAM.retain(spec.streamName());
+ shardRanges.refreshPeriodically(kinesis, aggSpec::nextShardRefresh);
+ } else {
+ shardRanges = ShardRanges.EMPTY;
+ }
}
@Override
@@ -977,20 +1045,36 @@ public final class KinesisIO {
@Override
protected void write(String partitionKey, @Nullable String explicitHashKey, byte[] data)
throws Throwable {
- BigInteger hashKey = effectiveHashKey(partitionKey, explicitHashKey);
- RecordsAggregator agg = aggregators.computeIfAbsent(hashKey, k -> newRecordsAggregator());
+ shardRanges.refreshPeriodically(kinesis, aggSpec::nextShardRefresh);
+
+ // calculate the effective hash key used for aggregation
+ BigInteger aggKey;
+ if (explicitHashKey != null) {
+ aggKey = new BigInteger(explicitHashKey);
+ } else {
+ BigInteger hashedPartitionKey = pkHasher.hashKey(partitionKey);
+ aggKey = shardRanges.shardAwareHashKey(hashedPartitionKey);
+ if (aggKey != null) {
+ // use the shard aware aggregation key as explicit hash key for optimal aggregation
+ explicitHashKey = aggKey.toString();
+ } else {
+ aggKey = hashedPartitionKey;
+ }
+ }
+
+ RecordsAggregator agg = aggregators.computeIfAbsent(aggKey, k -> newRecordsAggregator());
if (!agg.addRecord(partitionKey, explicitHashKey, data)) {
// aggregated record too full, add a request entry and reset aggregator
- addRequestEntry(agg.getAndReset(aggregationTimeoutWithJitter()));
- aggregators.remove(hashKey);
+ addRequestEntry(agg.getAndReset(aggSpec.nextBufferTimeout()));
+ aggregators.remove(aggKey);
if (agg.addRecord(partitionKey, explicitHashKey, data)) {
- aggregators.put(hashKey, agg); // new aggregation started
+ aggregators.put(aggKey, agg); // new aggregation started
} else {
super.write(partitionKey, explicitHashKey, data); // skip aggregation
}
} else if (!agg.hasCapacity()) {
addRequestEntry(agg.get());
- aggregators.remove(hashKey);
+ aggregators.remove(aggKey);
}
// only check timeouts sporadically if concurrency is already maxed out
@@ -1001,14 +1085,7 @@ public final class KinesisIO {
private RecordsAggregator newRecordsAggregator() {
return new RecordsAggregator(
- Math.min(aggSpec.maxBytes(), spec.batchMaxBytes()), aggregationTimeoutWithJitter());
- }
-
- private Instant aggregationTimeoutWithJitter() {
- double millis =
- (1 - aggSpec.maxBufferedTimeJitter() + aggSpec.maxBufferedTimeJitter() * Math.random())
- * aggSpec.maxBufferedTime().getMillis();
- return Instant.ofEpochMilli(DateTimeUtils.currentTimeMillis() + (long) millis);
+ Math.min(aggSpec.maxBytes(), spec.batchMaxBytes()), aggSpec.nextBufferTimeout());
}
private void checkAggregationTimeouts() throws Throwable {
@@ -1021,9 +1098,8 @@ public final class KinesisIO {
if (agg.timeout().isAfter(now)) {
break;
}
- LOG.debug(
- "Adding aggregated entry after timeout [delay = {} ms]",
- now.getMillis() - agg.timeout().getMillis());
+ long delayMillis = now.getMillis() - agg.timeout().getMillis();
+ LOG.debug("Adding aggregated entry after timeout [delay = {} ms]", delayMillis);
addRequestEntry(agg.get());
removals.add(e.getKey());
}
@@ -1041,16 +1117,23 @@ public final class KinesisIO {
super.finishBundle();
}
- private BigInteger effectiveHashKey(String partitionKey, @Nullable String explicitHashKey) {
- return explicitHashKey == null
- ? new BigInteger(1, md5(partitionKey.getBytes(UTF_8)))
- : new BigInteger(explicitHashKey);
+ @Override
+ public void close() throws Exception {
+ super.close();
+ SHARD_RANGES_BY_STREAM.release(shardRanges);
}
+ }
- private byte[] md5(byte[] data) {
- byte[] hash = md5Digest.digest(data);
+ @VisibleForTesting
+ @NotThreadSafe
+ static class PartitionKeyHasher {
+ private final MessageDigest md5Digest = md5Digest();
+
+ /** Hash partition key to 128 bit integer. */
+ BigInteger hashKey(String partitionKey) {
+ byte[] hashedBytes = md5Digest.digest(partitionKey.getBytes(UTF_8));
md5Digest.reset();
- return hash;
+ return new BigInteger(1, hashedBytes);
}
private static MessageDigest md5Digest() {
@@ -1062,6 +1145,99 @@ public final class KinesisIO {
}
}
+ /** Shard hash ranges per stream to generate shard aware hash keys for record aggregation. */
+ @VisibleForTesting
+ @ThreadSafe
+ interface ShardRanges {
+ ShardRanges EMPTY = new ShardRanges() {};
+
+ static ShardRanges of(String stream) {
+ return new ShardRangesImpl(stream);
+ }
+
+ /**
+ * Align partition key hash to lower bound of key range of the target shard. If unavailable
+ * {@code null} is returned.
+ */
+ default @Nullable BigInteger shardAwareHashKey(BigInteger hashedPartitionKey) {
+ return null;
+ }
+
+ /** Check for and trigger periodic refresh if needed. */
+ default void refreshPeriodically(
+ KinesisAsyncClient kinesis, Supplier<Instant> nextRefreshFn) {}
+
+ class ShardRangesImpl implements ShardRanges {
+ private static final Logger LOG = LoggerFactory.getLogger(ShardRanges.class);
+
+ private final String streamName;
+
+ private final AtomicBoolean running = new AtomicBoolean(false);
+ private NavigableSet<BigInteger> shardBounds = ImmutableSortedSet.of();
+ private Instant nextRefresh = Instant.EPOCH;
+
+ private ShardRangesImpl(String streamName) {
+ this.streamName = streamName;
+ }
+
+ @Override
+ public @Nullable BigInteger shardAwareHashKey(BigInteger hashedPartitionKey) {
+ BigInteger lowerBound = shardBounds.floor(hashedPartitionKey);
+ if (!shardBounds.isEmpty() && lowerBound == null) {
+ LOG.warn("No shard found for {} [shards={}]", hashedPartitionKey, shardBounds.size());
+ }
+ return lowerBound;
+ }
+
+ @Override
+ public void refreshPeriodically(
+ KinesisAsyncClient client, Supplier<Instant> nextRefreshFn) {
+ if (nextRefresh.isBeforeNow() && running.compareAndSet(false, true)) {
+ refresh(client, nextRefreshFn, new TreeSet<>(), null);
+ }
+ }
+
+ @SuppressWarnings("FutureReturnValueIgnored") // safe to ignore
+ private void refresh(
+ KinesisAsyncClient client,
+ Supplier<Instant> nextRefreshFn,
+ TreeSet<BigInteger> bounds,
+ @Nullable String nextToken) {
+ ListShardsRequest.Builder reqBuilder =
+ ListShardsRequest.builder().shardFilter(f -> f.type(AT_LATEST));
+ if (nextToken != null) {
+ reqBuilder.nextToken(nextToken);
+ } else {
+ reqBuilder.streamName(streamName);
+ }
+ client
+ .listShards(reqBuilder.build())
+ .whenComplete(
+ (resp, exc) -> {
+ if (exc != null) {
+ LOG.warn("Failed to refresh shards.", exc);
+ nextRefresh = nextRefreshFn.get(); // retry later
+ running.set(false);
+ return;
+ }
+ resp.shards().forEach(shard -> bounds.add(lowerHashKey(shard)));
+ if (resp.nextToken() != null) {
+ refresh(client, nextRefreshFn, bounds, resp.nextToken());
+ return;
+ }
+ LOG.debug("Done refreshing {} shards.", bounds.size());
+ nextRefresh = nextRefreshFn.get();
+ running.set(false);
+ shardBounds = bounds; // swap key ranges
+ });
+ }
+
+ private BigInteger lowerHashKey(Shard shard) {
+ return new BigInteger(shard.hashKeyRange().startingHashKey());
+ }
+ }
+ }
+
private static class Stats implements AsyncPutRecordsHandler.Stats {
private static final Logger LOG = LoggerFactory.getLogger(Stats.class);
private static final Duration LOG_STATS_PERIOD = Duration.standardSeconds(10);
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisPartitioner.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisPartitioner.java
index 99d5f915446..eaf9b0b7607 100644
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisPartitioner.java
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisPartitioner.java
@@ -47,6 +47,26 @@ public interface KinesisPartitioner<T> extends Serializable {
return null;
}
+ /**
+ * An explicit partitioner that always returns a {@code Nonnull} explicit hash key. The partition
+ * key is irrelevant in this case, though it cannot be {@code null}.
+ */
+ interface ExplicitPartitioner<T> extends KinesisPartitioner<T> {
+ @Override
+ default @Nonnull String getPartitionKey(T record) {
+ return "a"; // will be ignored, but can't be null or empty
+ }
+
+ /**
+ * Required hash value (128-bit integer) to determine explicitly the shard a record is assigned
+ * to based on the hash key range of each shard. The explicit hash key overrides the partition
+ * key hash.
+ */
+ @Override
+ @Nonnull
+ String getExplicitHashKey(T record);
+ }
+
/**
* Explicit hash key partitioner that randomly returns one of x precalculated hash keys. Hash keys
* are derived by equally dividing the 128-bit hash universe, assuming that hash ranges of shards
@@ -70,15 +90,9 @@ public interface KinesisPartitioner<T> extends Serializable {
hashKey = hashKey.add(distance);
}
- return new KinesisPartitioner<T>() {
+ return new ExplicitPartitioner<T>() {
@Nonnull
@Override
- public String getPartitionKey(T record) {
- return "a"; // ignored, but can't be null
- }
-
- @Nullable
- @Override
public String getExplicitHashKey(T record) {
return hashKeys[new Random().nextInt(shards)];
}
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/RecordsAggregator.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/RecordsAggregator.java
index 5694ec7f14f..bcf546e792e 100644
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/RecordsAggregator.java
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/RecordsAggregator.java
@@ -41,10 +41,6 @@ import software.amazon.kinesis.retrieval.kpl.Messages.AggregatedRecord;
* Record aggregator compatible with the record (de)aggregation of the Kinesis Producer Library
* (KPL) and Kinesis Client Library (KCL).
*
- * <p>However, only records with the same effective hash key should be aggregated to keep complexity
- * manageable. Otherwise, the aggregator would have to be aware of the most up-to-date explicit hash
- * key ranges per shard.
- *
* <p>https://docs.aws.amazon.com/streams/latest/dev/kinesis-kpl-concepts.html#kinesis-kpl-concepts-aggretation
*/
@NotThreadSafe
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ClientPoolTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ObjectPoolTest.java
similarity index 63%
rename from sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ClientPoolTest.java
rename to sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ObjectPoolTest.java
index 228e214f8b3..154957ee3c0 100644
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ClientPoolTest.java
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ObjectPoolTest.java
@@ -20,10 +20,10 @@ package org.apache.beam.sdk.io.aws2.common;
import static java.util.concurrent.ForkJoinPool.commonPool;
import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoInteractions;
@@ -34,30 +34,24 @@ import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinTask;
import java.util.function.Function;
import java.util.stream.Stream;
-import org.junit.Before;
+import org.apache.beam.sdk.testing.ExpectedLogs;
+import org.junit.Rule;
import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Spy;
-import org.mockito.junit.MockitoJUnitRunner;
-
-@RunWith(MockitoJUnitRunner.class)
-public class ClientPoolTest {
- @Spy ClientProvider provider = new ClientProvider();
- ClientPool<Function<String, AutoCloseable>, String, AutoCloseable> pool;
-
- @Before
- public void init() {
- pool = new ClientPool<>((p, c) -> p.apply(c));
- }
+
+public class ObjectPoolTest {
+ Function<String, AutoCloseable> provider = spy(new Provider());
+ ObjectPool<String, AutoCloseable> pool = new ObjectPool<>(provider, obj -> obj.close());
+
+ @Rule public ExpectedLogs logs = ExpectedLogs.none(ObjectPool.class);
class ResourceTask implements Callable<AutoCloseable> {
@Override
- public AutoCloseable call() throws Exception {
- AutoCloseable client = pool.retain(provider, "config");
- pool.retain(provider, "config");
- pool.release(provider, "config");
+ public AutoCloseable call() {
+ AutoCloseable client = pool.retain("config");
+ pool.retain("config");
+ pool.release(client);
verifyNoInteractions(client);
- pool.release(provider, "config");
+ pool.release(client);
return client;
}
}
@@ -83,8 +77,8 @@ public class ClientPoolTest {
String config1 = "config1";
String config2 = "config2";
- assertThat(pool.retain(provider, config1)).isSameAs(pool.retain(provider, config1));
- assertThat(pool.retain(provider, config1)).isNotSameAs(pool.retain(provider, config2));
+ assertThat(pool.retain(config1)).isSameAs(pool.retain(config1));
+ assertThat(pool.retain(config1)).isNotSameAs(pool.retain(config2));
verify(provider, times(2)).apply(anyString());
verify(provider, times(1)).apply(config1);
verify(provider, times(1)).apply(config2);
@@ -97,47 +91,70 @@ public class ClientPoolTest {
AutoCloseable client = null;
for (int i = 0; i < sharedInstances; i++) {
- client = pool.retain(provider, config);
+ client = pool.retain(config);
}
- for (int i = 1; i < sharedInstances; i++) {
- pool.release(provider, config);
+ for (int i = 0; i < sharedInstances - 1; i++) {
+ pool.release(client);
}
verifyNoInteractions(client);
// verify close on last release
- pool.release(provider, config);
+ pool.release(client);
verify(client).close();
// verify further attempts to release have no effect
- pool.release(provider, config);
+ pool.release(client);
verifyNoMoreInteractions(client);
}
@Test
- public void recreateClientOnceReleased() throws Exception {
+ public void closeClientsOnceReleasedByKey() throws Exception {
String config = "config";
- AutoCloseable client1 = pool.retain(provider, config);
- pool.release(provider, config);
- AutoCloseable client2 = pool.retain(provider, config);
+ int sharedInstances = 10;
- verify(provider, times(2)).apply(config);
+ AutoCloseable client = null;
+ for (int i = 0; i < sharedInstances; i++) {
+ client = pool.retain(config);
+ }
+
+ for (int i = 0; i < sharedInstances - 1; i++) {
+ pool.releaseByKey(config);
+ }
+ verifyNoInteractions(client);
+ // verify close on last release
+ pool.releaseByKey(config);
+ verify(client).close();
+ // verify further attempts to release have no effect
+ pool.releaseByKey(config);
+ verifyNoMoreInteractions(client);
+ }
+
+ @Test
+ public void recreateClientOnceReleased() throws Exception {
+ String config = "config";
+ AutoCloseable client1 = pool.retain(config);
+ pool.release(client1);
verify(client1).close();
+
+ AutoCloseable client2 = pool.retain(config);
verifyNoInteractions(client2);
+
+ verify(provider, times(2)).apply(config);
+ assertThat(client1).isNotSameAs(client2);
}
@Test
public void releaseWithError() throws Exception {
- String config = "config";
- AutoCloseable client1 = pool.retain(provider, config);
- doThrow(new Exception("error on close")).when(client1).close();
- assertThatThrownBy(() -> pool.release(provider, config)).hasMessage("error on close");
+ Exception onClose = new Exception("error on close");
- AutoCloseable client2 = pool.retain(provider, config);
- verify(provider, times(2)).apply(config);
- verify(client1).close();
- verifyNoInteractions(client2);
+ AutoCloseable client = pool.retain("config");
+ doThrow(onClose).when(client).close();
+ pool.release(client);
+
+ verify(client).close();
+ logs.verifyWarn("Exception destroying pooled object.", onClose);
}
- static class ClientProvider implements Function<String, AutoCloseable> {
+ static class Provider implements Function<String, AutoCloseable> {
@Override
public AutoCloseable apply(String configName) {
return mock(AutoCloseable.class, configName);
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIOWriteTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIOWriteTest.java
index 1db39161306..e5418f4ff14 100644
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIOWriteTest.java
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIOWriteTest.java
@@ -17,18 +17,26 @@
*/
package org.apache.beam.sdk.io.aws2.kinesis;
+import static java.math.BigInteger.ONE;
+import static java.util.Arrays.stream;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.supplyAsync;
import static java.util.function.Function.identity;
+import static java.util.stream.Collectors.toList;
import static org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.MAX_BYTES_PER_RECORD;
import static org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.MAX_BYTES_PER_REQUEST;
import static org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.MAX_RECORDS_PER_REQUEST;
+import static org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.MAX_HASH_KEY;
+import static org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.MIN_HASH_KEY;
+import static org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.explicitRandomPartitioner;
import static org.apache.beam.sdk.io.common.TestRow.getExpectedValues;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables.concat;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists.transform;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.joda.time.Duration.ZERO;
+import static org.joda.time.Duration.millis;
import static org.joda.time.Duration.standardSeconds;
import static org.mockito.AdditionalMatchers.and;
import static org.mockito.ArgumentMatchers.any;
@@ -40,14 +48,20 @@ import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
+import java.math.BigInteger;
import java.util.List;
+import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.function.Supplier;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.io.aws2.MockClientBuilderFactory;
+import org.apache.beam.sdk.io.aws2.common.ClientConfiguration;
+import org.apache.beam.sdk.io.aws2.common.RetryConfiguration;
import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write;
import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.AggregatedWriter;
+import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.PartitionKeyHasher;
+import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.ShardRanges;
import org.apache.beam.sdk.io.common.TestRow;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
@@ -55,22 +69,27 @@ import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Objects;
import org.assertj.core.api.ThrowableAssert;
import org.joda.time.DateTimeUtils;
-import org.joda.time.Duration;
+import org.joda.time.Instant;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
+import org.mockito.ArgumentMatcher;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
import software.amazon.awssdk.services.kinesis.KinesisAsyncClientBuilder;
+import software.amazon.awssdk.services.kinesis.model.HashKeyRange;
+import software.amazon.awssdk.services.kinesis.model.ListShardsRequest;
+import software.amazon.awssdk.services.kinesis.model.ListShardsResponse;
import software.amazon.awssdk.services.kinesis.model.PutRecordsRequest;
-import software.amazon.awssdk.services.kinesis.model.PutRecordsResponse;
+import software.amazon.awssdk.services.kinesis.model.Shard;
/** Tests for {@link KinesisIO#write()}. */
@RunWith(MockitoJUnitRunner.StrictStubs.class)
@@ -82,8 +101,12 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
@Mock public KinesisAsyncClient client;
@Before
- public void configureClientBuilderFactory() {
+ public void configure() {
MockClientBuilderFactory.set(pipeline, KinesisAsyncClientBuilder.class, client);
+
+ CompletableFuture<ListShardsResponse> errorResp = new CompletableFuture<>();
+ errorResp.completeExceptionally(new RuntimeException("Unavailable, retried later"));
+ when(client.listShards(any(ListShardsRequest.class))).thenReturn(errorResp);
}
@After
@@ -158,11 +181,11 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
@Test
public void testWriteFailure() {
- when(client.putRecords(any(PutRecordsRequest.class)))
+ when(client.putRecords(anyRequest()))
.thenReturn(
- completedFuture(PutRecordsResponse.builder().build()),
+ completedFuture(successResponse),
supplyAsync(() -> checkNotNull(null, "putRecords failed")),
- completedFuture(PutRecordsResponse.builder().build()));
+ completedFuture(successResponse));
pipeline
.apply(GenerateSequence.from(0).to(100))
@@ -178,17 +201,20 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
@Test
public void testWriteWithPartialSuccess() {
- when(client.putRecords(any(PutRecordsRequest.class)))
+ when(client.putRecords(anyRequest()))
.thenReturn(completedFuture(partialSuccessResponse(70, 30)))
.thenReturn(completedFuture(partialSuccessResponse(10, 20)))
- .thenReturn(completedFuture(PutRecordsResponse.builder().build()));
+ .thenReturn(completedFuture(successResponse));
+
+ // minimize delay due to retries
+ RetryConfiguration retry = RetryConfiguration.builder().maxBackoff(millis(1)).build();
pipeline
.apply(Create.of(100))
.apply(ParDo.of(new GenerateTestRows()))
.apply(
kinesisWrite()
- // .withRetryConfiguration(RetryConfiguration.fixed(10, Duration.millis(1)))
+ .withClientConfiguration(ClientConfiguration.builder().retry(retry).build())
.withRecordAggregationDisabled());
pipeline.run().waitUntilFinish();
@@ -201,9 +227,8 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
}
@Test
- public void testWriteAggregated() {
- when(client.putRecords(any(PutRecordsRequest.class)))
- .thenReturn(completedFuture(PutRecordsResponse.builder().build()));
+ public void testWriteAggregatedByDefault() {
+ when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
pipeline
.apply(Create.of(100))
@@ -215,10 +240,82 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
verify(client).close();
}
+ /** With two known shards, rows are expected to end up in one aggregated record per shard. */
+ @Test
+ public void testWriteAggregatedShardAware() {
+   // 2 shards: the second one starts at half of the max hash key
+   BigInteger secondShardLowerBound = MAX_HASH_KEY.shiftRight(1);
+   mockShardRanges(MIN_HASH_KEY, secondShardLowerBound);
+   when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
+
+   Write<TestRow> write = kinesisWrite().withPartitioner(row -> row.id().toString());
+   pipeline.apply(Create.of(100)).apply(ParDo.of(new GenerateTestRows())).apply(write);
+   pipeline.run().waitUntilFinish();
+
+   // 1 aggregated record per shard
+   verify(client).putRecords(argThat(hasSize(2)));
+   verify(client).listShards(any(ListShardsRequest.class));
+   verify(client).close();
+ }
+
+ /** Verifies the written records if the list shards response is still pending during the run. */
+ @Test
+ public void testWriteAggregatedShardRefreshPending() {
+   CompletableFuture<ListShardsResponse> pendingShards = new CompletableFuture<>();
+   when(client.listShards(any(ListShardsRequest.class))).thenReturn(pendingShards);
+   when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
+
+   Write<TestRow> write = kinesisWrite().withPartitioner(row -> row.id().toString());
+   pipeline.apply(Create.of(100)).apply(ParDo.of(new GenerateTestRows())).apply(write);
+   pipeline.run().waitUntilFinish();
+
+   // complete list shards only after the pipeline has finished
+   pendingShards.complete(ListShardsResponse.builder().build());
+
+   // while shards are unknown, each row is aggregated into an individual aggregated record
+   verify(client).putRecords(argThat(hasSize(100)));
+   verify(client).listShards(any(ListShardsRequest.class));
+   verify(client).close();
+ }
+
+ /** Verifies that shards are never listed if the shard refresh interval is set to zero. */
+ @Test
+ public void testWriteAggregatedShardRefreshDisabled() {
+   when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
+
+   Write<TestRow> write =
+       kinesisWrite()
+           .withRecordAggregation(b -> b.shardRefreshInterval(ZERO)) // disable refresh
+           .withPartitioner(row -> row.id().toString());
+   pipeline.apply(Create.of(100)).apply(ParDo.of(new GenerateTestRows())).apply(write);
+   pipeline.run().waitUntilFinish();
+
+   // each row is aggregated into an individual aggregated record
+   verify(client).putRecords(argThat(hasSize(100)));
+   // refresh is disabled, shards must not be listed
+   verify(client, times(0)).listShards(any(ListShardsRequest.class));
+   verify(client).close();
+ }
+
+ /** Verifies that an explicit partitioner controls aggregation without listing any shards. */
+ @Test
+ public void testWriteAggregatedUsingExplicitPartitioner() {
+   when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
+
+   Write<TestRow> write = kinesisWrite().withPartitioner(explicitRandomPartitioner(2));
+   pipeline.apply(Create.of(100)).apply(ParDo.of(new GenerateTestRows())).apply(write);
+   pipeline.run().waitUntilFinish();
+
+   // number of records follows the configuration of the partitioner
+   verify(client).putRecords(argThat(hasSize(2)));
+   // shard refresh is disabled for explicit partitioners
+   verify(client, times(0)).listShards(any(ListShardsRequest.class));
+   verify(client).close();
+ }
+
@Test
public void testWriteAggregatedWithMaxBytes() {
- when(client.putRecords(any(PutRecordsRequest.class)))
- .thenReturn(completedFuture(PutRecordsResponse.builder().build()));
+ when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
// overhead protocol + key overhead + 500 records, each 4 bytes data + overhead
final int expectedBytes = 20 + 3 + 500 * 10;
@@ -241,8 +338,7 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
@Test
public void testWriteAggregatedWithMaxBytesAndBatchMaxBytes() {
- when(client.putRecords(any(PutRecordsRequest.class)))
- .thenReturn(completedFuture(PutRecordsResponse.builder().build()));
+ when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
// overhead protocol + key overhead + 500 records, each 4 bytes data + overhead
final int expectedBytes = 20 + 3 + 500 * 10;
@@ -267,8 +363,7 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
@Test
public void testWriteAggregatedWithMaxBytesAndBatchMaxRecords() {
- when(client.putRecords(any(PutRecordsRequest.class)))
- .thenReturn(completedFuture(PutRecordsResponse.builder().build()));
+ when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
// overhead protocol + key overhead + 500 records, each 4 bytes data + overhead
final int expectedBytes = 20 + 3 + 500 * 10;
@@ -293,21 +388,19 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
@Test
public void testWriteAggregatedWithMaxBufferTime() throws Throwable {
- when(client.putRecords(any(PutRecordsRequest.class)))
- .thenReturn(completedFuture(PutRecordsResponse.builder().build()));
+ when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
Write<TestRow> write =
kinesisWrite()
.withPartitioner(r -> r.id().toString())
- .withRecordAggregation(
- b -> b.maxBufferedTime(Duration.millis(100)).maxBufferedTimeJitter(0.2));
+ .withRecordAggregation(b -> b.maxBufferedTime(millis(100)).maxBufferedTimeJitter(0.2));
+ DateTimeUtils.setCurrentMillisFixed(0);
AggregatedWriter<TestRow> writer =
new AggregatedWriter<>(pipeline.getOptions(), write, write.recordAggregation());
writer.startBundle();
- DateTimeUtils.setCurrentMillisFixed(0);
for (int i = 1; i <= 3; i++) {
writer.write(TestRow.fromSeed(i));
}
@@ -328,15 +421,82 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
writer.close();
InOrder ordered = inOrder(client);
- ordered
- .verify(client)
- .putRecords(and(argThat(hasSize(3)), argThat(hasPartitions("1", "2", "3"))));
- ordered.verify(client).putRecords(and(argThat(hasSize(2)), argThat(hasPartitions("4", "5"))));
- ordered.verify(client).putRecords(and(argThat(hasSize(1)), argThat(hasPartitions("6"))));
+ ordered.verify(client).putRecords(argThat(hasPartitions("1", "2", "3")));
+ ordered.verify(client).putRecords(argThat(hasPartitions("4", "5")));
+ ordered.verify(client).putRecords(argThat(hasPartitions("6")));
+ ordered.verify(client).close();
+ verifyNoMoreInteractions(client);
+ }
+
+ @Test
+ public void testWriteAggregatedWithShardsRefresh() throws Throwable {
+ when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
+
+ Write<TestRow> write =
+ kinesisWrite()
+ .withPartitioner(r -> r.id().toString())
+ .withRecordAggregation(b -> b.shardRefreshInterval(millis(1000)));
+
+ DateTimeUtils.setCurrentMillisFixed(1);
+ AggregatedWriter<TestRow> writer =
+ new AggregatedWriter<>(pipeline.getOptions(), write, write.recordAggregation());
+
+ // initially, no shards known
+ for (int i = 1; i <= 3; i++) {
+ writer.write(TestRow.fromSeed(i));
+ }
+
+ // forward clock, trigger timeouts and refresh shards
+ DateTimeUtils.setCurrentMillisFixed(1500);
+ mockShardRanges(MIN_HASH_KEY);
+
+ for (int i = 1; i <= 10; i++) {
+ writer.write(TestRow.fromSeed(i)); // all aggregated into one record
+ }
+
+ writer.finishBundle();
+ writer.close();
+
+ InOrder ordered = inOrder(client);
+ ordered.verify(client).putRecords(argThat(hasPartitions("1", "2", "3")));
+ ordered.verify(client).putRecords(argThat(hasExplicitPartitions(MIN_HASH_KEY.toString())));
ordered.verify(client).close();
+ verify(client, times(2)).listShards(any(ListShardsRequest.class));
verifyNoMoreInteractions(client);
}
+ /** Verifies that shard lower bounds are gathered from all pages of a paginated shard listing. */
+ @Test
+ public void testShardRangesRefresh() {
+   BigInteger lowerBound1 = MIN_HASH_KEY;
+   BigInteger lowerBound2 = MAX_HASH_KEY.shiftRight(2);
+   BigInteger lowerBound3 = MAX_HASH_KEY.shiftRight(1);
+
+   // 3 pages with one shard each: the initial request is matched by stream name only, the
+   // follow-up requests by the next token of the previous page
+   when(client.listShards(argThat(isRequest(STREAM, null))))
+       .thenReturn(completedFuture(listShardsResponse("a", shard(lowerBound1))));
+   when(client.listShards(argThat(isRequest(null, "a"))))
+       .thenReturn(completedFuture(listShardsResponse("b", shard(lowerBound2))));
+   when(client.listShards(argThat(isRequest(null, "b"))))
+       .thenReturn(completedFuture(listShardsResponse(null, shard(lowerBound3))));
+
+   PartitionKeyHasher hasher = new PartitionKeyHasher();
+   ShardRanges ranges = ShardRanges.of(STREAM);
+   ranges.refreshPeriodically(client, Instant::now);
+
+   verify(client, times(3)).listShards(any(ListShardsRequest.class));
+
+   // partition key "a" hashes into the 1st shard
+   BigInteger hashOfA = hasher.hashKey("a");
+   assertThat(ranges.shardAwareHashKey(hashOfA)).isEqualTo(lowerBound1);
+   assertThat(hashOfA).isBetween(lowerBound1, lowerBound2.subtract(ONE));
+
+   // partition key "b" hashes into the 3rd shard
+   BigInteger hashOfB = hasher.hashKey("b");
+   assertThat(ranges.shardAwareHashKey(hashOfB)).isEqualTo(lowerBound3);
+   assertThat(hashOfB).isBetween(lowerBound3, MAX_HASH_KEY);
+
+   // partition key "c" hashes into the 2nd shard
+   BigInteger hashOfC = hasher.hashKey("c");
+   assertThat(ranges.shardAwareHashKey(hashOfC)).isEqualTo(lowerBound2);
+   assertThat(hashOfC).isBetween(lowerBound2, lowerBound3.subtract(ONE));
+ }
+
@Test
public void validateMissingStreamName() {
assertThrown(identity())
@@ -407,6 +567,30 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
.hasMessage("maxBytes must be positive and <= " + MAX_BYTES_PER_RECORD);
}
+ /** Builds a shard that only has the starting hash key of its hash key range set. */
+ private Shard shard(BigInteger lowerRange) {
+   HashKeyRange range = HashKeyRange.builder().startingHashKey(lowerRange.toString()).build();
+   return Shard.builder().hashKeyRange(range).build();
+ }
+
+ /** Builds a list shards response containing the given shards and next token (may be null). */
+ private ListShardsResponse listShardsResponse(String nextToken, Shard... shards) {
+   ListShardsResponse.Builder response = ListShardsResponse.builder();
+   response.shards(shards);
+   response.nextToken(nextToken);
+   return response.build();
+ }
+
+ /** Matches non-null list shards requests having exactly the given stream name and next token. */
+ protected ArgumentMatcher<ListShardsRequest> isRequest(String stream, String nextToken) {
+   return req -> {
+     if (req == null) {
+       return false;
+     }
+     return Objects.equal(stream, req.streamName()) && Objects.equal(nextToken, req.nextToken());
+   };
+ }
+
+ /** Stubs list shards to return a single page of shards starting at the given lower bounds. */
+ private void mockShardRanges(BigInteger... lowerBounds) {
+   List<Shard> shards = stream(lowerBounds).map(this::shard).collect(toList());
+   ListShardsResponse response = ListShardsResponse.builder().shards(shards).build();
+
+   when(client.listShards(any(ListShardsRequest.class))).thenReturn(completedFuture(response));
+ }
+
private ThrowableAssert assertThrown(Function<Write<TestRow>, Write<TestRow>> writeConfig) {
pipeline.enableAbandonedNodeEnforcement(false);
PCollection<TestRow> input = mock(PCollection.class);
@@ -422,8 +606,7 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
private Supplier<List<List<TestRow>>> captureBatchRecords(KinesisAsyncClient mock) {
ArgumentCaptor<PutRecordsRequest> cap = ArgumentCaptor.forClass(PutRecordsRequest.class);
- when(mock.putRecords(cap.capture()))
- .thenReturn(completedFuture(PutRecordsResponse.builder().build()));
+ when(mock.putRecords(cap.capture())).thenReturn(completedFuture(successResponse));
return () -> transform(cap.getAllValues(), req -> transform(req.records(), this::toTestRow));
}
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/PutRecordsHelpers.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/PutRecordsHelpers.java
index 0a4976caa16..ae178118d51 100644
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/PutRecordsHelpers.java
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/PutRecordsHelpers.java
@@ -40,6 +40,8 @@ import software.amazon.awssdk.services.kinesis.model.PutRecordsResultEntry;
public abstract class PutRecordsHelpers {
protected static final String ERROR_CODE = "ProvisionedThroughputExceededException";
+ PutRecordsResponse successResponse = PutRecordsResponse.builder().build();
+
protected PutRecordsRequest anyRequest() {
return any();
}
@@ -62,6 +64,12 @@ public abstract class PutRecordsHelpers {
&& transform(req.records(), r -> r.partitionKey()).containsAll(asList(partitions));
}
+ /**
+  * Matches requests that have as many records as partitions are given and whose explicit hash
+  * keys contain all of the given partitions.
+  */
+ protected ArgumentMatcher<PutRecordsRequest> hasExplicitPartitions(String... partitions) {
+   return req -> {
+     if (!hasSize(partitions.length).matches(req)) {
+       return false;
+     }
+     return transform(req.records(), r -> r.explicitHashKey()).containsAll(asList(partitions));
+   };
+ }
+
protected PutRecordsResponse partialSuccessResponse(int successes, int errors) {
PutRecordsResultEntry e = PutRecordsResultEntry.builder().errorCode(ERROR_CODE).build();
PutRecordsResultEntry s = PutRecordsResultEntry.builder().build();
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/testing/KinesisIOIT.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/testing/KinesisIOIT.java
index e049c437663..c100ddca7f7 100644
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/testing/KinesisIOIT.java
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/testing/KinesisIOIT.java
@@ -18,7 +18,6 @@
package org.apache.beam.sdk.io.aws2.kinesis.testing;
import static java.nio.charset.StandardCharsets.UTF_8;
-import static org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.explicitRandomPartitioner;
import static org.testcontainers.containers.localstack.LocalStackContainer.Service.KINESIS;
import java.io.Serializable;
@@ -74,7 +73,7 @@ public class KinesisIOIT implements Serializable {
void setKinesisStream(String value);
@Description("Number of shards of stream")
- @Default.Integer(2)
+ @Default.Integer(8)
Integer getKinesisShards();
void setKinesisShards(Integer count);
@@ -120,7 +119,7 @@ public class KinesisIOIT implements Serializable {
KinesisIO.Write<TestRow> write =
KinesisIO.<TestRow>write()
.withStreamName(env.options().getKinesisStream())
- .withPartitioner(explicitRandomPartitioner(env.options().getKinesisShards()))
+ .withPartitioner(row -> row.name())
.withSerializer(testRowToBytes);
if (!options.getUseRecordAggregation()) {
write = write.withRecordAggregationDisabled();