You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by mm...@apache.org on 2023/08/07 14:49:34 UTC

[beam] branch master updated: [AWS] Adjust interface of SqsIO.writeBatches to make entryMapper optional. (#27800)

This is an automated email from the ASF dual-hosted git repository.

mmack pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 4bae68c08c1 [AWS] Adjust interface of SqsIO.writeBatches to make entryMapper optional. (#27800)
4bae68c08c1 is described below

commit 4bae68c08c113f51f68307aa9e52ba83fcaef85a
Author: Moritz Mack <mm...@talend.com>
AuthorDate: Mon Aug 7 16:49:25 2023 +0200

    [AWS] Adjust interface of SqsIO.writeBatches to make entryMapper optional. (#27800)
---
 .../org/apache/beam/sdk/io/aws2/sqs/SqsIO.java     | 187 ++++++++++++++++-----
 .../sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java     | 141 ++++++++++++++--
 .../beam/sdk/io/aws2/sqs/SqsIOWriteTest.java       |  68 --------
 .../beam/sdk/io/aws2/sqs/testing/SqsIOIT.java      |   4 +-
 4 files changed, 270 insertions(+), 130 deletions(-)

diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java
index ac31738154a..7ad84c20330 100644
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java
@@ -26,6 +26,7 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Prec
 import com.google.auto.value.AutoValue;
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -41,27 +42,35 @@ import org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory;
 import org.apache.beam.sdk.io.aws2.common.ClientConfiguration;
 import org.apache.beam.sdk.io.aws2.options.AwsOptions;
 import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.schemas.NoSuchSchemaException;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.Row;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
-import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
 import org.checkerframework.checker.nullness.qual.NonNull;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.checkerframework.dataflow.qual.Pure;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
 import software.amazon.awssdk.regions.Region;
 import software.amazon.awssdk.services.sqs.SqsAsyncClient;
-import software.amazon.awssdk.services.sqs.SqsClient;
 import software.amazon.awssdk.services.sqs.model.BatchResultErrorEntry;
 import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
 import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
@@ -129,19 +138,20 @@ public class SqsIO {
         .build();
   }
 
+  /** @deprecated Use {@link #writeBatches()} for more configuration options. */
+  @Deprecated
   public static Write write() {
     return new AutoValue_SqsIO_Write.Builder()
         .setClientConfiguration(ClientConfiguration.EMPTY)
         .build();
   }
 
-  public static <T> WriteBatches<T> writeBatches(WriteBatches.EntryBuilder<T> entryBuilder) {
+  public static <T> WriteBatches<T> writeBatches() {
     return new AutoValue_SqsIO_WriteBatches.Builder<T>()
         .clientConfiguration(ClientConfiguration.EMPTY)
         .concurrentRequests(WriteBatches.DEFAULT_CONCURRENCY)
         .batchSize(WriteBatches.MAX_BATCH_SIZE)
         .batchTimeout(WriteBatches.DEFAULT_BATCH_TIMEOUT)
-        .entryBuilder(entryBuilder)
         .build();
   }
 
@@ -225,11 +235,15 @@ public class SqsIO {
       return input.getPipeline().apply(transform);
     }
   }
+
   /**
    * A {@link PTransform} to send messages to SQS. See {@link SqsIO} for more information on usage
    * and configuration.
+   *
+   * @deprecated superseded by {@link WriteBatches}
    */
   @AutoValue
+  @Deprecated
   public abstract static class Write extends PTransform<PCollection<SendMessageRequest>, PDone> {
 
     abstract @Pure ClientConfiguration getClientConfiguration();
@@ -251,39 +265,14 @@ public class SqsIO {
 
     @Override
     public PDone expand(PCollection<SendMessageRequest> input) {
-      AwsOptions awsOptions = input.getPipeline().getOptions().as(AwsOptions.class);
-      ClientBuilderFactory.validate(awsOptions, getClientConfiguration());
-
-      input.apply(ParDo.of(new SqsWriteFn(this)));
+      // Delegate to writeBatches (batch size 1), keeping the configured client settings.
+      WriteBatches<SendMessageRequest> batches = SqsIO.writeBatches();
+      batches = batches.withClientConfiguration(getClientConfiguration()).withBatchSize(1);
+      input.apply(batches.to(SendMessageRequest::queueUrl));
       return PDone.in(input.getPipeline());
     }
   }
 
-  private static class SqsWriteFn extends DoFn<SendMessageRequest, Void> {
-    private final Write spec;
-    private transient @MonotonicNonNull SqsClient sqs = null;
-
-    SqsWriteFn(Write write) {
-      this.spec = write;
-    }
-
-    @Setup
-    public void setup(PipelineOptions options) throws Exception {
-      AwsOptions awsOpts = options.as(AwsOptions.class);
-      sqs =
-          ClientBuilderFactory.buildClient(
-              awsOpts, SqsClient.builder(), spec.getClientConfiguration());
-    }
-
-    @ProcessElement
-    public void processElement(ProcessContext processContext) throws Exception {
-      if (sqs == null) {
-        throw new IllegalStateException("No SQS client");
-      }
-      sqs.sendMessage(processContext.element());
-    }
-  }
-
   /**
    * A {@link PTransform} to send messages to SQS. See {@link SqsIO} for more information on usage
    * and configuration.
@@ -291,6 +280,7 @@ public class SqsIO {
   @AutoValue
   public abstract static class WriteBatches<T>
       extends PTransform<PCollection<T>, WriteBatches.Result> {
+    private static final Logger LOG = LoggerFactory.getLogger(WriteBatches.class);
     private static final int DEFAULT_CONCURRENCY = 5;
     private static final int MAX_BATCH_SIZE = 10;
     private static final Duration DEFAULT_BATCH_TIMEOUT = Duration.standardSeconds(3);
@@ -303,7 +293,7 @@ public class SqsIO {
 
     abstract @Pure ClientConfiguration clientConfiguration();
 
-    abstract @Pure EntryBuilder<T> entryBuilder();
+    abstract @Pure @Nullable EntryMapperFn<T> entryMapper();
 
     abstract @Pure @Nullable DynamicDestination<T> dynamicDestination();
 
@@ -325,7 +315,7 @@ public class SqsIO {
 
       abstract Builder<T> clientConfiguration(ClientConfiguration config);
 
-      abstract Builder<T> entryBuilder(EntryBuilder<T> entryBuilder);
+      abstract Builder<T> entryMapper(@Nullable EntryMapperFn<T> entryMapper);
 
       abstract Builder<T> dynamicDestination(@Nullable DynamicDestination<T> destination);
 
@@ -346,6 +336,22 @@ public class SqsIO {
       return builder().concurrentRequests(concurrentRequests).build();
     }
 
+    /**
+     * Optional mapper to create a batch entry from a unique entry id and the input {@code T},
+     * otherwise inferred from the schema.
+     */
+    public WriteBatches<T> withEntryMapper(EntryMapperFn<T> mapper) {
+      return builder().entryMapper(mapper).build();
+    }
+
+    /**
+     * Optional mapper to create a batch entry from the input {@code T} using a builder, otherwise
+     * inferred from the schema.
+     */
+    public WriteBatches<T> withEntryMapper(EntryMapperFn.Builder<T> mapper) {
+      return builder().entryMapper(mapper).build();
+    }
+
     /** The batch size to use, default (and AWS limit) is {@code 10}. */
     public WriteBatches<T> withBatchSize(int batchSize) {
       checkArgument(
@@ -375,11 +381,25 @@ public class SqsIO {
       return builder().dynamicDestination(null).queueUrl(queueUrl).build();
     }
 
+    private EntryMapperFn<T> schemaEntryMapper(PCollection<T> input) {
+      checkState(input.hasSchema(), "withEntryMapper is required if schema is not available");
+      SchemaRegistry registry = input.getPipeline().getSchemaRegistry();
+      try {
+        return new SchemaEntryMapper<>(
+            input.getSchema(),
+            registry.getSchema(SendMessageBatchRequestEntry.class),
+            input.getToRowFunction(),
+            registry.getFromRowFunction(SendMessageBatchRequestEntry.class));
+      } catch (NoSuchSchemaException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
     @Override
     public Result expand(PCollection<T> input) {
       AwsOptions awsOptions = input.getPipeline().getOptions().as(AwsOptions.class);
       ClientBuilderFactory.validate(awsOptions, clientConfiguration());
-
+      EntryMapperFn<T> mapper = entryMapper() != null ? entryMapper() : schemaEntryMapper(input);
       input.apply(
           ParDo.of(
               new DoFn<T, Void>() {
@@ -387,7 +407,8 @@ public class SqsIO {
 
                 @Setup
                 public void setup(PipelineOptions options) {
-                  handler = new BatchHandler<>(WriteBatches.this, options.as(AwsOptions.class));
+                  handler =
+                      new BatchHandler<>(WriteBatches.this, mapper, options.as(AwsOptions.class));
                 }
 
                 @StartBundle
@@ -420,9 +441,86 @@ public class SqsIO {
       return new Result(input.getPipeline());
     }
 
-    /** Batch entry builder. */
-    public interface EntryBuilder<T>
-        extends BiConsumer<SendMessageBatchRequestEntry.Builder, T>, Serializable {}
+    /**
+     * Mapper to create a {@link SendMessageBatchRequestEntry} from a unique batch entry id and the
+     * input {@code T}.
+     */
+    public interface EntryMapperFn<T>
+        extends BiFunction<String, T, SendMessageBatchRequestEntry>, Serializable {
+
+      /** A more convenient {@link EntryMapperFn} variant that already sets the entry id. */
+      interface Builder<T>
+          extends BiConsumer<SendMessageBatchRequestEntry.Builder, T>, EntryMapperFn<T> {
+        @Override
+        default SendMessageBatchRequestEntry apply(String entryId, T msg) {
+          SendMessageBatchRequestEntry.Builder builder = SendMessageBatchRequestEntry.builder();
+          accept(builder, msg);
+          return builder.id(entryId).build();
+        }
+      }
+    }
+
+    @VisibleForTesting
+    static class SchemaEntryMapper<T> implements EntryMapperFn<T> {
+      private final SerializableFunction<T, Row> toRow;
+      private final SerializableFunction<Row, SendMessageBatchRequestEntry> fromRow;
+      private final Schema schema;
+      private final int[] fieldMapping;
+
+      SchemaEntryMapper(
+          Schema sourceSchema,
+          Schema targetSchema,
+          SerializableFunction<T, Row> toRow,
+          SerializableFunction<Row, SendMessageBatchRequestEntry> fromRow) {
+        this.toRow = toRow;
+        this.fromRow = fromRow;
+        this.schema = targetSchema;
+        this.fieldMapping = new int[targetSchema.getFieldCount()];
+        // -1 marks target fields that have no matching source field.
+        Arrays.fill(fieldMapping, -1);
+
+        List<String> ignored = Lists.newLinkedList();
+        List<String> invalid = Lists.newLinkedList();
+
+        for (int i = 0; i < sourceSchema.getFieldCount(); i++) {
+          Field sourceField = sourceSchema.getField(i);
+          if (targetSchema.hasField(sourceField.getName())) {
+            int targetIdx = targetSchema.indexOf(sourceField.getName());
+            // make sure field types match
+            if (!sourceField.typesEqual(targetSchema.getField(targetIdx))) {
+              invalid.add(sourceField.getName());
+            }
+            fieldMapping[targetIdx] = i;
+          } else {
+            ignored.add(sourceField.getName());
+          }
+        }
+        checkState(
+            ignored.size() < sourceSchema.getFieldCount(),
+            "No fields matched, expected %s but got %s",
+            schema.getFieldNames(),
+            ignored);
+        // Preconditions templates only substitute %s (unlike SLF4J's {} used below).
+        checkState(invalid.isEmpty(), "Detected incompatible types for input fields: %s", invalid);
+
+        if (!ignored.isEmpty()) {
+          LOG.warn("Ignoring unmatched input fields: {}", ignored);
+        }
+      }
+
+      @Override
+      public SendMessageBatchRequestEntry apply(String entryId, T input) {
+        Row row = toRow.apply(input);
+        Object[] values = new Object[fieldMapping.length];
+        values[0] = entryId;
+        for (int i = 0; i < values.length; i++) {
+          if (fieldMapping[i] >= 0) {
+            values[i] = row.getValue(fieldMapping[i]);
+          }
+        }
+        return fromRow.apply(Row.withSchema(schema).attachValues(values));
+      }
+    }
 
     /** Result of {@link #writeBatches}. */
     public static class Result implements POutput {
@@ -451,12 +549,14 @@ public class SqsIO {
       private final WriteBatches<T> spec;
       private final SqsAsyncClient sqs;
       private final Batches batches;
+      private final EntryMapperFn<T> entryMapper;
       private final AsyncBatchWriteHandler<SendMessageBatchRequestEntry, BatchResultErrorEntry>
           handler;
 
-      BatchHandler(WriteBatches<T> spec, AwsOptions options) {
+      BatchHandler(WriteBatches<T> spec, EntryMapperFn<T> entryMapper, AwsOptions options) {
         this.spec = spec;
         this.sqs = buildClient(options, SqsAsyncClient.builder(), spec.clientConfiguration());
+        this.entryMapper = entryMapper;
         this.handler =
             AsyncBatchWriteHandler.byId(
                 spec.concurrentRequests(),
@@ -488,10 +588,7 @@ public class SqsIO {
       }
 
       public void process(T msg) {
-        SendMessageBatchRequestEntry.Builder builder = SendMessageBatchRequestEntry.builder();
-        spec.entryBuilder().accept(builder, msg);
-        SendMessageBatchRequestEntry entry = builder.id(batches.nextId()).build();
-
+        SendMessageBatchRequestEntry entry = entryMapper.apply(batches.nextId(), msg);
         Batch batch = batches.getLocked(msg);
         batch.add(entry);
         if (batch.size() >= spec.batchSize() || batch.isExpired()) {
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java
index aeb9122df9f..dff4b4e72c8 100644
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java
@@ -17,19 +17,26 @@
  */
 package org.apache.beam.sdk.io.aws2.sqs;
 
+import static java.nio.charset.StandardCharsets.UTF_8;
 import static java.util.concurrent.CompletableFuture.completedFuture;
 import static java.util.concurrent.CompletableFuture.supplyAsync;
 import static java.util.stream.Collectors.toList;
 import static java.util.stream.IntStream.range;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.joda.time.Duration.millis;
+import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
 import java.util.Arrays;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import java.util.stream.Stream;
 import org.apache.beam.sdk.Pipeline;
@@ -38,31 +45,38 @@ import org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler;
 import org.apache.beam.sdk.io.aws2.common.ClientConfiguration;
 import org.apache.beam.sdk.io.aws2.common.RetryConfiguration;
 import org.apache.beam.sdk.io.aws2.sqs.SqsIO.WriteBatches;
-import org.apache.beam.sdk.io.aws2.sqs.SqsIO.WriteBatches.EntryBuilder;
+import org.apache.beam.sdk.io.aws2.sqs.SqsIO.WriteBatches.EntryMapperFn;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
 import org.apache.beam.sdk.testing.ExpectedLogs;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams;
 import org.joda.time.Duration;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
+import software.amazon.awssdk.core.SdkBytes;
 import software.amazon.awssdk.services.sqs.SqsAsyncClient;
 import software.amazon.awssdk.services.sqs.SqsAsyncClientBuilder;
 import software.amazon.awssdk.services.sqs.model.BatchResultErrorEntry;
+import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
+import software.amazon.awssdk.services.sqs.model.MessageSystemAttributeValue;
 import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
 import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
 import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse;
+import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
 
 /** Tests for {@link WriteBatches}. */
 @RunWith(MockitoJUnitRunner.class)
 public class SqsIOWriteBatchesTest {
-  private static final EntryBuilder<String> SET_MESSAGE_BODY =
+  private static final EntryMapperFn.Builder<String> SET_MESSAGE_BODY =
       SendMessageBatchRequestEntry.Builder::messageBody;
   private static final SendMessageBatchResponse SUCCESS =
       SendMessageBatchResponse.builder().build();
@@ -76,13 +90,80 @@ public class SqsIOWriteBatchesTest {
     MockClientBuilderFactory.set(p, SqsAsyncClientBuilder.class, sqs);
   }
 
+  @Test
+  public void testSchemaEntryMapper() throws Exception {
+    SchemaRegistry registry = p.getSchemaRegistry();
+
+    Map<String, MessageAttributeValue> attributes =
+        ImmutableMap.of("key", MessageAttributeValue.builder().stringValue("value").build());
+    Map<String, MessageSystemAttributeValue> systemAttributes =
+        ImmutableMap.of(
+            "key",
+            MessageSystemAttributeValue.builder()
+                .binaryValue(SdkBytes.fromString("bytes", UTF_8))
+                .build());
+
+    SendMessageRequest input =
+        SendMessageRequest.builder()
+            .messageBody("body")
+            .delaySeconds(3)
+            .messageAttributes(attributes)
+            .messageSystemAttributesWithStrings(systemAttributes)
+            .build();
+
+    SqsIO.WriteBatches.EntryMapperFn<SendMessageRequest> mapper =
+        new SqsIO.WriteBatches.SchemaEntryMapper<>(
+            registry.getSchema(SendMessageRequest.class),
+            registry.getSchema(SendMessageBatchRequestEntry.class),
+            registry.getToRowFunction(SendMessageRequest.class),
+            registry.getFromRowFunction(SendMessageBatchRequestEntry.class));
+
+    assertThat(mapper.apply("1", input))
+        .isEqualTo(
+            SendMessageBatchRequestEntry.builder()
+                .id("1")
+                .messageBody("body")
+                .delaySeconds(3)
+                .messageAttributes(attributes)
+                .messageSystemAttributesWithStrings(systemAttributes)
+                .build());
+  }
+
+  @Test
+  public void testWrite() {
+    // write uses writeBatches with batch size 1
+    when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS));
+
+    SendMessageRequest.Builder msgBuilder = SendMessageRequest.builder().queueUrl("queue");
+    Set<SendMessageRequest> messages =
+        range(0, 100)
+            .mapToObj(i -> msgBuilder.messageBody("test" + i).build())
+            .collect(Collectors.toSet());
+
+    p.apply(Create.of(messages)).apply(SqsIO.write());
+    p.run().waitUntilFinish();
+
+    ArgumentCaptor<SendMessageBatchRequest> captor =
+        ArgumentCaptor.forClass(SendMessageBatchRequest.class);
+    verify(sqs, times(100)).sendMessageBatch(captor.capture());
+
+    for (SendMessageBatchRequest req : captor.getAllValues()) {
+      assertThat(req.queueUrl()).isEqualTo("queue");
+      assertThat(req.entries()).hasSize(1);
+      for (SendMessageBatchRequestEntry entry : req.entries()) {
+        assertTrue(messages.remove(msgBuilder.messageBody(entry.messageBody()).build()));
+      }
+    }
+    assertTrue(messages.isEmpty());
+  }
+
   @Test
   public void testWriteBatches() {
     when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS));
 
     p.apply(Create.of(23))
         .apply(ParDo.of(new CreateMessages()))
-        .apply(SqsIO.writeBatches(SET_MESSAGE_BODY).to("queue"));
+        .apply(SqsIO.<String>writeBatches().withEntryMapper(SET_MESSAGE_BODY).to("queue"));
 
     p.run().waitUntilFinish();
 
@@ -104,7 +185,7 @@ public class SqsIOWriteBatchesTest {
 
     p.apply(Create.of(23))
         .apply(ParDo.of(new CreateMessages()))
-        .apply(SqsIO.writeBatches(SET_MESSAGE_BODY).to("queue"));
+        .apply(SqsIO.<String>writeBatches().withEntryMapper(SET_MESSAGE_BODY).to("queue"));
 
     assertThatThrownBy(() -> p.run().waitUntilFinish())
         .isInstanceOf(Pipeline.PipelineExecutionException.class)
@@ -122,7 +203,7 @@ public class SqsIOWriteBatchesTest {
 
     p.apply(Create.of(23))
         .apply(ParDo.of(new CreateMessages()))
-        .apply(SqsIO.writeBatches(SET_MESSAGE_BODY).to("queue"));
+        .apply(SqsIO.<String>writeBatches().withEntryMapper(SET_MESSAGE_BODY).to("queue"));
 
     p.run().waitUntilFinish();
 
@@ -145,7 +226,11 @@ public class SqsIOWriteBatchesTest {
 
     p.apply(Create.of(8))
         .apply(ParDo.of(new CreateMessages()))
-        .apply(SqsIO.writeBatches(SET_MESSAGE_BODY).withBatchSize(3).to("queue"));
+        .apply(
+            SqsIO.<String>writeBatches()
+                .withEntryMapper(SET_MESSAGE_BODY)
+                .withBatchSize(3)
+                .to("queue"));
 
     p.run().waitUntilFinish();
 
@@ -157,6 +242,27 @@ public class SqsIOWriteBatchesTest {
     verifyNoMoreInteractions(sqs);
   }
 
+  @Test
+  public void testWriteBatchesWithTimeout() {
+    when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS));
+
+    p.apply(Create.of(5))
+        .apply(ParDo.of(new CreateMessages()))
+        .apply(
+            // simulate delay between messages > batch timeout
+            SqsIO.<String>writeBatches()
+                .withEntryMapper(withDelay(millis(100), SET_MESSAGE_BODY))
+                .withBatchTimeout(millis(150))
+                .to("queue"));
+
+    p.run().waitUntilFinish();
+
+    SendMessageBatchRequestEntry[] entries = entries(range(0, 5));
+    // due to added delay, batches are timed out on arrival of every 3rd msg
+    verify(sqs).sendMessageBatch(request("queue", entries[0], entries[1], entries[2]));
+    verify(sqs).sendMessageBatch(request("queue", entries[3], entries[4]));
+  }
+
   @Test
   public void testWriteBatchesToDynamic() {
     when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS));
@@ -167,7 +273,8 @@ public class SqsIOWriteBatchesTest {
     p.apply(Create.of(10))
         .apply(ParDo.of(new CreateMessages()))
         .apply(
-            SqsIO.writeBatches(SET_MESSAGE_BODY)
+            SqsIO.<String>writeBatches()
+                .withEntryMapper(SET_MESSAGE_BODY)
                 .withClientConfiguration(ClientConfiguration.builder().retry(retry).build())
                 .withBatchSize(3)
                 .to(msg -> Integer.valueOf(msg) % 2 == 0 ? "even" : "uneven"));
@@ -187,24 +294,25 @@ public class SqsIOWriteBatchesTest {
   }
 
   @Test
-  public void testWriteBatchesWithTimeout() {
+  public void testWriteBatchesToDynamicWithTimeout() {
     when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS));
 
     p.apply(Create.of(5))
         .apply(ParDo.of(new CreateMessages()))
         .apply(
             // simulate delay between messages > batch timeout
-            SqsIO.writeBatches(withDelay(millis(200), SET_MESSAGE_BODY))
-                .withBatchTimeout(millis(100))
-                .to("queue"));
+            SqsIO.<String>writeBatches()
+                .withEntryMapper(withDelay(millis(100), SET_MESSAGE_BODY))
+                .withBatchTimeout(millis(150))
+                .to(msg -> Integer.valueOf(msg) % 2 == 0 ? "even" : "uneven"));
 
     p.run().waitUntilFinish();
 
     SendMessageBatchRequestEntry[] entries = entries(range(0, 5));
-    // due to added delay, batches are timed out on arrival of every 2nd msg
-    verify(sqs).sendMessageBatch(request("queue", entries[0], entries[1]));
-    verify(sqs).sendMessageBatch(request("queue", entries[2], entries[3]));
-    verify(sqs).sendMessageBatch(request("queue", entries[4]));
+    // due to added delay, dynamic batches are timed out on arrival of every 2nd msg (per batch)
+    verify(sqs).sendMessageBatch(request("even", entries[0], entries[2]));
+    verify(sqs).sendMessageBatch(request("uneven", entries[1], entries[3]));
+    verify(sqs).sendMessageBatch(request("even", entries[4]));
   }
 
   private SendMessageBatchRequest anyRequest() {
@@ -252,7 +360,8 @@ public class SqsIOWriteBatchesTest {
     }
   }
 
-  private static <T> EntryBuilder<T> withDelay(Duration delay, EntryBuilder<T> builder) {
+  private static <T> EntryMapperFn.Builder<T> withDelay(
+      Duration delay, EntryMapperFn.Builder<T> builder) {
     return (t1, t2) -> {
       builder.accept(t1, t2);
       try {
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteTest.java
deleted file mode 100644
index 738cf282adf..00000000000
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteTest.java
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.io.aws2.sqs;
-
-import static java.util.stream.IntStream.range;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
-import java.util.List;
-import java.util.stream.Collectors;
-import org.apache.beam.sdk.io.aws2.MockClientBuilderFactory;
-import org.apache.beam.sdk.testing.TestPipeline;
-import org.apache.beam.sdk.transforms.Create;
-import org.junit.Before;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.mockito.junit.MockitoJUnitRunner;
-import software.amazon.awssdk.services.sqs.SqsClient;
-import software.amazon.awssdk.services.sqs.SqsClientBuilder;
-import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
-import software.amazon.awssdk.services.sqs.model.SendMessageResponse;
-
-/** Tests for {@link SqsIO.Write}. */
-@RunWith(MockitoJUnitRunner.class)
-public class SqsIOWriteTest {
-  @Rule public TestPipeline p = TestPipeline.create();
-  @Mock public SqsClient sqs;
-
-  @Before
-  public void configureClientBuilderFactory() {
-    MockClientBuilderFactory.set(p, SqsClientBuilder.class, sqs);
-  }
-
-  @Test
-  public void testWrite() {
-    when(sqs.sendMessage(any(SendMessageRequest.class)))
-        .thenReturn(SendMessageResponse.builder().build());
-
-    SendMessageRequest.Builder builder = SendMessageRequest.builder().queueUrl("url");
-    List<SendMessageRequest> messages =
-        range(0, 100)
-            .mapToObj(i -> builder.messageBody("test" + i).build())
-            .collect(Collectors.toList());
-
-    p.apply(Create.of(messages)).apply(SqsIO.write());
-    p.run().waitUntilFinish();
-
-    messages.forEach(msg -> verify(sqs).sendMessage(msg));
-  }
-}
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/testing/SqsIOIT.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/testing/SqsIOIT.java
index 2f10f3d08f1..9e78f7a15ef 100644
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/testing/SqsIOIT.java
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/testing/SqsIOIT.java
@@ -114,7 +114,9 @@ public class SqsIOIT {
         .apply("Prepare TestRows", ParDo.of(new DeterministicallyConstructTestRowFn()))
         .apply(
             "Write to SQS",
-            SqsIO.<TestRow>writeBatches((b, row) -> b.messageBody(row.name())).to(sqsQueue.url));
+            SqsIO.<TestRow>writeBatches()
+                .withEntryMapper((b, row) -> b.messageBody(row.name()))
+                .to(sqsQueue.url));
 
     // Read test dataset from SQS.
     PCollection<String> output =