You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by da...@apache.org on 2017/05/04 07:17:09 UTC

[03/50] [abbrv] beam git commit: Removed coder and formatFn from PubsubIO.Write

Removed coder and formatFn from PubsubIO.Write


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/429c6131
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/429c6131
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/429c6131

Branch: refs/heads/DSL_SQL
Commit: 429c61310f65ff44b8bb3b96799f63dd7b1377f3
Parents: 25dc94b
Author: Eugene Kirpichov <ki...@google.com>
Authored: Thu Apr 20 19:27:28 2017 -0700
Committer: Eugene Kirpichov <ki...@google.com>
Committed: Tue May 2 23:08:29 2017 -0700

----------------------------------------------------------------------
 .../beam/runners/dataflow/DataflowRunner.java   |  55 +++++-----
 .../apache/beam/sdk/io/gcp/pubsub/PubsubIO.java |  72 +++++++------
 .../sdk/io/gcp/pubsub/PubsubUnboundedSink.java  | 102 ++++++-------------
 .../io/gcp/pubsub/PubsubUnboundedSinkTest.java  |  50 +++++----
 4 files changed, 126 insertions(+), 153 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/429c6131/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index 6f29797..6aaa11b 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -960,26 +960,27 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
   }
 
   /**
-   * Suppress application of {@link PubsubUnboundedSink#expand} in streaming mode so that we
-   * can instead defer to Windmill's implementation.
+   * Suppress application of {@link PubsubUnboundedSink#expand} in streaming mode so that we can
+   * instead defer to Windmill's implementation.
    */
