You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by mm...@apache.org on 2022/04/07 16:28:59 UTC

[beam] branch master updated: [BEAM-14104] Support shard aware aggregation in Kinesis writer.

This is an automated email from the ASF dual-hosted git repository.

mmack pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 416fc9ba752 [BEAM-14104] Support shard aware aggregation in Kinesis writer.
     new ad4561ea5d8 Merge pull request #17113 from mosche/BEAM-14104-ShardAwareAggregation
416fc9ba752 is described below

commit 416fc9ba75251f17543f8556071885a026ef5745
Author: Moritz Mack <mm...@talend.com>
AuthorDate: Thu Mar 17 18:00:38 2022 +0100

    [BEAM-14104] Support shard aware aggregation in Kinesis writer.
---
 .../apache/beam/sdk/io/aws2/common/ClientPool.java | 123 ----------
 .../apache/beam/sdk/io/aws2/common/ObjectPool.java | 151 ++++++++++++
 .../sdk/io/aws2/common/RetryConfiguration.java     |   2 +-
 .../apache/beam/sdk/io/aws2/kinesis/KinesisIO.java | 270 +++++++++++++++++----
 .../sdk/io/aws2/kinesis/KinesisPartitioner.java    |  28 ++-
 .../sdk/io/aws2/kinesis/RecordsAggregator.java     |   4 -
 .../{ClientPoolTest.java => ObjectPoolTest.java}   |  99 ++++----
 .../sdk/io/aws2/kinesis/KinesisIOWriteTest.java    | 243 ++++++++++++++++---
 .../sdk/io/aws2/kinesis/PutRecordsHelpers.java     |   8 +
 .../sdk/io/aws2/kinesis/testing/KinesisIOIT.java   |   5 +-
 10 files changed, 677 insertions(+), 256 deletions(-)

diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientPool.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientPool.java
deleted file mode 100644
index 1a7cd29ec98..00000000000
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ClientPool.java
+++ /dev/null
@@ -1,123 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.io.aws2.common;
-
-import java.util.function.BiFunction;
-import org.apache.beam.sdk.io.aws2.options.AwsOptions;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.BiMap;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashBiMap;
-import org.apache.commons.lang3.tuple.Pair;
-import org.checkerframework.checker.nullness.qual.Nullable;
-import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
-
-/**
- * Reference counting pool to easily share AWS clients or similar by individual client provider and
- * configuration (optional).
- *
- * <p>NOTE: This relies heavily on the implementation of {@link #equals(Object)} for {@link
- * ProviderT} and {@link ConfigT}. If not implemented properly, clients can't be shared between
- * instances of {@link org.apache.beam.sdk.transforms.DoFn}.
- *
- * @param <ProviderT> Client provider
- * @param <ConfigT> Optional, nullable configuration
- * @param <ClientT> Client
- */
-public class ClientPool<ProviderT, ConfigT, ClientT extends AutoCloseable> {
-  private final BiMap<Pair<ProviderT, ConfigT>, RefCounted> pool = HashBiMap.create(2);
-  private final BiFunction<ProviderT, ConfigT, ClientT> builder;
-
-  public static <
-          ClientT extends AutoCloseable, BuilderT extends AwsClientBuilder<BuilderT, ClientT>>
-      ClientPool<AwsOptions, ClientConfiguration, ClientT> pooledClientFactory(BuilderT builder) {
-    return new ClientPool<>((opts, conf) -> ClientBuilderFactory.buildClient(opts, builder, conf));
-  }
-
-  public ClientPool(BiFunction<ProviderT, ConfigT, ClientT> builder) {
-    this.builder = builder;
-  }
-
-  /** Retain a reference to a shared client instance. If not available, an instance is created. */
-  public ClientT retain(ProviderT provider, @Nullable ConfigT config) {
-    @SuppressWarnings("nullness")
-    Pair<ProviderT, ConfigT> key = Pair.of(provider, config);
-    synchronized (pool) {
-      RefCounted ref = pool.computeIfAbsent(key, RefCounted::new);
-      ref.count++;
-      return ref.client;
-    }
-  }
-
-  /**
-   * Release a reference to a shared client instance using {@link ProviderT} and {@link ConfigT} .
-   * If that instance is not used anymore, it will be removed and destroyed.
-   */
-  public void release(ProviderT provider, @Nullable ConfigT config) throws Exception {
-    @SuppressWarnings("nullness")
-    Pair<ProviderT, ConfigT> key = Pair.of(provider, config);
-    RefCounted ref;
-    synchronized (pool) {
-      ref = pool.get(key);
-      if (ref == null || --ref.count > 0) {
-        return;
-      }
-      pool.remove(key);
-    }
-    ref.client.close();
-  }
-
-  /**
-   * Release a reference to a shared client instance. If that instance is not used anymore, it will
-   * be removed and destroyed.
-   */
-  public void release(ClientT client) throws Exception {
-    Pair<ProviderT, ConfigT> pair = pool.inverse().get(new RefCounted(client));
-    if (pair != null) {
-      release(pair.getLeft(), pair.getRight());
-    }
-  }
-
-  private class RefCounted {
-    private int count = 0;
-    private final ClientT client;
-
-    RefCounted(ClientT client) {
-      this.client = client;
-    }
-
-    RefCounted(Pair<ProviderT, ConfigT> key) {
-      this(builder.apply(key.getLeft(), key.getRight()));
-    }
-
-    @Override
-    public boolean equals(@Nullable Object o) {
-      if (this == o) {
-        return true;
-      }
-      if (o == null || getClass() != o.getClass()) {
-        return false;
-      }
-      // only identity of ref counted client matters
-      return client == ((RefCounted) o).client;
-    }
-
-    @Override
-    public int hashCode() {
-      return client.hashCode();
-    }
-  }
-}
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ObjectPool.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ObjectPool.java
new file mode 100644
index 00000000000..a17c6b56e5f
--- /dev/null
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/ObjectPool.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.aws2.common;
+
+import static org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory.buildClient;
+
+import java.util.function.Function;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.function.ThrowingConsumer;
+import org.apache.beam.sdk.io.aws2.options.AwsOptions;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.BiMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.HashBiMap;
+import org.apache.commons.lang3.tuple.Pair;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.slf4j.LoggerFactory;
+import software.amazon.awssdk.awscore.client.builder.AwsClientBuilder;
+import software.amazon.awssdk.core.SdkClient;
+
+/**
+ * Reference counting object pool to easily share & destroy objects.
+ *
+ * <p>Internal only, subject to incompatible changes or removal at any time!
+ *
+ * <p>NOTE: This relies heavily on the implementation of {@link #equals(Object)} for {@link KeyT}.
+ * If not implemented properly, objects can't be shared between instances of {@link
+ * org.apache.beam.sdk.transforms.DoFn}.
+ *
+ * <p>All pool state is guarded by a single lock; the finalizer is invoked outside of that lock.
+ *
+ * @param <KeyT> Key to share objects by
+ * @param <ObjectT> Shared object
+ */
+@Internal
+@Experimental
+public class ObjectPool<KeyT extends @NonNull Object, ObjectT extends @NonNull Object> {
+  private final BiMap<KeyT, RefCounted> pool = HashBiMap.create(2);
+  private final Function<KeyT, ObjectT> builder;
+  private final @Nullable ThrowingConsumer<Exception, ObjectT> finalizer;
+
+  public ObjectPool(Function<KeyT, ObjectT> builder) {
+    this(builder, null);
+  }
+
+  public ObjectPool(
+      Function<KeyT, ObjectT> builder, @Nullable ThrowingConsumer<Exception, ObjectT> finalizer) {
+    this.builder = builder;
+    this.finalizer = finalizer;
+  }
+
+  /** Retain a reference to a shared object instance. If not available, an instance is created. */
+  public ObjectT retain(KeyT key) {
+    synchronized (pool) {
+      RefCounted ref = pool.computeIfAbsent(key, k -> new RefCounted(builder.apply(k)));
+      ref.count++;
+      return ref.shared;
+    }
+  }
+
+  /**
+   * Release a reference to a shared object instance using {@link KeyT}. If that instance is not
+   * used anymore, it will be removed and destroyed.
+   */
+  public void releaseByKey(KeyT key) {
+    @Nullable ObjectT unused;
+    synchronized (pool) {
+      unused = decrementRef(key);
+    }
+    destroy(unused);
+  }
+
+  /**
+   * Release a reference to a shared object instance. If that instance is not used anymore, it will
+   * be removed and destroyed. Objects that are not managed by this pool are ignored.
+   */
+  public void release(ObjectT object) {
+    @Nullable ObjectT unused = null;
+    synchronized (pool) {
+      // lookup and decrement under the same lock, HashBiMap is not thread-safe
+      KeyT key = pool.inverse().get(new RefCounted(object));
+      if (key != null) {
+        unused = decrementRef(key);
+      }
+    }
+    destroy(unused);
+  }
+
+  /**
+   * Decrement the reference count of {@code key}. Must be called while holding the pool lock.
+   *
+   * @return the shared object if it was removed from the pool and has to be destroyed, otherwise
+   *     {@code null}
+   */
+  private @Nullable ObjectT decrementRef(KeyT key) {
+    RefCounted ref = pool.get(key);
+    if (ref == null || --ref.count > 0) {
+      return null;
+    }
+    pool.remove(key);
+    return ref.shared;
+  }
+
+  /** Destroy an object removed from the pool, if any. Must be called outside of the pool lock. */
+  private void destroy(@Nullable ObjectT unused) {
+    if (unused == null || finalizer == null) {
+      return;
+    }
+    try {
+      finalizer.accept(unused);
+    } catch (Exception e) {
+      LoggerFactory.getLogger(ObjectPool.class).warn("Exception destroying pooled object.", e);
+    }
+  }
+
+  public static <ClientT extends SdkClient, BuilderT extends AwsClientBuilder<BuilderT, ClientT>>
+      ClientPool<ClientT> pooledClientFactory(BuilderT builder) {
+    return new ClientPool<>(c -> buildClient(c.getLeft(), builder, c.getRight()));
+  }
+
+  /** Client pool to easily share AWS clients per configuration. */
+  public static class ClientPool<ClientT extends SdkClient>
+      extends ObjectPool<Pair<AwsOptions, ClientConfiguration>, ClientT> {
+
+    private ClientPool(Function<Pair<AwsOptions, ClientConfiguration>, ClientT> builder) {
+      super(builder, c -> c.close());
+    }
+
+    /** Retain a reference to a shared client instance. If not available, an instance is created. */
+    public ClientT retain(AwsOptions provider, ClientConfiguration config) {
+      return retain(Pair.of(provider, config));
+    }
+  }
+
+  private class RefCounted {
+    private int count = 0;
+    private final ObjectT shared;
+
+    RefCounted(ObjectT shared) {
+      this.shared = shared;
+    }
+
+    @Override
+    public boolean equals(@Nullable Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      // only identity of ref counted shared object matters
+      return shared == ((RefCounted) o).shared;
+    }
+
+    @Override
+    public int hashCode() {
+      // consistent with the identity based equals above
+      return System.identityHashCode(shared);
+    }
+  }
+}
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/RetryConfiguration.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/RetryConfiguration.java
index 6ca342927a8..4a816cb3d91 100644
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/RetryConfiguration.java
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/common/RetryConfiguration.java
@@ -68,7 +68,7 @@ public abstract class RetryConfiguration implements Serializable {
   public abstract RetryConfiguration.Builder toBuilder();
 
   public static Builder builder() {
-    return Builder.builder();
+    return Builder.builder().numRetries(3);
   }
 
   @AutoValue.Builder
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java
index 69f4f28427c..5042a574592 100644
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIO.java
@@ -22,6 +22,7 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Prec
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
 import static org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY;
 import static org.apache.commons.lang3.StringUtils.isEmpty;
+import static software.amazon.awssdk.services.kinesis.model.ShardFilterType.AT_LATEST;
 
 import com.google.auto.value.AutoValue;
 import java.io.Serializable;
@@ -33,17 +34,24 @@ import java.util.ArrayList;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.NavigableSet;
+import java.util.TreeSet;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Supplier;
+import javax.annotation.concurrent.NotThreadSafe;
+import javax.annotation.concurrent.ThreadSafe;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.annotations.Experimental.Kind;
 import org.apache.beam.sdk.io.Read.Unbounded;
 import org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory;
 import org.apache.beam.sdk.io.aws2.common.ClientConfiguration;
-import org.apache.beam.sdk.io.aws2.common.ClientPool;
+import org.apache.beam.sdk.io.aws2.common.ObjectPool;
+import org.apache.beam.sdk.io.aws2.common.ObjectPool.ClientPool;
 import org.apache.beam.sdk.io.aws2.common.RetryConfiguration;
+import org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.ExplicitPartitioner;
 import org.apache.beam.sdk.io.aws2.options.AwsOptions;
 import org.apache.beam.sdk.metrics.Counter;
 import org.apache.beam.sdk.metrics.Distribution;
@@ -65,7 +73,9 @@ import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSortedSet;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.checkerframework.dataflow.qual.Pure;
 import org.joda.time.DateTimeUtils;
@@ -79,7 +89,9 @@ import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
 import software.amazon.awssdk.core.SdkBytes;
 import software.amazon.awssdk.regions.Region;
 import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
+import software.amazon.awssdk.services.kinesis.model.ListShardsRequest;
 import software.amazon.awssdk.services.kinesis.model.PutRecordsRequestEntry;
+import software.amazon.awssdk.services.kinesis.model.Shard;
 import software.amazon.kinesis.common.InitialPositionInStream;
 
 /**
@@ -173,7 +185,8 @@ import software.amazon.kinesis.common.InitialPositionInStream;
  * utilized at all.
  *
  * <p>If you require finer control over the distribution of records, override {@link
- * KinesisPartitioner#getExplicitHashKey(Object)} according to your needs.
+ * KinesisPartitioner#getExplicitHashKey(Object)} according to your needs. However, this might
+ * impact record aggregation.
  *
  * <h4>Aggregation of records</h4>
  *
@@ -182,10 +195,29 @@ import software.amazon.kinesis.common.InitialPositionInStream;
  * href="https://docs.aws.amazon.com/streams/latest/dev/kinesis-kpl-concepts.html#kinesis-kpl-concepts-aggretation">aggregated
  * KPL record</a>.
  *
- * <p>However, only records with the same effective hash key are aggregated, in which the effective
- * hash key is either the explicit hash key if defined, or otherwise the hashed partition key.
+ * <p>Records of the same effective hash key get aggregated. The effective hash key is:
  *
- * <p>Record aggregation can be explicitly disabled using {@link
+ * <ol>
+ *   <li>the explicit hash key, if provided.
+ *   <li>the lower bound of the hash key range of the target shard according to the given partition
+ *       key, if available.
+ *   <li>or otherwise the hashed partition key
+ * </ol>
+ *
+ * <p>To provide shard aware aggregation in 2., hash key ranges of shards are loaded and refreshed
+ * periodically. This allows aggregating records into as many aggregates as there are shards in
+ * the stream, making the best possible use of the Kinesis API limits.
+ *
+ * <p><b>Note:</b> There's an important downside to consider when using shard aware aggregation:
+ * records get assigned to a shard (via an explicit hash key) on the client side, but the
+ * respective client side state can't be guaranteed to always be up-to-date. If a shard gets split,
+ * all aggregates are mapped to the lower child shard until the state is refreshed. Timing of
+ * refreshes, however, will diverge between the different workers.
+ *
+ * <p>If using an {@link ExplicitPartitioner} or disabling shard refresh via {@link
+ * RecordAggregation}, no shard details will be loaded (and used).
+ *
+ * <p>Record aggregation can be entirely disabled using {@link
  * Write#withRecordAggregationDisabled()}.
  *
  * <h3>Configuration of AWS clients</h3>
@@ -536,11 +568,30 @@ public final class KinesisIO {
 
     abstract double maxBufferedTimeJitter();
 
+    abstract Duration shardRefreshInterval();
+
+    abstract double shardRefreshIntervalJitter();
+
+    Instant nextBufferTimeout() {
+      return nextInstant(maxBufferedTime(), maxBufferedTimeJitter());
+    }
+
+    Instant nextShardRefresh() {
+      return nextInstant(shardRefreshInterval(), shardRefreshIntervalJitter());
+    }
+
+    private Instant nextInstant(Duration duration, double jitter) {
+      double millis = (1 - jitter + jitter * Math.random()) * duration.getMillis();
+      return Instant.ofEpochMilli(DateTimeUtils.currentTimeMillis() + (long) millis);
+    }
+
     public static Builder builder() {
       return new AutoValue_KinesisIO_RecordAggregation.Builder()
           .maxBytes(Write.MAX_BYTES_PER_RECORD)
           .maxBufferedTimeJitter(0.7) // 70% jitter
-          .maxBufferedTime(Duration.standardSeconds(1));
+          .maxBufferedTime(Duration.millis(500))
+          .shardRefreshIntervalJitter(0.5) // 50% jitter
+          .shardRefreshInterval(Duration.standardMinutes(2));
     }
 
     @AutoValue.Builder
@@ -556,8 +607,19 @@ public final class KinesisIO {
        */
       public abstract Builder maxBufferedTime(Duration interval);
 
+      /**
+       * Refresh interval for shards.
+       *
+       * <p>This is used for shard aware record aggregation to assign all records hashed to a
+       * particular shard to the same explicit hash key. Set to {@link Duration#ZERO} to disable
+       * loading shards.
+       */
+      public abstract Builder shardRefreshInterval(Duration interval);
+
       abstract Builder maxBufferedTimeJitter(double jitter);
 
+      abstract Builder shardRefreshIntervalJitter(double jitter);
+
       abstract RecordAggregation autoBuild();
 
       public RecordAggregation build() {
@@ -670,9 +732,6 @@ public final class KinesisIO {
      * Enable record aggregation that is compatible with the KPL / KCL.
      *
      * <p>https://docs.aws.amazon.com/streams/latest/dev/kinesis-kpl-concepts.html#kinesis-kpl-concepts-aggretation
-     *
-     * <p>Note: The aggregation is a lot simpler than the one offered by KPL. It only aggregates
-     * records with the same partition key as it's not aware of explicit hash key ranges per shard.
      */
     public Write<T> withRecordAggregation(RecordAggregation aggregation) {
       return builder().recordAggregation(aggregation).build();
@@ -682,9 +741,6 @@ public final class KinesisIO {
      * Enable record aggregation that is compatible with the KPL / KCL.
      *
      * <p>https://docs.aws.amazon.com/streams/latest/dev/kinesis-kpl-concepts.html#kinesis-kpl-concepts-aggretation
-     *
-     * <p>Note: The aggregation is a lot simpler than the one offered by KPL. It only aggregates
-     * records with the same partition key as it's not aware of explicit hash key ranges per shard.
      */
     public Write<T> withRecordAggregation(Consumer<RecordAggregation.Builder> aggregation) {
       RecordAggregation.Builder builder = RecordAggregation.builder();
@@ -803,14 +859,14 @@ public final class KinesisIO {
 
       private static final int PARTIAL_RETRIES = 10; // Retries for partial success (throttling)
 
-      private static final ClientPool<AwsOptions, ClientConfiguration, KinesisAsyncClient> CLIENTS =
-          ClientPool.pooledClientFactory(KinesisAsyncClient.builder());
+      private static final ClientPool<KinesisAsyncClient> CLIENTS =
+          ObjectPool.pooledClientFactory(KinesisAsyncClient.builder());
 
       protected final Write<T> spec;
       protected final Stats stats;
       protected final AsyncPutRecordsHandler handler;
+      protected final KinesisAsyncClient kinesis;
 
-      private final KinesisAsyncClient kinesis;
       private List<PutRecordsRequestEntry> requestEntries;
       private int requestBytes = 0;
 
@@ -947,25 +1003,37 @@ public final class KinesisIO {
      * with KCL to correctly implement the binary protocol, specifically {@link
      * software.amazon.kinesis.retrieval.kpl.Messages.AggregatedRecord}.
      *
-     * <p>Note: The aggregation is a lot simpler than the one offered by KPL. While the KPL is aware
-     * of effective hash key ranges assigned to each shard, we're not and don't want to be to keep
-     * complexity manageable and avoid the risk of silently loosing records in the KCL:
+     * <p>To aggregate records the best possible way, records are assigned an explicit hash key that
+     * corresponds to the lower bound of the hash key range of the target shard. If a record
+     * already has an explicit hash key assigned, it is kept unchanged.
      *
-     * <p>{@link software.amazon.kinesis.retrieval.AggregatorUtil#deaggregate(List, BigInteger,
-     * BigInteger)} drops records not matching the expected hash key range.
+     * <p>Hash key ranges of shards are expected to be only slowly changing and get refreshed
+     * infrequently. If using an {@link ExplicitPartitioner} or disabling shard refresh via {@link
+     * RecordAggregation}, no shard details will be pulled.
      */
     static class AggregatedWriter<T> extends Writer<T> {
       private static final Logger LOG = LoggerFactory.getLogger(AggregatedWriter.class);
+      private static final ObjectPool<String, ShardRanges> SHARD_RANGES_BY_STREAM =
+          new ObjectPool<>(ShardRanges::of);
 
       private final RecordAggregation aggSpec;
       private final Map<BigInteger, RecordsAggregator> aggregators;
-      private final MessageDigest md5Digest;
+      private final PartitionKeyHasher pkHasher;
+
+      private final ShardRanges shardRanges;
 
       AggregatedWriter(PipelineOptions options, Write<T> spec, RecordAggregation aggSpec) {
         super(options, spec);
         this.aggSpec = aggSpec;
-        this.aggregators = new LinkedHashMap<>();
-        this.md5Digest = md5Digest();
+        aggregators = new LinkedHashMap<>();
+        pkHasher = new PartitionKeyHasher();
+        if (aggSpec.shardRefreshInterval().isLongerThan(Duration.ZERO)
+            && !(spec.partitioner() instanceof ExplicitPartitioner)) {
+          shardRanges = SHARD_RANGES_BY_STREAM.retain(spec.streamName());
+          shardRanges.refreshPeriodically(kinesis, aggSpec::nextShardRefresh);
+        } else {
+          shardRanges = ShardRanges.EMPTY;
+        }
       }
 
       @Override
@@ -977,20 +1045,36 @@ public final class KinesisIO {
       @Override
       protected void write(String partitionKey, @Nullable String explicitHashKey, byte[] data)
           throws Throwable {
-        BigInteger hashKey = effectiveHashKey(partitionKey, explicitHashKey);
-        RecordsAggregator agg = aggregators.computeIfAbsent(hashKey, k -> newRecordsAggregator());
+        shardRanges.refreshPeriodically(kinesis, aggSpec::nextShardRefresh);
+
+        // calculate the effective hash key used for aggregation
+        BigInteger aggKey;
+        if (explicitHashKey != null) {
+          aggKey = new BigInteger(explicitHashKey);
+        } else {
+          BigInteger hashedPartitionKey = pkHasher.hashKey(partitionKey);
+          aggKey = shardRanges.shardAwareHashKey(hashedPartitionKey);
+          if (aggKey != null) {
+            // use the shard aware aggregation key as explicit hash key for optimal aggregation
+            explicitHashKey = aggKey.toString();
+          } else {
+            aggKey = hashedPartitionKey;
+          }
+        }
+
+        RecordsAggregator agg = aggregators.computeIfAbsent(aggKey, k -> newRecordsAggregator());
         if (!agg.addRecord(partitionKey, explicitHashKey, data)) {
           // aggregated record too full, add a request entry and reset aggregator
-          addRequestEntry(agg.getAndReset(aggregationTimeoutWithJitter()));
-          aggregators.remove(hashKey);
+          addRequestEntry(agg.getAndReset(aggSpec.nextBufferTimeout()));
+          aggregators.remove(aggKey);
           if (agg.addRecord(partitionKey, explicitHashKey, data)) {
-            aggregators.put(hashKey, agg); // new aggregation started
+            aggregators.put(aggKey, agg); // new aggregation started
           } else {
             super.write(partitionKey, explicitHashKey, data); // skip aggregation
           }
         } else if (!agg.hasCapacity()) {
           addRequestEntry(agg.get());
-          aggregators.remove(hashKey);
+          aggregators.remove(aggKey);
         }
 
         // only check timeouts sporadically if concurrency is already maxed out
@@ -1001,14 +1085,7 @@ public final class KinesisIO {
 
       private RecordsAggregator newRecordsAggregator() {
         return new RecordsAggregator(
-            Math.min(aggSpec.maxBytes(), spec.batchMaxBytes()), aggregationTimeoutWithJitter());
-      }
-
-      private Instant aggregationTimeoutWithJitter() {
-        double millis =
-            (1 - aggSpec.maxBufferedTimeJitter() + aggSpec.maxBufferedTimeJitter() * Math.random())
-                * aggSpec.maxBufferedTime().getMillis();
-        return Instant.ofEpochMilli(DateTimeUtils.currentTimeMillis() + (long) millis);
+            Math.min(aggSpec.maxBytes(), spec.batchMaxBytes()), aggSpec.nextBufferTimeout());
       }
 
       private void checkAggregationTimeouts() throws Throwable {
@@ -1021,9 +1098,8 @@ public final class KinesisIO {
           if (agg.timeout().isAfter(now)) {
             break;
           }
-          LOG.debug(
-              "Adding aggregated entry after timeout [delay = {} ms]",
-              now.getMillis() - agg.timeout().getMillis());
+          long delayMillis = now.getMillis() - agg.timeout().getMillis();
+          LOG.debug("Adding aggregated entry after timeout [delay = {} ms]", delayMillis);
           addRequestEntry(agg.get());
           removals.add(e.getKey());
         }
@@ -1041,16 +1117,23 @@ public final class KinesisIO {
         super.finishBundle();
       }
 
-      private BigInteger effectiveHashKey(String partitionKey, @Nullable String explicitHashKey) {
-        return explicitHashKey == null
-            ? new BigInteger(1, md5(partitionKey.getBytes(UTF_8)))
-            : new BigInteger(explicitHashKey);
+      @Override
+      public void close() throws Exception {
+        super.close();
+        SHARD_RANGES_BY_STREAM.release(shardRanges);
       }
+    }
 
-      private byte[] md5(byte[] data) {
-        byte[] hash = md5Digest.digest(data);
+    @VisibleForTesting
+    @NotThreadSafe
+    static class PartitionKeyHasher {
+      private final MessageDigest md5Digest = md5Digest();
+
+      /** Hash partition key to 128 bit integer. */
+      BigInteger hashKey(String partitionKey) {
+        byte[] hashedBytes = md5Digest.digest(partitionKey.getBytes(UTF_8));
         md5Digest.reset();
-        return hash;
+        return new BigInteger(1, hashedBytes);
       }
 
       private static MessageDigest md5Digest() {
@@ -1062,6 +1145,118 @@ public final class KinesisIO {
       }
     }
 
+    /** Shard hash ranges per stream to generate shard aware hash keys for record aggregation. */
+    @VisibleForTesting
+    @ThreadSafe
+    interface ShardRanges {
+      ShardRanges EMPTY = new ShardRanges() {};
+
+      static ShardRanges of(String stream) {
+        return new ShardRangesImpl(stream);
+      }
+
+      /**
+       * Align partition key hash to lower bound of key range of the target shard. If unavailable
+       * {@code null} is returned.
+       */
+      default @Nullable BigInteger shardAwareHashKey(BigInteger hashedPartitionKey) {
+        return null;
+      }
+
+      /** Check for and trigger periodic refresh if needed. */
+      default void refreshPeriodically(
+          KinesisAsyncClient kinesis, Supplier<Instant> nextRefreshFn) {}
+
+      class ShardRangesImpl implements ShardRanges {
+        private static final Logger LOG = LoggerFactory.getLogger(ShardRanges.class);
+
+        private final String streamName;
+
+        private final AtomicBoolean running = new AtomicBoolean(false);
+        // volatile, updated by async SDK callbacks but read by concurrent writer threads
+        private volatile NavigableSet<BigInteger> shardBounds = ImmutableSortedSet.of();
+        private volatile Instant nextRefresh = Instant.EPOCH;
+
+        private ShardRangesImpl(String streamName) {
+          this.streamName = streamName;
+        }
+
+        @Override
+        public @Nullable BigInteger shardAwareHashKey(BigInteger hashedPartitionKey) {
+          NavigableSet<BigInteger> bounds = shardBounds; // read once, it may be swapped anytime
+          BigInteger lowerBound = bounds.floor(hashedPartitionKey);
+          if (lowerBound == null && !bounds.isEmpty()) {
+            LOG.warn("No shard found for {} [shards={}]", hashedPartitionKey, bounds.size());
+          }
+          return lowerBound;
+        }
+
+        @Override
+        public void refreshPeriodically(
+            KinesisAsyncClient client, Supplier<Instant> nextRefreshFn) {
+          if (nextRefresh.isBeforeNow() && running.compareAndSet(false, true)) {
+            refresh(client, nextRefreshFn, new TreeSet<>(), null);
+          }
+        }
+
+        @SuppressWarnings("FutureReturnValueIgnored") // safe to ignore
+        private void refresh(
+            KinesisAsyncClient client,
+            Supplier<Instant> nextRefreshFn,
+            TreeSet<BigInteger> bounds,
+            @Nullable String nextToken) {
+          ListShardsRequest.Builder reqBuilder =
+              ListShardsRequest.builder().shardFilter(f -> f.type(AT_LATEST));
+          if (nextToken != null) {
+            reqBuilder.nextToken(nextToken);
+          } else {
+            reqBuilder.streamName(streamName);
+          }
+          try {
+            client
+                .listShards(reqBuilder.build())
+                .whenComplete(
+                    (resp, exc) -> {
+                      if (exc != null) {
+                        refreshFailed(nextRefreshFn, exc);
+                        return;
+                      }
+                      try {
+                        resp.shards().forEach(shard -> bounds.add(lowerHashKey(shard)));
+                      } catch (RuntimeException e) {
+                        refreshFailed(nextRefreshFn, e); // e.g. malformed hash key
+                        return;
+                      }
+                      if (resp.nextToken() != null) {
+                        refresh(client, nextRefreshFn, bounds, resp.nextToken());
+                        return;
+                      }
+                      LOG.debug("Done refreshing {} shards.", bounds.size());
+                      shardBounds = bounds; // swap key ranges before allowing a new refresh
+                      nextRefresh = nextRefreshFn.get();
+                      running.set(false);
+                    });
+          } catch (RuntimeException e) {
+            refreshFailed(nextRefreshFn, e); // client failed synchronously
+          }
+        }
+
+        /**
+         * Handle a failed refresh: log it, schedule a retry with the next regular refresh and
+         * release the running flag so that refreshing isn't blocked forever.
+         */
+        private void refreshFailed(Supplier<Instant> nextRefreshFn, Throwable cause) {
+          LOG.warn("Failed to refresh shards.", cause);
+          nextRefresh = nextRefreshFn.get(); // retry later
+          running.set(false);
+        }
+
+        private BigInteger lowerHashKey(Shard shard) {
+          return new BigInteger(shard.hashKeyRange().startingHashKey());
+        }
+      }
+    }
 
     private static class Stats implements AsyncPutRecordsHandler.Stats {
       private static final Logger LOG = LoggerFactory.getLogger(Stats.class);
       private static final Duration LOG_STATS_PERIOD = Duration.standardSeconds(10);
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisPartitioner.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisPartitioner.java
index 99d5f915446..eaf9b0b7607 100644
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisPartitioner.java
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisPartitioner.java
@@ -47,6 +47,26 @@ public interface KinesisPartitioner<T> extends Serializable {
     return null;
   }
 
+  /**
+   * A partitioner that always returns a non-null explicit hash key. Since the explicit hash key
+   * takes precedence, the partition key is ignored; it defaults to a constant non-empty value.
+   */
+  interface ExplicitPartitioner<T> extends KinesisPartitioner<T> {
+    @Override
+    default @Nonnull String getPartitionKey(T record) {
+      return "a"; // ignored, but must be neither null nor empty
+    }
+
+    /**
+     * Required explicit hash key (a 128-bit integer in decimal notation) that determines the
+     * shard a record is assigned to, based on the hash key range of each shard. It overrides the
+     * hash of the partition key.
+     */
+    @Override
+    @Nonnull
+    String getExplicitHashKey(T record);
+  }
+
   /**
    * Explicit hash key partitioner that randomly returns one of x precalculated hash keys. Hash keys
    * are derived by equally dividing the 128-bit hash universe, assuming that hash ranges of shards
@@ -70,15 +90,9 @@ public interface KinesisPartitioner<T> extends Serializable {
       hashKey = hashKey.add(distance);
     }
 
-    return new KinesisPartitioner<T>() {
+    return new ExplicitPartitioner<T>() {
       @Nonnull
       @Override
-      public String getPartitionKey(T record) {
-        return "a"; // ignored, but can't be null
-      }
-
-      @Nullable
-      @Override
       public String getExplicitHashKey(T record) {
         return hashKeys[new Random().nextInt(shards)];
       }
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/RecordsAggregator.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/RecordsAggregator.java
index 5694ec7f14f..bcf546e792e 100644
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/RecordsAggregator.java
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/kinesis/RecordsAggregator.java
@@ -41,10 +41,6 @@ import software.amazon.kinesis.retrieval.kpl.Messages.AggregatedRecord;
  * Record aggregator compatible with the record (de)aggregation of the Kinesis Producer Library
  * (KPL) and Kinesis Client Library (KCL).
  *
- * <p>However, only records with the same effective hash key should be aggregated to keep complexity
- * manageable. Otherwise, the aggregator would have to be aware of the most up-to-date explicit hash
- * key ranges per shard.
- *
  * <p>https://docs.aws.amazon.com/streams/latest/dev/kinesis-kpl-concepts.html#kinesis-kpl-concepts-aggretation
  */
 @NotThreadSafe
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ClientPoolTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ObjectPoolTest.java
similarity index 63%
rename from sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ClientPoolTest.java
rename to sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ObjectPoolTest.java
index 228e214f8b3..154957ee3c0 100644
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ClientPoolTest.java
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/common/ObjectPoolTest.java
@@ -20,10 +20,10 @@ package org.apache.beam.sdk.io.aws2.common;
 import static java.util.concurrent.ForkJoinPool.commonPool;
 import static java.util.stream.Collectors.toList;
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
@@ -34,30 +34,24 @@ import java.util.concurrent.Callable;
 import java.util.concurrent.ForkJoinTask;
 import java.util.function.Function;
 import java.util.stream.Stream;
-import org.junit.Before;
+import org.apache.beam.sdk.testing.ExpectedLogs;
+import org.junit.Rule;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Spy;
-import org.mockito.junit.MockitoJUnitRunner;
-
-@RunWith(MockitoJUnitRunner.class)
-public class ClientPoolTest {
-  @Spy ClientProvider provider = new ClientProvider();
-  ClientPool<Function<String, AutoCloseable>, String, AutoCloseable> pool;
-
-  @Before
-  public void init() {
-    pool = new ClientPool<>((p, c) -> p.apply(c));
-  }
+
+public class ObjectPoolTest {
+  Function<String, AutoCloseable> provider = spy(new Provider());
+  ObjectPool<String, AutoCloseable> pool = new ObjectPool<>(provider, obj -> obj.close());
+
+  @Rule public ExpectedLogs logs = ExpectedLogs.none(ObjectPool.class);
 
   class ResourceTask implements Callable<AutoCloseable> {
     @Override
-    public AutoCloseable call() throws Exception {
-      AutoCloseable client = pool.retain(provider, "config");
-      pool.retain(provider, "config");
-      pool.release(provider, "config");
+    public AutoCloseable call() {
+      AutoCloseable client = pool.retain("config");
+      pool.retain("config"); // second reference to the same pooled object
+      pool.release(client); // one reference still retained
       verifyNoInteractions(client);
-      pool.release(provider, "config");
+      pool.release(client); // last reference released
       return client;
     }
   }
@@ -83,8 +77,8 @@ public class ClientPoolTest {
     String config1 = "config1";
     String config2 = "config2";
 
-    assertThat(pool.retain(provider, config1)).isSameAs(pool.retain(provider, config1));
-    assertThat(pool.retain(provider, config1)).isNotSameAs(pool.retain(provider, config2));
+    assertThat(pool.retain(config1)).isSameAs(pool.retain(config1));
+    assertThat(pool.retain(config1)).isNotSameAs(pool.retain(config2));
     verify(provider, times(2)).apply(anyString());
     verify(provider, times(1)).apply(config1);
     verify(provider, times(1)).apply(config2);
@@ -97,47 +91,70 @@ public class ClientPoolTest {
 
     AutoCloseable client = null;
     for (int i = 0; i < sharedInstances; i++) {
-      client = pool.retain(provider, config);
+      client = pool.retain(config);
     }
 
-    for (int i = 1; i < sharedInstances; i++) {
-      pool.release(provider, config);
+    for (int i = 0; i < sharedInstances - 1; i++) {
+      pool.release(client);
     }
     verifyNoInteractions(client);
     // verify close on last release
-    pool.release(provider, config);
+    pool.release(client);
     verify(client).close();
     // verify further attempts to release have no effect
-    pool.release(provider, config);
+    pool.release(client);
     verifyNoMoreInteractions(client);
   }
 
   @Test
-  public void recreateClientOnceReleased() throws Exception {
+  public void closeClientsOnceReleasedByKey() throws Exception {
     String config = "config";
-    AutoCloseable client1 = pool.retain(provider, config);
-    pool.release(provider, config);
-    AutoCloseable client2 = pool.retain(provider, config);
+    int sharedInstances = 10;
 
-    verify(provider, times(2)).apply(config);
+    AutoCloseable client = null;
+    for (int i = 0; i < sharedInstances; i++) {
+      client = pool.retain(config); // same pooled object every time
+    }
+
+    for (int i = 0; i < sharedInstances - 1; i++) {
+      pool.releaseByKey(config); // release via key rather than instance
+    }
+    verifyNoInteractions(client);
+    // verify close on last release
+    pool.releaseByKey(config);
+    verify(client).close();
+    // verify further attempts to release have no effect
+    pool.releaseByKey(config);
+    verifyNoMoreInteractions(client);
+  }
+
+  @Test
+  public void recreateClientOnceReleased() throws Exception {
+    String config = "config";
+    AutoCloseable client1 = pool.retain(config);
+    pool.release(client1); // last reference, so the object is closed
     verify(client1).close();
+
+    AutoCloseable client2 = pool.retain(config); // created anew by the provider
     verifyNoInteractions(client2);
+
+    verify(provider, times(2)).apply(config);
+    assertThat(client1).isNotSameAs(client2);
   }
 
   @Test
   public void releaseWithError() throws Exception {
-    String config = "config";
-    AutoCloseable client1 = pool.retain(provider, config);
-    doThrow(new Exception("error on close")).when(client1).close();
-    assertThatThrownBy(() -> pool.release(provider, config)).hasMessage("error on close");
+    Exception onClose = new Exception("error on close");
 
-    AutoCloseable client2 = pool.retain(provider, config);
-    verify(provider, times(2)).apply(config);
-    verify(client1).close();
-    verifyNoInteractions(client2);
+    AutoCloseable client = pool.retain("config");
+    doThrow(onClose).when(client).close();
+    pool.release(client); // close() failure is logged, not propagated
+
+    verify(client).close();
+    logs.verifyWarn("Exception destroying pooled object.", onClose);
   }
 
-  static class ClientProvider implements Function<String, AutoCloseable> {
+  static class Provider implements Function<String, AutoCloseable> {
     @Override
     public AutoCloseable apply(String configName) {
       return mock(AutoCloseable.class, configName);
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIOWriteTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIOWriteTest.java
index 1db39161306..e5418f4ff14 100644
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIOWriteTest.java
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/KinesisIOWriteTest.java
@@ -17,18 +17,26 @@
  */
 package org.apache.beam.sdk.io.aws2.kinesis;
 
+import static java.math.BigInteger.ONE;
+import static java.util.Arrays.stream;
 import static java.util.concurrent.CompletableFuture.completedFuture;
 import static java.util.concurrent.CompletableFuture.supplyAsync;
 import static java.util.function.Function.identity;
+import static java.util.stream.Collectors.toList;
 import static org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.MAX_BYTES_PER_RECORD;
 import static org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.MAX_BYTES_PER_REQUEST;
 import static org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.MAX_RECORDS_PER_REQUEST;
+import static org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.MAX_HASH_KEY;
+import static org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.MIN_HASH_KEY;
+import static org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.explicitRandomPartitioner;
 import static org.apache.beam.sdk.io.common.TestRow.getExpectedValues;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables.concat;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists.transform;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.joda.time.Duration.ZERO;
+import static org.joda.time.Duration.millis;
 import static org.joda.time.Duration.standardSeconds;
 import static org.mockito.AdditionalMatchers.and;
 import static org.mockito.ArgumentMatchers.any;
@@ -40,14 +48,20 @@ import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
+import java.math.BigInteger;
 import java.util.List;
+import java.util.concurrent.CompletableFuture;
 import java.util.function.Function;
 import java.util.function.Supplier;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.io.GenerateSequence;
 import org.apache.beam.sdk.io.aws2.MockClientBuilderFactory;
+import org.apache.beam.sdk.io.aws2.common.ClientConfiguration;
+import org.apache.beam.sdk.io.aws2.common.RetryConfiguration;
 import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write;
 import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.AggregatedWriter;
+import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.PartitionKeyHasher;
+import org.apache.beam.sdk.io.aws2.kinesis.KinesisIO.Write.ShardRanges;
 import org.apache.beam.sdk.io.common.TestRow;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
@@ -55,22 +69,27 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Objects;
 import org.assertj.core.api.ThrowableAssert;
 import org.joda.time.DateTimeUtils;
-import org.joda.time.Duration;
+import org.joda.time.Instant;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
+import org.mockito.ArgumentMatcher;
 import org.mockito.InOrder;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
 import software.amazon.awssdk.services.kinesis.KinesisAsyncClient;
 import software.amazon.awssdk.services.kinesis.KinesisAsyncClientBuilder;
+import software.amazon.awssdk.services.kinesis.model.HashKeyRange;
+import software.amazon.awssdk.services.kinesis.model.ListShardsRequest;
+import software.amazon.awssdk.services.kinesis.model.ListShardsResponse;
 import software.amazon.awssdk.services.kinesis.model.PutRecordsRequest;
-import software.amazon.awssdk.services.kinesis.model.PutRecordsResponse;
+import software.amazon.awssdk.services.kinesis.model.Shard;
 
 /** Tests for {@link KinesisIO#write()}. */
 @RunWith(MockitoJUnitRunner.StrictStubs.class)
@@ -82,8 +101,12 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
   @Mock public KinesisAsyncClient client;
 
   @Before
-  public void configureClientBuilderFactory() {
+  public void configure() {
     MockClientBuilderFactory.set(pipeline, KinesisAsyncClientBuilder.class, client);
+    // by default, listing shards fails; tests needing shard ranges re-stub listShards
+    CompletableFuture<ListShardsResponse> errorResp = new CompletableFuture<>();
+    errorResp.completeExceptionally(new RuntimeException("Unavailable, retried later"));
+    when(client.listShards(any(ListShardsRequest.class))).thenReturn(errorResp);
   }
 
   @After
@@ -158,11 +181,11 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
 
   @Test
   public void testWriteFailure() {
-    when(client.putRecords(any(PutRecordsRequest.class)))
+    when(client.putRecords(anyRequest()))
         .thenReturn(
-            completedFuture(PutRecordsResponse.builder().build()),
+            completedFuture(successResponse),
             supplyAsync(() -> checkNotNull(null, "putRecords failed")),
-            completedFuture(PutRecordsResponse.builder().build()));
+            completedFuture(successResponse));
 
     pipeline
         .apply(GenerateSequence.from(0).to(100))
@@ -178,17 +201,20 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
 
   @Test
   public void testWriteWithPartialSuccess() {
-    when(client.putRecords(any(PutRecordsRequest.class)))
+    when(client.putRecords(anyRequest()))
         .thenReturn(completedFuture(partialSuccessResponse(70, 30)))
         .thenReturn(completedFuture(partialSuccessResponse(10, 20)))
-        .thenReturn(completedFuture(PutRecordsResponse.builder().build()));
+        .thenReturn(completedFuture(successResponse));
+
+    // minimize delay due to retries
+    RetryConfiguration retry = RetryConfiguration.builder().maxBackoff(millis(1)).build();
 
     pipeline
         .apply(Create.of(100))
         .apply(ParDo.of(new GenerateTestRows()))
         .apply(
             kinesisWrite()
-                // .withRetryConfiguration(RetryConfiguration.fixed(10, Duration.millis(1)))
+                .withClientConfiguration(ClientConfiguration.builder().retry(retry).build())
                 .withRecordAggregationDisabled());
 
     pipeline.run().waitUntilFinish();
@@ -201,9 +227,8 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
   }
 
   @Test
-  public void testWriteAggregated() {
-    when(client.putRecords(any(PutRecordsRequest.class)))
-        .thenReturn(completedFuture(PutRecordsResponse.builder().build()));
+  public void testWriteAggregatedByDefault() {
+    when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
 
     pipeline
         .apply(Create.of(100))
@@ -215,10 +240,82 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
     verify(client).close();
   }
 
+  @Test
+  public void testWriteAggregatedShardAware() {
+    mockShardRanges(MIN_HASH_KEY, MAX_HASH_KEY.shiftRight(1)); // 2 shards, split at MAX/2
+    when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
+
+    pipeline
+        .apply(Create.of(100))
+        .apply(ParDo.of(new GenerateTestRows()))
+        .apply(kinesisWrite().withPartitioner(row -> row.id().toString()));
+
+    pipeline.run().waitUntilFinish();
+    verify(client).putRecords(argThat(hasSize(2))); // 1 aggregated record per shard
+    verify(client).listShards(any(ListShardsRequest.class)); // shards listed once
+    verify(client).close();
+  }
+
+  @Test
+  public void testWriteAggregatedShardRefreshPending() {
+    CompletableFuture<ListShardsResponse> resp = new CompletableFuture<>();
+    when(client.listShards(any(ListShardsRequest.class))).thenReturn(resp); // kept pending
+
+    when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
+
+    pipeline
+        .apply(Create.of(100))
+        .apply(ParDo.of(new GenerateTestRows()))
+        .apply(kinesisWrite().withPartitioner(row -> row.id().toString()));
+
+    pipeline.run().waitUntilFinish();
+    resp.complete(ListShardsResponse.builder().build()); // complete list shards after pipeline
+
+    // while shards are unknown, each row is aggregated into an individual aggregated record
+    verify(client).putRecords(argThat(hasSize(100))); // single request
+    verify(client).listShards(any(ListShardsRequest.class));
+    verify(client).close();
+  }
+
+  @Test
+  public void testWriteAggregatedShardRefreshDisabled() {
+    when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
+
+    pipeline
+        .apply(Create.of(100))
+        .apply(ParDo.of(new GenerateTestRows()))
+        .apply(
+            kinesisWrite()
+                .withRecordAggregation(b -> b.shardRefreshInterval(ZERO)) // disable refresh
+                .withPartitioner(row -> row.id().toString()));
+
+    pipeline.run().waitUntilFinish();
+
+    // without shard ranges, each row is aggregated into an individual aggregated record
+    verify(client).putRecords(argThat(hasSize(100))); // single request
+    verify(client, times(0)).listShards(any(ListShardsRequest.class)); // disabled
+    verify(client).close();
+  }
+
+  @Test
+  public void testWriteAggregatedUsingExplicitPartitioner() {
+    when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
+
+    pipeline
+        .apply(Create.of(100))
+        .apply(ParDo.of(new GenerateTestRows()))
+        .apply(kinesisWrite().withPartitioner(explicitRandomPartitioner(2)));
+
+    pipeline.run().waitUntilFinish();
+    verify(client).putRecords(argThat(hasSize(2))); // 1 record per explicit hash key
+    verify(client, times(0))
+        .listShards(any(ListShardsRequest.class)); // disabled for explicit partitioner
+    verify(client).close();
+  }
+
   @Test
   public void testWriteAggregatedWithMaxBytes() {
-    when(client.putRecords(any(PutRecordsRequest.class)))
-        .thenReturn(completedFuture(PutRecordsResponse.builder().build()));
+    when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
 
     // overhead protocol + key overhead + 500 records, each 4 bytes data + overhead
     final int expectedBytes = 20 + 3 + 500 * 10;
@@ -241,8 +338,7 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
 
   @Test
   public void testWriteAggregatedWithMaxBytesAndBatchMaxBytes() {
-    when(client.putRecords(any(PutRecordsRequest.class)))
-        .thenReturn(completedFuture(PutRecordsResponse.builder().build()));
+    when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
 
     // overhead protocol + key overhead + 500 records, each 4 bytes data + overhead
     final int expectedBytes = 20 + 3 + 500 * 10;
@@ -267,8 +363,7 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
 
   @Test
   public void testWriteAggregatedWithMaxBytesAndBatchMaxRecords() {
-    when(client.putRecords(any(PutRecordsRequest.class)))
-        .thenReturn(completedFuture(PutRecordsResponse.builder().build()));
+    when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
 
     // overhead protocol + key overhead + 500 records, each 4 bytes data + overhead
     final int expectedBytes = 20 + 3 + 500 * 10;
@@ -293,21 +388,19 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
 
   @Test
   public void testWriteAggregatedWithMaxBufferTime() throws Throwable {
-    when(client.putRecords(any(PutRecordsRequest.class)))
-        .thenReturn(completedFuture(PutRecordsResponse.builder().build()));
+    when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
 
     Write<TestRow> write =
         kinesisWrite()
             .withPartitioner(r -> r.id().toString())
-            .withRecordAggregation(
-                b -> b.maxBufferedTime(Duration.millis(100)).maxBufferedTimeJitter(0.2));
+            .withRecordAggregation(b -> b.maxBufferedTime(millis(100)).maxBufferedTimeJitter(0.2));
 
+    DateTimeUtils.setCurrentMillisFixed(0);
     AggregatedWriter<TestRow> writer =
         new AggregatedWriter<>(pipeline.getOptions(), write, write.recordAggregation());
 
     writer.startBundle();
 
-    DateTimeUtils.setCurrentMillisFixed(0);
     for (int i = 1; i <= 3; i++) {
       writer.write(TestRow.fromSeed(i));
     }
@@ -328,15 +421,82 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
     writer.close();
 
     InOrder ordered = inOrder(client);
-    ordered
-        .verify(client)
-        .putRecords(and(argThat(hasSize(3)), argThat(hasPartitions("1", "2", "3"))));
-    ordered.verify(client).putRecords(and(argThat(hasSize(2)), argThat(hasPartitions("4", "5"))));
-    ordered.verify(client).putRecords(and(argThat(hasSize(1)), argThat(hasPartitions("6"))));
+    ordered.verify(client).putRecords(argThat(hasPartitions("1", "2", "3")));
+    ordered.verify(client).putRecords(argThat(hasPartitions("4", "5")));
+    ordered.verify(client).putRecords(argThat(hasPartitions("6")));
+    ordered.verify(client).close();
+    verifyNoMoreInteractions(client);
+  }
+
+  @Test
+  public void testWriteAggregatedWithShardsRefresh() throws Throwable {
+    when(client.putRecords(anyRequest())).thenReturn(completedFuture(successResponse));
+
+    Write<TestRow> write =
+        kinesisWrite()
+            .withPartitioner(r -> r.id().toString())
+            .withRecordAggregation(b -> b.shardRefreshInterval(millis(1000)));
+
+    DateTimeUtils.setCurrentMillisFixed(1);
+    AggregatedWriter<TestRow> writer =
+        new AggregatedWriter<>(pipeline.getOptions(), write, write.recordAggregation());
+
+    // initially no shards are known, listShards fails as stubbed in configure()
+    for (int i = 1; i <= 3; i++) {
+      writer.write(TestRow.fromSeed(i));
+    }
+
+    // forward clock, trigger timeouts and refresh shards
+    DateTimeUtils.setCurrentMillisFixed(1500);
+    mockShardRanges(MIN_HASH_KEY); // single shard covering the full range
+
+    for (int i = 1; i <= 10; i++) {
+      writer.write(TestRow.fromSeed(i)); // all aggregated into one record
+    }
+
+    writer.finishBundle();
+    writer.close();
+
+    InOrder ordered = inOrder(client);
+    ordered.verify(client).putRecords(argThat(hasPartitions("1", "2", "3")));
+    ordered.verify(client).putRecords(argThat(hasExplicitPartitions(MIN_HASH_KEY.toString())));
     ordered.verify(client).close();
+    verify(client, times(2)).listShards(any(ListShardsRequest.class)); // initial + refresh
     verifyNoMoreInteractions(client);
   }
 
+  @Test
+  public void testShardRangesRefresh() {
+    BigInteger shard1 = MIN_HASH_KEY;
+    BigInteger shard2 = MAX_HASH_KEY.shiftRight(2);
+    BigInteger shard3 = MAX_HASH_KEY.shiftRight(1);
+
+    when(client.listShards(argThat(isRequest(STREAM, null)))) // first page
+        .thenReturn(completedFuture(listShardsResponse("a", shard(shard1))));
+    when(client.listShards(argThat(isRequest(null, "a"))))
+        .thenReturn(completedFuture(listShardsResponse("b", shard(shard2))));
+    when(client.listShards(argThat(isRequest(null, "b")))) // last page
+        .thenReturn(completedFuture(listShardsResponse(null, shard(shard3))));
+
+    PartitionKeyHasher pkHasher = new PartitionKeyHasher();
+    ShardRanges shardRanges = ShardRanges.of(STREAM);
+    shardRanges.refreshPeriodically(client, Instant::now);
+
+    verify(client, times(3)).listShards(any(ListShardsRequest.class)); // all 3 pages fetched
+    // shardAwareHashKey maps a hash key to the lower bound of its shard's range
+    BigInteger hashKeyA = pkHasher.hashKey("a");
+    assertThat(shardRanges.shardAwareHashKey(hashKeyA)).isEqualTo(shard1);
+    assertThat(hashKeyA).isBetween(shard1, shard2.subtract(ONE));
+
+    BigInteger hashKeyB = pkHasher.hashKey("b");
+    assertThat(shardRanges.shardAwareHashKey(hashKeyB)).isEqualTo(shard3);
+    assertThat(hashKeyB).isBetween(shard3, MAX_HASH_KEY);
+
+    BigInteger hashKeyC = pkHasher.hashKey("c");
+    assertThat(shardRanges.shardAwareHashKey(hashKeyC)).isEqualTo(shard2);
+    assertThat(hashKeyC).isBetween(shard2, shard3.subtract(ONE));
+  }
+
   @Test
   public void validateMissingStreamName() {
     assertThrown(identity())
@@ -407,6 +567,30 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
         .hasMessage("maxBytes must be positive and <= " + MAX_BYTES_PER_RECORD);
   }
 
+  private Shard shard(BigInteger lowerRange) {
+    String startingHashKey = lowerRange.toString();
+    HashKeyRange range = HashKeyRange.builder().startingHashKey(startingHashKey).build();
+    return Shard.builder().hashKeyRange(range).build();
+  }
+
+  private ListShardsResponse listShardsResponse(String nextToken, Shard... shards) {
+    return ListShardsResponse.builder().nextToken(nextToken).shards(shards).build();
+  }
+
+  protected ArgumentMatcher<ListShardsRequest> isRequest(String stream, String nextToken) {
+    return request ->
+        request != null
+            && Objects.equal(nextToken, request.nextToken())
+            && Objects.equal(stream, request.streamName());
+  }
+
+  private void mockShardRanges(BigInteger... lowerBounds) {
+    List<Shard> shards = stream(lowerBounds).map(this::shard).collect(toList());
+    ListShardsResponse response = ListShardsResponse.builder().shards(shards).build();
+
+    when(client.listShards(any(ListShardsRequest.class))).thenReturn(completedFuture(response));
+  }
+
   private ThrowableAssert assertThrown(Function<Write<TestRow>, Write<TestRow>> writeConfig) {
     pipeline.enableAbandonedNodeEnforcement(false);
     PCollection<TestRow> input = mock(PCollection.class);
@@ -422,8 +606,7 @@ public class KinesisIOWriteTest extends PutRecordsHelpers {
 
   private Supplier<List<List<TestRow>>> captureBatchRecords(KinesisAsyncClient mock) {
     ArgumentCaptor<PutRecordsRequest> cap = ArgumentCaptor.forClass(PutRecordsRequest.class);
-    when(mock.putRecords(cap.capture()))
-        .thenReturn(completedFuture(PutRecordsResponse.builder().build()));
+    when(mock.putRecords(cap.capture())).thenReturn(completedFuture(successResponse));
     return () -> transform(cap.getAllValues(), req -> transform(req.records(), this::toTestRow));
   }
 
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/PutRecordsHelpers.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/PutRecordsHelpers.java
index 0a4976caa16..ae178118d51 100644
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/PutRecordsHelpers.java
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/PutRecordsHelpers.java
@@ -40,6 +40,8 @@ import software.amazon.awssdk.services.kinesis.model.PutRecordsResultEntry;
 public abstract class PutRecordsHelpers {
   protected static final String ERROR_CODE = "ProvisionedThroughputExceededException";
 
+  protected final PutRecordsResponse successResponse = PutRecordsResponse.builder().build();
+
   protected PutRecordsRequest anyRequest() {
     return any();
   }
@@ -62,6 +64,12 @@ public abstract class PutRecordsHelpers {
             && transform(req.records(), r -> r.partitionKey()).containsAll(asList(partitions));
   }
 
+  protected ArgumentMatcher<PutRecordsRequest> hasExplicitPartitions(String... partitions) {
+    return req -> // order-insensitive containsAll check; multiplicity is not verified
+        hasSize(partitions.length).matches(req)
+            && transform(req.records(), r -> r.explicitHashKey()).containsAll(asList(partitions));
+  }
+
   protected PutRecordsResponse partialSuccessResponse(int successes, int errors) {
     PutRecordsResultEntry e = PutRecordsResultEntry.builder().errorCode(ERROR_CODE).build();
     PutRecordsResultEntry s = PutRecordsResultEntry.builder().build();
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/testing/KinesisIOIT.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/testing/KinesisIOIT.java
index e049c437663..c100ddca7f7 100644
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/testing/KinesisIOIT.java
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/kinesis/testing/KinesisIOIT.java
@@ -18,7 +18,6 @@
 package org.apache.beam.sdk.io.aws2.kinesis.testing;
 
 import static java.nio.charset.StandardCharsets.UTF_8;
-import static org.apache.beam.sdk.io.aws2.kinesis.KinesisPartitioner.explicitRandomPartitioner;
 import static org.testcontainers.containers.localstack.LocalStackContainer.Service.KINESIS;
 
 import java.io.Serializable;
@@ -74,7 +73,7 @@ public class KinesisIOIT implements Serializable {
     void setKinesisStream(String value);
 
     @Description("Number of shards of stream")
-    @Default.Integer(2)
+    @Default.Integer(8)
     Integer getKinesisShards();
 
     void setKinesisShards(Integer count);
@@ -120,7 +119,7 @@ public class KinesisIOIT implements Serializable {
     KinesisIO.Write<TestRow> write =
         KinesisIO.<TestRow>write()
             .withStreamName(env.options().getKinesisStream())
-            .withPartitioner(explicitRandomPartitioner(env.options().getKinesisShards()))
+            .withPartitioner(row -> row.name())
             .withSerializer(testRowToBytes);
     if (!options.getUseRecordAggregation()) {
       write = write.withRecordAggregationDisabled();