Posted to commits@beam.apache.org by da...@apache.org on 2016/06/07 16:59:51 UTC

[06/11] incubator-beam git commit: squash PR 271

squash PR 271


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/cf21fdb4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/cf21fdb4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/cf21fdb4

Branch: refs/heads/release-0.1.0-incubating
Commit: cf21fdb40fd600b6f9f5b8ee4b124e0d21c09260
Parents: eb529d4
Author: Raghu Angadi <ra...@google.com>
Authored: Fri Jun 3 17:46:38 2016 -0700
Committer: Davor Bonaci <da...@google.com>
Committed: Mon Jun 6 18:08:22 2016 -0700

----------------------------------------------------------------------
 .../org/apache/beam/sdk/io/kafka/KafkaIO.java   | 441 ++++++++++++++++++-
 .../apache/beam/sdk/io/kafka/KafkaIOTest.java   | 300 ++++++++++++-
 2 files changed, 705 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf21fdb4/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
index 7fff641..9645d7c 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java
@@ -23,7 +23,10 @@ import static com.google.common.base.Preconditions.checkState;
 
 import org.apache.beam.sdk.coders.ByteArrayCoder;
 import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderException;
+import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.SerializableCoder;
+import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.io.Read.Unbounded;
 import org.apache.beam.sdk.io.UnboundedSource;
 import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark;
@@ -34,10 +37,12 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.sdk.util.ExposedByteArrayInputStream;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.PInput;
 
 import com.google.common.annotations.VisibleForTesting;
@@ -56,10 +61,17 @@ import org.apache.kafka.clients.consumer.ConsumerConfig;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.ConsumerRecords;
 import org.apache.kafka.clients.consumer.KafkaConsumer;
+import org.apache.kafka.clients.producer.Callback;
+import org.apache.kafka.clients.producer.KafkaProducer;
+import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.clients.producer.ProducerConfig;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.apache.kafka.clients.producer.RecordMetadata;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.errors.WakeupException;
 import org.apache.kafka.common.serialization.ByteArrayDeserializer;
+import org.apache.kafka.common.serialization.Serializer;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.slf4j.Logger;
@@ -86,8 +98,8 @@ import java.util.concurrent.atomic.AtomicBoolean;
 import javax.annotation.Nullable;
 
 /**
- * An unbounded source for <a href="http://kafka.apache.org/">Kafka</a> topics. Kafka version 0.9
- * and above are supported.
+ * An unbounded source and a sink for <a href="http://kafka.apache.org/">Kafka</a> topics.
+ * Kafka version 0.9 and above are supported.
  *
  * <h3>Reading from Kafka topics</h3>
  *
@@ -146,25 +158,54 @@ import javax.annotation.Nullable;
 * beginning by setting appropriate properties in {@link ConsumerConfig}, through
  * {@link Read#updateConsumerProperties(Map)}.
  *
+ * <h3>Writing to Kafka</h3>
+ *
+ * The KafkaIO sink supports writing key-value pairs to a Kafka topic. Users can also write
+ * just the values. To configure a Kafka sink, you must specify at a minimum the Kafka
+ * <tt>bootstrapServers</tt> and the topic to write to. The following example illustrates various
+ * options for configuring the sink:
+ *
+ * <pre>{@code
+ *
+ *  pipeline
+ *    .apply(...) // returns PCollection<KV<Long, String>>
+ *    .apply(KafkaIO.write()
+ *       .withBootstrapServers("broker_1:9092,broker_2:9092")
+ *       .withTopic("results")
+ *
+ *       // set Coder for Key and Value
+ *       .withKeyCoder(BigEndianLongCoder.of())
+ *       .withValueCoder(StringUtf8Coder.of())
+ *
+ *       // you can further customize KafkaProducer used to write the records by adding more
+ *       // settings for ProducerConfig. E.g., to enable compression:
+ *       .updateProducerProperties(ImmutableMap.of("compression.type", "gzip"))
+ *    );
+ * }</pre>
+ *
+ * Often you might want to write just the values, without any keys, to Kafka. Use
+ * {@code values()} to write records with a default empty (null) key:
+ *
+ * <pre>{@code
+ *  PCollection<String> strings = ...;
+ *  strings.apply(KafkaIO.write()
+ *      .withBootstrapServers("broker_1:9092,broker_2:9092")
+ *      .withTopic("results")
+ *      .withValueCoder(StringUtf8Coder.of()) // just need coder for value
+ *      .values() // writes values to Kafka with default key
+ *    );
+ * }</pre>
+ *
  * <h3>Advanced Kafka Configuration</h3>
- * KafakIO allows setting most of the properties in {@link ConsumerConfig}. E.g. if you would like
- * to enable offset <em>auto commit</em> (for external monitoring or other purposes), you can set
+ * KafkaIO allows setting most of the properties in {@link ConsumerConfig} for the source or in
+ * {@link ProducerConfig} for the sink. E.g., if you would like to enable offset
+ * <em>auto commit</em> (for external monitoring or other purposes), you can set
  * <tt>"group.id"</tt>, <tt>"enable.auto.commit"</tt>, etc.
  */
 public class KafkaIO {
 
-  private static final Logger LOG = LoggerFactory.getLogger(KafkaIO.class);
-
-  private static class NowTimestampFn<T> implements SerializableFunction<T, Instant> {
-    @Override
-    public Instant apply(T input) {
-      return Instant.now();
-    }
-  }
-
-
   /**
-   * Creates and uninitialized {@link Read} {@link PTransform}. Before use, basic Kafka
+   * Creates an uninitialized {@link Read} {@link PTransform}. Before use, basic Kafka
    * configuration should be set with {@link Read#withBootstrapServers(String)} and
    * {@link Read#withTopics(List)}. Other optional settings include key and value coders,
    * custom timestamp and watermark functions.
@@ -182,6 +223,21 @@ public class KafkaIO {
   }
 
   /**
+   * Creates an uninitialized {@link Write} {@link PTransform}. Before use, Kafka configuration
+   * should be set with {@link Write#withBootstrapServers(String)} and {@link Write#withTopic}
+   * along with {@link Coder}s for the (optional) key and the values.
+   */
+  public static Write<byte[], byte[]> write() {
+    return new Write<byte[], byte[]>(
+        null,
+        ByteArrayCoder.of(),
+        ByteArrayCoder.of(),
+        TypedWrite.DEFAULT_PRODUCER_PROPERTIES);
+  }
+
+  ///////////////////////// Read Support \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\
+
+  /**
    * A {@link PTransform} to read from Kafka topics. See {@link KafkaIO} for more
    * information on usage and configuration.
    */
@@ -253,13 +309,9 @@ public class KafkaIO {
      * Update consumer configuration with new properties.
      */
     public Read<K, V> updateConsumerProperties(Map<String, Object> configUpdates) {
-      for (String key : configUpdates.keySet()) {
-        checkArgument(!IGNORED_CONSUMER_PROPERTIES.containsKey(key),
-            "No need to configure '%s'. %s", key, IGNORED_CONSUMER_PROPERTIES.get(key));
-      }
 
-      Map<String, Object> config = new HashMap<>(consumerConfig);
-      config.putAll(configUpdates);
+      Map<String, Object> config = updateKafkaProperties(consumerConfig,
+          IGNORED_CONSUMER_PROPERTIES, configUpdates);
 
       return new Read<K, V>(topics, topicPartitions, keyCoder, valueCoder,
           consumerFactoryFn, config, maxNumRecords, maxReadTime);
@@ -305,8 +357,8 @@ public class KafkaIO {
      * A set of properties that are not required or don't make sense for our consumer.
      */
     private static final Map<String, String> IGNORED_CONSUMER_PROPERTIES = ImmutableMap.of(
-        ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, "Set keyDecoderFn instead",
-        ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, "Set valueDecoderFn instead"
+        ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, "Set keyCoder instead",
+        ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, "Set valueCoder instead"
         // "group.id", "enable.auto.commit", "auto.commit.interval.ms" :
         //     lets allow these, applications can have better resume point for restarts.
         );
@@ -508,6 +560,37 @@ public class KafkaIO {
     }
   }
 
+  ////////////////////////////////////////////////////////////////////////////////////////////////
+
+  private static final Logger LOG = LoggerFactory.getLogger(KafkaIO.class);
+
+  /**
+   * Returns a new config map that is a merge of the current config and the updates.
+   * Verifies that the updates do not include ignored properties.
+   */
+  private static Map<String, Object> updateKafkaProperties(
+      Map<String, Object> currentConfig,
+      Map<String, String> ignoredProperties,
+      Map<String, Object> updates) {
+
+    for (String key : updates.keySet()) {
+      checkArgument(!ignoredProperties.containsKey(key),
+          "No need to configure '%s'. %s", key, ignoredProperties.get(key));
+    }
+
+    Map<String, Object> config = new HashMap<>(currentConfig);
+    config.putAll(updates);
+
+    return config;
+  }
+
+  private static class NowTimestampFn<T> implements SerializableFunction<T, Instant> {
+    @Override
+    public Instant apply(T input) {
+      return Instant.now();
+    }
+  }
+
   /** Static class, prevent instantiation. */
   private KafkaIO() {}
 
@@ -719,7 +802,6 @@ public class KafkaIO {
       private double avgRecordSize = 0;
       private static final int movingAvgWindow = 1000; // very roughly avg of last 1000 elements
 
-
       PartitionState(TopicPartition partition, long offset) {
         this.topicPartition = partition;
         this.consumedOffset = offset;
@@ -1073,4 +1155,315 @@ public class KafkaIO {
       Closeables.close(consumer, true);
     }
   }
+
+  //////////////////////// Sink Support \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\
+
+  /**
+   * A {@link PTransform} to write to a Kafka topic. See {@link KafkaIO} for more
+   * information on usage and configuration.
+   */
+  public static class Write<K, V> extends TypedWrite<K, V> {
+
+    /**
+     * Returns a new {@link Write} transform with Kafka producer pointing to
+     * {@code bootstrapServers}.
+     */
+    public Write<K, V> withBootstrapServers(String bootstrapServers) {
+      return updateProducerProperties(
+          ImmutableMap.<String, Object>of(
+              ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers));
+    }
+
+    /**
+     * Returns a new {@link Write} transform that writes to the given topic.
+     */
+    public Write<K, V> withTopic(String topic) {
+      return new Write<K, V>(topic, keyCoder, valueCoder, producerConfig);
+    }
+
+    /**
+     * Returns a new {@link Write} with a {@link Coder} for serializing the key (if any) to bytes.
+     * A key is optional while writing to Kafka. Note that when a key is set, its hash is used to
+     * determine the partition in Kafka (see {@link ProducerRecord} for more details).
+     */
+    public <KeyT> Write<KeyT, V> withKeyCoder(Coder<KeyT> keyCoder) {
+      return new Write<KeyT, V>(topic, keyCoder, valueCoder, producerConfig);
+    }
+
+    /**
+     * Returns a new {@link Write} with a {@link Coder} for serializing the value to bytes.
+     */
+    public <ValueT> Write<K, ValueT> withValueCoder(Coder<ValueT> valueCoder) {
+      return new Write<K, ValueT>(topic, keyCoder, valueCoder, producerConfig);
+    }
+
+    public Write<K, V> updateProducerProperties(Map<String, Object> configUpdates) {
+      Map<String, Object> config = updateKafkaProperties(producerConfig,
+          TypedWrite.IGNORED_PRODUCER_PROPERTIES, configUpdates);
+      return new Write<K, V>(topic, keyCoder, valueCoder, config);
+    }
+
+    private Write(
+        String topic,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        Map<String, Object> producerConfig) {
+      super(topic, keyCoder, valueCoder, producerConfig,
+          Optional.<SerializableFunction<Map<String, Object>, Producer<K, V>>>absent());
+    }
+  }
+
+  /**
+   * A {@link PTransform} to write to a Kafka topic. See {@link KafkaIO} for more
+   * information on usage and configuration.
+   */
+  public static class TypedWrite<K, V> extends PTransform<PCollection<KV<K, V>>, PDone> {
+
+    /**
+     * Returns a new {@link Write} with a custom function to create the Kafka producer. Primarily
+     * used for tests. The default is {@link KafkaProducer}.
+     */
+    public TypedWrite<K, V> withProducerFactoryFn(
+        SerializableFunction<Map<String, Object>, Producer<K, V>> producerFactoryFn) {
+      return new TypedWrite<K, V>(topic, keyCoder, valueCoder, producerConfig,
+          Optional.of(producerFactoryFn));
+    }
+
+    /**
+     * Returns a new transform that writes just the values to Kafka. This is useful for writing
+     * collections of values rather than {@link KV}s.
+     */
+    @SuppressWarnings("unchecked")
+    public PTransform<PCollection<V>, PDone> values() {
+      return new KafkaValueWrite<V>((TypedWrite<Void, V>) this);
+      // Any way to avoid casting here to TypedWrite<Void, V>? We can't create
+      // new TypedWrite without casting producerFactoryFn.
+    }
+
+    @Override
+    public PDone apply(PCollection<KV<K, V>> input) {
+      input.apply(ParDo.of(new KafkaWriter<K, V>(
+          topic, keyCoder, valueCoder, producerConfig, producerFactoryFnOpt)));
+      return PDone.in(input.getPipeline());
+    }
+
+    @Override
+    public void validate(PCollection<KV<K, V>> input) {
+      checkNotNull(producerConfig.get(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG),
+          "Kafka bootstrap servers should be set");
+      checkNotNull(topic, "Kafka topic should be set");
+    }
+
+    //////////////////////////////////////////////////////////////////////////////////////////
+
+    protected final String topic;
+    protected final Coder<K> keyCoder;
+    protected final Coder<V> valueCoder;
+    protected final Optional<SerializableFunction<Map<String, Object>, Producer<K, V>>>
+        producerFactoryFnOpt;
+    protected final Map<String, Object> producerConfig;
+
+    protected TypedWrite(
+        String topic,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        Map<String, Object> producerConfig,
+        Optional<SerializableFunction<Map<String, Object>, Producer<K, V>>> producerFactoryFnOpt) {
+
+      this.topic = topic;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.producerConfig = producerConfig;
+      this.producerFactoryFnOpt = producerFactoryFnOpt;
+    }
+
+    // set config defaults
+    private static final Map<String, Object> DEFAULT_PRODUCER_PROPERTIES =
+        ImmutableMap.<String, Object>of(
+            ProducerConfig.RETRIES_CONFIG, 3,
+            ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, CoderBasedKafkaSerializer.class,
+            ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, CoderBasedKafkaSerializer.class);
+
+    /**
+     * A set of properties that are not required or don't make sense for our producer.
+     */
+    private static final Map<String, String> IGNORED_PRODUCER_PROPERTIES = ImmutableMap.of(
+        ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, "Set keyCoder instead",
+        ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, "Set valueCoder instead",
+        configForKeySerializer(), "Reserved for internal serializer",
+        configForValueSerializer(), "Reserved for internal serializer"
+     );
+  }
+
+  /**
+   * Same as {@code Write<K, V>} without a key. Null is used for the key, as is the convention in
+   * Kafka when there is no key specified. The majority of Kafka writers don't specify a key.
+   */
+  private static class KafkaValueWrite<V> extends PTransform<PCollection<V>, PDone> {
+
+    private final TypedWrite<Void, V> kvWriteTransform;
+
+    private KafkaValueWrite(TypedWrite<Void, V> kvWriteTransform) {
+      this.kvWriteTransform = kvWriteTransform;
+    }
+
+    @Override
+    public PDone apply(PCollection<V> input) {
+      return input
+        .apply("Kafka values with default key",
+          ParDo.of(new DoFn<V, KV<Void, V>>() {
+            @Override
+            public void processElement(ProcessContext ctx) throws Exception {
+              ctx.output(KV.<Void, V>of(null, ctx.element()));
+            }
+          }))
+        .setCoder(KvCoder.of(VoidCoder.of(), kvWriteTransform.valueCoder))
+        .apply(kvWriteTransform);
+    }
+  }
+
+  private static class KafkaWriter<K, V> extends DoFn<KV<K, V>, Void> {
+
+    @Override
+    public void startBundle(Context c) throws Exception {
+      // Producer initialization is fairly costly. Move this to a future initialization API to
+      // avoid creating a producer for each bundle.
+      if (producer == null) {
+        if (producerFactoryFnOpt.isPresent()) {
+           producer = producerFactoryFnOpt.get().apply(producerConfig);
+        } else {
+          producer = new KafkaProducer<K, V>(producerConfig);
+        }
+      }
+    }
+
+    @Override
+    public void processElement(ProcessContext ctx) throws Exception {
+      checkForFailures();
+
+      KV<K, V> kv = ctx.element();
+      producer.send(
+          new ProducerRecord<K, V>(topic, kv.getKey(), kv.getValue()),
+          new SendCallback());
+    }
+
+    @Override
+    public void finishBundle(Context c) throws Exception {
+      producer.flush();
+      producer.close();
+      producer = null;
+      checkForFailures();
+    }
+
+    ///////////////////////////////////////////////////////////////////////////////////
+
+    private final String topic;
+    private final Map<String, Object> producerConfig;
+    private final Optional<SerializableFunction<Map<String, Object>, Producer<K, V>>>
+                  producerFactoryFnOpt;
+
+    private transient Producer<K, V> producer = null;
+    // first exception and number of failures since last invocation of checkForFailures():
+    private transient Exception sendException = null;
+    private transient long numSendFailures = 0;
+
+    KafkaWriter(String topic,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        Map<String, Object> producerConfig,
+        Optional<SerializableFunction<Map<String, Object>, Producer<K, V>>> producerFactoryFnOpt) {
+
+      this.topic = topic;
+      this.producerFactoryFnOpt = producerFactoryFnOpt;
+
+      // Set custom Kafka serializers. We can not simply serialize the user objects to bytes
+      // ourselves before handing them to the producer, because the key and value objects are
+      // also used in Kafka's Partitioner interface. This does not matter for the default
+      // partitioner in Kafka, as it uses just the serialized key bytes to pick a partition,
+      // but it makes sure a user's custom partitioner would work as expected.
+
+      this.producerConfig = new HashMap<>(producerConfig);
+      this.producerConfig.put(configForKeySerializer(), keyCoder);
+      this.producerConfig.put(configForValueSerializer(), valueCoder);
+    }
+
+    private synchronized void checkForFailures() throws IOException {
+      if (numSendFailures == 0) {
+        return;
+      }
+
+      String msg = String.format(
+          "KafkaWriter : failed to send %d records (since last report)", numSendFailures);
+
+      Exception e = sendException;
+      sendException = null;
+      numSendFailures = 0;
+
+      LOG.warn(msg);
+      throw new IOException(msg, e);
+    }
+
+    private class SendCallback implements Callback {
+      @Override
+      public void onCompletion(RecordMetadata metadata, Exception exception) {
+        if (exception == null) {
+          return;
+        }
+
+        synchronized (KafkaWriter.this) {
+          if (sendException == null) {
+            sendException = exception;
+          }
+          numSendFailures++;
+        }
+        // don't log exception stacktrace here, exception will be propagated up.
+        LOG.warn("KafkaWriter send failed : '{}'", exception.getMessage());
+      }
+    }
+  }
+
+  /**
+   * Implements Kafka's {@link Serializer} with a {@link Coder}. The coder is stored as a
+   * value in the producer configuration map.
+   */
+  public static class CoderBasedKafkaSerializer<T> implements Serializer<T> {
+
+    @SuppressWarnings("unchecked")
+    @Override
+    public void configure(Map<String, ?> configs, boolean isKey) {
+      String configKey = isKey ? configForKeySerializer() : configForValueSerializer();
+      coder = (Coder<T>) configs.get(configKey);
+      checkNotNull(coder, "could not instantiate coder for Kafka serialization");
+    }
+
+    @Override
+    public byte[] serialize(String topic, @Nullable T data) {
+      if (data == null) {
+        return null; // common for keys to be null
+      }
+
+      try {
+        return CoderUtils.encodeToByteArray(coder, data);
+      } catch (CoderException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public void close() {
+    }
+
+    private Coder<T> coder = null;
+    private static final String CONFIG_FORMAT = "beam.coder.based.kafka.%s.serializer";
+  }
+
+
+  private static String configForKeySerializer() {
+    return String.format(CoderBasedKafkaSerializer.CONFIG_FORMAT, "key");
+  }
+
+  private static String configForValueSerializer() {
+    return String.format(CoderBasedKafkaSerializer.CONFIG_FORMAT, "value");
+  }
 }
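
A minimal end-to-end sketch of the API added above, combining the existing Kafka source with the
new sink and the "Advanced Kafka Configuration" hooks described in the Javadoc. This is a usage
sketch only: the broker addresses, topic names, group id, and the assumption that a Pipeline named
"pipeline" already exists are illustrative placeholders, not part of this commit.

  pipeline
    .apply(KafkaIO.read()
        .withBootstrapServers("broker_1:9092,broker_2:9092")
        .withTopics(ImmutableList.of("input_topic"))
        .withKeyCoder(BigEndianLongCoder.of())
        .withValueCoder(StringUtf8Coder.of())
        // extra ConsumerConfig properties, e.g. to commit offsets back to Kafka
        // for external monitoring (see "Advanced Kafka Configuration" above)
        .updateConsumerProperties(ImmutableMap.<String, Object>of(
            "group.id", "beam_app", "enable.auto.commit", "true"))
        .withoutMetadata())                        // PCollection<KV<Long, String>>
    .apply(Values.<String>create())                // PCollection<String>
    .apply(KafkaIO.write()
        .withBootstrapServers("broker_1:9092,broker_2:9092")
        .withTopic("output_topic")
        .withValueCoder(StringUtf8Coder.of())      // only the value coder is needed
        // extra ProducerConfig properties, e.g. compression
        .updateProducerProperties(
            ImmutableMap.<String, Object>of("compression.type", "gzip"))
        .values());                                // write values with a null key

  pipeline.run();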

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/cf21fdb4/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
index 957271e..7d4337d 100644
--- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
+++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaIOTest.java
@@ -18,8 +18,12 @@
 package org.apache.beam.sdk.io.kafka;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
 
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineExecutionException;
+import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
 import org.apache.beam.sdk.coders.BigEndianLongCoder;
 import org.apache.beam.sdk.io.Read;
 import org.apache.beam.sdk.io.UnboundedSource;
@@ -49,20 +53,31 @@ import org.apache.kafka.clients.consumer.Consumer;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
 import org.apache.kafka.clients.consumer.MockConsumer;
 import org.apache.kafka.clients.consumer.OffsetResetStrategy;
+import org.apache.kafka.clients.producer.MockProducer;
+import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.clients.producer.ProducerRecord;
 import org.apache.kafka.common.PartitionInfo;
 import org.apache.kafka.common.TopicPartition;
 import org.joda.time.Instant;
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.experimental.categories.Category;
+import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 import javax.annotation.Nullable;
 
@@ -79,6 +94,9 @@ public class KafkaIOTest {
    *   - test KafkaRecordCoder
    */
 
+  @Rule
+  public ExpectedException thrown = ExpectedException.none();
+
   // Update mock consumer with records distributed among the given topics, each with given number
   // of partitions. Records are assigned in round-robin order among the partitions.
   private static MockConsumer<byte[], byte[]> mkMockConsumer(
@@ -113,8 +131,8 @@ public class KafkaIOTest {
               tp.topic(),
               tp.partition(),
               offsets[pIdx]++,
-              null, // key
-              ByteBuffer.wrap(new byte[8]).putLong(i).array())); // value is 8 byte record id.
+              ByteBuffer.wrap(new byte[4]).putInt(i).array(),    // key is 4 byte record id
+              ByteBuffer.wrap(new byte[8]).putLong(i).array())); // value is 8 byte record id
     }
 
     MockConsumer<byte[], byte[]> consumer =
@@ -161,16 +179,17 @@ public class KafkaIOTest {
    * Creates a consumer with two topics, with 5 partitions each.
    * numElements are (round-robin) assigned all the 10 partitions.
    */
-  private static KafkaIO.TypedRead<byte[], Long> mkKafkaReadTransform(
+  private static KafkaIO.TypedRead<Integer, Long> mkKafkaReadTransform(
       int numElements,
-      @Nullable SerializableFunction<KV<byte[], Long>, Instant> timestampFn) {
+      @Nullable SerializableFunction<KV<Integer, Long>, Instant> timestampFn) {
 
     List<String> topics = ImmutableList.of("topic_a", "topic_b");
 
-    KafkaIO.Read<byte[], Long> reader = KafkaIO.read()
+    KafkaIO.Read<Integer, Long> reader = KafkaIO.read()
         .withBootstrapServers("none")
         .withTopics(topics)
         .withConsumerFactoryFn(new ConsumerFactoryFn(topics, 10, numElements)) // 20 partitions
+        .withKeyCoder(BigEndianIntegerCoder.of())
         .withValueCoder(BigEndianLongCoder.of())
         .withMaxNumRecords(numElements);
 
@@ -305,9 +324,9 @@ public class KafkaIOTest {
     int numElements = 1000;
     int numSplits = 10;
 
-    UnboundedSource<KafkaRecord<byte[], Long>, ?> initial =
+    UnboundedSource<KafkaRecord<Integer, Long>, ?> initial =
         mkKafkaReadTransform(numElements, null).makeSource();
-    List<? extends UnboundedSource<KafkaRecord<byte[], Long>, ?>> splits =
+    List<? extends UnboundedSource<KafkaRecord<Integer, Long>, ?>> splits =
         initial.generateInitialSplits(numSplits, p.getOptions());
     assertEquals("Expected exact splitting", numSplits, splits.size());
 
@@ -317,7 +336,7 @@ public class KafkaIOTest {
     for (int i = 0; i < splits.size(); ++i) {
       pcollections = pcollections.and(
           p.apply("split" + i, Read.from(splits.get(i)).withMaxNumRecords(elementsPerSplit))
-           .apply("Remove Metadata " + i, ParDo.of(new RemoveKafkaMetadata<byte[], Long>()))
+           .apply("Remove Metadata " + i, ParDo.of(new RemoveKafkaMetadata<Integer, Long>()))
            .apply("collection " + i, Values.<Long>create()));
     }
     PCollection<Long> input = pcollections.apply(Flatten.<Long>pCollections());
@@ -330,9 +349,9 @@ public class KafkaIOTest {
    * A timestamp function that uses the given value as the timestamp.
    */
   private static class ValueAsTimestampFn
-                       implements SerializableFunction<KV<byte[], Long>, Instant> {
+                       implements SerializableFunction<KV<Integer, Long>, Instant> {
     @Override
-    public Instant apply(KV<byte[], Long> input) {
+    public Instant apply(KV<Integer, Long> input) {
       return new Instant(input.getValue());
     }
   }
@@ -352,13 +371,13 @@ public class KafkaIOTest {
     int numElements = 85; // 85 to make sure some partitions have more records than others.
 
     // create a single split:
-    UnboundedSource<KafkaRecord<byte[], Long>, KafkaCheckpointMark> source =
+    UnboundedSource<KafkaRecord<Integer, Long>, KafkaCheckpointMark> source =
         mkKafkaReadTransform(numElements, new ValueAsTimestampFn())
           .makeSource()
           .generateInitialSplits(1, PipelineOptionsFactory.fromArgs(new String[0]).create())
           .get(0);
 
-    UnboundedReader<KafkaRecord<byte[], Long>> reader = source.createReader(null, null);
+    UnboundedReader<KafkaRecord<Integer, Long>> reader = source.createReader(null, null);
     final int numToSkip = 3;
 
     // advance numToSkip elements
@@ -394,4 +413,261 @@ public class KafkaIOTest {
       }
     }
   }
+
+  @Test
+  public void testSink() throws Exception {
+    // Simply read from the Kafka source and write to the Kafka sink, then verify that the
+    // records are correctly published to the mock Kafka producer.
+
+    int numElements = 1000;
+
+    synchronized (MOCK_PRODUCER_LOCK) {
+
+      MOCK_PRODUCER.clear();
+
+      ProducerSendCompletionThread completionThread = new ProducerSendCompletionThread().start();
+
+      Pipeline pipeline = TestPipeline.create();
+      String topic = "test";
+
+      pipeline
+        .apply(mkKafkaReadTransform(numElements, new ValueAsTimestampFn())
+            .withoutMetadata())
+        .apply(KafkaIO.write()
+            .withBootstrapServers("none")
+            .withTopic(topic)
+            .withKeyCoder(BigEndianIntegerCoder.of())
+            .withValueCoder(BigEndianLongCoder.of())
+            .withProducerFactoryFn(new ProducerFactoryFn()));
+
+      pipeline.run();
+
+      completionThread.shutdown();
+
+      verifyProducerRecords(topic, numElements, false);
+     }
+  }
+
+  @Test
+  public void testValuesSink() throws Exception {
+    // Similar to testSink(), but uses the values() interface.
+
+    int numElements = 1000;
+
+    synchronized (MOCK_PRODUCER_LOCK) {
+
+      MOCK_PRODUCER.clear();
+
+      ProducerSendCompletionThread completionThread = new ProducerSendCompletionThread().start();
+
+      Pipeline pipeline = TestPipeline.create();
+      String topic = "test";
+
+      pipeline
+        .apply(mkKafkaReadTransform(numElements, new ValueAsTimestampFn())
+            .withoutMetadata())
+        .apply(Values.<Long>create()) // there are no keys
+        .apply(KafkaIO.write()
+            .withBootstrapServers("none")
+            .withTopic(topic)
+            .withKeyCoder(BigEndianIntegerCoder.of())
+            .withValueCoder(BigEndianLongCoder.of())
+            .withProducerFactoryFn(new ProducerFactoryFn())
+            .values());
+
+      pipeline.run();
+
+      completionThread.shutdown();
+
+      verifyProducerRecords(topic, numElements, true);
+    }
+  }
+
+  @Test
+  public void testSinkWithSendErrors() throws Throwable {
+    // Similar to testSink(), except that up to 10 of the send calls to the producer will fail
+    // asynchronously.
+
+    // TODO: Ideally we want the pipeline to run to completion by retrying bundles that fail.
+    // We limit the number of errors injected to 10 below, which would reflect a real streaming
+    // pipeline. But it is not clear how to achieve that here. For now expect an exception:
+
+    thrown.expect(InjectedErrorException.class);
+    thrown.expectMessage("Injected Error #1");
+
+    int numElements = 1000;
+
+    synchronized (MOCK_PRODUCER_LOCK) {
+
+      MOCK_PRODUCER.clear();
+
+      Pipeline pipeline = TestPipeline.create();
+      String topic = "test";
+
+      ProducerSendCompletionThread completionThreadWithErrors =
+          new ProducerSendCompletionThread(10, 100).start();
+
+      pipeline
+        .apply(mkKafkaReadTransform(numElements, new ValueAsTimestampFn())
+            .withoutMetadata())
+        .apply(KafkaIO.write()
+            .withBootstrapServers("none")
+            .withTopic(topic)
+            .withKeyCoder(BigEndianIntegerCoder.of())
+            .withValueCoder(BigEndianLongCoder.of())
+            .withProducerFactoryFn(new ProducerFactoryFn()));
+
+      try {
+        pipeline.run();
+      } catch (PipelineExecutionException e) {
+        // throwing inner exception helps assert that first exception is thrown from the Sink
+        throw e.getCause().getCause();
+      } finally {
+        completionThreadWithErrors.shutdown();
+      }
+    }
+  }
+
+  private static void verifyProducerRecords(String topic, int numElements, boolean keyIsAbsent) {
+
+    // verify that appropriate messages are written to kafka
+    List<ProducerRecord<Integer, Long>> sent = MOCK_PRODUCER.history();
+
+    // sort by values
+    Collections.sort(sent, new Comparator<ProducerRecord<Integer, Long>>() {
+      @Override
+      public int compare(ProducerRecord<Integer, Long> o1, ProducerRecord<Integer, Long> o2) {
+        return Long.compare(o1.value(), o2.value());
+      }
+    });
+
+    for (int i = 0; i < numElements; i++) {
+      ProducerRecord<Integer, Long> record = sent.get(i);
+      assertEquals(topic, record.topic());
+      if (keyIsAbsent) {
+        assertNull(record.key());
+      } else {
+        assertEquals(i, record.key().intValue());
+      }
+      assertEquals(i, record.value().longValue());
+    }
+  }
+
+  /**
+   * Singleton MockProducer. Using a singleton here since we need access to the object to fetch
+   * the actual records published to the producer. This prevents running the tests that use
+   * the producer in parallel, but there are only one or two such tests.
+   */
+  private static final MockProducer<Integer, Long> MOCK_PRODUCER =
+    new MockProducer<Integer, Long>(
+      false, // disable synchronous completion of send. see ProducerSendCompletionThread below.
+      new KafkaIO.CoderBasedKafkaSerializer<Integer>(),
+      new KafkaIO.CoderBasedKafkaSerializer<Long>()) {
+
+      // Override flush() so that it does not complete all the waiting sends at once, giving
+      // ProducerSendCompletionThread a chance to inject errors.
+
+      @Override
+      public void flush() {
+        while (completeNext()) {
+          // There are still uncompleted records; pace their completion so that the
+          // completion thread gets a chance to handle (and possibly fail) some of them.
+          try {
+            Thread.sleep(10);
+          } catch (InterruptedException e) {
+          }
+        }
+      }
+    };
+
+  // Use a separate lock object to serialize tests that use MOCK_PRODUCER so that we don't
+  // interfere with MockProducer's own internal locking.
+  private static final Object MOCK_PRODUCER_LOCK = new Object();
+
+  private static class ProducerFactoryFn
+    implements SerializableFunction<Map<String, Object>, Producer<Integer, Long>> {
+
+    @Override
+    public Producer<Integer, Long> apply(Map<String, Object> config) {
+      return MOCK_PRODUCER;
+    }
+  }
+
+  private static class InjectedErrorException extends RuntimeException {
+    public InjectedErrorException(String message) {
+      super(message);
+    }
+  }
+
+  /**
+   * We start MockProducer with auto-completion disabled. That implies a record is not marked sent
+   * until #completeNext() is called on it. This class starts a thread to asynchronously 'complete'
+   * the sends. During completion, we can also make those requests fail. This error injection
+   * is used in one of the tests.
+   */
+  private static class ProducerSendCompletionThread {
+
+    private final int maxErrors;
+    private final int errorFrequency;
+    private final AtomicBoolean done = new AtomicBoolean(false);
+    private final ExecutorService injectorThread;
+    private int numCompletions = 0;
+
+    ProducerSendCompletionThread() {
+      // complete everything successfully
+      this(0, 0);
+    }
+
+    ProducerSendCompletionThread(final int maxErrors, final int errorFrequency) {
+      this.maxErrors = maxErrors;
+      this.errorFrequency = errorFrequency;
+      injectorThread = Executors.newSingleThreadExecutor();
+    }
+
+    ProducerSendCompletionThread start() {
+      injectorThread.submit(new Runnable() {
+        @Override
+        public void run() {
+          int errorsInjected = 0;
+
+          while (!done.get()) {
+            boolean successful;
+
+            if (errorsInjected < maxErrors && ((numCompletions + 1) % errorFrequency) == 0) {
+              successful = MOCK_PRODUCER.errorNext(
+                  new InjectedErrorException("Injected Error #" + (errorsInjected + 1)));
+
+              if (successful) {
+                errorsInjected++;
+              }
+            } else {
+              successful = MOCK_PRODUCER.completeNext();
+            }
+
+            if (successful) {
+              numCompletions++;
+            } else {
+              // wait a bit since there are no unsent records
+              try {
+                Thread.sleep(1);
+              } catch (InterruptedException e) {
+              }
+            }
+          }
+        }
+      });
+
+      return this;
+    }
+
+    void shutdown() {
+      done.set(true);
+      injectorThread.shutdown();
+      try {
+        assertTrue(injectorThread.awaitTermination(10, TimeUnit.SECONDS));
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+        throw new RuntimeException(e);
+      }
+    }
+  }
 }
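
As a usage note on withProducerFactoryFn(): when error injection is not needed, a MockProducer
created with autoComplete = true acknowledges each send() immediately, so no
ProducerSendCompletionThread (and no custom flush() override) is required. A minimal sketch,
assuming the same imports and helpers as the test class above; the producer field, factory class,
and test method below are illustrative and not part of this commit:

  private static final MockProducer<Integer, Long> SYNC_MOCK_PRODUCER =
      new MockProducer<Integer, Long>(
          true, // synchronous completion of send()
          new KafkaIO.CoderBasedKafkaSerializer<Integer>(),
          new KafkaIO.CoderBasedKafkaSerializer<Long>());

  private static class SyncProducerFactoryFn
      implements SerializableFunction<Map<String, Object>, Producer<Integer, Long>> {
    @Override
    public Producer<Integer, Long> apply(Map<String, Object> config) {
      return SYNC_MOCK_PRODUCER;
    }
  }

  @Test
  public void testSinkWithSyncMockProducer() throws Exception {
    int numElements = 100;
    Pipeline pipeline = TestPipeline.create();

    pipeline
      .apply(mkKafkaReadTransform(numElements, new ValueAsTimestampFn()).withoutMetadata())
      .apply(KafkaIO.write()
          .withBootstrapServers("none")
          .withTopic("test")
          .withKeyCoder(BigEndianIntegerCoder.of())
          .withValueCoder(BigEndianLongCoder.of())
          .withProducerFactoryFn(new SyncProducerFactoryFn()));

    pipeline.run();

    // With a dedicated singleton producer, no MOCK_PRODUCER_LOCK is needed; every element read
    // from the mock consumer should show up exactly once in the producer's history.
    assertEquals(numElements, SYNC_MOCK_PRODUCER.history().size());
  }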