-  private static class StreamingPubsubIOWrite<T> extends PTransform<PCollection<T>, PDone> {
-    private final PubsubUnboundedSink<T> transform;
+  private static class StreamingPubsubIOWrite
+      extends PTransform<PCollection<PubsubIO.PubsubMessage>, PDone> {
+    private final PubsubUnboundedSink transform;
 
     /**
      * Builds an instance of this class from the overridden transform.
      */
     public StreamingPubsubIOWrite(
-        DataflowRunner runner, PubsubUnboundedSink<T> transform) {
+        DataflowRunner runner, PubsubUnboundedSink transform) {
       this.transform = transform;
     }
 
-    PubsubUnboundedSink<T> getOverriddenTransform() {
+    PubsubUnboundedSink getOverriddenTransform() {
       return transform;
     }
 
     @Override
-    public PDone expand(PCollection<T> input) {
+    public PDone expand(PCollection<PubsubIO.PubsubMessage> input) {
       return PDone.in(input.getPipeline());
     }
 
@@ -990,23 +991,23 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
 
     static {
       DataflowPipelineTranslator.registerTransformTranslator(
-          StreamingPubsubIOWrite.class, new StreamingPubsubIOWriteTranslator<>());
+          StreamingPubsubIOWrite.class, new StreamingPubsubIOWriteTranslator());
     }
   }
 
   /**
    * Rewrite {@link StreamingPubsubIOWrite} to the appropriate internal node.
    */
-  private static class StreamingPubsubIOWriteTranslator<T> implements
-      TransformTranslator<StreamingPubsubIOWrite<T>> {
+  private static class StreamingPubsubIOWriteTranslator implements
+      TransformTranslator<StreamingPubsubIOWrite> {
 
     @Override
     public void translate(
-        StreamingPubsubIOWrite<T> transform,
+        StreamingPubsubIOWrite transform,
         TranslationContext context) {
       checkArgument(context.getPipelineOptions().isStreaming(),
                     "StreamingPubsubIOWrite is only for streaming pipelines.");
-      PubsubUnboundedSink<T> overriddenTransform = transform.getOverriddenTransform();
+      PubsubUnboundedSink overriddenTransform = transform.getOverriddenTransform();
       StepTranslationContext stepContext = context.addStep(transform, "ParallelWrite");
       stepContext.addInput(PropertyNames.FORMAT, "pubsub");
       if (overriddenTransform.getTopicProvider().isAccessible()) {
@@ -1025,19 +1026,10 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
         stepContext.addInput(
             PropertyNames.PUBSUB_ID_ATTRIBUTE, overriddenTransform.getIdAttribute());
       }
-      if (overriddenTransform.getFormatFn() != null) {
-        stepContext.addInput(
-            PropertyNames.PUBSUB_SERIALIZED_ATTRIBUTES_FN,
-            byteArrayToJsonString(serializeToByteArray(overriddenTransform.getFormatFn())));
-        // No coder is needed in this case since the formatFn formats directly into a byte[],
-        // however the Dataflow backend require a coder to be set.
-        stepContext.addEncodingInput(WindowedValue.getValueOnlyCoder(VoidCoder.of()));
-      } else if (overriddenTransform.getElementCoder() != null) {
-        stepContext.addEncodingInput(WindowedValue.getValueOnlyCoder(
-            overriddenTransform.getElementCoder()));
-      }
-      PCollection<T> input = context.getInput(transform);
-      stepContext.addInput(PropertyNames.PARALLEL_INPUT, input);
+      // No coder is needed in this case since the collection being written is already of
+      // PubsubMessage; however, the Dataflow backend requires a coder to be set.
+      stepContext.addEncodingInput(WindowedValue.getValueOnlyCoder(VoidCoder.of()));
+      stepContext.addInput(PropertyNames.PARALLEL_INPUT, context.getInput(transform));
     }
   }
 
@@ -1331,8 +1323,9 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
     }
   }
 
-  private class StreamingPubsubIOWriteOverrideFactory<T>
-      implements PTransformOverrideFactory<PCollection<T>, PDone, PubsubUnboundedSink<T>> {
+  private class StreamingPubsubIOWriteOverrideFactory
+      implements PTransformOverrideFactory<
+          PCollection<PubsubIO.PubsubMessage>, PDone, PubsubUnboundedSink> {
     private final DataflowRunner runner;
 
     private StreamingPubsubIOWriteOverrideFactory(DataflowRunner runner) {
@@ -1340,11 +1333,13 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
     }
 
     @Override
-    public PTransformReplacement<PCollection<T>, PDone> getReplacementTransform(
-        AppliedPTransform<PCollection<T>, PDone, PubsubUnboundedSink<T>> transform) {
+    public PTransformReplacement<PCollection<PubsubIO.PubsubMessage>, PDone>
+        getReplacementTransform(
+            AppliedPTransform<PCollection<PubsubIO.PubsubMessage>, PDone, PubsubUnboundedSink>
+                transform) {
       return PTransformReplacement.of(
           PTransformReplacements.getSingletonMainInput(transform),
-          new StreamingPubsubIOWrite<>(runner, transform.getTransform()));
+          new StreamingPubsubIOWrite(runner, transform.getTransform()));
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/429c6131/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java
index af8b7d6..1c3de76 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubIO.java
@@ -21,6 +21,7 @@ import static com.google.common.base.Preconditions.checkNotNull;
 import static com.google.common.base.Preconditions.checkState;
 
 import com.google.auto.value.AutoValue;
+import com.google.common.collect.ImmutableMap;
 import com.google.protobuf.Message;
 import java.io.IOException;
 import java.io.Serializable;
@@ -537,7 +538,7 @@ public class PubsubIO {
    * stream.
    */
   public static Write<String> writeStrings() {
-    return PubsubIO.<String>write().withCoder(StringUtf8Coder.of());
+    return PubsubIO.<String>write().withFormatFn(new FormatPayloadAsUtf8());
   }
 
   /**
@@ -545,7 +546,9 @@ public class PubsubIO {
    * to a Google Cloud Pub/Sub stream.
    */
   public static <T extends Message> Write<T> writeProtos(Class<T> messageClass) {
-    return PubsubIO.<T>write().withCoder(ProtoCoder.of(messageClass));
+    // TODO: Like in readProtos(), stop using ProtoCoder and instead format the payload directly.
+    return PubsubIO.<T>write()
+        .withFormatFn(new FormatPayloadUsingCoder<>(ProtoCoder.of(messageClass)));
   }
 
   /**
@@ -553,7 +556,8 @@ public class PubsubIO {
    * to a Google Cloud Pub/Sub stream.
    */
   public static <T extends Message> Write<T> writeAvros(Class<T> clazz) {
-    return PubsubIO.<T>write().withCoder(AvroCoder.of(clazz));
+    // TODO: Like in readAvros(), stop using AvroCoder and instead format the payload directly.
+    return PubsubIO.<T>write().withFormatFn(new FormatPayloadUsingCoder<>(AvroCoder.of(clazz)));
   }
 
   /** Implementation of {@link #read}. */
@@ -801,10 +805,6 @@ public class PubsubIO {
     @Nullable
     abstract String getIdAttribute();
 
-    /** The input type Coder. */
-    @Nullable
-    abstract Coder<T> getCoder();
-
     /** The format function for input PubsubMessage objects. */
     @Nullable
     abstract SimpleFunction<T, PubsubMessage> getFormatFn();
@@ -819,8 +819,6 @@ public class PubsubIO {
 
       abstract Builder<T> setIdAttribute(String idAttribute);
 
-      abstract Builder<T> setCoder(Coder<T> coder);
-
       abstract Builder<T> setFormatFn(SimpleFunction<T, PubsubMessage> formatFn);
 
       abstract Write<T> build();
@@ -872,14 +870,6 @@ public class PubsubIO {
     }
 
     /**
-     * Uses the given {@link Coder} to encode each of the elements of the input {@link PCollection}
-     * into an output record.
-     */
-    public Write<T> withCoder(Coder<T> coder) {
-      return toBuilder().setCoder(coder).build();
-    }
-
-    /**
      * Used to write a PubSub message together with PubSub attributes. The user-supplied format
      * function translates the input type T to a PubsubMessage object, which is used by the sink
      * to separately set the PubSub message's payload and attributes.
@@ -898,13 +888,11 @@ public class PubsubIO {
           input.apply(ParDo.of(new PubsubBoundedWriter()));
           return PDone.in(input.getPipeline());
         case UNBOUNDED:
-          return input.apply(new PubsubUnboundedSink<T>(
+          return input.apply(MapElements.via(getFormatFn())).apply(new PubsubUnboundedSink(
               FACTORY,
               NestedValueProvider.of(getTopicProvider(), new TopicPathTranslator()),
-              getCoder(),
               getTimestampAttribute(),
               getIdAttribute(),
-              getFormatFn(),
               100 /* numShards */));
       }
       throw new RuntimeException(); // cases are exhaustive.
@@ -944,19 +932,12 @@ public class PubsubIO {
 
       @ProcessElement
       public void processElement(ProcessContext c) throws IOException {
-        byte[] payload = null;
-        Map<String, String> attributes = null;
-        if (getFormatFn() != null) {
-          PubsubMessage message = getFormatFn().apply(c.element());
-          payload = message.getMessage();
-          attributes = message.getAttributeMap();
-        } else {
-          payload = CoderUtils.encodeToByteArray(getCoder(), c.element());
-        }
+        byte[] payload;
+        PubsubMessage message = getFormatFn().apply(c.element());
+        payload = message.getMessage();
+        Map<String, String> attributes = message.getAttributeMap();
         // NOTE: The record id is always null.
-        OutgoingMessage message =
-            new OutgoingMessage(payload, attributes, c.timestamp().getMillis(), null);
-        output.add(message);
+        output.add(new OutgoingMessage(payload, attributes, c.timestamp().getMillis(), null));
 
         if (output.size() >= MAX_PUBLISH_BATCH_SIZE) {
           publish();
@@ -1016,6 +997,33 @@ public class PubsubIO {
     }
   }
 
+  private static class FormatPayloadAsUtf8 extends SimpleFunction<String, PubsubMessage> {
+    @Override
+    public PubsubMessage apply(String input) {
+      return new PubsubMessage(
+          input.getBytes(StandardCharsets.UTF_8), ImmutableMap.<String, String>of());
+    }
+  }
+
+  private static class FormatPayloadUsingCoder<T extends Message>
+      extends SimpleFunction<T, PubsubMessage> {
+    private Coder<T> coder;
+
+    public FormatPayloadUsingCoder(Coder<T> coder) {
+      this.coder = coder;
+    }
+
+    @Override
+    public PubsubMessage apply(T input) {
+      try {
+        return new PubsubMessage(
+            CoderUtils.encodeToByteArray(coder, input), ImmutableMap.<String, String>of());
+      } catch (CoderException e) {
+        throw new RuntimeException("Could not decode Pubsub message", e);
+      }
+    }
+  }
+
   private static class IdentityMessageFn extends SimpleFunction<PubsubMessage, PubsubMessage> {
     @Override
     public PubsubMessage apply(PubsubMessage input) {

http://git-wip-us.apache.org/repos/asf/beam/blob/429c6131/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSink.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSink.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSink.java
index 8d273ba..67530ec 100644
--- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSink.java
+++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSink.java
@@ -21,7 +21,6 @@ package org.apache.beam.sdk.io.gcp.pubsub;
 import static com.google.common.base.Preconditions.checkState;
 
 import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.ImmutableMap;
 import com.google.common.hash.Hashing;
 import java.io.IOException;
 import java.io.InputStream;
@@ -53,7 +52,6 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.SimpleFunction;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.transforms.display.DisplayData.Builder;
 import org.apache.beam.sdk.transforms.windowing.AfterFirst;
@@ -62,7 +60,6 @@ import org.apache.beam.sdk.transforms.windowing.AfterProcessingTime;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
 import org.apache.beam.sdk.transforms.windowing.Repeatedly;
 import org.apache.beam.sdk.transforms.windowing.Window;
-import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PDone;
@@ -84,7 +81,7 @@ import org.joda.time.Duration;
  * to dedup messages.
  * </ul>
  */
-public class PubsubUnboundedSink<T> extends PTransform<PCollection<T>, PDone> {
+public class PubsubUnboundedSink extends PTransform<PCollection<PubsubIO.PubsubMessage>, PDone> {
   /**
    * Default maximum number of messages per publish.
    */
@@ -157,33 +154,22 @@ public class PubsubUnboundedSink<T> extends PTransform<PCollection<T>, PDone> {
   /**
    * Convert elements to messages and shard them.
    */
-  private static class ShardFn<T> extends DoFn<T, KV<Integer, OutgoingMessage>> {
+  private static class ShardFn extends DoFn<PubsubIO.PubsubMessage, KV<Integer, OutgoingMessage>> {
     private final Counter elementCounter = Metrics.counter(ShardFn.class, "elements");
-    private final Coder<T> elementCoder;
     private final int numShards;
     private final RecordIdMethod recordIdMethod;
-    private final SimpleFunction<T, PubsubIO.PubsubMessage> formatFn;
 
-    ShardFn(Coder<T> elementCoder, int numShards,
-            SimpleFunction<T, PubsubIO.PubsubMessage> formatFn, RecordIdMethod recordIdMethod) {
-      this.elementCoder = elementCoder;
+    ShardFn(int numShards, RecordIdMethod recordIdMethod) {
       this.numShards = numShards;
-      this.formatFn = formatFn;
       this.recordIdMethod = recordIdMethod;
     }
 
     @ProcessElement
     public void processElement(ProcessContext c) throws Exception {
       elementCounter.inc();
-      byte[] elementBytes = null;
-      Map<String, String> attributes = ImmutableMap.<String, String>of();
-      if (formatFn != null) {
-        PubsubIO.PubsubMessage message = formatFn.apply(c.element());
-        elementBytes = message.getMessage();
-        attributes = message.getAttributeMap();
-      } else {
-        elementBytes = CoderUtils.encodeToByteArray(elementCoder, c.element());
-      }
+      PubsubIO.PubsubMessage message = c.element();
+      byte[] elementBytes = message.getMessage();
+      Map<String, String> attributes = message.getAttributeMap();
 
       long timestampMsSinceEpoch = c.timestamp().getMillis();
       @Nullable String recordId = null;
@@ -335,12 +321,6 @@ public class PubsubUnboundedSink<T> extends PTransform<PCollection<T>, PDone> {
   private final ValueProvider<TopicPath> topic;
 
   /**
-   * Coder for elements. It is the responsibility of the underlying Pubsub transport to
-   * re-encode element bytes if necessary, eg as Base64 strings.
-   */
-  private final Coder<T> elementCoder;
-
-  /**
    * Pubsub metadata field holding timestamp of each element, or {@literal null} if should use
    * Pubsub message publish timestamp instead.
    */
@@ -383,49 +363,37 @@ public class PubsubUnboundedSink<T> extends PTransform<PCollection<T>, PDone> {
    */
   private final RecordIdMethod recordIdMethod;
 
-  /**
-   * In order to publish attributes, a formatting function is used to format the output into
-   * a {@link PubsubIO.PubsubMessage}.
-   */
-  private final SimpleFunction<T, PubsubIO.PubsubMessage> formatFn;
-
   @VisibleForTesting
   PubsubUnboundedSink(
       PubsubClientFactory pubsubFactory,
       ValueProvider<TopicPath> topic,
-      Coder<T> elementCoder,
       String timestampAttribute,
       String idAttribute,
       int numShards,
       int publishBatchSize,
       int publishBatchBytes,
       Duration maxLatency,
-      SimpleFunction<T, PubsubIO.PubsubMessage> formatFn,
       RecordIdMethod recordIdMethod) {
     this.pubsubFactory = pubsubFactory;
     this.topic = topic;
-    this.elementCoder = elementCoder;
     this.timestampAttribute = timestampAttribute;
     this.idAttribute = idAttribute;
     this.numShards = numShards;
     this.publishBatchSize = publishBatchSize;
     this.publishBatchBytes = publishBatchBytes;
     this.maxLatency = maxLatency;
-    this.formatFn = formatFn;
     this.recordIdMethod = idAttribute == null ? RecordIdMethod.NONE : recordIdMethod;
   }
 
   public PubsubUnboundedSink(
       PubsubClientFactory pubsubFactory,
       ValueProvider<TopicPath> topic,
-      Coder<T> elementCoder,
       String timestampAttribute,
       String idAttribute,
-      SimpleFunction<T, PubsubIO.PubsubMessage> formatFn,
       int numShards) {
-    this(pubsubFactory, topic, elementCoder, timestampAttribute, idAttribute, numShards,
+    this(pubsubFactory, topic, timestampAttribute, idAttribute, numShards,
          DEFAULT_PUBLISH_BATCH_SIZE, DEFAULT_PUBLISH_BATCH_BYTES, DEFAULT_MAX_LATENCY,
-         formatFn, RecordIdMethod.RANDOM);
+         RecordIdMethod.RANDOM);
   }
 
   /**
@@ -458,37 +426,31 @@ public class PubsubUnboundedSink<T> extends PTransform<PCollection<T>, PDone> {
     return idAttribute;
   }
 
-  /**
-   * Get the format function used for PubSub attributes.
-   */
-  @Nullable
-  public SimpleFunction<T, PubsubIO.PubsubMessage> getFormatFn() {
-    return formatFn;
-  }
-
-  /**
-   * Get the Coder used to encode output elements.
-   */
-  public Coder<T> getElementCoder() {
-    return elementCoder;
-  }
-
   @Override
-  public PDone expand(PCollection<T> input) {
-    input.apply("PubsubUnboundedSink.Window", Window.<T>into(new GlobalWindows())
-        .triggering(
-            Repeatedly.forever(
-                AfterFirst.of(AfterPane.elementCountAtLeast(publishBatchSize),
-                    AfterProcessingTime.pastFirstElementInPane()
-                    .plusDelayOf(maxLatency))))
-            .discardingFiredPanes())
-         .apply("PubsubUnboundedSink.Shard",
-             ParDo.of(new ShardFn<T>(elementCoder, numShards, formatFn, recordIdMethod)))
-         .setCoder(KvCoder.of(VarIntCoder.of(), CODER))
-         .apply(GroupByKey.<Integer, OutgoingMessage>create())
-         .apply("PubsubUnboundedSink.Writer",
-             ParDo.of(new WriterFn(pubsubFactory, topic, timestampAttribute, idAttribute,
-                 publishBatchSize, publishBatchBytes)));
+  public PDone expand(PCollection<PubsubIO.PubsubMessage> input) {
+    input
+        .apply(
+            "PubsubUnboundedSink.Window",
+            Window.<PubsubIO.PubsubMessage>into(new GlobalWindows())
+                .triggering(
+                    Repeatedly.forever(
+                        AfterFirst.of(
+                            AfterPane.elementCountAtLeast(publishBatchSize),
+                            AfterProcessingTime.pastFirstElementInPane().plusDelayOf(maxLatency))))
+                .discardingFiredPanes())
+        .apply("PubsubUnboundedSink.Shard", ParDo.of(new ShardFn(numShards, recordIdMethod)))
+        .setCoder(KvCoder.of(VarIntCoder.of(), CODER))
+        .apply(GroupByKey.<Integer, OutgoingMessage>create())
+        .apply(
+            "PubsubUnboundedSink.Writer",
+            ParDo.of(
+                new WriterFn(
+                    pubsubFactory,
+                    topic,
+                    timestampAttribute,
+                    idAttribute,
+                    publishBatchSize,
+                    publishBatchBytes)));
     return PDone.in(input.getPipeline());
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/429c6131/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSinkTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSinkTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSinkTest.java
index 580ada9..11e7d83 100644
--- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSinkTest.java
+++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/pubsub/PubsubUnboundedSinkTest.java
@@ -23,11 +23,10 @@ import com.google.common.collect.ImmutableMap;
 import com.google.common.hash.Hashing;
 import java.io.IOException;
 import java.io.Serializable;
+import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
-
-import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.io.gcp.pubsub.PubsubClient.OutgoingMessage;
 import org.apache.beam.sdk.io.gcp.pubsub.PubsubClient.TopicPath;
 import org.apache.beam.sdk.io.gcp.pubsub.PubsubTestClient.PubsubTestClientFactory;
@@ -39,7 +38,6 @@ import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
-import org.apache.beam.sdk.transforms.SimpleFunction;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.junit.Rule;
@@ -62,10 +60,23 @@ public class PubsubUnboundedSinkTest implements Serializable {
   private static final String ID_ATTRIBUTE = "id";
   private static final int NUM_SHARDS = 10;
 
-  private static class Stamp extends DoFn<String, String> {
+  private static class Stamp extends DoFn<String, PubsubIO.PubsubMessage> {
+    private final Map<String, String> attributes;
+
+    private Stamp() {
+      this(ImmutableMap.<String, String>of());
+    }
+
+    private Stamp(Map<String, String> attributes) {
+      this.attributes = attributes;
+    }
+
     @ProcessElement
     public void processElement(ProcessContext c) {
-      c.outputWithTimestamp(c.element(), new Instant(TIMESTAMP));
+      c.outputWithTimestamp(
+          new PubsubIO.PubsubMessage(
+              c.element().getBytes(StandardCharsets.UTF_8), attributes),
+          new Instant(TIMESTAMP));
     }
   }
 
@@ -97,19 +108,14 @@ public class PubsubUnboundedSinkTest implements Serializable {
     try (PubsubTestClientFactory factory =
              PubsubTestClient.createFactoryForPublish(TOPIC, outgoing,
                                                       ImmutableList.<OutgoingMessage>of())) {
-      PubsubUnboundedSink<String> sink =
-          new PubsubUnboundedSink<>(factory, StaticValueProvider.of(TOPIC), StringUtf8Coder.of(),
+      PubsubUnboundedSink sink =
+          new PubsubUnboundedSink(factory, StaticValueProvider.of(TOPIC),
               TIMESTAMP_ATTRIBUTE, ID_ATTRIBUTE, NUM_SHARDS, batchSize, batchBytes,
               Duration.standardSeconds(2),
-              new SimpleFunction<String, PubsubIO.PubsubMessage>() {
-                @Override
-                public PubsubIO.PubsubMessage apply(String input) {
-                  return new PubsubIO.PubsubMessage(input.getBytes(), ATTRIBUTES);
-                }
-              },
               RecordIdMethod.DETERMINISTIC);
       p.apply(Create.of(ImmutableList.of(DATA)))
-       .apply(ParDo.of(new Stamp()))
+       .apply(ParDo.of(new Stamp(ATTRIBUTES)))
+       .setCoder(PubsubMessageWithAttributesCoder.of())
        .apply(sink);
       p.run();
     }
@@ -133,12 +139,13 @@ public class PubsubUnboundedSinkTest implements Serializable {
     try (PubsubTestClientFactory factory =
              PubsubTestClient.createFactoryForPublish(TOPIC, outgoing,
                                                       ImmutableList.<OutgoingMessage>of())) {
-      PubsubUnboundedSink<String> sink =
-          new PubsubUnboundedSink<>(factory, StaticValueProvider.of(TOPIC), StringUtf8Coder.of(),
+      PubsubUnboundedSink sink =
+          new PubsubUnboundedSink(factory, StaticValueProvider.of(TOPIC),
               TIMESTAMP_ATTRIBUTE, ID_ATTRIBUTE, NUM_SHARDS, batchSize, batchBytes,
-              Duration.standardSeconds(2), null, RecordIdMethod.DETERMINISTIC);
+              Duration.standardSeconds(2), RecordIdMethod.DETERMINISTIC);
       p.apply(Create.of(data))
        .apply(ParDo.of(new Stamp()))
+       .setCoder(PubsubMessagePayloadOnlyCoder.of())
        .apply(sink);
       p.run();
     }
@@ -168,13 +175,14 @@ public class PubsubUnboundedSinkTest implements Serializable {
     try (PubsubTestClientFactory factory =
              PubsubTestClient.createFactoryForPublish(TOPIC, outgoing,
                                                       ImmutableList.<OutgoingMessage>of())) {
-      PubsubUnboundedSink<String> sink =
-          new PubsubUnboundedSink<>(factory, StaticValueProvider.of(TOPIC),
-              StringUtf8Coder.of(), TIMESTAMP_ATTRIBUTE, ID_ATTRIBUTE,
+      PubsubUnboundedSink sink =
+          new PubsubUnboundedSink(factory, StaticValueProvider.of(TOPIC),
+              TIMESTAMP_ATTRIBUTE, ID_ATTRIBUTE,
               NUM_SHARDS, batchSize, batchBytes, Duration.standardSeconds(2),
-              null, RecordIdMethod.DETERMINISTIC);
+              RecordIdMethod.DETERMINISTIC);
       p.apply(Create.of(data))
        .apply(ParDo.of(new Stamp()))
+       .setCoder(PubsubMessagePayloadOnlyCoder.of())
        .apply(sink);
       p.run();
     }