You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2020/03/17 15:50:55 UTC
[beam] branch master updated: [BEAM-8374] Add alternate SnsIO PublishResult coders
This is an automated email from the ASF dual-hosted git repository.
lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 791d6f0 [BEAM-8374] Add alternate SnsIO PublishResult coders
new 2ff0bc1 Merge pull request #9758 from jfarr/sns-io-fix
791d6f0 is described below
commit 791d6f06dc2842f2afebe686d6638f3d6339f82c
Author: jfarr <jf...@godaddy.com>
AuthorDate: Wed Oct 9 23:29:40 2019 -0700
[BEAM-8374] Add alternate SnsIO PublishResult coders
---
sdks/java/io/amazon-web-services/build.gradle | 1 +
.../apache/beam/sdk/io/aws/coders/AwsCoders.java | 138 +++++++++++++++++++++
.../beam/sdk/io/aws/coders/package-info.java | 19 +++
.../beam/sdk/io/aws/sns/PublishResultCoder.java | 62 ---------
.../beam/sdk/io/aws/sns/PublishResultCoders.java | 121 ++++++++++++++++++
.../sdk/io/aws/sns/SnsCoderProviderRegistrar.java | 3 +-
.../java/org/apache/beam/sdk/io/aws/sns/SnsIO.java | 45 ++++++-
.../beam/sdk/io/aws/coders/AwsCodersTest.java | 68 ++++++++++
.../sdk/io/aws/sns/PublishResultCodersTest.java | 91 ++++++++++++++
.../org/apache/beam/sdk/io/aws/sns/SnsIOTest.java | 54 ++++++++
10 files changed, 535 insertions(+), 67 deletions(-)
diff --git a/sdks/java/io/amazon-web-services/build.gradle b/sdks/java/io/amazon-web-services/build.gradle
index 6948a58..c288a19 100644
--- a/sdks/java/io/amazon-web-services/build.gradle
+++ b/sdks/java/io/amazon-web-services/build.gradle
@@ -47,6 +47,7 @@ dependencies {
testCompile library.java.hamcrest_library
testCompile library.java.mockito_core
testCompile library.java.junit
+ testCompile "org.assertj:assertj-core:3.11.1"
testCompile 'org.elasticmq:elasticmq-rest-sqs_2.12:0.14.1'
testCompile 'org.testcontainers:localstack:1.11.2'
testRuntimeOnly library.java.slf4j_jdk14
diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/coders/AwsCoders.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/coders/AwsCoders.java
new file mode 100644
index 0000000..01b54ce
--- /dev/null
+++ b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/coders/AwsCoders.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.aws.coders;
+
+import com.amazonaws.ResponseMetadata;
+import com.amazonaws.http.HttpResponse;
+import com.amazonaws.http.SdkHttpMetadata;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.Map;
+import java.util.Optional;
+import org.apache.beam.sdk.coders.AtomicCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.CustomCoder;
+import org.apache.beam.sdk.coders.MapCoder;
+import org.apache.beam.sdk.coders.NullableCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+
+/** {@link Coder}s for common AWS SDK objects. */
+public final class AwsCoders {
+
+ private AwsCoders() {}
+
+ /**
+ * Returns a new coder for ResponseMetadata.
+ *
+ * @return the ResponseMetadata coder
+ */
+ public static Coder<ResponseMetadata> responseMetadata() {
+ return ResponseMetadataCoder.of();
+ }
+
+ /**
+ * Returns a new coder for SdkHttpMetadata.
+ *
+ * @return the SdkHttpMetadata coder
+ */
+ public static Coder<SdkHttpMetadata> sdkHttpMetadata() {
+ return new SdkHttpMetadataCoder(true);
+ }
+
+ /**
+ * Returns a new coder for SdkHttpMetadata that does not serialize the response headers.
+ *
+ * @return the SdkHttpMetadata coder
+ */
+ public static Coder<SdkHttpMetadata> sdkHttpMetadataWithoutHeaders() {
+ return new SdkHttpMetadataCoder(false);
+ }
+
+ private static class ResponseMetadataCoder extends AtomicCoder<ResponseMetadata> {
+
+ private static final Coder<Map<String, String>> METADATA_ENCODER =
+ NullableCoder.of(MapCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()));
+ private static final ResponseMetadataCoder INSTANCE = new ResponseMetadataCoder();
+
+ private ResponseMetadataCoder() {}
+
+ public static ResponseMetadataCoder of() {
+ return INSTANCE;
+ }
+
+ @Override
+ public void encode(ResponseMetadata value, OutputStream outStream)
+ throws CoderException, IOException {
+ METADATA_ENCODER.encode(
+ ImmutableMap.of(ResponseMetadata.AWS_REQUEST_ID, value.getRequestId()), outStream);
+ }
+
+ @Override
+ public ResponseMetadata decode(InputStream inStream) throws CoderException, IOException {
+ return new ResponseMetadata(METADATA_ENCODER.decode(inStream));
+ }
+ }
+
+ private static class SdkHttpMetadataCoder extends CustomCoder<SdkHttpMetadata> {
+
+ private static final Coder<Integer> STATUS_CODE_CODER = VarIntCoder.of();
+ private static final Coder<Map<String, String>> HEADERS_ENCODER =
+ NullableCoder.of(MapCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()));
+
+ private final boolean includeHeaders;
+
+ protected SdkHttpMetadataCoder(boolean includeHeaders) {
+ this.includeHeaders = includeHeaders;
+ }
+
+ @Override
+ public void encode(SdkHttpMetadata value, OutputStream outStream)
+ throws CoderException, IOException {
+ STATUS_CODE_CODER.encode(value.getHttpStatusCode(), outStream);
+ if (includeHeaders) {
+ HEADERS_ENCODER.encode(value.getHttpHeaders(), outStream);
+ }
+ }
+
+ @Override
+ public SdkHttpMetadata decode(InputStream inStream) throws CoderException, IOException {
+ final int httpStatusCode = STATUS_CODE_CODER.decode(inStream);
+ HttpResponse httpResponse = new HttpResponse(null, null);
+ httpResponse.setStatusCode(httpStatusCode);
+ if (includeHeaders) {
+ Optional.ofNullable(HEADERS_ENCODER.decode(inStream))
+ .ifPresent(
+ headers ->
+ headers.keySet().forEach(k -> httpResponse.addHeader(k, headers.get(k))));
+ }
+ return SdkHttpMetadata.from(httpResponse);
+ }
+
+ @Override
+ public void verifyDeterministic() throws NonDeterministicException {
+ STATUS_CODE_CODER.verifyDeterministic();
+ if (includeHeaders) {
+ HEADERS_ENCODER.verifyDeterministic();
+ }
+ }
+ }
+}
diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/coders/package-info.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/coders/package-info.java
new file mode 100644
index 0000000..1b76a71
--- /dev/null
+++ b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/coders/package-info.java
@@ -0,0 +1,19 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/** Defines common coders for Amazon Web Services. */
+package org.apache.beam.sdk.io.aws.coders;
diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/PublishResultCoder.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/PublishResultCoder.java
deleted file mode 100644
index 8c2a29a..0000000
--- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/PublishResultCoder.java
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.sdk.io.aws.sns;
-
-import com.amazonaws.services.sns.model.PublishResult;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
-import java.io.Serializable;
-import java.util.List;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.CoderException;
-import org.apache.beam.sdk.coders.StringUtf8Coder;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
-
-/** Custom Coder for handling publish result. */
-public class PublishResultCoder extends Coder<PublishResult> implements Serializable {
- private static final PublishResultCoder INSTANCE = new PublishResultCoder();
-
- private PublishResultCoder() {}
-
- static PublishResultCoder of() {
- return INSTANCE;
- }
-
- @Override
- public void encode(PublishResult value, OutputStream outStream)
- throws CoderException, IOException {
- StringUtf8Coder.of().encode(value.getMessageId(), outStream);
- }
-
- @Override
- public PublishResult decode(InputStream inStream) throws CoderException, IOException {
- final String messageId = StringUtf8Coder.of().decode(inStream);
- return new PublishResult().withMessageId(messageId);
- }
-
- @Override
- public List<? extends Coder<?>> getCoderArguments() {
- return ImmutableList.of();
- }
-
- @Override
- public void verifyDeterministic() throws NonDeterministicException {
- StringUtf8Coder.of().verifyDeterministic();
- }
-}
diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/PublishResultCoders.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/PublishResultCoders.java
new file mode 100644
index 0000000..297c1c8
--- /dev/null
+++ b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/PublishResultCoders.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.aws.sns;
+
+import com.amazonaws.ResponseMetadata;
+import com.amazonaws.http.SdkHttpMetadata;
+import com.amazonaws.services.sns.model.PublishResult;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.CustomCoder;
+import org.apache.beam.sdk.coders.NullableCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.io.aws.coders.AwsCoders;
+
+/**
+ * Coders for SNS {@link PublishResult}.
+ *
+ * <p>Every coder writes the messageId first; the richer variants additionally write the
+ * sdkResponseMetadata and then the sdkHttpMetadata, both of which may be null.
+ */
+public final class PublishResultCoders {
+
+  private static final Coder<String> MESSAGE_ID_CODER = StringUtf8Coder.of();
+  private static final Coder<ResponseMetadata> RESPONSE_METADATA_CODER =
+      NullableCoder.of(AwsCoders.responseMetadata());
+
+  private PublishResultCoders() {}
+
+  /**
+   * Returns a new PublishResult coder that serializes only the messageId; all other fields of a
+   * decoded PublishResult are left unset.
+   *
+   * @return the PublishResult coder
+   */
+  public static Coder<PublishResult> defaultPublishResult() {
+    Coder<ResponseMetadata> noResponseMetadata = null;
+    Coder<SdkHttpMetadata> noHttpMetadata = null;
+    return new PublishResultCoder(noResponseMetadata, noHttpMetadata);
+  }
+
+  /**
+   * Returns a new PublishResult coder that serializes the messageId, the sdkResponseMetadata and
+   * the sdkHttpMetadata, the latter including the HTTP response headers.
+   *
+   * @return the PublishResult coder
+   */
+  public static Coder<PublishResult> fullPublishResult() {
+    Coder<SdkHttpMetadata> httpMetadataCoder = NullableCoder.of(AwsCoders.sdkHttpMetadata());
+    return new PublishResultCoder(RESPONSE_METADATA_CODER, httpMetadataCoder);
+  }
+
+  /**
+   * Returns a new PublishResult coder that serializes the messageId, the sdkResponseMetadata and
+   * the sdkHttpMetadata, but omits the HTTP response headers.
+   *
+   * @return the PublishResult coder
+   */
+  public static Coder<PublishResult> fullPublishResultWithoutHeaders() {
+    Coder<SdkHttpMetadata> httpMetadataCoder =
+        NullableCoder.of(AwsCoders.sdkHttpMetadataWithoutHeaders());
+    return new PublishResultCoder(RESPONSE_METADATA_CODER, httpMetadataCoder);
+  }
+
+  /**
+   * Coder for {@link PublishResult}; a null component coder means the corresponding field is
+   * neither written nor read.
+   */
+  static class PublishResultCoder extends CustomCoder<PublishResult> {
+
+    private final Coder<ResponseMetadata> responseMetadataEncoder;
+    private final Coder<SdkHttpMetadata> sdkHttpMetadataCoder;
+
+    private PublishResultCoder(
+        Coder<ResponseMetadata> metadataCoder, Coder<SdkHttpMetadata> httpMetadataCoder) {
+      this.responseMetadataEncoder = metadataCoder;
+      this.sdkHttpMetadataCoder = httpMetadataCoder;
+    }
+
+    @Override
+    public void encode(PublishResult value, OutputStream outStream)
+        throws CoderException, IOException {
+      MESSAGE_ID_CODER.encode(value.getMessageId(), outStream);
+      if (responseMetadataEncoder != null) {
+        ResponseMetadata responseMetadata = value.getSdkResponseMetadata();
+        responseMetadataEncoder.encode(responseMetadata, outStream);
+      }
+      if (sdkHttpMetadataCoder != null) {
+        SdkHttpMetadata httpMetadata = value.getSdkHttpMetadata();
+        sdkHttpMetadataCoder.encode(httpMetadata, outStream);
+      }
+    }
+
+    @Override
+    public PublishResult decode(InputStream inStream) throws CoderException, IOException {
+      PublishResult decoded = new PublishResult();
+      decoded.setMessageId(MESSAGE_ID_CODER.decode(inStream));
+      if (responseMetadataEncoder != null) {
+        ResponseMetadata responseMetadata = responseMetadataEncoder.decode(inStream);
+        decoded.setSdkResponseMetadata(responseMetadata);
+      }
+      if (sdkHttpMetadataCoder != null) {
+        SdkHttpMetadata httpMetadata = sdkHttpMetadataCoder.decode(inStream);
+        decoded.setSdkHttpMetadata(httpMetadata);
+      }
+      return decoded;
+    }
+
+    @Override
+    public void verifyDeterministic() throws NonDeterministicException {
+      MESSAGE_ID_CODER.verifyDeterministic();
+      verifyIfPresent(responseMetadataEncoder);
+      verifyIfPresent(sdkHttpMetadataCoder);
+    }
+
+    /** Verifies an optional component coder; an absent (null) component is skipped. */
+    private static void verifyIfPresent(Coder<?> component) throws NonDeterministicException {
+      if (component != null) {
+        component.verifyDeterministic();
+      }
+    }
+  }
+}
diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/SnsCoderProviderRegistrar.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/SnsCoderProviderRegistrar.java
index 3a17cdd..9893a02 100644
--- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/SnsCoderProviderRegistrar.java
+++ b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/SnsCoderProviderRegistrar.java
@@ -32,6 +32,7 @@ public class SnsCoderProviderRegistrar implements CoderProviderRegistrar {
@Override
public List<CoderProvider> getCoderProviders() {
return ImmutableList.of(
- CoderProviders.forCoder(TypeDescriptor.of(PublishResult.class), PublishResultCoder.of()));
+ // Default PublishResult coder: encodes only the messageId (see PublishResultCoders).
+ CoderProviders.forCoder(
+ TypeDescriptor.of(PublishResult.class), PublishResultCoders.defaultPublishResult()));
}
}
diff --git a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/SnsIO.java b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/SnsIO.java
index 233431d..35bf2d5 100644
--- a/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/SnsIO.java
+++ b/sdks/java/io/amazon-web-services/src/main/java/org/apache/beam/sdk/io/aws/sns/SnsIO.java
@@ -32,6 +32,7 @@ import java.util.function.Predicate;
import javax.annotation.Nullable;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.transforms.DoFn;
@@ -79,6 +80,11 @@ import org.slf4j.LoggerFactory;
* <li>need to specify AwsClientsProvider. You can pass on the default one BasicSnsProvider
* <li>an output tag where you can get results. Example in SnsIOTest
* </ul>
+ *
+ * <p>By default, the output PublishResult contains only the messageId, all other fields are null.
+ * If you need the full ResponseMetadata and SdkHttpMetadata you can call {@link
+ * Write#withFullPublishResult}. If you need the HTTP status code but not the response headers you
+ * can call {@link Write#withFullPublishResultWithoutHeaders}.
*/
@Experimental(Kind.SOURCE_SINK)
public final class SnsIO {
@@ -136,7 +142,7 @@ public final class SnsIO {
/**
* An interface used to control if we retry the SNS Publish call when a {@link Throwable}
* occurs. If {@link RetryPredicate#test(Object)} returns true, {@link Write} tries to resend
- * the requests to the Solr server if the {@link RetryConfiguration} permits it.
+ * the requests to SNS if the {@link RetryConfiguration} permits it.
*/
@FunctionalInterface
interface RetryPredicate extends Predicate<Throwable>, Serializable {}
@@ -171,6 +177,9 @@ public final class SnsIO {
@Nullable
abstract TupleTag<PublishResult> getResultOutputTag();
+ /**
+ * Coder set on the result output in expand(), or null to leave the output's coder unchanged.
+ *
+ * <p>NOTE(review): this is a raw {@code Coder}; typing it as {@code Coder<PublishResult>}
+ * (together with {@code Builder#setCoder}) would remove the unchecked conversion in expand().
+ */
+ @Nullable
+ abstract Coder getCoder();
+
abstract Builder builder();
@AutoValue.Builder
@@ -184,6 +193,8 @@ public final class SnsIO {
abstract Builder setResultOutputTag(TupleTag<PublishResult> results);
+ /** Sets the coder used for the result output; null keeps the output's existing coder. */
+ abstract Builder setCoder(Coder coder);
+
abstract Write build();
}
@@ -257,12 +268,38 @@ public final class SnsIO {
return builder().setResultOutputTag(results).build();
}
+    /**
+     * Encodes the result output with {@link PublishResultCoders#fullPublishResult()}, keeping the
+     * messageId, the sdkResponseMetadata and the sdkHttpMetadata including the HTTP response
+     * headers.
+     */
+    public Write withFullPublishResult() {
+      Coder<PublishResult> fullCoder = PublishResultCoders.fullPublishResult();
+      return withCoder(fullCoder);
+    }
+
+    /**
+     * Encodes the result output with {@link PublishResultCoders#fullPublishResultWithoutHeaders()},
+     * keeping the messageId, the sdkResponseMetadata and the sdkHttpMetadata but not the HTTP
+     * response headers.
+     */
+    public Write withFullPublishResultWithoutHeaders() {
+      Coder<PublishResult> headerlessCoder = PublishResultCoders.fullPublishResultWithoutHeaders();
+      return withCoder(headerlessCoder);
+    }
+
+    /**
+     * Encodes the result output with the given coder instead of the default one, which keeps only
+     * the messageId.
+     *
+     * @param coder the coder for published results; must not be null (a null coder would
+     *     otherwise be silently ignored and the default used)
+     * @return a copy of this transform that applies {@code coder} to the result output
+     */
+    public Write withCoder(Coder<PublishResult> coder) {
+      checkArgument(coder != null, "coder can not be null");
+      return builder().setCoder(coder).build();
+    }
+
@Override
public PCollectionTuple expand(PCollection<PublishRequest> input) {
checkArgument(getTopicName() != null, "withTopicName() is required");
- return input.apply(
- ParDo.of(new SnsWriterFn(this))
- .withOutputTags(getResultOutputTag(), TupleTagList.empty()));
+ PCollectionTuple result =
+ input.apply(
+ ParDo.of(new SnsWriterFn(this))
+ .withOutputTags(getResultOutputTag(), TupleTagList.empty()));
+ // Override the result output's coder only when one was configured via withCoder() (or the
+ // withFullPublishResult*() shortcuts); otherwise the output keeps its default coder.
+ if (getCoder() != null) {
+ result.get(getResultOutputTag()).setCoder(getCoder());
+ }
+ return result;
}
static class SnsWriterFn extends DoFn<PublishRequest, PublishResult> {
diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/coders/AwsCodersTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/coders/AwsCodersTest.java
new file mode 100644
index 0000000..2baaa21
--- /dev/null
+++ b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/coders/AwsCodersTest.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.aws.coders;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.junit.Assert.assertThat;
+
+import com.amazonaws.ResponseMetadata;
+import com.amazonaws.http.HttpResponse;
+import com.amazonaws.http.SdkHttpMetadata;
+import java.util.UUID;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.junit.Test;
+
+/** Round-trip tests for the coders returned by {@link AwsCoders}. */
+public class AwsCodersTest {
+
+  @Test
+  public void testResponseMetadataDecodeEncodeEquals() throws Exception {
+    ResponseMetadata original = buildResponseMetadata();
+    ResponseMetadata decoded = CoderUtils.clone(AwsCoders.responseMetadata(), original);
+    assertThat(decoded.getRequestId(), equalTo(original.getRequestId()));
+  }
+
+  @Test
+  public void testSdkHttpMetadataDecodeEncodeEquals() throws Exception {
+    SdkHttpMetadata original = buildSdkHttpMetadata();
+    SdkHttpMetadata decoded = CoderUtils.clone(AwsCoders.sdkHttpMetadata(), original);
+    assertThat(decoded.getHttpStatusCode(), equalTo(original.getHttpStatusCode()));
+    assertThat(decoded.getHttpHeaders(), equalTo(original.getHttpHeaders()));
+  }
+
+  @Test
+  public void testSdkHttpMetadataWithoutHeadersDecodeEncodeEquals() throws Exception {
+    SdkHttpMetadata original = buildSdkHttpMetadata();
+    SdkHttpMetadata decoded =
+        CoderUtils.clone(AwsCoders.sdkHttpMetadataWithoutHeaders(), original);
+    assertThat(decoded.getHttpStatusCode(), equalTo(original.getHttpStatusCode()));
+    // This coder does not serialize headers, so none survive the round trip.
+    assertThat(decoded.getHttpHeaders().isEmpty(), equalTo(true));
+  }
+
+  /** Builds response metadata carrying a random AWS request id. */
+  private ResponseMetadata buildResponseMetadata() {
+    String requestId = UUID.randomUUID().toString();
+    return new ResponseMetadata(ImmutableMap.of(ResponseMetadata.AWS_REQUEST_ID, requestId));
+  }
+
+  /** Builds HTTP metadata for a 200 response with a single Content-Type header. */
+  private SdkHttpMetadata buildSdkHttpMetadata() {
+    HttpResponse response = new HttpResponse(null, null);
+    response.setStatusCode(200);
+    response.addHeader("Content-Type", "application/json");
+    return SdkHttpMetadata.from(response);
+  }
+}
diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/PublishResultCodersTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/PublishResultCodersTest.java
new file mode 100644
index 0000000..e21456d
--- /dev/null
+++ b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/PublishResultCodersTest.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.aws.sns;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.junit.Assert.assertThat;
+
+import com.amazonaws.ResponseMetadata;
+import com.amazonaws.http.HttpResponse;
+import com.amazonaws.http.SdkHttpMetadata;
+import com.amazonaws.services.sns.model.PublishResult;
+import java.util.UUID;
+import org.apache.beam.sdk.testing.CoderProperties;
+import org.apache.beam.sdk.util.CoderUtils;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.junit.Test;
+
+/** Round-trip tests for the coders returned by {@link PublishResultCoders}. */
+public class PublishResultCodersTest {
+
+  @Test
+  public void testDefaultPublishResultDecodeEncodeEquals() throws Exception {
+    CoderProperties.coderDecodeEncodeEqual(
+        PublishResultCoders.defaultPublishResult(), messageIdOnlyPublishResult());
+  }
+
+  @Test
+  public void testFullPublishResultWithoutHeadersDecodeEncodeEquals() throws Exception {
+    CoderProperties.coderDecodeEncodeEqual(
+        PublishResultCoders.fullPublishResultWithoutHeaders(), messageIdOnlyPublishResult());
+
+    PublishResult original = buildFullPublishResult();
+    PublishResult decoded =
+        CoderUtils.clone(PublishResultCoders.fullPublishResultWithoutHeaders(), original);
+    assertSameMetadata(original, decoded);
+    assertThat(decoded.getSdkHttpMetadata().getHttpHeaders().isEmpty(), equalTo(true));
+  }
+
+  @Test
+  public void testFullPublishResultIncludingHeadersDecodeEncodeEquals() throws Exception {
+    CoderProperties.coderDecodeEncodeEqual(
+        PublishResultCoders.fullPublishResult(), messageIdOnlyPublishResult());
+
+    PublishResult original = buildFullPublishResult();
+    PublishResult decoded = CoderUtils.clone(PublishResultCoders.fullPublishResult(), original);
+    assertSameMetadata(original, decoded);
+    assertThat(
+        decoded.getSdkHttpMetadata().getHttpHeaders(),
+        equalTo(original.getSdkHttpMetadata().getHttpHeaders()));
+  }
+
+  /** Asserts that the request id and the HTTP status code survived the round trip. */
+  private static void assertSameMetadata(PublishResult expected, PublishResult actual) {
+    assertThat(
+        actual.getSdkResponseMetadata().getRequestId(),
+        equalTo(expected.getSdkResponseMetadata().getRequestId()));
+    assertThat(
+        actual.getSdkHttpMetadata().getHttpStatusCode(),
+        equalTo(expected.getSdkHttpMetadata().getHttpStatusCode()));
+  }
+
+  /** Builds a PublishResult whose only populated field is a random messageId. */
+  private static PublishResult messageIdOnlyPublishResult() {
+    return new PublishResult().withMessageId(UUID.randomUUID().toString());
+  }
+
+  /** Builds a PublishResult with a random messageId, a random request id and a 200 response. */
+  private PublishResult buildFullPublishResult() {
+    PublishResult publishResult = messageIdOnlyPublishResult();
+    String requestId = UUID.randomUUID().toString();
+    publishResult.setSdkResponseMetadata(
+        new ResponseMetadata(ImmutableMap.of(ResponseMetadata.AWS_REQUEST_ID, requestId)));
+    HttpResponse httpResponse = new HttpResponse(null, null);
+    httpResponse.setStatusCode(200);
+    httpResponse.addHeader("Content-Type", "application/json");
+    publishResult.setSdkHttpMetadata(SdkHttpMetadata.from(httpResponse));
+    return publishResult;
+  }
+}
diff --git a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOTest.java b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOTest.java
index 4a39d31..db28f3d 100644
--- a/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOTest.java
+++ b/sdks/java/io/amazon-web-services/src/test/java/org/apache/beam/sdk/io/aws/sns/SnsIOTest.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.io.aws.sns;
+import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.fail;
import com.amazonaws.http.SdkHttpMetadata;
@@ -26,18 +27,25 @@ import com.amazonaws.services.sns.model.GetTopicAttributesResult;
import com.amazonaws.services.sns.model.InternalErrorException;
import com.amazonaws.services.sns.model.PublishRequest;
import com.amazonaws.services.sns.model.PublishResult;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
import java.io.Serializable;
import java.util.HashMap;
import java.util.UUID;
import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.AtomicCoder;
+import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.testing.ExpectedLogs;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.junit.Rule;
import org.junit.Test;
@@ -131,6 +139,52 @@ public class SnsIOTest implements Serializable {
fail("Pipeline is expected to fail because we were unable to write to SNS.");
}
+  @Test
+  public void testCustomCoder() throws Exception {
+    final PublishRequest request = createSampleMessage("my_first_message");
+    final TupleTag<PublishResult> resultTag = new TupleTag<>();
+    final AmazonSNS snsClient = getAmazonSnsMockSuccess();
+    final MockCoder coder = new MockCoder();
+
+    final SnsIO.Write write =
+        SnsIO.write()
+            .withTopicName(topicName)
+            .withAWSClientsProvider(new Provider(snsClient))
+            .withResultOutputTag(resultTag)
+            .withCoder(coder);
+    final PCollectionTuple snsWrites = p.apply(Create.of(request)).apply(write);
+
+    final PCollection<Long> publishedResultsSize =
+        snsWrites
+            .get(resultTag)
+            .apply(MapElements.into(TypeDescriptors.strings()).via(PublishResult::getMessageId))
+            .apply(Count.globally());
+    PAssert.that(publishedResultsSize).containsInAnyOrder(ImmutableList.of(1L));
+    p.run().waitUntilFinish();
+    // The configured coder must have been used to encode at least one result.
+    assertThat(coder.captured).isNotNull();
+  }
+
+  /**
+   * Coder that delegates to {@link PublishResultCoders#defaultPublishResult()} and remembers the
+   * last value it encoded, so a test can check that a configured coder was actually applied.
+   *
+   * <p>Hand-coded because Mockito mocks cause NotSerializableException even with
+   * withSettings().serializable().
+   */
+  private static class MockCoder extends AtomicCoder<PublishResult> {
+
+    // volatile: encode() is presumably invoked from pipeline execution threads, while the test
+    // thread reads this field after the pipeline has finished.
+    private volatile PublishResult captured;
+
+    @Override
+    public void encode(PublishResult value, OutputStream outStream)
+        throws CoderException, IOException {
+      this.captured = value;
+      PublishResultCoders.defaultPublishResult().encode(value, outStream);
+    }
+
+    @Override
+    public PublishResult decode(InputStream inStream) throws CoderException, IOException {
+      return PublishResultCoders.defaultPublishResult().decode(inStream);
+    }
+  }
+
private static AmazonSNS getAmazonSnsMockSuccess() {
final AmazonSNS amazonSNS = Mockito.mock(AmazonSNS.class);
configureAmazonSnsMock(amazonSNS);