Posted to commits@beam.apache.org by ec...@apache.org on 2018/12/11 14:22:17 UTC

[beam] 04/04: start source instantiation

This is an automated email from the ASF dual-hosted git repository.

echauchot pushed a commit to branch spark-runner_structured-streaming
in repository https://gitbox.apache.org/repos/asf/beam.git

commit e86247f88aafd9d32abef9eb4b897d393317d264
Author: Etienne Chauchot <ec...@apache.org>
AuthorDate: Mon Dec 10 15:27:49 2018 +0100

    start source instantiation
---
 .../batch/ReadSourceTranslatorBatch.java           | 27 ++++++++++++++++++----
 .../translation/io/DatasetSource.java              | 10 ++++----
 2 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java
index 63f2fdf..a75730a 100644
--- a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java
+++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java
@@ -22,18 +22,25 @@ import org.apache.beam.runners.core.construction.ReadTranslation;
 import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
 import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext;
 import org.apache.beam.runners.spark.structuredstreaming.translation.io.DatasetSource;
+import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.io.BoundedSource;
 import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.spark.api.java.function.MapFunction;
 import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.streaming.DataStreamReader;
 
 class ReadSourceTranslatorBatch<T>
     implements TransformTranslator<PTransform<PBegin, PCollection<T>>> {
 
+  private static final String SOURCE_PROVIDER_CLASS = DatasetSource.class.getCanonicalName();
+
   @SuppressWarnings("unchecked")
   @Override
   public void translateTransform(
@@ -41,18 +48,28 @@ class ReadSourceTranslatorBatch<T>
     AppliedPTransform<PBegin, PCollection<T>, PTransform<PBegin, PCollection<T>>> rootTransform =
         (AppliedPTransform<PBegin, PCollection<T>, PTransform<PBegin, PCollection<T>>>)
             context.getCurrentTransform();
-    BoundedSource<T> source;
+
+    String providerClassName = SOURCE_PROVIDER_CLASS.split("\\$")[0]; // keeps the whole name when it contains no '$'
+    BoundedSource<T> source;
     try {
       source = ReadTranslation.boundedSourceFromTransform(rootTransform);
     } catch (IOException e) {
       throw new RuntimeException(e);
     }
-    PCollection<T> output = (PCollection<T>) context.getOutput();
-
     SparkSession sparkSession = context.getSparkSession();
-    DatasetSource datasetSource = new DatasetSource(context, source);
-    Dataset<Row> dataset = sparkSession.readStream().format("DatasetSource").load();
+    Dataset<Row> rowDataset = sparkSession.readStream().format(providerClassName).load();
+    //TODO initialize source : how, to get a reference to the DatasetSource instance that spark
+    // instantiates to be able to call DatasetSource.initialize()
+    MapFunction<Row, WindowedValue<T>> func = new MapFunction<Row, WindowedValue<T>>() {
+      @Override public WindowedValue<T> call(Row value) throws Exception {
+        //TODO fix row content extraction: I guess cast is not enough
+        return (WindowedValue<T>) value.get(0);
+      }
+    };
+    //TODO fix encoder
+    Dataset<WindowedValue<T>> dataset = rowDataset.map(func, null);
 
+    PCollection<T> output = (PCollection<T>) context.getOutput();
     context.putDataset(output, dataset);
   }
 }
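
The hunk above leaves its three TODOs open: getting hold of the DatasetSource instance that Spark instantiates behind format(...).load(), extracting the element from each Row, and choosing an Encoder for map(). Below is a minimal sketch of one way the last two could be resolved, assuming (this is an assumption, not part of the commit) that the DatasetSource reader emits a single byte[] column holding the coder-serialized WindowedValue; the helper class, its method, and the Kryo encoder are illustrative placeholders. The first TODO is picked up again after the DatasetSource diff below.

import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;

/** Sketch only, not part of the commit: decodes rows written as coder-serialized bytes. */
class RowToWindowedValueSketch {

  static <T> Dataset<WindowedValue<T>> decode(Dataset<Row> rowDataset, BoundedSource<T> source) {
    // Hypothetical convention: the reader wrote one byte[] column per element,
    // encoded with the source's output coder wrapped in a WindowedValue coder.
    Coder<WindowedValue<T>> coder =
        WindowedValue.getFullCoder(source.getOutputCoder(), GlobalWindow.Coder.INSTANCE);
    MapFunction<Row, WindowedValue<T>> fromRow =
        row -> CoderUtils.decodeFromByteArray(coder, row.<byte[]>getAs(0));
    // Kryo is a stop-gap Encoder; a coder-backed Encoder would be the cleaner long-term fix.
    @SuppressWarnings("unchecked")
    Encoder<WindowedValue<T>> encoder =
        Encoders.kryo((Class<WindowedValue<T>>) (Class<?>) WindowedValue.class);
    return rowDataset.map(fromRow, encoder);
  }
}

Whatever convention is chosen, the coder used here has to match what the DatasetSource reader used when it built the rows, otherwise decoding fails at runtime.
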
diff --git a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/io/DatasetSource.java b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/io/DatasetSource.java
index f230a70..75cdd5d 100644
--- a/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/io/DatasetSource.java
+++ b/runners/spark-structured-streaming/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/io/DatasetSource.java
@@ -30,6 +30,7 @@ import org.apache.beam.sdk.io.BoundedSource;
 import org.apache.beam.sdk.io.BoundedSource.BoundedReader;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.sources.DataSourceRegister;
 import org.apache.spark.sql.sources.v2.ContinuousReadSupport;
 import org.apache.spark.sql.sources.v2.DataSourceOptions;
 import org.apache.spark.sql.sources.v2.DataSourceV2;
@@ -45,14 +46,15 @@ import org.apache.spark.sql.types.StructType;
  * is tagged experimental in Spark, this class does not implement {@link ContinuousReadSupport}. This
  * class is just a mix-in.
  */
-public class DatasetSource<T> implements DataSourceV2, MicroBatchReadSupport {
+public class DatasetSource<T> implements DataSourceV2, MicroBatchReadSupport{
 
-  private final int numPartitions;
-  private final Long bundleSize;
+  private int numPartitions;
+  private Long bundleSize;
   private TranslationContext context;
   private BoundedSource<T> source;
 
-  public DatasetSource(TranslationContext context, BoundedSource<T> source) {
+
+  public void initialize(TranslationContext context, BoundedSource<T> source) {
     this.context = context;
     this.source = source;
     this.numPartitions = context.getSparkSession().sparkContext().defaultParallelism();
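
The remaining TODO, wiring up DatasetSource.initialize(), is awkward because Spark instantiates the DataSourceV2 provider reflectively from the name passed to format(), so the translator never holds the instance it would need to call initialize() on. One possible workaround (an assumption, not the commit's approach) is to pass what the provider needs through the string-valued DataSourceOptions, for example a serialized BoundedSource under a made-up key, and rebuild it inside createMicroBatchReader(schema, checkpointLocation, options). A minimal sketch; the option key and helper class are hypothetical:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Base64;
import org.apache.beam.sdk.io.BoundedSource;

/** Sketch only, not part of the commit: ferries a BoundedSource through string-valued DataSourceOptions. */
final class SourceOptionCodec {

  static final String BEAM_SOURCE_OPTION = "beam.source"; // hypothetical option key

  /** Translator side: readStream().format(provider).option(BEAM_SOURCE_OPTION, encode(source)). */
  static String encode(BoundedSource<?> source) {
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
      out.writeObject(source); // BoundedSource is Serializable
    } catch (IOException e) {
      throw new RuntimeException("Could not serialize BoundedSource", e);
    }
    return Base64.getEncoder().encodeToString(bytes.toByteArray());
  }

  /** Provider side: called from createMicroBatchReader(...) with options.get(BEAM_SOURCE_OPTION).get(). */
  @SuppressWarnings("unchecked")
  static <T> BoundedSource<T> decode(String encoded) {
    byte[] bytes = Base64.getDecoder().decode(encoded);
    try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes))) {
      return (BoundedSource<T>) in.readObject();
    } catch (IOException | ClassNotFoundException e) {
      throw new RuntimeException("Could not deserialize BoundedSource", e);
    }
  }
}

With something along these lines, initialize(context, source) could be driven from createMicroBatchReader() rather than from the translator. The TranslationContext cannot travel as a string option, so it would still need another channel (for instance a registry on the driver); that part is left open here, as it is in the commit. Since the commit also imports DataSourceRegister, a later change could register a short format name instead of passing the fully qualified provider class name to format().
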