You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this rendering.
Posted to commits@beam.apache.org by mm...@apache.org on 2023/08/07 14:49:34 UTC
[beam] branch master updated: [AWS] Adjust interface of SqsIO.writeBatches to make entryMapper optional. (#27800)
This is an automated email from the ASF dual-hosted git repository.
mmack pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 4bae68c08c1 [AWS] Adjust interface of SqsIO.writeBatches to make entryMapper optional. (#27800)
4bae68c08c1 is described below
commit 4bae68c08c113f51f68307aa9e52ba83fcaef85a
Author: Moritz Mack <mm...@talend.com>
AuthorDate: Mon Aug 7 16:49:25 2023 +0200
[AWS] Adjust interface of SqsIO.writeBatches to make entryMapper optional. (#27800)
---
.../org/apache/beam/sdk/io/aws2/sqs/SqsIO.java | 187 ++++++++++++++++-----
.../sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java | 141 ++++++++++++++--
.../beam/sdk/io/aws2/sqs/SqsIOWriteTest.java | 68 --------
.../beam/sdk/io/aws2/sqs/testing/SqsIOIT.java | 4 +-
4 files changed, 270 insertions(+), 130 deletions(-)
diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java
index ac31738154a..7ad84c20330 100644
--- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java
+++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SqsIO.java
@@ -26,6 +26,7 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Prec
import com.google.auto.value.AutoValue;
import java.io.Serializable;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -41,27 +42,35 @@ import org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory;
import org.apache.beam.sdk.io.aws2.common.ClientConfiguration;
import org.apache.beam.sdk.io.aws2.options.AwsOptions;
import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.schemas.NoSuchSchemaException;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
-import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.checkerframework.dataflow.qual.Pure;
import org.joda.time.Duration;
import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.sqs.SqsAsyncClient;
-import software.amazon.awssdk.services.sqs.SqsClient;
import software.amazon.awssdk.services.sqs.model.BatchResultErrorEntry;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
@@ -129,19 +138,20 @@ public class SqsIO {
.build();
}
+ /**
+ * Returns a {@link Write} transform that sends each {@link SendMessageRequest} to the queue given
+ * by its {@code queueUrl}, starting from an empty {@link ClientConfiguration}.
+ *
+ * @deprecated Use {@link #writeBatches()} for more configuration options.
+ */
+ @Deprecated
public static Write write() {
return new AutoValue_SqsIO_Write.Builder()
.setClientConfiguration(ClientConfiguration.EMPTY)
.build();
}
- public static <T> WriteBatches<T> writeBatches(WriteBatches.EntryBuilder<T> entryBuilder) {
+ /**
+ * Returns a {@link WriteBatches} transform that sends elements to SQS in batches. Defaults are 5
+ * concurrent requests, a batch size of 10 (the AWS limit) and a batch timeout of 3 seconds.
+ *
+ * <p>Set the target queue(s) using {@code to(...)}. Batch entries are created by the mapper given
+ * to {@link WriteBatches#withEntryMapper}, or inferred from the input schema if none is set.
+ */
+ public static <T> WriteBatches<T> writeBatches() {
return new AutoValue_SqsIO_WriteBatches.Builder<T>()
.clientConfiguration(ClientConfiguration.EMPTY)
.concurrentRequests(WriteBatches.DEFAULT_CONCURRENCY)
.batchSize(WriteBatches.MAX_BATCH_SIZE)
.batchTimeout(WriteBatches.DEFAULT_BATCH_TIMEOUT)
- .entryBuilder(entryBuilder)
.build();
}
@@ -225,11 +235,15 @@ public class SqsIO {
return input.getPipeline().apply(transform);
}
}
+
/**
* A {@link PTransform} to send messages to SQS. See {@link SqsIO} for more information on usage
* and configuration.
+ *
+ * @deprecated superseded by {@link WriteBatches}
*/
@AutoValue
+ @Deprecated
public abstract static class Write extends PTransform<PCollection<SendMessageRequest>, PDone> {
abstract @Pure ClientConfiguration getClientConfiguration();
@@ -251,39 +265,14 @@ public class SqsIO {
@Override
public PDone expand(PCollection<SendMessageRequest> input) {
- AwsOptions awsOptions = input.getPipeline().getOptions().as(AwsOptions.class);
- ClientBuilderFactory.validate(awsOptions, getClientConfiguration());
-
- input.apply(ParDo.of(new SqsWriteFn(this)));
+ // Delegate to writeBatches using single-entry batches. The client configuration of this
+ // transform must be forwarded explicitly, otherwise getClientConfiguration() would be ignored.
+ // An explicit entry mapper keeps this transform usable for inputs without a schema.
+ input.apply(
+ SqsIO.<SendMessageRequest>writeBatches()
+ .withBatchSize(1)
+ .withClientConfiguration(getClientConfiguration())
+ .withEntryMapper(Write::createEntry)
+ .to(SendMessageRequest::queueUrl));
return PDone.in(input.getPipeline());
}
+
+ /**
+ * Creates a batch entry with the given unique {@code id} that carries the message content of
+ * {@code msg}; the queue url is not part of an entry and is used as destination instead.
+ */
+ private static SendMessageBatchRequestEntry createEntry(String id, SendMessageRequest msg) {
+ return SendMessageBatchRequestEntry.builder()
+ .id(id)
+ .messageBody(msg.messageBody())
+ .delaySeconds(msg.delaySeconds())
+ .messageAttributes(msg.messageAttributes())
+ .messageSystemAttributesWithStrings(msg.messageSystemAttributesAsStrings())
+ .messageDeduplicationId(msg.messageDeduplicationId())
+ .messageGroupId(msg.messageGroupId())
+ .build();
+ }
}
- private static class SqsWriteFn extends DoFn<SendMessageRequest, Void> {
- private final Write spec;
- private transient @MonotonicNonNull SqsClient sqs = null;
-
- SqsWriteFn(Write write) {
- this.spec = write;
- }
-
- @Setup
- public void setup(PipelineOptions options) throws Exception {
- AwsOptions awsOpts = options.as(AwsOptions.class);
- sqs =
- ClientBuilderFactory.buildClient(
- awsOpts, SqsClient.builder(), spec.getClientConfiguration());
- }
-
- @ProcessElement
- public void processElement(ProcessContext processContext) throws Exception {
- if (sqs == null) {
- throw new IllegalStateException("No SQS client");
- }
- sqs.sendMessage(processContext.element());
- }
- }
-
/**
* A {@link PTransform} to send messages to SQS. See {@link SqsIO} for more information on usage
* and configuration.
@@ -291,6 +280,7 @@ public class SqsIO {
@AutoValue
public abstract static class WriteBatches<T>
extends PTransform<PCollection<T>, WriteBatches.Result> {
+ private static final Logger LOG = LoggerFactory.getLogger(WriteBatches.class);
private static final int DEFAULT_CONCURRENCY = 5;
private static final int MAX_BATCH_SIZE = 10;
private static final Duration DEFAULT_BATCH_TIMEOUT = Duration.standardSeconds(3);
@@ -303,7 +293,7 @@ public class SqsIO {
abstract @Pure ClientConfiguration clientConfiguration();
- abstract @Pure EntryBuilder<T> entryBuilder();
+ abstract @Pure @Nullable EntryMapperFn<T> entryMapper();
abstract @Pure @Nullable DynamicDestination<T> dynamicDestination();
@@ -325,7 +315,7 @@ public class SqsIO {
abstract Builder<T> clientConfiguration(ClientConfiguration config);
- abstract Builder<T> entryBuilder(EntryBuilder<T> entryBuilder);
+ abstract Builder<T> entryMapper(@Nullable EntryMapperFn<T> entryMapper);
abstract Builder<T> dynamicDestination(@Nullable DynamicDestination<T> destination);
@@ -346,6 +336,22 @@ public class SqsIO {
return builder().concurrentRequests(concurrentRequests).build();
}
+ /**
+ * Optional mapper to create a batch entry from a unique entry id and the input {@code T},
+ * otherwise inferred from the schema.
+ *
+ * @param mapper the entry mapper, must not be {@code null}
+ */
+ public WriteBatches<T> withEntryMapper(EntryMapperFn<T> mapper) {
+ // reject null explicitly, it would otherwise silently fall back to schema inference
+ checkArgument(mapper != null, "entry mapper cannot be null");
+ return builder().entryMapper(mapper).build();
+ }
+
+ /**
+ * Optional mapper to create a batch entry from the input {@code T} using a builder, otherwise
+ * inferred from the schema. The entry id is set automatically.
+ *
+ * @param mapper the builder based entry mapper, must not be {@code null}
+ */
+ public WriteBatches<T> withEntryMapper(EntryMapperFn.Builder<T> mapper) {
+ checkArgument(mapper != null, "entry mapper cannot be null");
+ return builder().entryMapper(mapper).build();
+ }
+
/** The batch size to use, default (and AWS limit) is {@code 10}. */
public WriteBatches<T> withBatchSize(int batchSize) {
checkArgument(
@@ -375,11 +381,25 @@ public class SqsIO {
return builder().dynamicDestination(null).queueUrl(queueUrl).build();
}
+ /**
+ * Creates an {@link EntryMapperFn} that infers batch entries from the schema of {@code input}.
+ * Used if no entry mapper was configured explicitly.
+ *
+ * @throws IllegalStateException if {@code input} has no schema or no schema is available for
+ * {@link SendMessageBatchRequestEntry}
+ */
+ private EntryMapperFn<T> schemaEntryMapper(PCollection<T> input) {
+ checkState(input.hasSchema(), "withEntryMapper is required if schema is not available");
+ SchemaRegistry registry = input.getPipeline().getSchemaRegistry();
+ try {
+ return new SchemaEntryMapper<>(
+ input.getSchema(),
+ registry.getSchema(SendMessageBatchRequestEntry.class),
+ input.getToRowFunction(),
+ registry.getFromRowFunction(SendMessageBatchRequestEntry.class));
+ } catch (NoSuchSchemaException e) {
+ // only lookups for the target type can fail here, the input schema was checked above
+ throw new IllegalStateException(
+ "Cannot infer batch entries, no schema available for "
+ + SendMessageBatchRequestEntry.class.getName()
+ + ". Use withEntryMapper instead.",
+ e);
+ }
+ }
+
@Override
public Result expand(PCollection<T> input) {
AwsOptions awsOptions = input.getPipeline().getOptions().as(AwsOptions.class);
ClientBuilderFactory.validate(awsOptions, clientConfiguration());
-
+ EntryMapperFn<T> mapper = entryMapper() != null ? entryMapper() : schemaEntryMapper(input);
input.apply(
ParDo.of(
new DoFn<T, Void>() {
@@ -387,7 +407,8 @@ public class SqsIO {
@Setup
public void setup(PipelineOptions options) {
- handler = new BatchHandler<>(WriteBatches.this, options.as(AwsOptions.class));
+ // The handler builds its own SQS client from the options; the entry mapper was resolved once
+ // in expand() (explicit or schema based) and is captured by this DoFn.
+ handler =
+ new BatchHandler<>(WriteBatches.this, mapper, options.as(AwsOptions.class));
}
@StartBundle
@@ -420,9 +441,86 @@ public class SqsIO {
return new Result(input.getPipeline());
}
- /** Batch entry builder. */
- public interface EntryBuilder<T>
- extends BiConsumer<SendMessageBatchRequestEntry.Builder, T>, Serializable {}
+ /**
+ * Mapper to create a {@link SendMessageBatchRequestEntry} from a unique batch entry id and the
+ * input {@code T}.
+ *
+ * <p>A fresh id is generated by the writer for every entry; implementations are expected to use
+ * it as the id of the returned entry.
+ */
+ @FunctionalInterface
+ public interface EntryMapperFn<T>
+ extends BiFunction<String, T, SendMessageBatchRequestEntry>, Serializable {
+
+ /**
+ * A more convenient {@link EntryMapperFn} variant that only populates a builder. The entry id
+ * is set after {@link #accept} returns and overrides any id set there.
+ */
+ @FunctionalInterface
+ interface Builder<T>
+ extends BiConsumer<SendMessageBatchRequestEntry.Builder, T>, EntryMapperFn<T> {
+ @Override
+ default SendMessageBatchRequestEntry apply(String entryId, T msg) {
+ SendMessageBatchRequestEntry.Builder builder = SendMessageBatchRequestEntry.builder();
+ accept(builder, msg);
+ return builder.id(entryId).build();
+ }
+ }
+ }
+
+ /**
+ * {@link EntryMapperFn} that infers batch entries from the input schema by matching top-level
+ * field names (and types) against the schema of {@link SendMessageBatchRequestEntry}.
+ *
+ * <p>Unmatched input fields are ignored with a warning. An input field named {@code id} is never
+ * mapped: the entry id is always the unique id passed to {@link #apply}.
+ */
+ @VisibleForTesting
+ static class SchemaEntryMapper<T> implements EntryMapperFn<T> {
+ /** Name of the target field receiving the unique batch entry id. */
+ private static final String ID_FIELD = "id";
+
+ private final SerializableFunction<T, Row> toRow;
+ private final SerializableFunction<Row, SendMessageBatchRequestEntry> fromRow;
+ private final Schema schema;
+ /** Per target field index, the index of the source field to copy, or -1 if unmapped. */
+ private final int[] fieldMapping;
+ /** Index of the id field in the target schema. */
+ private final int idIndex;
+
+ SchemaEntryMapper(
+ Schema sourceSchema,
+ Schema targetSchema,
+ SerializableFunction<T, Row> toRow,
+ SerializableFunction<Row, SendMessageBatchRequestEntry> fromRow) {
+ checkState(
+ targetSchema.hasField(ID_FIELD), "Target schema is missing field '%s'", ID_FIELD);
+ this.toRow = toRow;
+ this.fromRow = fromRow;
+ this.schema = targetSchema;
+ this.idIndex = targetSchema.indexOf(ID_FIELD);
+ this.fieldMapping = new int[targetSchema.getFieldCount()];
+
+ Arrays.fill(fieldMapping, -1);
+
+ List<String> ignored = Lists.newArrayList();
+ List<String> invalid = Lists.newArrayList();
+
+ for (int i = 0; i < sourceSchema.getFieldCount(); i++) {
+ Field sourceField = sourceSchema.getField(i);
+ String name = sourceField.getName();
+ // never map a source id, it would replace the unique entry id and break batching
+ if (!ID_FIELD.equals(name) && targetSchema.hasField(name)) {
+ int targetIdx = targetSchema.indexOf(name);
+ // make sure field types match
+ if (!sourceField.typesEqual(targetSchema.getField(targetIdx))) {
+ invalid.add(name);
+ }
+ fieldMapping[targetIdx] = i;
+ } else {
+ ignored.add(name);
+ }
+ }
+ checkState(
+ ignored.size() < sourceSchema.getFieldCount(),
+ "No fields matched, expected %s but got %s",
+ schema.getFieldNames(),
+ ignored);
+
+ checkState(invalid.isEmpty(), "Detected incompatible types for input fields: %s", invalid);
+
+ if (!ignored.isEmpty()) {
+ LOG.warn("Ignoring unmatched input fields: {}", ignored);
+ }
+ }
+
+ @Override
+ public SendMessageBatchRequestEntry apply(String entryId, T input) {
+ Row row = toRow.apply(input);
+ Object[] values = new Object[fieldMapping.length];
+ for (int i = 0; i < values.length; i++) {
+ if (fieldMapping[i] >= 0) {
+ values[i] = row.getValue(fieldMapping[i]);
+ }
+ }
+ // set last and by index, independent of the field order of the target schema
+ values[idIndex] = entryId;
+ return fromRow.apply(Row.withSchema(schema).attachValues(values));
+ }
+ }
/** Result of {@link #writeBatches}. */
public static class Result implements POutput {
@@ -451,12 +549,14 @@ public class SqsIO {
private final WriteBatches<T> spec;
private final SqsAsyncClient sqs;
private final Batches batches;
+ private final EntryMapperFn<T> entryMapper;
private final AsyncBatchWriteHandler<SendMessageBatchRequestEntry, BatchResultErrorEntry>
handler;
- BatchHandler(WriteBatches<T> spec, AwsOptions options) {
+ BatchHandler(WriteBatches<T> spec, EntryMapperFn<T> entryMapper, AwsOptions options) {
this.spec = spec;
this.sqs = buildClient(options, SqsAsyncClient.builder(), spec.clientConfiguration());
+ this.entryMapper = entryMapper;
this.handler =
AsyncBatchWriteHandler.byId(
spec.concurrentRequests(),
@@ -488,10 +588,7 @@ public class SqsIO {
}
public void process(T msg) {
- SendMessageBatchRequestEntry.Builder builder = SendMessageBatchRequestEntry.builder();
- spec.entryBuilder().accept(builder, msg);
- SendMessageBatchRequestEntry entry = builder.id(batches.nextId()).build();
-
+ SendMessageBatchRequestEntry entry = entryMapper.apply(batches.nextId(), msg);
Batch batch = batches.getLocked(msg);
batch.add(entry);
if (batch.size() >= spec.batchSize() || batch.isExpired()) {
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java
index aeb9122df9f..dff4b4e72c8 100644
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteBatchesTest.java
@@ -17,19 +17,26 @@
*/
package org.apache.beam.sdk.io.aws2.sqs;
+import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.supplyAsync;
import static java.util.stream.Collectors.toList;
import static java.util.stream.IntStream.range;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.joda.time.Duration.millis;
+import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import java.util.Arrays;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.beam.sdk.Pipeline;
@@ -38,31 +45,38 @@ import org.apache.beam.sdk.io.aws2.common.AsyncBatchWriteHandler;
import org.apache.beam.sdk.io.aws2.common.ClientConfiguration;
import org.apache.beam.sdk.io.aws2.common.RetryConfiguration;
import org.apache.beam.sdk.io.aws2.sqs.SqsIO.WriteBatches;
-import org.apache.beam.sdk.io.aws2.sqs.SqsIO.WriteBatches.EntryBuilder;
+import org.apache.beam.sdk.io.aws2.sqs.SqsIO.WriteBatches.EntryMapperFn;
+import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.testing.ExpectedLogs;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams;
import org.joda.time.Duration;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
+import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sqs.SqsAsyncClient;
import software.amazon.awssdk.services.sqs.SqsAsyncClientBuilder;
import software.amazon.awssdk.services.sqs.model.BatchResultErrorEntry;
+import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
+import software.amazon.awssdk.services.sqs.model.MessageSystemAttributeValue;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse;
+import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
/** Tests for {@link WriteBatches}. */
@RunWith(MockitoJUnitRunner.class)
public class SqsIOWriteBatchesTest {
- private static final EntryBuilder<String> SET_MESSAGE_BODY =
+ private static final EntryMapperFn.Builder<String> SET_MESSAGE_BODY =
SendMessageBatchRequestEntry.Builder::messageBody;
private static final SendMessageBatchResponse SUCCESS =
SendMessageBatchResponse.builder().build();
@@ -76,13 +90,80 @@ public class SqsIOWriteBatchesTest {
MockClientBuilderFactory.set(p, SqsAsyncClientBuilder.class, sqs);
}
+ /**
+ * Verifies that {@link SqsIO.WriteBatches.SchemaEntryMapper} copies the fields shared by {@link
+ * SendMessageRequest} and {@link SendMessageBatchRequestEntry} (body, delay, message and system
+ * attributes) and sets the given entry id.
+ */
+ @Test
+ public void testSchemaEntryMapper() throws Exception {
+ SchemaRegistry registry = p.getSchemaRegistry();
+
+ Map<String, MessageAttributeValue> attributes =
+ ImmutableMap.of("key", MessageAttributeValue.builder().stringValue("value").build());
+ Map<String, MessageSystemAttributeValue> systemAttributes =
+ ImmutableMap.of(
+ "key",
+ MessageSystemAttributeValue.builder()
+ .binaryValue(SdkBytes.fromString("bytes", UTF_8))
+ .build());
+
+ SendMessageRequest input =
+ SendMessageRequest.builder()
+ .messageBody("body")
+ .delaySeconds(3)
+ .messageAttributes(attributes)
+ .messageSystemAttributesWithStrings(systemAttributes)
+ .build();
+
+ SqsIO.WriteBatches.EntryMapperFn<SendMessageRequest> mapper =
+ new SqsIO.WriteBatches.SchemaEntryMapper<>(
+ registry.getSchema(SendMessageRequest.class),
+ registry.getSchema(SendMessageBatchRequestEntry.class),
+ registry.getToRowFunction(SendMessageRequest.class),
+ registry.getFromRowFunction(SendMessageBatchRequestEntry.class));
+
+ assertThat(mapper.apply("1", input))
+ .isEqualTo(
+ SendMessageBatchRequestEntry.builder()
+ .id("1")
+ .messageBody("body")
+ .delaySeconds(3)
+ .messageAttributes(attributes)
+ .messageSystemAttributesWithStrings(systemAttributes)
+ .build());
+ }
+
+ /**
+ * Verifies that the deprecated {@link SqsIO#write()} sends every message in a batch request of
+ * its own to the queue given by the message's {@code queueUrl}.
+ */
+ @Test
+ public void testWrite() {
+ // write uses writeBatches with batch size 1
+ when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS));
+
+ SendMessageRequest.Builder msgBuilder = SendMessageRequest.builder().queueUrl("queue");
+ // NOTE(review): Collectors.toSet() does not guarantee a mutable set, yet remove() is called on
+ // it below; consider Collectors.toCollection(HashSet::new).
+ Set<SendMessageRequest> messages =
+ range(0, 100)
+ .mapToObj(i -> msgBuilder.messageBody("test" + i).build())
+ .collect(Collectors.toSet());
+
+ p.apply(Create.of(messages)).apply(SqsIO.write());
+ p.run().waitUntilFinish();
+
+ ArgumentCaptor<SendMessageBatchRequest> captor =
+ ArgumentCaptor.forClass(SendMessageBatchRequest.class);
+ verify(sqs, times(100)).sendMessageBatch(captor.capture());
+
+ for (SendMessageBatchRequest req : captor.getAllValues()) {
+ assertThat(req.queueUrl()).isEqualTo("queue");
+ assertThat(req.entries()).hasSize(1);
+ // rebuild the original request from the entry body; each must match exactly one message
+ for (SendMessageBatchRequestEntry entry : req.entries()) {
+ assertTrue(messages.remove(msgBuilder.messageBody(entry.messageBody()).build()));
+ }
+ }
+ assertTrue(messages.isEmpty());
+ }
+
@Test
public void testWriteBatches() {
when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS));
p.apply(Create.of(23))
.apply(ParDo.of(new CreateMessages()))
- .apply(SqsIO.writeBatches(SET_MESSAGE_BODY).to("queue"));
+ .apply(SqsIO.<String>writeBatches().withEntryMapper(SET_MESSAGE_BODY).to("queue"));
p.run().waitUntilFinish();
@@ -104,7 +185,7 @@ public class SqsIOWriteBatchesTest {
p.apply(Create.of(23))
.apply(ParDo.of(new CreateMessages()))
- .apply(SqsIO.writeBatches(SET_MESSAGE_BODY).to("queue"));
+ .apply(SqsIO.<String>writeBatches().withEntryMapper(SET_MESSAGE_BODY).to("queue"));
assertThatThrownBy(() -> p.run().waitUntilFinish())
.isInstanceOf(Pipeline.PipelineExecutionException.class)
@@ -122,7 +203,7 @@ public class SqsIOWriteBatchesTest {
p.apply(Create.of(23))
.apply(ParDo.of(new CreateMessages()))
- .apply(SqsIO.writeBatches(SET_MESSAGE_BODY).to("queue"));
+ .apply(SqsIO.<String>writeBatches().withEntryMapper(SET_MESSAGE_BODY).to("queue"));
p.run().waitUntilFinish();
@@ -145,7 +226,11 @@ public class SqsIOWriteBatchesTest {
p.apply(Create.of(8))
.apply(ParDo.of(new CreateMessages()))
- .apply(SqsIO.writeBatches(SET_MESSAGE_BODY).withBatchSize(3).to("queue"));
+ .apply(
+ SqsIO.<String>writeBatches()
+ .withEntryMapper(SET_MESSAGE_BODY)
+ .withBatchSize(3)
+ .to("queue"));
p.run().waitUntilFinish();
@@ -157,6 +242,27 @@ public class SqsIOWriteBatchesTest {
verifyNoMoreInteractions(sqs);
}
+ /** Verifies that a batch is flushed once it has exceeded the batch timeout. */
+ @Test
+ public void testWriteBatchesWithTimeout() {
+ when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS));
+
+ p.apply(Create.of(5))
+ .apply(ParDo.of(new CreateMessages()))
+ .apply(
+ // simulate a 100ms delay per message, shorter than the 150ms batch timeout: a batch is
+ // only found expired when its 3rd entry arrives (~200ms after the 1st, assuming the
+ // batch age counts from its first entry)
+ SqsIO.<String>writeBatches()
+ .withEntryMapper(withDelay(millis(100), SET_MESSAGE_BODY))
+ .withBatchTimeout(millis(150))
+ .to("queue"));
+
+ p.run().waitUntilFinish();
+
+ SendMessageBatchRequestEntry[] entries = entries(range(0, 5));
+ // due to added delay, batches are timed out on arrival of every 3rd msg
+ verify(sqs).sendMessageBatch(request("queue", entries[0], entries[1], entries[2]));
+ verify(sqs).sendMessageBatch(request("queue", entries[3], entries[4]));
+ }
+
@Test
public void testWriteBatchesToDynamic() {
when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS));
@@ -167,7 +273,8 @@ public class SqsIOWriteBatchesTest {
p.apply(Create.of(10))
.apply(ParDo.of(new CreateMessages()))
.apply(
- SqsIO.writeBatches(SET_MESSAGE_BODY)
+ SqsIO.<String>writeBatches()
+ .withEntryMapper(SET_MESSAGE_BODY)
.withClientConfiguration(ClientConfiguration.builder().retry(retry).build())
.withBatchSize(3)
.to(msg -> Integer.valueOf(msg) % 2 == 0 ? "even" : "uneven"));
@@ -187,24 +294,25 @@ public class SqsIOWriteBatchesTest {
}
+ /** Verifies that batch timeouts are tracked separately per dynamic destination. */
@Test
- public void testWriteBatchesWithTimeout() {
+ public void testWriteBatchesToDynamicWithTimeout() {
when(sqs.sendMessageBatch(anyRequest())).thenReturn(completedFuture(SUCCESS));
p.apply(Create.of(5))
.apply(ParDo.of(new CreateMessages()))
.apply(
- // simulate delay between messages > batch timeout
+ // simulate a 100ms delay per message, shorter than the 150ms batch timeout; messages
+ // alternate between destinations, so a batch is found expired when its 2nd entry arrives
+ // (~200ms after the 1st, assuming the batch age counts from its first entry)
- SqsIO.writeBatches(withDelay(millis(200), SET_MESSAGE_BODY))
- .withBatchTimeout(millis(100))
- .to("queue"));
+ SqsIO.<String>writeBatches()
+ .withEntryMapper(withDelay(millis(100), SET_MESSAGE_BODY))
+ .withBatchTimeout(millis(150))
+ .to(msg -> Integer.valueOf(msg) % 2 == 0 ? "even" : "uneven"));
p.run().waitUntilFinish();
SendMessageBatchRequestEntry[] entries = entries(range(0, 5));
- // due to added delay, batches are timed out on arrival of every 2nd msg
- verify(sqs).sendMessageBatch(request("queue", entries[0], entries[1]));
- verify(sqs).sendMessageBatch(request("queue", entries[2], entries[3]));
- verify(sqs).sendMessageBatch(request("queue", entries[4]));
+ // due to added delay, dynamic batches are timed out on arrival of every 2nd msg (per batch)
+ verify(sqs).sendMessageBatch(request("even", entries[0], entries[2]));
+ verify(sqs).sendMessageBatch(request("uneven", entries[1], entries[3]));
+ verify(sqs).sendMessageBatch(request("even", entries[4]));
}
private SendMessageBatchRequest anyRequest() {
@@ -252,7 +360,8 @@ public class SqsIOWriteBatchesTest {
}
}
- private static <T> EntryBuilder<T> withDelay(Duration delay, EntryBuilder<T> builder) {
+ private static <T> EntryMapperFn.Builder<T> withDelay(
+ Duration delay, EntryMapperFn.Builder<T> builder) {
return (t1, t2) -> {
builder.accept(t1, t2);
try {
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteTest.java
deleted file mode 100644
index 738cf282adf..00000000000
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/SqsIOWriteTest.java
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.io.aws2.sqs;
-
-import static java.util.stream.IntStream.range;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
-import java.util.List;
-import java.util.stream.Collectors;
-import org.apache.beam.sdk.io.aws2.MockClientBuilderFactory;
-import org.apache.beam.sdk.testing.TestPipeline;
-import org.apache.beam.sdk.transforms.Create;
-import org.junit.Before;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.mockito.junit.MockitoJUnitRunner;
-import software.amazon.awssdk.services.sqs.SqsClient;
-import software.amazon.awssdk.services.sqs.SqsClientBuilder;
-import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
-import software.amazon.awssdk.services.sqs.model.SendMessageResponse;
-
-/** Tests for {@link SqsIO.Write}. */
-@RunWith(MockitoJUnitRunner.class)
-public class SqsIOWriteTest {
- @Rule public TestPipeline p = TestPipeline.create();
- @Mock public SqsClient sqs;
-
- @Before
- public void configureClientBuilderFactory() {
- MockClientBuilderFactory.set(p, SqsClientBuilder.class, sqs);
- }
-
- @Test
- public void testWrite() {
- when(sqs.sendMessage(any(SendMessageRequest.class)))
- .thenReturn(SendMessageResponse.builder().build());
-
- SendMessageRequest.Builder builder = SendMessageRequest.builder().queueUrl("url");
- List<SendMessageRequest> messages =
- range(0, 100)
- .mapToObj(i -> builder.messageBody("test" + i).build())
- .collect(Collectors.toList());
-
- p.apply(Create.of(messages)).apply(SqsIO.write());
- p.run().waitUntilFinish();
-
- messages.forEach(msg -> verify(sqs).sendMessage(msg));
- }
-}
diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/testing/SqsIOIT.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/testing/SqsIOIT.java
index 2f10f3d08f1..9e78f7a15ef 100644
--- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/testing/SqsIOIT.java
+++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sqs/testing/SqsIOIT.java
@@ -114,7 +114,9 @@ public class SqsIOIT {
.apply("Prepare TestRows", ParDo.of(new DeterministicallyConstructTestRowFn()))
.apply(
"Write to SQS",
- SqsIO.<TestRow>writeBatches((b, row) -> b.messageBody(row.name())).to(sqsQueue.url));
+ SqsIO.<TestRow>writeBatches()
+ .withEntryMapper((b, row) -> b.messageBody(row.name()))
+ .to(sqsQueue.url));
// Read test dataset from SQS.
PCollection<String> output =