Posted to commits@beam.apache.org by ec...@apache.org on 2022/09/22 10:30:15 UTC
[beam] branch master updated: Improved pipeline translation in SparkStructuredStreamingRunner (#22446)
This is an automated email from the ASF dual-hosted git repository.
echauchot pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 762edd7f3a6 Improved pipeline translation in SparkStructuredStreamingRunner (#22446)
762edd7f3a6 is described below
commit 762edd7f3a64f076dbee156fa48b8a7e5e6a512f
Author: Moritz Mack <mm...@talend.com>
AuthorDate: Thu Sep 22 12:30:03 2022 +0200
Improved pipeline translation in SparkStructuredStreamingRunner (#22446)
* Closes #22445: Improved pipeline translation in SparkStructuredStreamingRunner (also closes #22382)
---
.../translation/helpers/EncoderFactory.java | 12 +-
.../translation/utils/ScalaInterop.java} | 27 +-
.../spark/structuredstreaming/Constants.java | 25 -
.../SparkStructuredStreamingRunner.java | 51 +-
.../io/BoundedDatasetFactory.java | 324 +++++++++++
.../streaming => io}/package-info.java | 4 +-
.../metrics/WithMetricsSupport.java | 2 +-
.../translation/AbstractTranslationContext.java | 235 --------
.../translation/PipelineTranslator.java | 57 +-
.../translation/TransformTranslator.java | 198 ++++++-
.../translation/TranslationContext.java | 124 ++++-
.../translation/batch/AggregatorCombiner.java | 270 ----------
.../translation/batch/Aggregators.java | 591 +++++++++++++++++++++
.../batch/CombineGloballyTranslatorBatch.java | 121 +++++
.../batch/CombinePerKeyTranslatorBatch.java | 181 ++++---
.../CreatePCollectionViewTranslatorBatch.java | 24 +-
.../translation/batch/DatasetSourceBatch.java | 240 ---------
.../translation/batch/DoFnFunction.java | 164 ------
.../batch/DoFnMapPartitionsFactory.java | 224 ++++++++
.../translation/batch/FlattenTranslatorBatch.java | 60 +--
.../translation/batch/GroupByKeyHelpers.java | 106 ++++
.../batch/GroupByKeyTranslatorBatch.java | 298 +++++++++--
.../translation/batch/ImpulseTranslatorBatch.java | 26 +-
.../translation/batch/ParDoTranslatorBatch.java | 315 ++++++-----
.../translation/batch/PipelineTranslatorBatch.java | 28 +-
.../translation/batch/ProcessContext.java | 138 -----
.../batch/ReadSourceTranslatorBatch.java | 76 +--
.../batch/ReshuffleTranslatorBatch.java | 30 --
.../batch/WindowAssignTranslatorBatch.java | 90 +++-
.../translation/helpers/CoderHelpers.java | 10 +-
.../translation/helpers/EncoderFactory.java | 71 ++-
.../translation/helpers/EncoderHelpers.java | 546 ++++++++++++++++++-
.../translation/helpers/MultiOutputCoder.java | 84 ---
.../translation/helpers/RowHelpers.java | 75 ---
.../translation/helpers/SchemaHelpers.java | 39 --
.../translation/helpers/WindowingHelpers.java | 82 ---
.../streaming/DatasetSourceStreaming.java | 25 -
.../streaming/PipelineTranslatorStreaming.java | 93 ----
.../streaming/ReadSourceTranslatorStreaming.java | 87 ---
.../translation/utils/ScalaInterop.java | 114 ++++
.../aggregators/metrics/sink/InMemoryMetrics.java | 2 +-
.../translation/batch/AggregatorsTest.java | 370 +++++++++++++
.../{CombineTest.java => CombineGloballyTest.java} | 129 ++---
.../{CombineTest.java => CombinePerKeyTest.java} | 92 ++--
.../translation/batch/ComplexSourceTest.java | 15 +-
.../translation/batch/FlattenTest.java | 12 +-
.../translation/batch/GroupByKeyTest.java | 152 ++++--
.../translation/batch/ParDoTest.java | 54 +-
.../translation/batch/SimpleSourceTest.java | 12 +-
.../translation/batch/WindowAssignTest.java | 12 +-
.../translation/helpers/EncoderHelpersTest.java | 210 +++++++-
.../runners/spark/SparkCommonPipelineOptions.java | 6 +
.../beam/runners/spark/SparkPipelineOptions.java | 6 -
53 files changed, 3978 insertions(+), 2361 deletions(-)
diff --git a/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
index 54b400f08d0..2b86ec839c9 100644
--- a/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
+++ b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
@@ -17,26 +17,24 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.listOf;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
+
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;
import org.apache.spark.sql.types.DataType;
-import scala.collection.immutable.List;
-import scala.collection.immutable.Nil$;
-import scala.collection.mutable.WrappedArray;
import scala.reflect.ClassTag$;
public class EncoderFactory {
static <T> Encoder<T> create(
Expression serializer, Expression deserializer, Class<? super T> clazz) {
- // TODO Isolate usage of Scala APIs in utility https://github.com/apache/beam/issues/22382
- List<Expression> serializers = Nil$.MODULE$.$colon$colon(serializer);
return new ExpressionEncoder<>(
SchemaHelpers.binarySchema(),
false,
- serializers,
+ listOf(serializer),
deserializer,
ClassTag$.MODULE$.apply(clazz));
}
@@ -46,6 +44,6 @@ public class EncoderFactory {
* input arg is {@code null}.
*/
static Expression invokeIfNotNull(Class<?> cls, String fun, DataType type, Expression... args) {
- return new StaticInvoke(cls, type, fun, new WrappedArray.ofRef<>(args), true, true);
+ return new StaticInvoke(cls, type, fun, seqOf(args), true, true);
}
}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/KVHelpers.java b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java
similarity index 61%
rename from runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/KVHelpers.java
rename to runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java
index 2406c0f49ab..c5bc71af602 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/KVHelpers.java
+++ b/runners/spark/2/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java
@@ -15,17 +15,26 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+package org.apache.beam.runners.spark.structuredstreaming.translation.utils;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.KV;
-import org.apache.spark.api.java.function.MapFunction;
+import scala.collection.Seq;
+import scala.collection.immutable.List;
+import scala.collection.immutable.Nil$;
+import scala.collection.mutable.WrappedArray;
-/** Helper functions for working with {@link org.apache.beam.sdk.values.KV}. */
-public final class KVHelpers {
+/** Utilities for easier interoperability with the Spark Scala API. */
+public class ScalaInterop {
+ private ScalaInterop() {}
- /** A Spark {@link MapFunction} for extracting the key out of a {@link KV} for GBK for example. */
- public static <K, V> MapFunction<WindowedValue<KV<K, V>>, K> extractKey() {
- return wv -> wv.getValue().getKey();
+ public static <T> Seq<T> seqOf(T... t) {
+ return new WrappedArray.ofRef<>(t);
+ }
+
+ public static <T> Seq<T> listOf(T t) {
+ return emptyList().$colon$colon(t);
+ }
+
+ public static <T> List<T> emptyList() {
+ return (List<T>) Nil$.MODULE$;
}
}
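
For illustration, a minimal usage sketch of these helpers (a hypothetical snippet, not part of the commit; it assumes only the Scala 2.x collection classes already on the Spark classpath):

    import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop;
    import scala.collection.Seq;
    import scala.collection.immutable.List;

    public class ScalaInteropExample {
      public static void main(String[] args) {
        // seqOf wraps the Java varargs array in a Scala WrappedArray (no copying).
        Seq<String> seq = ScalaInterop.seqOf("a", "b", "c");
        // listOf conses a single element onto Nil ($colon$colon is Scala's "::").
        List<String> list = ScalaInterop.listOf("x");
        System.out.println(seq.length()); // 3
        System.out.println(list.head());  // x
      }
    }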
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/Constants.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/Constants.java
deleted file mode 100644
index 08c187ce6c6..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/Constants.java
+++ /dev/null
@@ -1,25 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming;
-
-public class Constants {
-
- public static final String BEAM_SOURCE_OPTION = "beam-source";
- public static final String DEFAULT_PARALLELISM = "default-parallelism";
- public static final String PIPELINE_OPTIONS = "pipeline-options";
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
index b1de9e941e4..68f54ac93bf 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
@@ -29,10 +29,9 @@ import org.apache.beam.runners.spark.structuredstreaming.metrics.AggregatorMetri
import org.apache.beam.runners.spark.structuredstreaming.metrics.CompositeSource;
import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator;
import org.apache.beam.runners.spark.structuredstreaming.metrics.SparkBeamMetricSource;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
import org.apache.beam.runners.spark.structuredstreaming.translation.PipelineTranslator;
+import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext;
import org.apache.beam.runners.spark.structuredstreaming.translation.batch.PipelineTranslatorBatch;
-import org.apache.beam.runners.spark.structuredstreaming.translation.streaming.PipelineTranslatorStreaming;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.PipelineRunner;
import org.apache.beam.sdk.metrics.MetricsEnvironment;
@@ -41,6 +40,7 @@ import org.apache.beam.sdk.options.ExperimentalOptions;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.options.PipelineOptionsValidator;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.apache.spark.SparkEnv$;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.metrics.MetricsSystem;
@@ -48,24 +48,34 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
- * SparkStructuredStreamingRunner is based on spark structured streaming framework and is no more
- * based on RDD/DStream API. See
- * https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html It is still
- * experimental, its coverage of the Beam model is partial. The SparkStructuredStreamingRunner
- * translate operations defined on a pipeline to a representation executable by Spark, and then
- * submitting the job to Spark to be executed. If we wanted to run a Beam pipeline with the default
- * options of a single threaded spark instance in local mode, we would do the following:
+ * A Spark runner built on top of Spark's SQL Engine (<a
+ * href="https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html">Structured
+ * Streaming framework</a>).
*
- * <p>{@code Pipeline p = [logic for pipeline creation] SparkStructuredStreamingPipelineResult
- * result = (SparkStructuredStreamingPipelineResult) p.run(); }
+ * <p><b>This runner is experimental; its coverage of the Beam model is still partial. Due to
+ * limitations of the Structured Streaming framework (e.g. lack of support for multiple stateful
+ * operators), streaming mode is not yet supported by this runner.</b>
+ *
+ * <p>The runner translates transforms defined on a Beam pipeline to Spark `Dataset` transformations
+ * (leveraging the high level Dataset API) and then submits these to Spark to be executed.
+ *
+ * <p>To run a Beam pipeline with the default options using Spark's local mode, we would do the
+ * following:
+ *
+ * <pre>{@code
+ * Pipeline p = [logic for pipeline creation]
+ * PipelineResult result = p.run();
+ * }</pre>
*
 * <p>To create a pipeline runner to run against a different Spark cluster, with a custom master URL
* we would do the following:
*
- * <p>{@code Pipeline p = [logic for pipeline creation] SparkStructuredStreamingPipelineOptions
- * options = SparkPipelineOptionsFactory.create(); options.setSparkMaster("spark://host:port");
- * SparkStructuredStreamingPipelineResult result = (SparkStructuredStreamingPipelineResult) p.run();
- * }
+ * <pre>{@code
+ * Pipeline p = [logic for pipeline creation]
+ * SparkCommonPipelineOptions options = p.getOptions().as(SparkCommonPipelineOptions.class);
+ * options.setSparkMaster("spark://host:port");
+ * PipelineResult result = p.run();
+ * }</pre>
*/
@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
@@ -135,7 +145,7 @@ public final class SparkStructuredStreamingRunner
AggregatorsAccumulator.clear();
MetricsAccumulator.clear();
- final AbstractTranslationContext translationContext = translatePipeline(pipeline);
+ final TranslationContext translationContext = translatePipeline(pipeline);
final ExecutorService executorService = Executors.newSingleThreadExecutor();
final Future<?> submissionFuture =
@@ -169,8 +179,10 @@ public final class SparkStructuredStreamingRunner
return result;
}
- private AbstractTranslationContext translatePipeline(Pipeline pipeline) {
+ private TranslationContext translatePipeline(Pipeline pipeline) {
PipelineTranslator.detectTranslationMode(pipeline, options);
+ Preconditions.checkArgument(
+ !options.isStreaming(), "%s does not support streaming pipelines.", getClass().getName());
// Default to using the primitive versions of Read.Bounded and Read.Unbounded for non-portable
// execution.
@@ -182,10 +194,7 @@ public final class SparkStructuredStreamingRunner
PipelineTranslator.replaceTransforms(pipeline, options);
prepareFilesToStage(options);
- PipelineTranslator pipelineTranslator =
- options.isStreaming()
- ? new PipelineTranslatorStreaming(options)
- : new PipelineTranslatorBatch(options);
+ PipelineTranslator pipelineTranslator = new PipelineTranslatorBatch(options);
final JavaSparkContext jsc =
JavaSparkContext.fromSparkContext(
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java
new file mode 100644
index 00000000000..83dc98f3c10
--- /dev/null
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/BoundedDatasetFactory.java
@@ -0,0 +1,324 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.io;
+
+import static java.util.stream.Collectors.toList;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList;
+import static org.apache.beam.sdk.util.WindowedValue.timestampedValueInGlobalWindow;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+import static scala.collection.JavaConverters.asScalaIterator;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.IntSupplier;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.sdk.io.BoundedSource;
+import org.apache.beam.sdk.io.BoundedSource.BoundedReader;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
+import org.apache.spark.InterruptibleIterator;
+import org.apache.spark.Partition;
+import org.apache.spark.SparkContext;
+import org.apache.spark.TaskContext;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Serializer;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.apache.spark.sql.connector.catalog.SupportsRead;
+import org.apache.spark.sql.connector.catalog.Table;
+import org.apache.spark.sql.connector.catalog.TableCapability;
+import org.apache.spark.sql.connector.read.Batch;
+import org.apache.spark.sql.connector.read.InputPartition;
+import org.apache.spark.sql.connector.read.PartitionReader;
+import org.apache.spark.sql.connector.read.PartitionReaderFactory;
+import org.apache.spark.sql.connector.read.Scan;
+import org.apache.spark.sql.connector.read.ScanBuilder;
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+import scala.Option;
+import scala.collection.Iterator;
+import scala.reflect.ClassTag;
+
+public class BoundedDatasetFactory {
+ private BoundedDatasetFactory() {}
+
+ /**
+ * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link Table}.
+ *
+ * <p>Unfortunately tables are expected to return an {@link InternalRow}, requiring serialization.
+ * This makes this approach, for the time being, significantly less performant than creating a
+ * dataset from an RDD.
+ */
+ public static <T> Dataset<WindowedValue<T>> createDatasetFromRows(
+ SparkSession session,
+ BoundedSource<T> source,
+ SerializablePipelineOptions options,
+ Encoder<WindowedValue<T>> encoder) {
+ Params<T> params = new Params<>(encoder, options, session.sparkContext().defaultParallelism());
+ BeamTable<T> table = new BeamTable<>(source, params);
+ LogicalPlan logicalPlan = DataSourceV2Relation.create(table, Option.empty(), Option.empty());
+ return Dataset.ofRows(session, logicalPlan).as(encoder);
+ }
+
+ /**
+ * Create a {@link Dataset} for a {@link BoundedSource} via a Spark {@link RDD}.
+ *
+ * <p>This is currently the most efficient approach as it avoids any serialization overhead.
+ */
+ public static <T> Dataset<WindowedValue<T>> createDatasetFromRDD(
+ SparkSession session,
+ BoundedSource<T> source,
+ SerializablePipelineOptions options,
+ Encoder<WindowedValue<T>> encoder) {
+ Params<T> params = new Params<>(encoder, options, session.sparkContext().defaultParallelism());
+ RDD<WindowedValue<T>> rdd = new BoundedRDD<>(session.sparkContext(), source, params);
+ return session.createDataset(rdd, encoder);
+ }
+
+ /** An {@link RDD} for a bounded Beam source. */
+ private static class BoundedRDD<T> extends RDD<WindowedValue<T>> {
+ final BoundedSource<T> source;
+ final Params<T> params;
+
+ public BoundedRDD(SparkContext sc, BoundedSource<T> source, Params<T> params) {
+ super(sc, emptyList(), ClassTag.apply(WindowedValue.class));
+ this.source = source;
+ this.params = params;
+ }
+
+ @Override
+ public Iterator<WindowedValue<T>> compute(Partition split, TaskContext context) {
+ return new InterruptibleIterator<>(
+ context,
+ asScalaIterator(new SourcePartitionIterator<>((SourcePartition<T>) split, params)));
+ }
+
+ @Override
+ public Partition[] getPartitions() {
+ return SourcePartition.partitionsOf(source, params).toArray(new Partition[0]);
+ }
+ }
+
+ /** A Spark {@link Table} for a bounded Beam source supporting batch reads only. */
+ private static class BeamTable<T> implements Table, SupportsRead {
+ final BoundedSource<T> source;
+ final Params<T> params;
+
+ BeamTable(BoundedSource<T> source, Params<T> params) {
+ this.source = source;
+ this.params = params;
+ }
+
+ public Encoder<WindowedValue<T>> getEncoder() {
+ return params.encoder;
+ }
+
+ @Override
+ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap ignored) {
+ return () ->
+ new Scan() {
+ @Override
+ public StructType readSchema() {
+ return params.encoder.schema();
+ }
+
+ @Override
+ public Batch toBatch() {
+ return new BeamBatch<>(source, params);
+ }
+ };
+ }
+
+ @Override
+ public String name() {
+ return "BeamSource<" + source.getClass().getName() + ">";
+ }
+
+ @Override
+ public StructType schema() {
+ return params.encoder.schema();
+ }
+
+ @Override
+ public Set<TableCapability> capabilities() {
+ return ImmutableSet.of(TableCapability.BATCH_READ);
+ }
+
+ private static class BeamBatch<T> implements Batch, Serializable {
+ final BoundedSource<T> source;
+ final Params<T> params;
+
+ private BeamBatch(BoundedSource<T> source, Params<T> params) {
+ this.source = source;
+ this.params = params;
+ }
+
+ @Override
+ public InputPartition[] planInputPartitions() {
+ return SourcePartition.partitionsOf(source, params).toArray(new InputPartition[0]);
+ }
+
+ @Override
+ public PartitionReaderFactory createReaderFactory() {
+ return p -> new BeamPartitionReader<>(((SourcePartition<T>) p), params);
+ }
+ }
+
+ private static class BeamPartitionReader<T> implements PartitionReader<InternalRow> {
+ final SourcePartitionIterator<T> iterator;
+ final Serializer<WindowedValue<T>> serializer;
+ transient @Nullable InternalRow next;
+
+ BeamPartitionReader(SourcePartition<T> partition, Params<T> params) {
+ iterator = new SourcePartitionIterator<>(partition, params);
+ serializer = ((ExpressionEncoder<WindowedValue<T>>) params.encoder).createSerializer();
+ }
+
+ @Override
+ public boolean next() throws IOException {
+ if (iterator.hasNext()) {
+ next = serializer.apply(iterator.next());
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public InternalRow get() {
+ if (next == null) {
+ throw new IllegalStateException("Next not available");
+ }
+ return next;
+ }
+
+ @Override
+ public void close() throws IOException {
+ next = null;
+ iterator.close();
+ }
+ }
+ }
+
+ /** A Spark partition wrapping the partitioned Beam {@link BoundedSource}. */
+ private static class SourcePartition<T> implements Partition, InputPartition {
+ final BoundedSource<T> source;
+ final int index;
+
+ SourcePartition(BoundedSource<T> source, IntSupplier idxSupplier) {
+ this.source = source;
+ this.index = idxSupplier.getAsInt();
+ }
+
+ static <T> List<SourcePartition<T>> partitionsOf(BoundedSource<T> source, Params<T> params) {
+ try {
+ PipelineOptions options = params.options.get();
+ long desiredSize = source.getEstimatedSizeBytes(options) / params.numPartitions;
+ List<BoundedSource<T>> split = (List<BoundedSource<T>>) source.split(desiredSize, options);
+ IntSupplier idxSupplier = new AtomicInteger(0)::getAndIncrement;
+ return split.stream().map(s -> new SourcePartition<>(s, idxSupplier)).collect(toList());
+ } catch (Exception e) {
+ throw new RuntimeException(
+ "Error splitting BoundedSource " + source.getClass().getCanonicalName(), e);
+ }
+ }
+
+ @Override
+ public int index() {
+ return index;
+ }
+
+ @Override
+ public int hashCode() {
+ return index;
+ }
+ }
+
+ /** A partition iterator on a partitioned Beam {@link BoundedSource}. */
+ private static class SourcePartitionIterator<T> extends AbstractIterator<WindowedValue<T>>
+ implements Closeable {
+ BoundedReader<T> reader;
+ boolean started = false;
+
+ public SourcePartitionIterator(SourcePartition<T> partition, Params<T> params) {
+ try {
+ reader = partition.source.createReader(params.options.get());
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to create reader from a BoundedSource.", e);
+ }
+ }
+
+ @Override
+ @SuppressWarnings("nullness") // ok, reader not used any longer
+ public void close() throws IOException {
+ if (reader != null) {
+ endOfData();
+ try {
+ reader.close();
+ } finally {
+ reader = null;
+ }
+ }
+ }
+
+ @Override
+ protected WindowedValue<T> computeNext() {
+ try {
+ if (started ? reader.advance() : start()) {
+ return timestampedValueInGlobalWindow(reader.getCurrent(), reader.getCurrentTimestamp());
+ } else {
+ close();
+ return endOfData();
+ }
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to start or advance reader.", e);
+ }
+ }
+
+ private boolean start() throws IOException {
+ started = true;
+ return reader.start();
+ }
+ }
+
+ /** Shared parameters. */
+ private static class Params<T> implements Serializable {
+ final Encoder<WindowedValue<T>> encoder;
+ final SerializablePipelineOptions options;
+ final int numPartitions;
+
+ Params(
+ Encoder<WindowedValue<T>> encoder, SerializablePipelineOptions options, int numPartitions) {
+ checkArgument(numPartitions > 0, "Number of partitions must be greater than zero.");
+ this.encoder = encoder;
+ this.options = options;
+ this.numPartitions = numPartitions;
+ }
+ }
+}
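
For orientation, a hedged sketch of how a translator might build a Dataset from a bounded source with this factory. The source and the encoder wiring below are illustrative assumptions (encoderFor and windowedValueEncoder mirror the EncoderHelpers signatures used elsewhere in this diff), not code from the commit:

    // Hypothetical wiring, fragment only (imports elided).
    SparkSession session = SparkSession.builder().master("local[*]").getOrCreate();
    SerializablePipelineOptions options =
        new SerializablePipelineOptions(PipelineOptionsFactory.create());
    BoundedSource<Long> source = CountingSource.upTo(1000); // any bounded Beam source
    Encoder<BoundedWindow> windowEnc =
        EncoderHelpers.encoderFor((Coder<BoundedWindow>) (Coder<?>) GlobalWindow.Coder.INSTANCE);
    Encoder<WindowedValue<Long>> encoder =
        EncoderHelpers.windowedValueEncoder(EncoderHelpers.encoderFor(VarLongCoder.of()), windowEnc);
    // The RDD-based path avoids the InternalRow round-trip of the Table-based path.
    Dataset<WindowedValue<Long>> dataset =
        BoundedDatasetFactory.createDatasetFromRDD(session, source, options, encoder);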
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/package-info.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/package-info.java
similarity index 83%
rename from runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/package-info.java
rename to runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/package-info.java
index 67f3613e056..23de70c705b 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/package-info.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/io/package-info.java
@@ -16,5 +16,5 @@
* limitations under the License.
*/
-/** Internal utilities to translate Beam pipelines to Spark streaming. */
-package org.apache.beam.runners.spark.structuredstreaming.translation.streaming;
+/** Spark-specific transforms for I/O. */
+package org.apache.beam.runners.spark.structuredstreaming.io;
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java
index d48a229996f..c9233a128c1 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/WithMetricsSupport.java
@@ -36,7 +36,6 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Ordering
* <p>{@link MetricRegistry} is not an interface, so this is not a by-the-book decorator. That said,
* it delegates all metric related getters to the "decorated" instance.
*/
-@SuppressWarnings({"rawtypes"}) // required by interface
public class WithMetricsSupport extends MetricRegistry {
private final MetricRegistry internalMetricRegistry;
@@ -70,6 +69,7 @@ public class WithMetricsSupport extends MetricRegistry {
}
@Override
+ @SuppressWarnings({"rawtypes"}) // required by interface
public SortedMap<String, Gauge> getGauges(final MetricFilter filter) {
ImmutableSortedMap.Builder<String, Gauge> builder =
new ImmutableSortedMap.Builder<>(Ordering.from(String.CASE_INSENSITIVE_ORDER));
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/AbstractTranslationContext.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/AbstractTranslationContext.java
deleted file mode 100644
index aed287ba6d5..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/AbstractTranslationContext.java
+++ /dev/null
@@ -1,235 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation;
-
-import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.stream.Collectors;
-import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
-import org.apache.beam.runners.core.construction.TransformInputs;
-import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.VoidCoder;
-import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.PValue;
-import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.spark.api.java.function.ForeachFunction;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.ForeachWriter;
-import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.streaming.DataStreamWriter;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * Base class that gives a context for {@link PTransform} translation: keeping track of the
- * datasets, the {@link SparkSession}, the current transform being translated.
- */
-@SuppressWarnings({
- "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
- "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class AbstractTranslationContext {
-
- private static final Logger LOG = LoggerFactory.getLogger(AbstractTranslationContext.class);
-
- /** All the datasets of the DAG. */
- private final Map<PValue, Dataset<?>> datasets;
- /** datasets that are not used as input to other datasets (leaves of the DAG). */
- private final Set<Dataset<?>> leaves;
-
- private final SerializablePipelineOptions serializablePipelineOptions;
-
- @SuppressFBWarnings("URF_UNREAD_FIELD") // make spotbugs happy
- private AppliedPTransform<?, ?, ?> currentTransform;
-
- @SuppressFBWarnings("URF_UNREAD_FIELD") // make spotbugs happy
- private final SparkSession sparkSession;
-
- private final Map<PCollectionView<?>, Dataset<?>> broadcastDataSets;
-
- public AbstractTranslationContext(SparkStructuredStreamingPipelineOptions options) {
- this.sparkSession = SparkSessionFactory.getOrCreateSession(options);
- this.serializablePipelineOptions = new SerializablePipelineOptions(options);
- this.datasets = new HashMap<>();
- this.leaves = new HashSet<>();
- this.broadcastDataSets = new HashMap<>();
- }
-
- public SparkSession getSparkSession() {
- return sparkSession;
- }
-
- public SerializablePipelineOptions getSerializableOptions() {
- return serializablePipelineOptions;
- }
-
- // --------------------------------------------------------------------------------------------
- // Transforms methods
- // --------------------------------------------------------------------------------------------
- public void setCurrentTransform(AppliedPTransform<?, ?, ?> currentTransform) {
- this.currentTransform = currentTransform;
- }
-
- public AppliedPTransform<?, ?, ?> getCurrentTransform() {
- return currentTransform;
- }
-
- // --------------------------------------------------------------------------------------------
- // Datasets methods
- // --------------------------------------------------------------------------------------------
- @SuppressWarnings("unchecked")
- public <T> Dataset<T> emptyDataset() {
- return (Dataset<T>) sparkSession.emptyDataset(EncoderHelpers.fromBeamCoder(VoidCoder.of()));
- }
-
- @SuppressWarnings("unchecked")
- public <T> Dataset<WindowedValue<T>> getDataset(PValue value) {
- Dataset<?> dataset = datasets.get(value);
- // assume that the Dataset is used as an input if retrieved here. So it is not a leaf anymore
- leaves.remove(dataset);
- return (Dataset<WindowedValue<T>>) dataset;
- }
-
- /**
- * TODO: All these 3 methods (putDataset*) are temporary and they are used only for generics type
- * checking. We should unify them in the future.
- */
- public void putDatasetWildcard(PValue value, Dataset<WindowedValue<?>> dataset) {
- if (!datasets.containsKey(value)) {
- datasets.put(value, dataset);
- leaves.add(dataset);
- }
- }
-
- public <T> void putDataset(PValue value, Dataset<WindowedValue<T>> dataset) {
- if (!datasets.containsKey(value)) {
- datasets.put(value, dataset);
- leaves.add(dataset);
- }
- }
-
- public <ViewT, ElemT> void setSideInputDataset(
- PCollectionView<ViewT> value, Dataset<WindowedValue<ElemT>> set) {
- if (!broadcastDataSets.containsKey(value)) {
- broadcastDataSets.put(value, set);
- }
- }
-
- @SuppressWarnings("unchecked")
- public <T> Dataset<T> getSideInputDataSet(PCollectionView<?> value) {
- return (Dataset<T>) broadcastDataSets.get(value);
- }
-
- // --------------------------------------------------------------------------------------------
- // PCollections methods
- // --------------------------------------------------------------------------------------------
- public PValue getInput() {
- return Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(currentTransform));
- }
-
- public Map<TupleTag<?>, PCollection<?>> getInputs() {
- return currentTransform.getInputs();
- }
-
- public PValue getOutput() {
- return Iterables.getOnlyElement(currentTransform.getOutputs().values());
- }
-
- public Map<TupleTag<?>, PCollection<?>> getOutputs() {
- return currentTransform.getOutputs();
- }
-
- @SuppressWarnings("unchecked")
- public Map<TupleTag<?>, Coder<?>> getOutputCoders() {
- return currentTransform.getOutputs().entrySet().stream()
- .filter(e -> e.getValue() instanceof PCollection)
- .collect(Collectors.toMap(Map.Entry::getKey, e -> ((PCollection) e.getValue()).getCoder()));
- }
-
- // --------------------------------------------------------------------------------------------
- // Pipeline methods
- // --------------------------------------------------------------------------------------------
-
- /** Starts the pipeline. */
- public void startPipeline() {
- SparkStructuredStreamingPipelineOptions options =
- serializablePipelineOptions.get().as(SparkStructuredStreamingPipelineOptions.class);
- int datasetIndex = 0;
- for (Dataset<?> dataset : leaves) {
- if (options.isStreaming()) {
- // TODO: deal with Beam Discarding, Accumulating and Accumulating & Retracting outputmodes
- // with DatastreamWriter.outputMode
- DataStreamWriter<?> dataStreamWriter = dataset.writeStream();
- // spark sets a default checkpoint dir if not set.
- if (options.getCheckpointDir() != null) {
- dataStreamWriter =
- dataStreamWriter.option("checkpointLocation", options.getCheckpointDir());
- }
- launchStreaming(dataStreamWriter.foreach(new NoOpForeachWriter<>()));
- } else {
- if (options.getTestMode()) {
- LOG.debug("**** dataset {} catalyst execution plans ****", ++datasetIndex);
- dataset.explain(true);
- }
- // apply a dummy fn just to apply foreach action that will trigger the pipeline run in
- // spark
- dataset.foreach((ForeachFunction) t -> {});
- }
- }
- }
-
- public abstract void launchStreaming(DataStreamWriter<?> dataStreamWriter);
-
- public static void printDatasetContent(Dataset<WindowedValue> dataset) {
- // cannot use dataset.show because dataset schema is binary so it will print binary
- // code.
- List<WindowedValue> windowedValues = dataset.collectAsList();
- for (WindowedValue windowedValue : windowedValues) {
- LOG.debug("**** dataset content {} ****", windowedValue.toString());
- }
- }
-
- private static class NoOpForeachWriter<T> extends ForeachWriter<T> {
-
- @Override
- public boolean open(long partitionId, long epochId) {
- return false;
- }
-
- @Override
- public void process(T value) {
- // do nothing
- }
-
- @Override
- public void close(Throwable errorOrNull) {
- // do nothing
- }
- }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java
index 0f851d9588d..0fa48fc3d3e 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java
@@ -17,24 +17,25 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation;
+import java.io.IOException;
import org.apache.beam.runners.core.construction.PTransformTranslation;
-import org.apache.beam.runners.spark.structuredstreaming.translation.batch.PipelineTranslatorBatch;
-import org.apache.beam.runners.spark.structuredstreaming.translation.streaming.PipelineTranslatorStreaming;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.StreamingOptions;
+import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PInput;
+import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.PValue;
+import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
* It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation. If we have a streaming job, it is instantiated as a {@link
- * PipelineTranslatorStreaming}. If we have a batch job, it is instantiated as a {@link
- * PipelineTranslatorBatch}.
+ * preparation.
*/
@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
@@ -42,7 +43,7 @@ import org.slf4j.LoggerFactory;
public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
private int depth = 0;
private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
- protected AbstractTranslationContext translationContext;
+ protected TranslationContext translationContext;
// --------------------------------------------------------------------------------------------
// Pipeline preparation methods
@@ -123,22 +124,25 @@ public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaul
}
/** Get a {@link TransformTranslator} for the given {@link TransformHierarchy.Node}. */
- protected abstract TransformTranslator<?> getTransformTranslator(TransformHierarchy.Node node);
+ protected abstract @Nullable <
+ InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+ TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
+ @Nullable TransformT transform);
/** Apply the given TransformTranslator to the given node. */
- private <T extends PTransform<?, ?>> void applyTransformTranslator(
- TransformHierarchy.Node node, TransformTranslator<?> transformTranslator) {
+ private <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+ void applyTransformTranslator(
+ TransformHierarchy.Node node,
+ TransformT transform,
+ TransformTranslator<InT, OutT, TransformT> transformTranslator) {
// create the applied PTransform on the translationContext
- translationContext.setCurrentTransform(node.toAppliedPTransform(getPipeline()));
-
- // avoid type capture
- @SuppressWarnings("unchecked")
- T typedTransform = (T) node.getTransform();
- @SuppressWarnings("unchecked")
- TransformTranslator<T> typedTransformTranslator = (TransformTranslator<T>) transformTranslator;
-
- // apply the transformTranslator
- typedTransformTranslator.translateTransform(typedTransform, translationContext);
+ AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+ (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+ try {
+ transformTranslator.translate(transform, appliedTransform, translationContext);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
}
// --------------------------------------------------------------------------------------------
@@ -164,10 +168,12 @@ public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaul
LOG.debug("{} enterCompositeTransform- {}", genSpaces(depth), node.getFullName());
depth++;
- TransformTranslator<?> transformTranslator = getTransformTranslator(node);
+ PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+ TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
+ getTransformTranslator(transform);
if (transformTranslator != null) {
- applyTransformTranslator(node, transformTranslator);
+ applyTransformTranslator(node, transform, transformTranslator);
LOG.debug("{} translated- {}", genSpaces(depth), node.getFullName());
return CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
} else {
@@ -187,16 +193,19 @@ public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaul
// get the transformation corresponding to the node we are
// currently visiting and translate it into its Spark alternative.
- TransformTranslator<?> transformTranslator = getTransformTranslator(node);
+ PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+ TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
+ getTransformTranslator(transform);
+
if (transformTranslator == null) {
String transformUrn = PTransformTranslation.urnForTransform(node.getTransform());
throw new UnsupportedOperationException(
"The transform " + transformUrn + " is currently not supported.");
}
- applyTransformTranslator(node, transformTranslator);
+ applyTransformTranslator(node, transform, transformTranslator);
}
- public AbstractTranslationContext getTranslationContext() {
+ public TranslationContext getTranslationContext() {
return translationContext;
}
}
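
To make the revised contract concrete: a subclass now resolves a translator from the transform instance itself. A hypothetical dispatch by transform class (the TRANSLATORS map is an illustrative assumption; PipelineTranslatorBatch's actual lookup may differ):

    // Illustrative sketch of an override matching the new signature.
    private static final Map<Class<? extends PTransform>, TransformTranslator<?, ?, ?>>
        TRANSLATORS = new HashMap<>();

    @Override
    protected @Nullable <InT extends PInput, OutT extends POutput,
            TransformT extends PTransform<InT, OutT>>
        TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
            @Nullable TransformT transform) {
      return transform == null
          ? null
          : (TransformTranslator<InT, OutT, TransformT>) TRANSLATORS.get(transform.getClass());
    }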
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java
index 61580aed219..d991a0d9148 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java
@@ -17,15 +17,197 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation;
-import java.io.Serializable;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderFor;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables.getOnlyElement;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.core.construction.TransformInputs;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PInput;
+import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.SparkSession;
+import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
+import scala.Tuple2;
+import scala.reflect.ClassTag;
+
+/**
+ * Supports translation between a Beam transform, and Spark's operations on Datasets.
+ *
+ * <p>WARNING: Do not make this class serializable! It could easily hide situations where
+ * unnecessary references leak into Spark closures.
+ */
+public abstract class TransformTranslator<
+ InT extends PInput, OutT extends POutput, TransformT extends PTransform<? extends InT, OutT>> {
+
+ protected abstract void translate(TransformT transform, Context cxt) throws IOException;
+
+ public final void translate(
+ TransformT transform,
+ AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform,
+ TranslationContext cxt)
+ throws IOException {
+ translate(transform, new Context(appliedTransform, cxt));
+ }
+
+ protected class Context {
+ private final AppliedPTransform<InT, OutT, PTransform<InT, OutT>> transform;
+ private final TranslationContext cxt;
+ private @MonotonicNonNull InT pIn = null;
+ private @MonotonicNonNull OutT pOut = null;
+
+ protected Context(
+ AppliedPTransform<InT, OutT, PTransform<InT, OutT>> transform, TranslationContext cxt) {
+ this.transform = transform;
+ this.cxt = cxt;
+ }
+
+ public InT getInput() {
+ if (pIn == null) {
+ pIn = (InT) getOnlyElement(TransformInputs.nonAdditionalInputs(transform));
+ }
+ return pIn;
+ }
+
+ public Map<TupleTag<?>, PCollection<?>> getInputs() {
+ return transform.getInputs();
+ }
+
+ public OutT getOutput() {
+ if (pOut == null) {
+ pOut = (OutT) getOnlyElement(transform.getOutputs().values());
+ }
+ return pOut;
+ }
+
+ public <T> PCollection<T> getOutput(TupleTag<T> tag) {
+ PCollection<T> pc = (PCollection<T>) transform.getOutputs().get(tag);
+ if (pc == null) {
+ throw new IllegalStateException("No output for tag " + tag);
+ }
+ return pc;
+ }
+
+ public Map<TupleTag<?>, PCollection<?>> getOutputs() {
+ return transform.getOutputs();
+ }
+
+ public AppliedPTransform<InT, OutT, PTransform<InT, OutT>> getCurrentTransform() {
+ return transform;
+ }
+
+ public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+ return cxt.getDataset(pCollection);
+ }
+
+ public <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+ cxt.putDataset(pCollection, dataset);
+ }
+
+ public SerializablePipelineOptions getSerializableOptions() {
+ return cxt.getSerializableOptions();
+ }
+
+ public PipelineOptions getOptions() {
+ return cxt.getSerializableOptions().get();
+ }
+
+ // FIXME Types don't guarantee anything!
+ public <ViewT, ElemT> void setSideInputDataset(
+ PCollectionView<ViewT> value, Dataset<WindowedValue<ElemT>> set) {
+ cxt.setSideInputDataset(value, set);
+ }
+
+ public <T> Dataset<T> getSideInputDataset(PCollectionView<?> sideInput) {
+ return cxt.getSideInputDataSet(sideInput);
+ }
+
+ public <T> Dataset<WindowedValue<T>> createDataset(
+ List<WindowedValue<T>> data, Encoder<WindowedValue<T>> enc) {
+ return data.isEmpty()
+ ? cxt.getSparkSession().emptyDataset(enc)
+ : cxt.getSparkSession().createDataset(data, enc);
+ }
+
+ public <T> Broadcast<T> broadcast(T value) {
+ return cxt.getSparkSession().sparkContext().broadcast(value, (ClassTag) ClassTag.AnyRef());
+ }
+
+ public SparkSession getSparkSession() {
+ return cxt.getSparkSession();
+ }
+
+ public <T> Encoder<T> encoderOf(Coder<T> coder) {
+ return coder instanceof KvCoder ? kvEncoderOf((KvCoder) coder) : getOrCreateEncoder(coder);
+ }
+
+ public <K, V> Encoder<KV<K, V>> kvEncoderOf(KvCoder<K, V> coder) {
+ return cxt.encoderOf(coder, c -> kvEncoder(keyEncoderOf(coder), valueEncoderOf(coder)));
+ }
+
+ public <K, V> Encoder<K> keyEncoderOf(KvCoder<K, V> coder) {
+ return getOrCreateEncoder(coder.getKeyCoder());
+ }
+
+ public <K, V> Encoder<V> valueEncoderOf(KvCoder<K, V> coder) {
+ return getOrCreateEncoder(coder.getValueCoder());
+ }
+
+ public <T> Encoder<WindowedValue<T>> windowedEncoder(Coder<T> coder) {
+ return windowedValueEncoder(encoderOf(coder), windowEncoder());
+ }
+
+ public <T> Encoder<WindowedValue<T>> windowedEncoder(Encoder<T> enc) {
+ return windowedValueEncoder(enc, windowEncoder());
+ }
+
+ public <T1, T2> Encoder<Tuple2<T1, T2>> tupleEncoder(Encoder<T1> e1, Encoder<T2> e2) {
+ return Encoders.tuple(e1, e2);
+ }
+
+ public <T, W extends BoundedWindow> Encoder<WindowedValue<T>> windowedEncoder(
+ Coder<T> coder, Coder<W> windowCoder) {
+ return windowedValueEncoder(encoderOf(coder), getOrCreateWindowCoder(windowCoder));
+ }
+
+ public Encoder<BoundedWindow> windowEncoder() {
+ checkState(
+ !transform.getInputs().isEmpty(), "Transform has no inputs, cannot get windowCoder!");
+ Coder<BoundedWindow> coder =
+ ((PCollection) getInput()).getWindowingStrategy().getWindowFn().windowCoder();
+ return cxt.encoderOf(coder, c -> encoderFor(c));
+ }
+
+ private <W extends BoundedWindow> Encoder<BoundedWindow> getOrCreateWindowCoder(
+ Coder<W> coder) {
+ return cxt.encoderOf((Coder<BoundedWindow>) coder, c -> encoderFor(c));
+ }
-/** Supports translation between a Beam transform, and Spark's operations on Datasets. */
-@SuppressWarnings({
- "rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
-})
-public interface TransformTranslator<TransformT extends PTransform> extends Serializable {
+ private <T> Encoder<T> getOrCreateEncoder(Coder<T> coder) {
+ return cxt.encoderOf(coder, c -> encoderFor(c));
+ }
+ }
- /** Base class for translators of {@link PTransform}. */
- void translateTransform(TransformT transform, AbstractTranslationContext context);
+ protected <T> Coder<BoundedWindow> windowCoder(PCollection<T> pc) {
+ return (Coder) pc.getWindowingStrategy().getWindowFn().windowCoder();
+ }
}
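
With the new base class, a concrete translator is a small non-serializable subclass. A hypothetical identity-style translator (not part of this commit) could look like:

    // Hypothetical translator: forwards the input Dataset unchanged.
    class NoOpTranslatorBatch<T>
        extends TransformTranslator<
            PCollection<T>, PCollection<T>, PTransform<PCollection<T>, PCollection<T>>> {

      @Override
      protected void translate(
          PTransform<PCollection<T>, PCollection<T>> transform, Context cxt) throws IOException {
        // Resolve the Dataset of the input PCollection and register it,
        // unchanged, as the Dataset backing the output PCollection.
        Dataset<WindowedValue<T>> input = cxt.getDataset(cxt.getInput());
        cxt.putDataset(cxt.getOutput(), input);
      }
    }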
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java
index 12cb2d2fef0..617aa67c5fe 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TranslationContext.java
@@ -17,27 +17,125 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation;
-import java.util.concurrent.TimeoutException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Function;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
-import org.apache.spark.sql.streaming.DataStreamWriter;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.util.Preconditions;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.spark.api.java.function.ForeachFunction;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
- * Subclass of {@link
- * org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext} that
- * address spark breaking changes.
+ * Context for {@link PTransform} translation: keeping track of the
+ * datasets, the {@link SparkSession}, the current transform being translated.
*/
-public class TranslationContext extends AbstractTranslationContext {
+public class TranslationContext {
+
+ private static final Logger LOG = LoggerFactory.getLogger(TranslationContext.class);
+
+ /** All the datasets of the DAG. */
+ private final Map<PValue, Dataset<?>> datasets;
+ /** datasets that are not used as input to other datasets (leaves of the DAG). */
+ private final Set<Dataset<?>> leaves;
+
+ private final SerializablePipelineOptions serializablePipelineOptions;
+
+ private final SparkSession sparkSession;
+
+ private final Map<PCollectionView<?>, Dataset<?>> broadcastDataSets;
+
+ private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
public TranslationContext(SparkStructuredStreamingPipelineOptions options) {
- super(options);
+ this.sparkSession = SparkSessionFactory.getOrCreateSession(options);
+ this.serializablePipelineOptions = new SerializablePipelineOptions(options);
+ this.datasets = new HashMap<>();
+ this.leaves = new HashSet<>();
+ this.broadcastDataSets = new HashMap<>();
+ this.encoders = new HashMap<>();
+ }
+
+ public SparkSession getSparkSession() {
+ return sparkSession;
+ }
+
+ public SerializablePipelineOptions getSerializableOptions() {
+ return serializablePipelineOptions;
+ }
+
+ public <T> Encoder<T> encoderOf(Coder<T> coder, Function<Coder<T>, Encoder<T>> loadFn) {
+ return (Encoder<T>) encoders.computeIfAbsent(coder, (Function) loadFn);
+ }
+
+ @SuppressWarnings("unchecked") // can't be avoided
+ public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+ Dataset<?> dataset = Preconditions.checkStateNotNull(datasets.get(pCollection));
+ // assume that the Dataset is used as an input if retrieved here. So it is not a leaf anymore
+ leaves.remove(dataset);
+ return (Dataset<WindowedValue<T>>) dataset;
+ }
+
+ public <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+ if (!datasets.containsKey(pCollection)) {
+ datasets.put(pCollection, dataset);
+ leaves.add(dataset);
+ }
+ }
+
+ public <ViewT, ElemT> void setSideInputDataset(
+ PCollectionView<ViewT> value, Dataset<WindowedValue<ElemT>> set) {
+ if (!broadcastDataSets.containsKey(value)) {
+ broadcastDataSets.put(value, set);
+ }
+ }
+
+ @SuppressWarnings("unchecked") // can't be avoided
+ public <T> Dataset<T> getSideInputDataSet(PCollectionView<?> value) {
+ return (Dataset<T>) Preconditions.checkStateNotNull(broadcastDataSets.get(value));
+ }
+
+ /**
+ * Starts the batch pipeline; streaming is not supported.
+ *
+ * @see org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner
+ */
+ public void startPipeline() {
+ encoders.clear();
+
+ SparkStructuredStreamingPipelineOptions options =
+ serializablePipelineOptions.get().as(SparkStructuredStreamingPipelineOptions.class);
+ int datasetIndex = 0;
+ for (Dataset<?> dataset : leaves) {
+ if (options.getTestMode()) {
+ LOG.debug("**** dataset {} catalyst execution plans ****", ++datasetIndex);
+ dataset.explain(true);
+ }
+ // force evaluation using a dummy foreach action
+ dataset.foreach((ForeachFunction) t -> {});
+ }
}
- @Override
- public void launchStreaming(DataStreamWriter<?> dataStreamWriter) {
- try {
- dataStreamWriter.start();
- } catch (TimeoutException e) {
- throw new RuntimeException("A timeout occurred when running the streaming pipeline", e);
+ public static <T> void printDatasetContent(Dataset<WindowedValue<T>> dataset) {
+ // Cannot use dataset.show() because the dataset schema is binary, so it would print
+ // binary data.
+ List<WindowedValue<T>> windowedValues = dataset.collectAsList();
+ for (WindowedValue<?> windowedValue : windowedValues) {
+ System.out.println(windowedValue);
}
}
}
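For illustration, a translator typically interacts with this context by fetching the Dataset registered for its input and registering a Dataset for its output. A minimal sketch follows (a hypothetical identity translation, not part of this commit, assuming it lives in the same package as TranslationContext):

import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.spark.sql.Dataset;

// Hypothetical sketch: how a translator consumes and registers datasets.
class IdentityTranslationSketch {
  <T> void translate(PCollection<T> input, PCollection<T> output, TranslationContext cxt) {
    // Retrieving the input Dataset marks it as consumed and removes it from the leaves.
    Dataset<WindowedValue<T>> ds = cxt.getDataset(input);
    // Registering the output Dataset adds it as a new leaf until something consumes it;
    // startPipeline() later forces evaluation of all remaining leaves via foreach.
    cxt.putDataset(output, ds);
  }
}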
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java
deleted file mode 100644
index d0f46ea807c..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorCombiner.java
+++ /dev/null
@@ -1,270 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.stream.Collectors;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.CoderException;
-import org.apache.beam.sdk.coders.IterableCoder;
-import org.apache.beam.sdk.transforms.Combine;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.transforms.windowing.PaneInfo;
-import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
-import org.apache.beam.sdk.transforms.windowing.WindowFn;
-import org.apache.beam.sdk.util.CoderUtils;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.KV;
-import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
-import org.apache.spark.sql.Encoder;
-import org.apache.spark.sql.expressions.Aggregator;
-import org.joda.time.Instant;
-import scala.Tuple2;
-
-/** An {@link Aggregator} for the Spark Batch Runner.
- * The accumulator is an {@code Iterable<WindowedValue<AccumT>>} because an {@code InputT} can be
- * in multiple windows. So, when accumulating {@code InputT} values, we create one accumulator per
- * input window.
- */
-class AggregatorCombiner<K, InputT, AccumT, OutputT, W extends BoundedWindow>
- extends Aggregator<
- WindowedValue<KV<K, InputT>>,
- Iterable<WindowedValue<AccumT>>,
- Iterable<WindowedValue<OutputT>>> {
-
- private final Combine.CombineFn<InputT, AccumT, OutputT> combineFn;
- private WindowingStrategy<InputT, W> windowingStrategy;
- private TimestampCombiner timestampCombiner;
- private Coder<AccumT> accumulatorCoder;
- private IterableCoder<WindowedValue<AccumT>> bufferEncoder;
- private IterableCoder<WindowedValue<OutputT>> outputCoder;
-
- public AggregatorCombiner(
- Combine.CombineFn<InputT, AccumT, OutputT> combineFn,
- WindowingStrategy<?, ?> windowingStrategy,
- Coder<AccumT> accumulatorCoder,
- Coder<OutputT> outputCoder) {
- this.combineFn = combineFn;
- this.windowingStrategy = (WindowingStrategy<InputT, W>) windowingStrategy;
- this.timestampCombiner = windowingStrategy.getTimestampCombiner();
- this.accumulatorCoder = accumulatorCoder;
- this.bufferEncoder =
- IterableCoder.of(
- WindowedValue.FullWindowedValueCoder.of(
- accumulatorCoder, windowingStrategy.getWindowFn().windowCoder()));
- this.outputCoder =
- IterableCoder.of(
- WindowedValue.FullWindowedValueCoder.of(
- outputCoder, windowingStrategy.getWindowFn().windowCoder()));
- }
-
- @Override
- public Iterable<WindowedValue<AccumT>> zero() {
- return new ArrayList<>();
- }
-
- private Iterable<WindowedValue<AccumT>> createAccumulator(WindowedValue<KV<K, InputT>> inputWv) {
- // need to create an accumulator because combineFn can modify its input accumulator.
- AccumT accumulator = combineFn.createAccumulator();
- AccumT accumT = combineFn.addInput(accumulator, inputWv.getValue().getValue());
- return Lists.newArrayList(
- WindowedValue.of(accumT, inputWv.getTimestamp(), inputWv.getWindows(), inputWv.getPane()));
- }
-
- @Override
- public Iterable<WindowedValue<AccumT>> reduce(
- Iterable<WindowedValue<AccumT>> accumulators, WindowedValue<KV<K, InputT>> inputWv) {
- return merge(accumulators, createAccumulator(inputWv));
- }
-
- @Override
- public Iterable<WindowedValue<AccumT>> merge(
- Iterable<WindowedValue<AccumT>> accumulators1,
- Iterable<WindowedValue<AccumT>> accumulators2) {
-
- // merge the windows of all the accumulators
- Iterable<WindowedValue<AccumT>> accumulators = Iterables.concat(accumulators1, accumulators2);
- Set<W> accumulatorsWindows = collectAccumulatorsWindows(accumulators);
- Map<W, W> windowToMergeResult;
- try {
- windowToMergeResult = mergeWindows(windowingStrategy, accumulatorsWindows);
- } catch (Exception e) {
- throw new RuntimeException("Unable to merge accumulators windows", e);
- }
-
- // group accumulators by their merged window
- Map<W, List<Tuple2<AccumT, Instant>>> mergedWindowToAccumulators = new HashMap<>();
- for (WindowedValue<AccumT> accumulatorWv : accumulators) {
- // Encode a version of the accumulator if it is in multiple windows. The combineFn is able to
- // mutate the accumulator instance and this could lead to incorrect results if the same
- // instance is merged across multiple windows so we decode a new instance as needed. This
- // prevents issues during merging of accumulators.
- byte[] encodedAccumT = null;
- if (accumulatorWv.getWindows().size() > 1) {
- try {
- encodedAccumT = CoderUtils.encodeToByteArray(accumulatorCoder, accumulatorWv.getValue());
- } catch (CoderException e) {
- throw new RuntimeException(
- String.format(
- "Unable to encode accumulator %s with coder %s.",
- accumulatorWv.getValue(), accumulatorCoder),
- e);
- }
- }
- for (BoundedWindow accumulatorWindow : accumulatorWv.getWindows()) {
- W mergedWindowForAccumulator = windowToMergeResult.get(accumulatorWindow);
- mergedWindowForAccumulator =
- (mergedWindowForAccumulator == null)
- ? (W) accumulatorWindow
- : mergedWindowForAccumulator;
-
- // Decode a copy of the accumulator when necessary.
- AccumT accumT;
- if (encodedAccumT != null) {
- try {
- accumT = CoderUtils.decodeFromByteArray(accumulatorCoder, encodedAccumT);
- } catch (CoderException e) {
- throw new RuntimeException(
- String.format(
- "Unable to encode accumulator %s with coder %s.",
- accumulatorWv.getValue(), accumulatorCoder),
- e);
- }
- } else {
- accumT = accumulatorWv.getValue();
- }
-
- // we need only the timestamp and the AccumT, we create a tuple
- Tuple2<AccumT, Instant> accumAndInstant =
- new Tuple2<>(
- accumT,
- timestampCombiner.assign(mergedWindowForAccumulator, accumulatorWv.getTimestamp()));
- if (mergedWindowToAccumulators.get(mergedWindowForAccumulator) == null) {
- mergedWindowToAccumulators.put(
- mergedWindowForAccumulator, Lists.newArrayList(accumAndInstant));
- } else {
- mergedWindowToAccumulators.get(mergedWindowForAccumulator).add(accumAndInstant);
- }
- }
- }
- // merge the accumulators for each mergedWindow
- List<WindowedValue<AccumT>> result = new ArrayList<>();
- for (Map.Entry<W, List<Tuple2<AccumT, Instant>>> entry :
- mergedWindowToAccumulators.entrySet()) {
- W mergedWindow = entry.getKey();
- List<Tuple2<AccumT, Instant>> accumsAndInstantsForMergedWindow = entry.getValue();
-
- // we need to create the first accumulator because combineFn.mergerAccumulators can modify the
- // first accumulator
- AccumT first = combineFn.createAccumulator();
- Iterable<AccumT> accumulatorsToMerge =
- Iterables.concat(
- Collections.singleton(first),
- accumsAndInstantsForMergedWindow.stream()
- .map(x -> x._1())
- .collect(Collectors.toList()));
- result.add(
- WindowedValue.of(
- combineFn.mergeAccumulators(accumulatorsToMerge),
- timestampCombiner.combine(
- accumsAndInstantsForMergedWindow.stream()
- .map(x -> x._2())
- .collect(Collectors.toList())),
- mergedWindow,
- PaneInfo.NO_FIRING));
- }
- return result;
- }
-
- @Override
- public Iterable<WindowedValue<OutputT>> finish(Iterable<WindowedValue<AccumT>> reduction) {
- List<WindowedValue<OutputT>> result = new ArrayList<>();
- for (WindowedValue<AccumT> windowedValue : reduction) {
- result.add(windowedValue.withValue(combineFn.extractOutput(windowedValue.getValue())));
- }
- return result;
- }
-
- @Override
- public Encoder<Iterable<WindowedValue<AccumT>>> bufferEncoder() {
- return EncoderHelpers.fromBeamCoder(bufferEncoder);
- }
-
- @Override
- public Encoder<Iterable<WindowedValue<OutputT>>> outputEncoder() {
- return EncoderHelpers.fromBeamCoder(outputCoder);
- }
-
- private Set<W> collectAccumulatorsWindows(Iterable<WindowedValue<AccumT>> accumulators) {
- Set<W> windows = new HashSet<>();
- for (WindowedValue<?> accumulator : accumulators) {
- for (BoundedWindow untypedWindow : accumulator.getWindows()) {
- @SuppressWarnings("unchecked")
- W window = (W) untypedWindow;
- windows.add(window);
- }
- }
- return windows;
- }
-
- private Map<W, W> mergeWindows(WindowingStrategy<InputT, W> windowingStrategy, Set<W> windows)
- throws Exception {
- WindowFn<InputT, W> windowFn = windowingStrategy.getWindowFn();
-
- if (!windowingStrategy.needsMerge()) {
- // Return an empty map, indicating that every window is not merged.
- return Collections.emptyMap();
- }
-
- Map<W, W> windowToMergeResult = new HashMap<>();
- windowFn.mergeWindows(new MergeContextImpl(windowFn, windows, windowToMergeResult));
- return windowToMergeResult;
- }
-
- private class MergeContextImpl extends WindowFn<InputT, W>.MergeContext {
-
- private Set<W> windows;
- private Map<W, W> windowToMergeResult;
-
- MergeContextImpl(WindowFn<InputT, W> windowFn, Set<W> windows, Map<W, W> windowToMergeResult) {
- windowFn.super();
- this.windows = windows;
- this.windowToMergeResult = windowToMergeResult;
- }
-
- @Override
- public Collection<W> windows() {
- return windows;
- }
-
- @Override
- public void merge(Collection<W> toBeMerged, W mergeResult) throws Exception {
- for (W w : toBeMerged) {
- windowToMergeResult.put(w, mergeResult);
- }
- }
- }
-}
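The window-merging mechanism this deleted class implemented (and which the new MergingWindowedAggregator below re-implements) boils down to WindowFn#mergeWindows driven by a custom MergeContext. A minimal standalone sketch using Sessions (the gap duration is an arbitrary example value):

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
import org.apache.beam.sdk.transforms.windowing.Sessions;
import org.joda.time.Duration;

// Sketch: build a window -> merged-window map via WindowFn#mergeWindows.
class MergeWindowsSketch {
  static Map<IntervalWindow, IntervalWindow> mergeResults(Set<IntervalWindow> windows)
      throws Exception {
    Sessions sessions = Sessions.withGapDuration(Duration.standardMinutes(10));
    Map<IntervalWindow, IntervalWindow> windowToMergeResult = new HashMap<>();
    sessions.mergeWindows(
        sessions.new MergeContext() {
          @Override
          public Collection<IntervalWindow> windows() {
            return windows;
          }

          @Override
          public void merge(Collection<IntervalWindow> toBeMerged, IntervalWindow mergeResult) {
            // Every window absorbed into mergeResult maps to it; unmerged windows stay absent.
            for (IntervalWindow w : toBeMerged) {
              windowToMergeResult.put(w, mergeResult);
            }
          }
        });
    return windowToMergeResult;
  }
}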
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/Aggregators.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/Aggregators.java
new file mode 100644
index 00000000000..45026f9d8bd
--- /dev/null
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/Aggregators.java
@@ -0,0 +1,591 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.mapEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.mutablePairEncoder;
+import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators.peekingIterator;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.function.BiFunction;
+import java.util.function.BinaryOperator;
+import java.util.function.Function;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.Sessions;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Collections2;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.PeekingIterator;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.expressions.Aggregator;
+import org.apache.spark.util.MutablePair;
+import org.checkerframework.checker.nullness.qual.MonotonicNonNull;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.PolyNull;
+import org.joda.time.Instant;
+
+public class Aggregators {
+
+ /**
+ * Creates a simple value {@link Aggregator} that is not window aware.
+ *
+ * @param <ValT> {@link CombineFn} input type
+ * @param <AccT> {@link CombineFn} accumulator type
+ * @param <ResT> {@link CombineFn} / {@link Aggregator} result type
+ * @param <InT> {@link Aggregator} input type
+ */
+ public static <ValT, AccT, ResT, InT> Aggregator<InT, ?, ResT> value(
+ CombineFn<ValT, AccT, ResT> fn,
+ Fun1<InT, ValT> valueFn,
+ Encoder<AccT> accEnc,
+ Encoder<ResT> outEnc) {
+ return new ValueAggregator<>(fn, valueFn, accEnc, outEnc);
+ }
+
+ /**
+ * Creates a windowed Spark {@link Aggregator} depending on the provided Beam {@link WindowFn}s.
+ *
+ * <p>Specialised implementations are provided for:
+ * <li>{@link Sessions}
+ * <li>Non-merging window functions
+ * <li>Merging window functions
+ *
+ * @param <ValT> {@link CombineFn} input type
+ * @param <AccT> {@link CombineFn} accumulator type
+ * @param <ResT> {@link CombineFn} / {@link Aggregator} result type
+ * @param <InT> {@link Aggregator} input type
+ */
+ public static <ValT, AccT, ResT, InT>
+ Aggregator<WindowedValue<InT>, ?, Collection<WindowedValue<ResT>>> windowedValue(
+ CombineFn<ValT, AccT, ResT> fn,
+ Fun1<WindowedValue<InT>, ValT> valueFn,
+ WindowingStrategy<?, ?> windowing,
+ Encoder<BoundedWindow> windowEnc,
+ Encoder<AccT> accEnc,
+ Encoder<WindowedValue<ResT>> outEnc) {
+ if (!windowing.needsMerge()) {
+ return new NonMergingWindowedAggregator<>(fn, valueFn, windowing, windowEnc, accEnc, outEnc);
+ } else if (windowing.getWindowFn().getClass().equals(Sessions.class)) {
+ return new SessionsAggregator<>(fn, valueFn, windowing, (Encoder) windowEnc, accEnc, outEnc);
+ }
+ return new MergingWindowedAggregator<>(fn, valueFn, windowing, windowEnc, accEnc, outEnc);
+ }
+
+ /**
+ * Simple value {@link Aggregator} that is not window aware.
+ *
+ * @param <ValT> {@link CombineFn} input type
+ * @param <AccT> {@link CombineFn} accumulator type
+ * @param <ResT> {@link CombineFn} / {@link Aggregator} result type
+ * @param <InT> {@link Aggregator} input type
+ */
+ private static class ValueAggregator<ValT, AccT, ResT, InT>
+ extends CombineFnAggregator<ValT, AccT, ResT, InT, AccT, ResT> {
+
+ public ValueAggregator(
+ CombineFn<ValT, AccT, ResT> fn,
+ Fun1<InT, ValT> valueFn,
+ Encoder<AccT> accEnc,
+ Encoder<ResT> outEnc) {
+ super(fn, valueFn, accEnc, outEnc);
+ }
+
+ @Override
+ public AccT zero() {
+ return emptyAcc();
+ }
+
+ @Override
+ public AccT reduce(AccT buff, InT in) {
+ return addToAcc(buff, value(in));
+ }
+
+ @Override
+ public AccT merge(AccT b1, AccT b2) {
+ return mergeAccs(b1, b2);
+ }
+
+ @Override
+ public ResT finish(AccT buff) {
+ return extract(buff);
+ }
+ }
+
+ /**
+ * Specialized windowed Spark {@link Aggregator} for Beam {@link WindowFn}s of type {@link
+ * Sessions}. The aggregator uses a {@link TreeMap} as buffer to maintain ordering of the {@link
+ * IntervalWindow}s and merge these more efficiently.
+ *
+ * <p>For efficiency, this aggregator re-implements {@link
+ * Sessions#mergeWindows(WindowFn.MergeContext)} to leverage the already sorted buffer.
+ *
+ * @param <ValT> {@link CombineFn} input type
+ * @param <AccT> {@link CombineFn} accumulator type
+ * @param <ResT> {@link CombineFn} / {@link Aggregator} result type
+ * @param <InT> {@link Aggregator} input type
+ */
+ private static class SessionsAggregator<ValT, AccT, ResT, InT>
+ extends WindowedAggregator<
+ ValT,
+ AccT,
+ ResT,
+ InT,
+ IntervalWindow,
+ TreeMap<IntervalWindow, MutablePair<Instant, AccT>>> {
+
+ SessionsAggregator(
+ CombineFn<ValT, AccT, ResT> combineFn,
+ Fun1<WindowedValue<InT>, ValT> valueFn,
+ WindowingStrategy<?, ?> windowing,
+ Encoder<IntervalWindow> windowEnc,
+ Encoder<AccT> accEnc,
+ Encoder<WindowedValue<ResT>> outEnc) {
+ super(combineFn, valueFn, windowing, windowEnc, accEnc, outEnc, (Class) TreeMap.class);
+ checkArgument(windowing.getWindowFn().getClass().equals(Sessions.class));
+ }
+
+ @Override
+ public final TreeMap<IntervalWindow, MutablePair<Instant, AccT>> zero() {
+ return new TreeMap<>();
+ }
+
+ @Override
+ @SuppressWarnings("keyfor")
+ public TreeMap<IntervalWindow, MutablePair<Instant, AccT>> reduce(
+ TreeMap<IntervalWindow, MutablePair<Instant, AccT>> buff, WindowedValue<InT> input) {
+ for (IntervalWindow window : (Collection<IntervalWindow>) input.getWindows()) {
+ @MonotonicNonNull MutablePair<Instant, AccT> acc = null;
+ @MonotonicNonNull IntervalWindow first = null, last = null;
+ // start with the window before or equal to the new window (if one exists)
+ @Nullable Entry<IntervalWindow, MutablePair<Instant, AccT>> lower = buff.floorEntry(window);
+ if (lower != null && window.intersects(lower.getKey())) {
+ // if intersecting, init accumulator and extend window to span both
+ acc = lower.getValue();
+ window = window.span(lower.getKey());
+ first = last = lower.getKey();
+ }
+ // merge following windows in order if they intersect, then stop
+ for (Entry<IntervalWindow, MutablePair<Instant, AccT>> entry :
+ buff.tailMap(window, false).entrySet()) {
+ MutablePair<Instant, AccT> entryAcc = entry.getValue();
+ IntervalWindow entryWindow = entry.getKey();
+ if (window.intersects(entryWindow)) {
+ // extend window and merge accumulators
+ window = window.span(entryWindow);
+ acc = acc == null ? entryAcc : mergeAccs(window, acc, entryAcc);
+ if (first == null) {
+ // there was no previous (lower) window intersecting the input window
+ first = last = entryWindow;
+ } else {
+ last = entryWindow;
+ }
+ } else {
+ break; // stop, later windows won't intersect either
+ }
+ }
+ if (first != null && last != null) {
+ // remove the entire subset from first to last after it has been merged into acc
+ buff.navigableKeySet().subSet(first, true, last, true).clear();
+ }
+ // add input and get accumulator for new (potentially merged) window
+ buff.put(window, addToAcc(window, acc, value(input), input.getTimestamp()));
+ }
+ return buff;
+ }
+
+ @Override
+ public TreeMap<IntervalWindow, MutablePair<Instant, AccT>> merge(
+ TreeMap<IntervalWindow, MutablePair<Instant, AccT>> b1,
+ TreeMap<IntervalWindow, MutablePair<Instant, AccT>> b2) {
+ if (b1.isEmpty()) {
+ return b2;
+ } else if (b2.isEmpty()) {
+ return b1;
+ }
+ // Init new tree map to merge both buffers
+ TreeMap<IntervalWindow, MutablePair<Instant, AccT>> res = zero();
+ PeekingIterator<Entry<IntervalWindow, MutablePair<Instant, AccT>>> it1 =
+ peekingIterator(b1.entrySet().iterator());
+ PeekingIterator<Entry<IntervalWindow, MutablePair<Instant, AccT>>> it2 =
+ peekingIterator(b2.entrySet().iterator());
+
+ @Nullable MutablePair<Instant, AccT> acc = null;
+ @Nullable IntervalWindow window = null;
+ while (it1.hasNext() || it2.hasNext()) {
+ // pick iterator with the smallest window ahead and forward it
+ Entry<IntervalWindow, MutablePair<Instant, AccT>> nextMin =
+ (it1.hasNext() && it2.hasNext())
+ ? it1.peek().getKey().compareTo(it2.peek().getKey()) <= 0 ? it1.next() : it2.next()
+ : it1.hasNext() ? it1.next() : it2.next();
+ if (window != null && window.intersects(nextMin.getKey())) {
+ // extend window and merge accumulators if intersecting
+ window = window.span(nextMin.getKey());
+ acc = mergeAccs(window, acc, nextMin.getValue());
+ } else {
+ // store window / accumulator if necessary and continue with next minimum
+ if (window != null && acc != null) {
+ res.put(window, acc);
+ }
+ acc = nextMin.getValue();
+ window = nextMin.getKey();
+ }
+ }
+ if (window != null && acc != null) {
+ res.put(window, acc);
+ }
+ return res;
+ }
+ }
+
+ /**
+ * Merging windowed Spark {@link Aggregator} using a Map of {@link BoundedWindow}s as aggregation
+ * buffer. When reducing new input, a windowed accumulator is created for each new window of the
+ * input that doesn't overlap with existing windows. Otherwise, if the window is known or
+ * overlaps, the window is extended accordingly and accumulators are merged.
+ *
+ * @param <ValT> {@link CombineFn} input type
+ * @param <AccT> {@link CombineFn} accumulator type
+ * @param <ResT> {@link CombineFn} / {@link Aggregator} result type
+ * @param <InT> {@link Aggregator} input type
+ */
+ private static class MergingWindowedAggregator<ValT, AccT, ResT, InT>
+ extends NonMergingWindowedAggregator<ValT, AccT, ResT, InT> {
+
+ private final WindowFn<ValT, BoundedWindow> windowFn;
+
+ public MergingWindowedAggregator(
+ CombineFn<ValT, AccT, ResT> combineFn,
+ Fun1<WindowedValue<InT>, ValT> valueFn,
+ WindowingStrategy<?, ?> windowing,
+ Encoder<BoundedWindow> windowEnc,
+ Encoder<AccT> accEnc,
+ Encoder<WindowedValue<ResT>> outEnc) {
+ super(combineFn, valueFn, windowing, windowEnc, accEnc, outEnc);
+ windowFn = (WindowFn<ValT, BoundedWindow>) windowing.getWindowFn();
+ }
+
+ @Override
+ protected Map<BoundedWindow, MutablePair<Instant, AccT>> reduce(
+ Map<BoundedWindow, MutablePair<Instant, AccT>> buff,
+ Collection<BoundedWindow> windows,
+ ValT value,
+ Instant timestamp) {
+ if (buff.isEmpty()) {
+ // no windows yet to be merged, use the non-merging behavior of super
+ return super.reduce(buff, windows, value, timestamp);
+ }
+ // Merge multiple windows into one target window using the reducer function if the window
+ // already exists. Otherwise, the input value is added to the accumulator. Merged windows are
+ // removed from the accumulator map.
+ Function<BoundedWindow, ReduceFn<AccT>> accFn =
+ target ->
+ (acc, w) -> {
+ MutablePair<Instant, AccT> accW = buff.remove(w);
+ return (accW != null)
+ ? mergeAccs(w, acc, accW)
+ : addToAcc(w, acc, value, timestamp);
+ };
+ Set<BoundedWindow> unmerged = mergeWindows(buff, ImmutableSet.copyOf(windows), accFn);
+ if (!unmerged.isEmpty()) {
+ // remaining windows don't have to be merged
+ return super.reduce(buff, unmerged, value, timestamp);
+ }
+ return buff;
+ }
+
+ @Override
+ public Map<BoundedWindow, MutablePair<Instant, AccT>> merge(
+ Map<BoundedWindow, MutablePair<Instant, AccT>> b1,
+ Map<BoundedWindow, MutablePair<Instant, AccT>> b2) {
+ // Merge multiple windows into one target window using the reducer function. Merged windows
+ // are removed from both accumulator maps
+ Function<BoundedWindow, ReduceFn<AccT>> reduceFn =
+ target -> (acc, w) -> mergeAccs(w, mergeAccs(w, acc, b1.remove(w)), b2.remove(w));
+
+ Set<BoundedWindow> unmerged = b2.keySet();
+ unmerged = mergeWindows(b1, unmerged, reduceFn);
+ if (!unmerged.isEmpty()) {
+ // keep only unmerged windows in the 2nd accumulator map; continue using "non-merging" merge
+ b2.keySet().retainAll(unmerged);
+ return super.merge(b1, b2);
+ }
+ return b1;
+ }
+
+ /** Reduce function to merge multiple windowed accumulator values into one target window. */
+ private interface ReduceFn<AccT>
+ extends BiFunction<MutablePair<Instant, AccT>, BoundedWindow, MutablePair<Instant, AccT>> {}
+
+ /**
+ * Attempt to merge windows of accumulator map with additional windows using the reducer
+ * function. The reducer function must support {@code null} as zero value.
+ *
+ * @return The subset of additional windows that don't require a merge.
+ */
+ private Set<BoundedWindow> mergeWindows(
+ Map<BoundedWindow, MutablePair<Instant, AccT>> buff,
+ Set<BoundedWindow> newWindows,
+ Function<BoundedWindow, ReduceFn<AccT>> reduceFn) {
+ try {
+ Set<BoundedWindow> newUnmerged = new HashSet<>(newWindows);
+ windowFn.mergeWindows(
+ windowFn.new MergeContext() {
+ @Override
+ public Collection<BoundedWindow> windows() {
+ return Sets.union(buff.keySet(), newWindows);
+ }
+
+ @Override
+ public void merge(Collection<BoundedWindow> merges, BoundedWindow target) {
+ buff.put(
+ target, merges.stream().reduce(null, reduceFn.apply(target), combiner(target)));
+ newUnmerged.removeAll(merges);
+ }
+ });
+ return newUnmerged;
+ } catch (Exception e) {
+ throw new RuntimeException("Unable to merge accumulators windows", e);
+ }
+ }
+ }
+
+ /**
+ * Non-merging windowed Spark {@link Aggregator} using a Map of {@link BoundedWindow}s as
+ * aggregation buffer. When reducing new input, a windowed accumulator is created for each new
+ * window of the input. Otherwise, if the window is known, the accumulators are merged.
+ *
+ * @param <ValT> {@link CombineFn} input type
+ * @param <AccT> {@link CombineFn} accumulator type
+ * @param <ResT> {@link CombineFn} / {@link Aggregator} result type
+ * @param <InT> {@link Aggregator} input type
+ */
+ private static class NonMergingWindowedAggregator<ValT, AccT, ResT, InT>
+ extends WindowedAggregator<
+ ValT, AccT, ResT, InT, BoundedWindow, Map<BoundedWindow, MutablePair<Instant, AccT>>> {
+
+ public NonMergingWindowedAggregator(
+ CombineFn<ValT, AccT, ResT> combineFn,
+ Fun1<WindowedValue<InT>, ValT> valueFn,
+ WindowingStrategy<?, ?> windowing,
+ Encoder<BoundedWindow> windowEnc,
+ Encoder<AccT> accEnc,
+ Encoder<WindowedValue<ResT>> outEnc) {
+ super(combineFn, valueFn, windowing, windowEnc, accEnc, outEnc, (Class) Map.class);
+ }
+
+ @Override
+ public Map<BoundedWindow, MutablePair<Instant, AccT>> zero() {
+ return new HashMap<>();
+ }
+
+ @Override
+ public final Map<BoundedWindow, MutablePair<Instant, AccT>> reduce(
+ Map<BoundedWindow, MutablePair<Instant, AccT>> buff, WindowedValue<InT> input) {
+ Collection<BoundedWindow> windows = (Collection<BoundedWindow>) input.getWindows();
+ return reduce(buff, windows, value(input), input.getTimestamp());
+ }
+
+ protected Map<BoundedWindow, MutablePair<Instant, AccT>> reduce(
+ Map<BoundedWindow, MutablePair<Instant, AccT>> buff,
+ Collection<BoundedWindow> windows,
+ ValT value,
+ Instant timestamp) {
+ // for each window add the value to the accumulator
+ for (BoundedWindow window : windows) {
+ buff.compute(window, (w, acc) -> addToAcc(w, acc, value, timestamp));
+ }
+ return buff;
+ }
+
+ @Override
+ public Map<BoundedWindow, MutablePair<Instant, AccT>> merge(
+ Map<BoundedWindow, MutablePair<Instant, AccT>> b1,
+ Map<BoundedWindow, MutablePair<Instant, AccT>> b2) {
+ if (b1.isEmpty()) {
+ return b2;
+ } else if (b2.isEmpty()) {
+ return b1;
+ }
+ if (b2.size() > b1.size()) {
+ return merge(b2, b1);
+ }
+ // merge entries of (smaller) 2nd agg buffer map into first by merging the accumulators
+ b2.forEach((w, acc) -> b1.merge(w, acc, combiner(w)));
+ return b1;
+ }
+ }
+
+ /**
+ * Abstract base of a Spark {@link Aggregator} on {@link WindowedValue}s using a Map of {@link W}
+ * as aggregation buffer.
+ *
+ * @param <ValT> {@link CombineFn} input type
+ * @param <AccT> {@link CombineFn} accumulator type
+ * @param <ResT> {@link CombineFn} / {@link Aggregator} result type
+ * @param <InT> {@link Aggregator} input type
+ * @param <W> bounded window type
+ * @param <MapT> aggregation buffer, a {@link Map} keyed by window {@link W}
+ */
+ private abstract static class WindowedAggregator<
+ ValT,
+ AccT,
+ ResT,
+ InT,
+ W extends @NonNull BoundedWindow,
+ MapT extends Map<W, @NonNull MutablePair<Instant, AccT>>>
+ extends CombineFnAggregator<
+ ValT, AccT, ResT, WindowedValue<InT>, MapT, Collection<WindowedValue<ResT>>> {
+ private final TimestampCombiner tsCombiner;
+
+ public WindowedAggregator(
+ CombineFn<ValT, AccT, ResT> combineFn,
+ Fun1<WindowedValue<InT>, ValT> valueFn,
+ WindowingStrategy<?, ?> windowing,
+ Encoder<W> windowEnc,
+ Encoder<AccT> accEnc,
+ Encoder<WindowedValue<ResT>> outEnc,
+ Class<MapT> clazz) {
+ super(
+ combineFn,
+ valueFn,
+ mapEncoder(windowEnc, mutablePairEncoder(encoderOf(Instant.class), accEnc), clazz),
+ collectionEncoder(outEnc));
+ tsCombiner = windowing.getTimestampCombiner();
+ }
+
+ protected final Instant resolveTimestamp(BoundedWindow w, Instant t1, Instant t2) {
+ return tsCombiner.merge(w, t1, t2);
+ }
+
+ /** Init accumulator with initial input value and timestamp. */
+ protected final MutablePair<Instant, AccT> initAcc(ValT value, Instant timestamp) {
+ return new MutablePair<>(timestamp, addToAcc(emptyAcc(), value));
+ }
+
+ /** Merge timestamped accumulators. */
+ protected final <T extends MutablePair<Instant, AccT>> @PolyNull T mergeAccs(
+ W window, @PolyNull T a1, @PolyNull T a2) {
+ if (a1 == null || a2 == null) {
+ return a1 == null ? a2 : a1;
+ }
+ return (T) a1.update(resolveTimestamp(window, a1._1, a2._1), mergeAccs(a1._2, a2._2));
+ }
+
+ @SuppressWarnings("nullness") // may return null
+ protected BinaryOperator<MutablePair<Instant, AccT>> combiner(W target) {
+ return (a1, a2) -> mergeAccs(target, a1, a2);
+ }
+
+ /** Add an input value to a nullable accumulator. */
+ protected final MutablePair<Instant, AccT> addToAcc(
+ W window, @Nullable MutablePair<Instant, AccT> acc, ValT val, Instant ts) {
+ if (acc == null) {
+ return initAcc(val, ts);
+ }
+ return acc.update(resolveTimestamp(window, acc._1, ts), addToAcc(acc._2, val));
+ }
+
+ @Override
+ @SuppressWarnings("nullness") // entries are non null
+ public final Collection<WindowedValue<ResT>> finish(MapT buffer) {
+ return Collections2.transform(buffer.entrySet(), this::windowedValue);
+ }
+
+ private WindowedValue<ResT> windowedValue(Entry<W, MutablePair<Instant, AccT>> e) {
+ return WindowedValue.of(extract(e.getValue()._2), e.getValue()._1, e.getKey(), NO_FIRING);
+ }
+ }
+
+ /**
+ * Abstract base of Spark {@link Aggregator}s using a Beam {@link CombineFn}.
+ *
+ * @param <ValT> {@link CombineFn} input type
+ * @param <AccT> {@link CombineFn} accumulator type
+ * @param <ResT> {@link CombineFn} result type
+ * @param <InT> {@link Aggregator} input type
+ * @param <BuffT> {@link Aggregator} buffer type
+ * @param <OutT> {@link Aggregator} output type
+ */
+ private abstract static class CombineFnAggregator<ValT, AccT, ResT, InT, BuffT, OutT>
+ extends Aggregator<InT, BuffT, OutT> {
+ private final CombineFn<ValT, AccT, ResT> fn;
+ private final Fun1<InT, ValT> valueFn;
+ private final Encoder<BuffT> bufferEnc;
+ private final Encoder<OutT> outputEnc;
+
+ public CombineFnAggregator(
+ CombineFn<ValT, AccT, ResT> fn,
+ Fun1<InT, ValT> valueFn,
+ Encoder<BuffT> bufferEnc,
+ Encoder<OutT> outputEnc) {
+ this.fn = fn;
+ this.valueFn = valueFn;
+ this.bufferEnc = bufferEnc;
+ this.outputEnc = outputEnc;
+ }
+
+ protected final ValT value(InT in) {
+ return valueFn.apply(in);
+ }
+
+ protected final AccT emptyAcc() {
+ return fn.createAccumulator();
+ }
+
+ protected final AccT mergeAccs(AccT a1, AccT a2) {
+ return fn.mergeAccumulators(ImmutableList.of(a1, a2));
+ }
+
+ protected final AccT addToAcc(AccT acc, ValT val) {
+ return fn.addInput(acc, val);
+ }
+
+ protected final ResT extract(AccT acc) {
+ return fn.extractOutput(acc);
+ }
+
+ @Override
+ public Encoder<BuffT> bufferEncoder() {
+ return bufferEnc;
+ }
+
+ @Override
+ public Encoder<OutT> outputEncoder() {
+ return outputEnc;
+ }
+ }
+}
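The CombineFnAggregator base above delegates the Spark Aggregator lifecycle to a Beam CombineFn. A minimal standalone sketch of that delegation (illustrative only; the commit's class additionally threads a value-extraction function and, in its subclasses, windowing through):

import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.expressions.Aggregator;

// Sketch: a Spark Aggregator backed by a Beam CombineFn. The lifecycle maps directly:
// zero/reduce/merge/finish <-> createAccumulator/addInput/mergeAccumulators/extractOutput.
class CombineFnBackedAggregator<InT, AccT, OutT> extends Aggregator<InT, AccT, OutT> {
  private final CombineFn<InT, AccT, OutT> fn;
  private final Encoder<AccT> accEnc;
  private final Encoder<OutT> outEnc;

  CombineFnBackedAggregator(
      CombineFn<InT, AccT, OutT> fn, Encoder<AccT> accEnc, Encoder<OutT> outEnc) {
    this.fn = fn;
    this.accEnc = accEnc;
    this.outEnc = outEnc;
  }

  @Override public AccT zero() { return fn.createAccumulator(); }
  @Override public AccT reduce(AccT acc, InT in) { return fn.addInput(acc, in); }
  @Override public AccT merge(AccT a, AccT b) {
    return fn.mergeAccumulators(java.util.Arrays.asList(a, b));
  }
  @Override public OutT finish(AccT acc) { return fn.extractOutput(acc); }
  @Override public Encoder<AccT> bufferEncoder() { return accEnc; }
  @Override public Encoder<OutT> outputEncoder() { return outEnc; }
}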
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java
new file mode 100644
index 00000000000..5bc017134e9
--- /dev/null
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTranslatorBatch.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.value;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
+import static scala.collection.Iterator.single;
+
+import java.util.Collection;
+import java.util.Map;
+import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1;
+import org.apache.beam.sdk.coders.CannotProvideCoderException;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.expressions.Aggregator;
+import scala.collection.Iterator;
+
+/**
+ * Translator for {@link Combine.Globally} using a Spark {@link Aggregator}.
+ *
+ * <p>To minimize the amount of data shuffled, this first reduces the data per partition using
+ * {@link Aggregator#reduce}, gathers the partial results (using {@code coalesce(1)}) and finally
+ * merges these using {@link Aggregator#merge}.
+ *
+ * <p>TODOs:
+ * <li>any missing features?
+ */
+class CombineGloballyTranslatorBatch<InT, AccT, OutT>
+ extends TransformTranslator<PCollection<InT>, PCollection<OutT>, Combine.Globally<InT, OutT>> {
+
+ @Override
+ public void translate(Combine.Globally<InT, OutT> transform, Context cxt) {
+ WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy();
+ CombineFn<InT, AccT, OutT> combineFn = (CombineFn<InT, AccT, OutT>) transform.getFn();
+
+ Coder<InT> inputCoder = cxt.getInput().getCoder();
+ Coder<OutT> outputCoder = cxt.getOutput().getCoder();
+ Coder<AccT> accumCoder = accumulatorCoder(combineFn, inputCoder, cxt);
+
+ Encoder<OutT> outEnc = cxt.encoderOf(outputCoder);
+ Encoder<AccT> accEnc = cxt.encoderOf(accumCoder);
+ Encoder<WindowedValue<OutT>> wvOutEnc = cxt.windowedEncoder(outEnc);
+
+ Dataset<WindowedValue<InT>> dataset = cxt.getDataset(cxt.getInput());
+
+ final Dataset<WindowedValue<OutT>> result;
+ if (GroupByKeyHelpers.eligibleForGlobalGroupBy(windowing, true)) {
+ Aggregator<InT, ?, OutT> agg = Aggregators.value(combineFn, v -> v, accEnc, outEnc);
+
+ // Drop the window and restore it afterwards; produces a single global aggregation result
+ result = aggregate(dataset, agg, value(), windowedValue(), wvOutEnc);
+ } else {
+ Aggregator<WindowedValue<InT>, ?, Collection<WindowedValue<OutT>>> agg =
+ Aggregators.windowedValue(
+ combineFn, value(), windowing, cxt.windowEncoder(), accEnc, wvOutEnc);
+
+ // Produces aggregation result per window
+ result =
+ aggregate(dataset, agg, v -> v, fun1(out -> ScalaInterop.scalaIterator(out)), wvOutEnc);
+ }
+ cxt.putDataset(cxt.getOutput(), result);
+ }
+
+ /**
+ * Aggregates the dataset globally without using a key.
+ *
+ * <p>There is no global, typed version of {@link Dataset#agg(Map)} on datasets. Instead, this
+ * reduces each partition first and then merges the partial results to produce the final result.
+ */
+ private static <InT, OutT, AggInT, BuffT, AggOutT> Dataset<WindowedValue<OutT>> aggregate(
+ Dataset<WindowedValue<InT>> ds,
+ Aggregator<AggInT, BuffT, AggOutT> agg,
+ Fun1<WindowedValue<InT>, AggInT> valueFn,
+ Fun1<AggOutT, Iterator<WindowedValue<OutT>>> finishFn,
+ Encoder<WindowedValue<OutT>> enc) {
+ // reduce partition using aggregator
+ Fun1<Iterator<WindowedValue<InT>>, Iterator<BuffT>> reduce =
+ fun1(it -> single(it.map(valueFn).foldLeft(agg.zero(), agg::reduce)));
+ // merge reduced partitions using aggregator
+ Fun1<Iterator<BuffT>, Iterator<WindowedValue<OutT>>> merge =
+ fun1(it -> finishFn.apply(agg.finish(it.hasNext() ? it.reduce(agg::merge) : agg.zero())));
+
+ return ds.mapPartitions(reduce, agg.bufferEncoder()).coalesce(1).mapPartitions(merge, enc);
+ }
+
+ private Coder<AccT> accumulatorCoder(
+ CombineFn<InT, AccT, OutT> fn, Coder<InT> valueCoder, Context cxt) {
+ try {
+ return fn.getAccumulatorCoder(cxt.getInput().getPipeline().getCoderRegistry(), valueCoder);
+ } catch (CannotProvideCoderException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private static <T> Fun1<T, Iterator<WindowedValue<T>>> windowedValue() {
+ return v -> single(WindowedValue.valueInGlobalWindow(v));
+ }
+}
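The private aggregate(...) helper above implements a two-phase pattern. A sketch of the same shape with a plain sum standing in for the Beam CombineFn (illustrative only):

import org.apache.spark.api.java.function.MapPartitionsFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;

// Sketch: per-partition reduce + coalesce(1) + merge, shown with a trivial sum.
class GlobalAggregateSketch {
  static Dataset<Long> sumGlobally(Dataset<Long> ds) {
    MapPartitionsFunction<Long, Long> sumPartition =
        it -> {
          long sum = 0;
          while (it.hasNext()) {
            sum += it.next();
          }
          return java.util.Collections.singletonList(sum).iterator();
        };
    // Phase 1: fold each partition down to one partial result (Aggregator#reduce above).
    // Phase 2: move all partials to a single partition and merge them (Aggregator#merge above).
    return ds.mapPartitions(sumPartition, Encoders.LONG())
        .coalesce(1)
        .mapPartitions(sumPartition, Encoders.LONG());
  }
}

Because coalesce(1) moves only the per-partition partials, the single merge partition sees one buffer per input partition rather than the full dataset.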
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java
index 2b0cf8be995..f990cd114f9 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTranslatorBatch.java
@@ -17,98 +17,141 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-import java.util.ArrayList;
-import java.util.List;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGlobalGroupBy;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGroupByWindow;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.explodeWindowedKey;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.value;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueKey;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueValue;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.windowedKV;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
+
+import java.util.Collection;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.transforms.Combine;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.KeyValueGroupedDataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.expressions.Aggregator;
import scala.Tuple2;
+import scala.collection.TraversableOnce;
-@SuppressWarnings({
- "rawtypes" // TODO(https://github.com/apache/beam/issues/20447)
-})
-class CombinePerKeyTranslatorBatch<K, InputT, AccumT, OutputT>
- implements TransformTranslator<
- PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>>> {
+/**
+ * Translator for {@link Combine.PerKey} using {@link Dataset#groupByKey} with a Spark {@link
+ * Aggregator}.
+ *
+ * <ul>
+ * <li>When using the default global window, window information is dropped and restored after the
+ * aggregation.
+ * <li>For non-merging windows, windows are exploded and moved into a composite key for better
+ * distribution. After the aggregation, windowed values are restored from the composite key.
+ * <li>All other cases use an aggregator on windowed values that is optimized for the current
+ * windowing strategy.
+ * </ul>
+ *
+ * TODOs:
+ * <li>combine with context (CombineFnWithContext)?
+ * <li>combine with sideInputs?
+ * <li>are there other missing features?
+ */
+class CombinePerKeyTranslatorBatch<K, InT, AccT, OutT>
+ extends TransformTranslator<
+ PCollection<KV<K, InT>>, PCollection<KV<K, OutT>>, Combine.PerKey<K, InT, OutT>> {
@Override
- public void translateTransform(
- PTransform<PCollection<KV<K, InputT>>, PCollection<KV<K, OutputT>>> transform,
- AbstractTranslationContext context) {
+ public void translate(Combine.PerKey<K, InT, OutT> transform, Context cxt) {
+ WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy();
+ CombineFn<InT, AccT, OutT> combineFn = (CombineFn<InT, AccT, OutT>) transform.getFn();
+
+ KvCoder<K, InT> inputCoder = (KvCoder<K, InT>) cxt.getInput().getCoder();
+ KvCoder<K, OutT> outputCoder = (KvCoder<K, OutT>) cxt.getOutput().getCoder();
+
+ Encoder<K> keyEnc = cxt.keyEncoderOf(inputCoder);
+ Encoder<KV<K, InT>> inputEnc = cxt.encoderOf(inputCoder);
+ Encoder<WindowedValue<KV<K, OutT>>> wvOutputEnc = cxt.windowedEncoder(outputCoder);
+ Encoder<AccT> accumEnc = accumEncoder(combineFn, inputCoder.getValueCoder(), cxt);
+
+ final Dataset<WindowedValue<KV<K, OutT>>> result;
+
+ boolean globalGroupBy = eligibleForGlobalGroupBy(windowing, true);
+ boolean groupByWindow = eligibleForGroupByWindow(windowing, true);
- Combine.PerKey combineTransform = (Combine.PerKey) transform;
- @SuppressWarnings("unchecked")
- final PCollection<KV<K, InputT>> input = (PCollection<KV<K, InputT>>) context.getInput();
- @SuppressWarnings("unchecked")
- final PCollection<KV<K, OutputT>> output = (PCollection<KV<K, OutputT>>) context.getOutput();
- @SuppressWarnings("unchecked")
- final Combine.CombineFn<InputT, AccumT, OutputT> combineFn =
- (Combine.CombineFn<InputT, AccumT, OutputT>) combineTransform.getFn();
- WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy();
+ if (globalGroupBy || groupByWindow) {
+ Aggregator<KV<K, InT>, ?, OutT> valueAgg =
+ Aggregators.value(combineFn, KV::getValue, accumEnc, cxt.valueEncoderOf(outputCoder));
- Dataset<WindowedValue<KV<K, InputT>>> inputDataset = context.getDataset(input);
+ if (globalGroupBy) {
+ // Drop the window and group by key globally to run the aggregation (combineFn);
+ // afterwards the global window is restored
+ result =
+ cxt.getDataset(cxt.getInput())
+ .groupByKey(valueKey(), keyEnc)
+ .mapValues(value(), inputEnc)
+ .agg(valueAgg.toColumn())
+ .map(globalKV(), wvOutputEnc);
+ } else {
+ Encoder<Tuple2<BoundedWindow, K>> windowedKeyEnc =
+ cxt.tupleEncoder(cxt.windowEncoder(), keyEnc);
- KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder();
- Coder<K> keyCoder = inputCoder.getKeyCoder();
- KvCoder<K, OutputT> outputKVCoder = (KvCoder<K, OutputT>) output.getCoder();
- Coder<OutputT> outputCoder = outputKVCoder.getValueCoder();
+ // Group by window and key to run the aggregation (combineFn)
+ result =
+ cxt.getDataset(cxt.getInput())
+ .flatMap(explodeWindowedKey(value()), cxt.tupleEncoder(windowedKeyEnc, inputEnc))
+ .groupByKey(fun1(Tuple2::_1), windowedKeyEnc)
+ .mapValues(fun1(Tuple2::_2), inputEnc)
+ .agg(valueAgg.toColumn())
+ .map(windowedKV(), wvOutputEnc);
+ }
+ } else {
+ // Optimized aggregator for non-merging and session window functions; all others depend on
+ // windowFn.mergeWindows
+ Aggregator<WindowedValue<KV<K, InT>>, ?, Collection<WindowedValue<OutT>>> aggregator =
+ Aggregators.windowedValue(
+ combineFn,
+ valueValue(),
+ windowing,
+ cxt.windowEncoder(),
+ accumEnc,
+ cxt.windowedEncoder(outputCoder.getValueCoder()));
+ result =
+ cxt.getDataset(cxt.getInput())
+ .groupByKey(valueKey(), keyEnc)
+ .agg(aggregator.toColumn())
+ .flatMap(explodeWindows(), wvOutputEnc);
+ }
+
+ cxt.putDataset(cxt.getOutput(), result);
+ }
+
+ private static <K, V>
+ Fun1<Tuple2<K, Collection<WindowedValue<V>>>, TraversableOnce<WindowedValue<KV<K, V>>>>
+ explodeWindows() {
+ return t ->
+ ScalaInterop.scalaIterator(t._2).map(wv -> wv.withValue(KV.of(t._1, wv.getValue())));
+ }
- KeyValueGroupedDataset<K, WindowedValue<KV<K, InputT>>> groupedDataset =
- inputDataset.groupByKey(KVHelpers.extractKey(), EncoderHelpers.fromBeamCoder(keyCoder));
+ private static <K, V> Fun1<Tuple2<K, V>, WindowedValue<KV<K, V>>> globalKV() {
+ return t -> WindowedValue.valueInGlobalWindow(KV.of(t._1, t._2));
+ }
- Coder<AccumT> accumulatorCoder = null;
+ private Encoder<AccT> accumEncoder(
+ CombineFn<InT, AccT, OutT> fn, Coder<InT> valueCoder, Context cxt) {
try {
- accumulatorCoder =
- combineFn.getAccumulatorCoder(
- input.getPipeline().getCoderRegistry(), inputCoder.getValueCoder());
+ CoderRegistry registry = cxt.getInput().getPipeline().getCoderRegistry();
+ return cxt.encoderOf(fn.getAccumulatorCoder(registry, valueCoder));
} catch (CannotProvideCoderException e) {
throw new RuntimeException(e);
}
-
- Dataset<Tuple2<K, Iterable<WindowedValue<OutputT>>>> combinedDataset =
- groupedDataset.agg(
- new AggregatorCombiner<K, InputT, AccumT, OutputT, BoundedWindow>(
- combineFn, windowingStrategy, accumulatorCoder, outputCoder)
- .toColumn());
-
- // expand the list into separate elements and put the key back into the elements
- WindowedValue.WindowedValueCoder<KV<K, OutputT>> wvCoder =
- WindowedValue.FullWindowedValueCoder.of(
- outputKVCoder, input.getWindowingStrategy().getWindowFn().windowCoder());
- Dataset<WindowedValue<KV<K, OutputT>>> outputDataset =
- combinedDataset.flatMap(
- (FlatMapFunction<
- Tuple2<K, Iterable<WindowedValue<OutputT>>>, WindowedValue<KV<K, OutputT>>>)
- tuple2 -> {
- K key = tuple2._1();
- Iterable<WindowedValue<OutputT>> windowedValues = tuple2._2();
- List<WindowedValue<KV<K, OutputT>>> result = new ArrayList<>();
- for (WindowedValue<OutputT> windowedValue : windowedValues) {
- KV<K, OutputT> kv = KV.of(key, windowedValue.getValue());
- result.add(
- WindowedValue.of(
- kv,
- windowedValue.getTimestamp(),
- windowedValue.getWindows(),
- windowedValue.getPane()));
- }
- return result.iterator();
- },
- EncoderHelpers.fromBeamCoder(wvCoder));
- context.putDataset(output, outputDataset);
}
}
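The groupByKey(...).agg(aggregator.toColumn()) shape used by this translator can be sketched with plain Spark types instead of Beam's WindowedValue/KV (illustrative only; the cast selects the Java-specific overload of groupByKey, which is otherwise ambiguous with the Scala one in Java code):

import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.expressions.Aggregator;
import scala.Tuple2;

// Sketch: group by key, then evaluate a custom Aggregator as a typed column.
class PerKeyAggregateSketch {
  static Dataset<Tuple2<String, Long>> aggregatePerKey(
      Dataset<Tuple2<String, Long>> kvs, Aggregator<Tuple2<String, Long>, ?, Long> agg) {
    KeyValueGroupedDataset<String, Tuple2<String, Long>> grouped =
        kvs.groupByKey(
            (MapFunction<Tuple2<String, Long>, String>) Tuple2::_1, Encoders.STRING());
    // agg(...) evaluates the Aggregator per distinct key, yielding one (key, result) row each.
    return grouped.agg(agg.toColumn());
  }
}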
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CreatePCollectionViewTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CreatePCollectionViewTranslatorBatch.java
index ae1eeced328..27151292023 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CreatePCollectionViewTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CreatePCollectionViewTranslatorBatch.java
@@ -19,39 +19,25 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
import java.io.IOException;
import org.apache.beam.runners.core.construction.CreatePCollectionViewTranslation;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.spark.sql.Dataset;
class CreatePCollectionViewTranslatorBatch<ElemT, ViewT>
- implements TransformTranslator<PTransform<PCollection<ElemT>, PCollection<ElemT>>> {
+ extends TransformTranslator<
+ PCollection<ElemT>, PCollection<ElemT>, View.CreatePCollectionView<ElemT, ViewT>> {
@Override
- public void translateTransform(
- PTransform<PCollection<ElemT>, PCollection<ElemT>> transform,
- AbstractTranslationContext context) {
+ public void translate(View.CreatePCollectionView<ElemT, ViewT> transform, Context context) {
Dataset<WindowedValue<ElemT>> inputDataSet = context.getDataset(context.getInput());
- @SuppressWarnings("unchecked")
- AppliedPTransform<
- PCollection<ElemT>,
- PCollection<ElemT>,
- PTransform<PCollection<ElemT>, PCollection<ElemT>>>
- application =
- (AppliedPTransform<
- PCollection<ElemT>,
- PCollection<ElemT>,
- PTransform<PCollection<ElemT>, PCollection<ElemT>>>)
- context.getCurrentTransform();
PCollectionView<ViewT> input;
try {
- input = CreatePCollectionViewTranslation.getView(application);
+ input = CreatePCollectionViewTranslation.getView(context.getCurrentTransform());
} catch (IOException e) {
throw new RuntimeException(e);
}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java
deleted file mode 100644
index 46bde96c30c..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DatasetSourceBatch.java
+++ /dev/null
@@ -1,240 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.BEAM_SOURCE_OPTION;
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.DEFAULT_PARALLELISM;
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.PIPELINE_OPTIONS;
-import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
-
-import java.io.IOException;
-import java.io.Serializable;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
-import org.apache.beam.runners.core.serialization.Base64Serializer;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.RowHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SchemaHelpers;
-import org.apache.beam.sdk.io.BoundedSource;
-import org.apache.beam.sdk.options.PipelineOptions;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
-import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.connector.catalog.SupportsRead;
-import org.apache.spark.sql.connector.catalog.Table;
-import org.apache.spark.sql.connector.catalog.TableCapability;
-import org.apache.spark.sql.connector.catalog.TableProvider;
-import org.apache.spark.sql.connector.expressions.Transform;
-import org.apache.spark.sql.connector.read.Batch;
-import org.apache.spark.sql.connector.read.InputPartition;
-import org.apache.spark.sql.connector.read.PartitionReader;
-import org.apache.spark.sql.connector.read.PartitionReaderFactory;
-import org.apache.spark.sql.connector.read.Scan;
-import org.apache.spark.sql.connector.read.ScanBuilder;
-import org.apache.spark.sql.types.StructType;
-import org.apache.spark.sql.util.CaseInsensitiveStringMap;
-
-/**
- * Spark DataSourceV2 API was removed in Spark3. This is a Beam source wrapper using the new spark 3
- * source API.
- */
-public class DatasetSourceBatch implements TableProvider {
-
- private static final StructType BINARY_SCHEMA = SchemaHelpers.binarySchema();
-
- public DatasetSourceBatch() {}
-
- @Override
- public StructType inferSchema(CaseInsensitiveStringMap options) {
- return BINARY_SCHEMA;
- }
-
- @Override
- public boolean supportsExternalMetadata() {
- return true;
- }
-
- @Override
- public Table getTable(
- StructType schema, Transform[] partitioning, Map<String, String> properties) {
- return new DatasetSourceBatchTable();
- }
-
- private static class DatasetSourceBatchTable implements SupportsRead {
-
- @Override
- public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
- return new ScanBuilder() {
-
- @Override
- public Scan build() {
- return new Scan() { // scan for Batch reading
-
- @Override
- public StructType readSchema() {
- return BINARY_SCHEMA;
- }
-
- @Override
- public Batch toBatch() {
- return new BeamBatch<>(options);
- }
- };
- }
- };
- }
-
- @Override
- public String name() {
- return "BeamSource";
- }
-
- @Override
- public StructType schema() {
- return BINARY_SCHEMA;
- }
-
- @Override
- public Set<TableCapability> capabilities() {
- final ImmutableSet<TableCapability> capabilities =
- ImmutableSet.of(TableCapability.BATCH_READ);
- return capabilities;
- }
-
- private static class BeamBatch<T> implements Batch, Serializable {
-
- private final int numPartitions;
- private final BoundedSource<T> source;
- private final SerializablePipelineOptions serializablePipelineOptions;
-
- private BeamBatch(CaseInsensitiveStringMap options) {
- if (Strings.isNullOrEmpty(options.get(BEAM_SOURCE_OPTION))) {
- throw new RuntimeException("Beam source was not set in DataSource options");
- }
- this.source =
- Base64Serializer.deserializeUnchecked(
- options.get(BEAM_SOURCE_OPTION), BoundedSource.class);
-
- if (Strings.isNullOrEmpty(options.get(DEFAULT_PARALLELISM))) {
- throw new RuntimeException("Spark default parallelism was not set in DataSource options");
- }
- this.numPartitions = Integer.parseInt(options.get(DEFAULT_PARALLELISM));
- checkArgument(numPartitions > 0, "Number of partitions must be greater than zero.");
-
- if (Strings.isNullOrEmpty(options.get(PIPELINE_OPTIONS))) {
- throw new RuntimeException("Beam pipelineOptions were not set in DataSource options");
- }
- this.serializablePipelineOptions =
- new SerializablePipelineOptions(options.get(PIPELINE_OPTIONS));
- }
-
- @Override
- public InputPartition[] planInputPartitions() {
- PipelineOptions options = serializablePipelineOptions.get();
- long desiredSizeBytes;
-
- try {
- desiredSizeBytes = source.getEstimatedSizeBytes(options) / numPartitions;
- List<? extends BoundedSource<T>> splits = source.split(desiredSizeBytes, options);
- InputPartition[] result = new InputPartition[splits.size()];
- int i = 0;
- for (BoundedSource<T> split : splits) {
- result[i++] = new BeamInputPartition<>(split);
- }
- return result;
- } catch (Exception e) {
- throw new RuntimeException(
- "Error in splitting BoundedSource " + source.getClass().getCanonicalName(), e);
- }
- }
-
- @Override
- public PartitionReaderFactory createReaderFactory() {
- return new PartitionReaderFactory() {
-
- @Override
- public PartitionReader<InternalRow> createReader(InputPartition partition) {
- return new BeamPartitionReader<T>(
- ((BeamInputPartition<T>) partition).getSource(), serializablePipelineOptions);
- }
- };
- }
-
- private static class BeamInputPartition<T> implements InputPartition {
-
- private final BoundedSource<T> source;
-
- private BeamInputPartition(BoundedSource<T> source) {
- this.source = source;
- }
-
- public BoundedSource<T> getSource() {
- return source;
- }
- }
-
- private static class BeamPartitionReader<T> implements PartitionReader<InternalRow> {
-
- private final BoundedSource<T> source;
- private final BoundedSource.BoundedReader<T> reader;
- private boolean started;
- private boolean closed;
-
- BeamPartitionReader(
- BoundedSource<T> source, SerializablePipelineOptions serializablePipelineOptions) {
- this.started = false;
- this.closed = false;
- this.source = source;
- // the reader is not serializable, so initialize it lazily
- try {
- reader =
- source.createReader(serializablePipelineOptions.get().as(PipelineOptions.class));
- } catch (IOException e) {
- throw new RuntimeException("Error creating BoundedReader ", e);
- }
- }
-
- @Override
- public boolean next() throws IOException {
- if (!started) {
- started = true;
- return reader.start();
- } else {
- return !closed && reader.advance();
- }
- }
-
- @Override
- public InternalRow get() {
- WindowedValue<T> windowedValue =
- WindowedValue.timestampedValueInGlobalWindow(
- reader.getCurrent(), reader.getCurrentTimestamp());
- return RowHelpers.storeWindowedValueInRow(windowedValue, source.getOutputCoder());
- }
-
- @Override
- public void close() throws IOException {
- closed = true;
- reader.close();
- }
- }
- }
- }
-}
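
For reference, a minimal sketch (not from this patch) of the partition planning the removed
source performed: the estimated source size is divided by the parallelism to derive the desired
bundle size handed to BoundedSource#split. Beam's CountingSource serves purely as a stand-in:

    import java.util.List;
    import org.apache.beam.sdk.io.BoundedSource;
    import org.apache.beam.sdk.io.CountingSource;
    import org.apache.beam.sdk.options.PipelineOptions;
    import org.apache.beam.sdk.options.PipelineOptionsFactory;

    public class SplitPlanningSketch {
      public static void main(String[] args) throws Exception {
        PipelineOptions options = PipelineOptionsFactory.create();
        BoundedSource<Long> source = CountingSource.upTo(1_000_000L);
        int numPartitions = 4; // stand-in for Spark's default parallelism
        // same arithmetic as planInputPartitions above
        long desiredSizeBytes = source.getEstimatedSizeBytes(options) / numPartitions;
        List<? extends BoundedSource<Long>> splits = source.split(desiredSizeBytes, options);
        System.out.println("Planned " + splits.size() + " input partitions");
      }
    }
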
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java
deleted file mode 100644
index 42a809fdd97..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnFunction.java
+++ /dev/null
@@ -1,164 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-
-import java.util.Collections;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import org.apache.beam.runners.core.DoFnRunner;
-import org.apache.beam.runners.core.DoFnRunners;
-import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
-import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsContainerStepMapAccumulator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.NoOpStepContext;
-import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SparkSideInputReader;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast;
-import org.apache.beam.runners.spark.structuredstreaming.translation.utils.CachedSideInputReader;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
-import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Function;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.LinkedListMultimap;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap;
-import org.apache.spark.api.java.function.MapPartitionsFunction;
-import scala.Tuple2;
-
-/**
- * Encapsulates a {@link DoFn} inside a Spark {@link
- * org.apache.spark.api.java.function.MapPartitionsFunction}.
- *
- * <p>We get a mapping from {@link org.apache.beam.sdk.values.TupleTag} to output index and must tag
- * all outputs with the output number. Afterwards a filter selects only those elements that belong
- * to a specific output.
- */
-@SuppressWarnings({
- "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public class DoFnFunction<InputT, OutputT>
- implements MapPartitionsFunction<WindowedValue<InputT>, Tuple2<TupleTag<?>, WindowedValue<?>>> {
-
- private final MetricsContainerStepMapAccumulator metricsAccum;
- private final String stepName;
- private final DoFn<InputT, OutputT> doFn;
- private transient boolean wasSetupCalled;
- private final WindowingStrategy<?, ?> windowingStrategy;
- private final Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs;
- private final SerializablePipelineOptions serializableOptions;
- private final List<TupleTag<?>> additionalOutputTags;
- private final TupleTag<OutputT> mainOutputTag;
- private final Coder<InputT> inputCoder;
- private final Map<TupleTag<?>, Coder<?>> outputCoderMap;
- private final SideInputBroadcast broadcastStateData;
- private DoFnSchemaInformation doFnSchemaInformation;
- private Map<String, PCollectionView<?>> sideInputMapping;
-
- public DoFnFunction(
- MetricsContainerStepMapAccumulator metricsAccum,
- String stepName,
- DoFn<InputT, OutputT> doFn,
- WindowingStrategy<?, ?> windowingStrategy,
- Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputs,
- SerializablePipelineOptions serializableOptions,
- List<TupleTag<?>> additionalOutputTags,
- TupleTag<OutputT> mainOutputTag,
- Coder<InputT> inputCoder,
- Map<TupleTag<?>, Coder<?>> outputCoderMap,
- SideInputBroadcast broadcastStateData,
- DoFnSchemaInformation doFnSchemaInformation,
- Map<String, PCollectionView<?>> sideInputMapping) {
- this.metricsAccum = metricsAccum;
- this.stepName = stepName;
- this.doFn = doFn;
- this.windowingStrategy = windowingStrategy;
- this.sideInputs = sideInputs;
- this.serializableOptions = serializableOptions;
- this.additionalOutputTags = additionalOutputTags;
- this.mainOutputTag = mainOutputTag;
- this.inputCoder = inputCoder;
- this.outputCoderMap = outputCoderMap;
- this.broadcastStateData = broadcastStateData;
- this.doFnSchemaInformation = doFnSchemaInformation;
- this.sideInputMapping = sideInputMapping;
- }
-
- @Override
- public Iterator<Tuple2<TupleTag<?>, WindowedValue<?>>> call(Iterator<WindowedValue<InputT>> iter)
- throws Exception {
- if (!wasSetupCalled && iter.hasNext()) {
- DoFnInvokers.tryInvokeSetupFor(doFn, serializableOptions.get());
- wasSetupCalled = true;
- }
-
- DoFnOutputManager outputManager = new DoFnOutputManager();
-
- DoFnRunner<InputT, OutputT> doFnRunner =
- DoFnRunners.simpleRunner(
- serializableOptions.get(),
- doFn,
- CachedSideInputReader.of(new SparkSideInputReader(sideInputs, broadcastStateData)),
- outputManager,
- mainOutputTag,
- additionalOutputTags,
- new NoOpStepContext(),
- inputCoder,
- outputCoderMap,
- windowingStrategy,
- doFnSchemaInformation,
- sideInputMapping);
-
- DoFnRunnerWithMetrics<InputT, OutputT> doFnRunnerWithMetrics =
- new DoFnRunnerWithMetrics<>(stepName, doFnRunner, metricsAccum);
-
- return new ProcessContext<>(
- doFn, doFnRunnerWithMetrics, outputManager, Collections.emptyIterator())
- .processPartition(iter)
- .iterator();
- }
-
- private class DoFnOutputManager
- implements ProcessContext.ProcessOutputManager<Tuple2<TupleTag<?>, WindowedValue<?>>> {
-
- private final Multimap<TupleTag<?>, WindowedValue<?>> outputs = LinkedListMultimap.create();
-
- @Override
- public void clear() {
- outputs.clear();
- }
-
- @Override
- public Iterator<Tuple2<TupleTag<?>, WindowedValue<?>>> iterator() {
- Iterator<Map.Entry<TupleTag<?>, WindowedValue<?>>> entryIter = outputs.entries().iterator();
- return Iterators.transform(entryIter, this.entryToTupleFn());
- }
-
- private <K, V> Function<Map.Entry<K, V>, Tuple2<K, V>> entryToTupleFn() {
- return en -> new Tuple2<>(en.getKey(), en.getValue());
- }
-
- @Override
- public synchronized <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
- outputs.put(tag, output);
- }
- }
-}
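
The removed DoFnFunction tagged every output with its TupleTag and relied on one downstream filter
per output. A simplified, self-contained sketch of that tag-then-filter pattern (plain Java
streams instead of Spark; the tags are hypothetical):

    import java.util.Arrays;
    import java.util.List;
    import java.util.stream.Collectors;
    import org.apache.beam.sdk.values.TupleTag;
    import scala.Tuple2;

    public class TagFilterSketch {
      public static void main(String[] args) {
        TupleTag<String> main = new TupleTag<String>("main") {};
        TupleTag<String> side = new TupleTag<String>("side") {};
        List<Tuple2<TupleTag<?>, String>> tagged =
            Arrays.asList(new Tuple2<>(main, "a"), new Tuple2<>(side, "b"), new Tuple2<>(main, "c"));
        // one filter per tag selects the elements belonging to that output
        List<String> mainOnly =
            tagged.stream()
                .filter(t -> t._1.equals(main))
                .map(Tuple2::_2)
                .collect(Collectors.toList());
        System.out.println(mainOnly); // [a, c]
      }
    }
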
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnMapPartitionsFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnMapPartitionsFactory.java
new file mode 100644
index 00000000000..c02e07319af
--- /dev/null
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnMapPartitionsFactory.java
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+
+import static java.util.function.Function.identity;
+import static java.util.stream.Collectors.toCollection;
+import static java.util.stream.Collectors.toMap;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.scalaIterator;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists.newArrayListWithCapacity;
+
+import java.io.Serializable;
+import java.util.ArrayDeque;
+import java.util.Collection;
+import java.util.Deque;
+import java.util.List;
+import java.util.Map;
+import org.apache.beam.runners.core.DoFnRunner;
+import org.apache.beam.runners.core.DoFnRunners;
+import org.apache.beam.runners.core.DoFnRunners.OutputManager;
+import org.apache.beam.runners.core.SideInputReader;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator;
+import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.NoOpStepContext;
+import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SparkSideInputReader;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.CachedSideInputReader;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun2;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
+import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator;
+import org.apache.spark.api.java.function.MapPartitionsFunction;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import scala.collection.Iterator;
+
+/**
+ * Encapsulates a {@link DoFn} inside a Spark {@link
+ * org.apache.spark.api.java.function.MapPartitionsFunction}.
+ */
+class DoFnMapPartitionsFactory<InT, OutT> implements Serializable {
+ private final String stepName;
+
+ private final DoFn<InT, OutT> doFn;
+ private final DoFnSchemaInformation doFnSchema;
+ private final SerializablePipelineOptions options;
+
+ private final Coder<InT> coder;
+ private final WindowingStrategy<?, ?> windowingStrategy;
+ private final TupleTag<OutT> mainOutput;
+ private final List<TupleTag<?>> additionalOutputs;
+ private final Map<TupleTag<?>, Coder<?>> outputCoders;
+
+ private final Map<String, PCollectionView<?>> sideInputs;
+ private final Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputWindows;
+ private final SideInputBroadcast broadcastStateData;
+
+ DoFnMapPartitionsFactory(
+ String stepName,
+ DoFn<InT, OutT> doFn,
+ DoFnSchemaInformation doFnSchema,
+ SerializablePipelineOptions options,
+ PCollection<InT> input,
+ TupleTag<OutT> mainOutput,
+ Map<TupleTag<?>, PCollection<?>> outputs,
+ Map<String, PCollectionView<?>> sideInputs,
+ SideInputBroadcast broadcastStateData) {
+ this.stepName = stepName;
+ this.doFn = doFn;
+ this.doFnSchema = doFnSchema;
+ this.options = options;
+ this.coder = input.getCoder();
+ this.windowingStrategy = input.getWindowingStrategy();
+ this.mainOutput = mainOutput;
+ this.additionalOutputs = additionalOutputs(outputs, mainOutput);
+ this.outputCoders = outputCoders(outputs);
+ this.sideInputs = sideInputs;
+ this.sideInputWindows = sideInputWindows(sideInputs.values());
+ this.broadcastStateData = broadcastStateData;
+ }
+
+ /** Create the {@link MapPartitionsFunction} using the provided output function. */
+ <OutputT extends @NonNull Object> Fun1<Iterator<WindowedValue<InT>>, Iterator<OutputT>> create(
+ Fun2<TupleTag<?>, WindowedValue<?>, OutputT> outputFn) {
+ return it ->
+ it.hasNext()
+ ? scalaIterator(new DoFnPartitionIt<>(outputFn, it))
+ : (Iterator<OutputT>) Iterator.empty();
+ }
+
+ // FIXME Add support for TimerInternals.TimerData
+ /**
+ * Partition iterator that lazily processes each element from the input iterator on demand,
+ * producing zero, one, or more output elements (via an internal buffer).
+ *
+ * <p>When initializing the iterator for a partition, {@code setup} followed by {@code startBundle}
+ * is called.
+ */
+ private class DoFnPartitionIt<FnInT extends InT, OutputT> extends AbstractIterator<OutputT> {
+ private final Deque<OutputT> buffer;
+ private final DoFnRunner<InT, OutT> doFnRunner;
+ private final Iterator<WindowedValue<FnInT>> partitionIt;
+
+ private boolean isBundleFinished;
+
+ DoFnPartitionIt(
+ Fun2<TupleTag<?>, WindowedValue<?>, OutputT> outputFn,
+ Iterator<WindowedValue<FnInT>> partitionIt) {
+ this.buffer = new ArrayDeque<>();
+ this.doFnRunner = metricsRunner(simpleRunner(outputFn, buffer));
+ this.partitionIt = partitionIt;
+ // Before starting to iterate over the partition, invoke setup and then startBundle
+ DoFnInvokers.tryInvokeSetupFor(doFn, options.get());
+ try {
+ doFnRunner.startBundle();
+ } catch (RuntimeException re) {
+ DoFnInvokers.invokerFor(doFn).invokeTeardown();
+ throw re;
+ }
+ }
+
+ @Override
+ protected OutputT computeNext() {
+ try {
+ while (true) {
+ if (!buffer.isEmpty()) {
+ return buffer.remove();
+ }
+ if (partitionIt.hasNext()) {
+ // grab the next element and process it.
+ doFnRunner.processElement((WindowedValue<InT>) partitionIt.next());
+ } else {
+ if (!isBundleFinished) {
+ isBundleFinished = true;
+ doFnRunner.finishBundle();
+ continue; // finishBundle can produce more output
+ }
+ DoFnInvokers.invokerFor(doFn).invokeTeardown();
+ return endOfData();
+ }
+ }
+ } catch (RuntimeException re) {
+ DoFnInvokers.invokerFor(doFn).invokeTeardown();
+ throw re;
+ }
+ }
+ }
+
+ private <OutputT> DoFnRunner<InT, OutT> simpleRunner(
+ Fun2<TupleTag<?>, WindowedValue<?>, OutputT> outputFn, Deque<OutputT> buffer) {
+ OutputManager outputManager =
+ new OutputManager() {
+ @Override
+ public <T> void output(TupleTag<T> tag, WindowedValue<T> output) {
+ buffer.add(outputFn.apply(tag, output));
+ }
+ };
+ SideInputReader sideInputReader =
+ CachedSideInputReader.of(new SparkSideInputReader(sideInputWindows, broadcastStateData));
+ return DoFnRunners.simpleRunner(
+ options.get(),
+ doFn,
+ sideInputReader,
+ outputManager,
+ mainOutput,
+ additionalOutputs,
+ new NoOpStepContext(),
+ coder,
+ outputCoders,
+ windowingStrategy,
+ doFnSchema,
+ sideInputs);
+ }
+
+ private DoFnRunner<InT, OutT> metricsRunner(DoFnRunner<InT, OutT> runner) {
+ return new DoFnRunnerWithMetrics<>(stepName, runner, MetricsAccumulator.getInstance());
+ }
+
+ private static Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputWindows(
+ Collection<PCollectionView<?>> views) {
+ return views.stream().collect(toMap(identity(), DoFnMapPartitionsFactory::windowingStrategy));
+ }
+
+ private static WindowingStrategy<?, ?> windowingStrategy(PCollectionView<?> view) {
+ PCollection<?> pc = view.getPCollection();
+ if (pc == null) {
+ throw new IllegalStateException("PCollection not available for " + view);
+ }
+ return pc.getWindowingStrategy();
+ }
+
+ private static List<TupleTag<?>> additionalOutputs(
+ Map<TupleTag<?>, PCollection<?>> outputs, TupleTag<?> mainOutput) {
+ return outputs.keySet().stream()
+ .filter(t -> !t.equals(mainOutput))
+ .collect(toCollection(() -> newArrayListWithCapacity(outputs.size() - 1)));
+ }
+
+ private static Map<TupleTag<?>, Coder<?>> outputCoders(Map<TupleTag<?>, PCollection<?>> outputs) {
+ return outputs.entrySet().stream()
+ .collect(toMap(Map.Entry::getKey, e -> e.getValue().getCoder()));
+ }
+}
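
The DoFnPartitionIt above interleaves pulling input with draining an output buffer. A stripped-down
sketch of just that buffering pattern, with a plain function standing in for the DoFnRunner:

    import java.util.ArrayDeque;
    import java.util.Deque;
    import java.util.Iterator;
    import java.util.function.Function;
    import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator;

    class BufferedFlatMapIterator<InT, OutT> extends AbstractIterator<OutT> {
      private final Deque<OutT> buffer = new ArrayDeque<>();
      private final Iterator<InT> input;
      private final Function<InT, Iterable<OutT>> fn; // stand-in for the DoFnRunner

      BufferedFlatMapIterator(Iterator<InT> input, Function<InT, Iterable<OutT>> fn) {
        this.input = input;
        this.fn = fn;
      }

      @Override
      protected OutT computeNext() {
        while (true) {
          if (!buffer.isEmpty()) {
            return buffer.remove(); // drain buffered outputs first
          }
          if (!input.hasNext()) {
            return endOfData();
          }
          // processing one element may buffer zero, one or more outputs
          fn.apply(input.next()).forEach(buffer::add);
        }
      }
    }

Unlike an eager mapPartitions implementation, outputs are only materialized as the downstream
consumer asks for them, keeping memory usage proportional to a single element's output.
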
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTranslatorBatch.java
index db361f7753e..eb071304998 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTranslatorBatch.java
@@ -17,49 +17,47 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
import java.util.Collection;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
+import java.util.Iterator;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
-import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
-@SuppressWarnings({
- "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
class FlattenTranslatorBatch<T>
- implements TransformTranslator<PTransform<PCollectionList<T>, PCollection<T>>> {
+ extends TransformTranslator<PCollectionList<T>, PCollection<T>, Flatten.PCollections<T>> {
@Override
- public void translateTransform(
- PTransform<PCollectionList<T>, PCollection<T>> transform,
- AbstractTranslationContext context) {
- Collection<PCollection<?>> pcollectionList = context.getInputs().values();
- Dataset<WindowedValue<T>> result = null;
- if (pcollectionList.isEmpty()) {
- result = context.emptyDataset();
- } else {
- for (PValue pValue : pcollectionList) {
- checkArgument(
- pValue instanceof PCollection,
- "Got non-PCollection input to flatten: %s of type %s",
- pValue,
- pValue.getClass().getSimpleName());
- @SuppressWarnings("unchecked")
- PCollection<T> pCollection = (PCollection<T>) pValue;
- Dataset<WindowedValue<T>> current = context.getDataset(pCollection);
- if (result == null) {
- result = current;
- } else {
- result = result.union(current);
- }
+ public void translate(Flatten.PCollections<T> transform, Context cxt) {
+ Collection<PCollection<?>> pCollections = cxt.getInputs().values();
+ Coder<T> outputCoder = cxt.getOutput().getCoder();
+ Encoder<WindowedValue<T>> outputEnc =
+ cxt.windowedEncoder(outputCoder, windowCoder(cxt.getOutput()));
+
+ Dataset<WindowedValue<T>> result;
+ Iterator<PCollection<T>> pcIt = (Iterator) pCollections.iterator();
+ if (pcIt.hasNext()) {
+ result = getDataset(pcIt.next(), outputCoder, outputEnc, cxt);
+ while (pcIt.hasNext()) {
+ result = result.union(getDataset(pcIt.next(), outputCoder, outputEnc, cxt));
}
+ } else {
+ result = cxt.createDataset(ImmutableList.of(), outputEnc);
}
- context.putDataset(context.getOutput(), result);
+ cxt.putDataset(cxt.getOutput(), result);
+ }
+
+ private Dataset<WindowedValue<T>> getDataset(
+ PCollection<T> pc, Coder<T> coder, Encoder<WindowedValue<T>> enc, Context cxt) {
+ Dataset<WindowedValue<T>> current = cxt.getDataset(pc);
+ // if coders don't match, map using the identity function to replace the encoder
+ return pc.getCoder().equals(coder) ? current : current.map(fun1(v -> v), enc);
}
}
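
The identity map in getDataset exists because Dataset#union requires matching schemas. A minimal
sketch (assuming a local SparkSession) of how re-encoding via an identity map aligns two Datasets
whose encoders produce different schemas:

    import java.util.Arrays;
    import org.apache.spark.api.java.function.MapFunction;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoders;
    import org.apache.spark.sql.SparkSession;

    public class UnionEncoderSketch {
      public static void main(String[] args) {
        SparkSession spark =
            SparkSession.builder().master("local[1]").appName("sketch").getOrCreate();
        Dataset<String> a = spark.createDataset(Arrays.asList("a"), Encoders.STRING());
        // a different encoder yields a different schema (binary instead of string)
        Dataset<String> b =
            spark.createDataset(Arrays.asList("b"), Encoders.javaSerialization(String.class));
        // a.union(b) would fail to resolve; the identity map re-encodes b first
        Dataset<String> bAligned = b.map((MapFunction<String, String>) v -> v, Encoders.STRING());
        System.out.println(a.union(bAligned).count()); // 2
        spark.stop();
      }
    }
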
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java
new file mode 100644
index 00000000000..28ab07114c6
--- /dev/null
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyHelpers.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
+import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING;
+import static org.apache.beam.sdk.transforms.windowing.TimestampCombiner.END_OF_WINDOW;
+
+import java.util.Collection;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop;
+import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import scala.Tuple2;
+import scala.collection.TraversableOnce;
+
+/**
+ * Package-private helpers to support translating grouping transforms using {@code groupByKey},
+ * such as {@link GroupByKeyTranslatorBatch} or {@link CombinePerKeyTranslatorBatch}.
+ */
+class GroupByKeyHelpers {
+
+ private GroupByKeyHelpers() {}
+
+ /**
+ * Checks if it's possible to use an optimized {@code groupByKey} that also moves the window into
+ * the key.
+ *
+ * @param windowing The windowing strategy
+ * @param endOfWindowOnly Flag to limit this optimization to {@link
+ * TimestampCombiner#END_OF_WINDOW}.
+ */
+ static boolean eligibleForGroupByWindow(
+ WindowingStrategy<?, ?> windowing, boolean endOfWindowOnly) {
+ return !windowing.needsMerge()
+ && (!endOfWindowOnly || windowing.getTimestampCombiner() == END_OF_WINDOW)
+ && windowing.getWindowFn().windowCoder().consistentWithEquals();
+ }
+
+ /**
+ * Checks if it's possible to use an optimized {@code groupByKey} for the global window.
+ *
+ * @param windowing The windowing strategy
+ * @param endOfWindowOnly Flag to limit this optimization to {@link
+ * TimestampCombiner#END_OF_WINDOW}.
+ */
+ static boolean eligibleForGlobalGroupBy(
+ WindowingStrategy<?, ?> windowing, boolean endOfWindowOnly) {
+ return windowing.getWindowFn() instanceof GlobalWindows
+ && (!endOfWindowOnly || windowing.getTimestampCombiner() == END_OF_WINDOW);
+ }
+
+ /**
+ * Explodes a windowed {@link KV} assigned to potentially multiple {@link BoundedWindow}s to a
+ * traversable of composite keys {@code (BoundedWindow, Key)} and value.
+ */
+ static <K, V, T>
+ Fun1<WindowedValue<KV<K, V>>, TraversableOnce<Tuple2<Tuple2<BoundedWindow, K>, T>>>
+ explodeWindowedKey(Fun1<WindowedValue<KV<K, V>>, T> valueFn) {
+ return v -> {
+ T value = valueFn.apply(v);
+ K key = v.getValue().getKey();
+ Collection<BoundedWindow> windows = (Collection<BoundedWindow>) v.getWindows();
+ return ScalaInterop.scalaIterator(windows).map(w -> tuple(tuple(w, key), value));
+ };
+ }
+
+ static <K, V> Fun1<Tuple2<Tuple2<BoundedWindow, K>, V>, WindowedValue<KV<K, V>>> windowedKV() {
+ return t -> windowedKV(t._1, t._2);
+ }
+
+ static <K, V> WindowedValue<KV<K, V>> windowedKV(Tuple2<BoundedWindow, K> key, V value) {
+ return WindowedValue.of(KV.of(key._2, value), key._1.maxTimestamp(), key._1, NO_FIRING);
+ }
+
+ static <V> Fun1<WindowedValue<V>, V> value() {
+ return v -> v.getValue();
+ }
+
+ static <K, V> Fun1<WindowedValue<KV<K, V>>, V> valueValue() {
+ return v -> v.getValue().getValue();
+ }
+
+ static <K, V> Fun1<WindowedValue<KV<K, V>>, K> valueKey() {
+ return v -> v.getValue().getKey();
+ }
+}
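
To illustrate explodeWindowedKey: a single value assigned to two windows becomes two composite-key
tuples, one per window. A small sketch (not from this patch) using Beam types directly:

    import java.util.Arrays;
    import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
    import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
    import org.apache.beam.sdk.transforms.windowing.PaneInfo;
    import org.apache.beam.sdk.util.WindowedValue;
    import org.apache.beam.sdk.values.KV;
    import org.joda.time.Instant;

    public class ExplodeSketch {
      public static void main(String[] args) {
        IntervalWindow w1 = new IntervalWindow(new Instant(0), new Instant(10));
        IntervalWindow w2 = new IntervalWindow(new Instant(5), new Instant(15));
        WindowedValue<KV<String, Integer>> value =
            WindowedValue.of(
                KV.of("k", 42), new Instant(7), Arrays.asList(w1, w2), PaneInfo.NO_FIRING);
        // each window is paired with the key to form the composite grouping key
        for (BoundedWindow w : value.getWindows()) {
          System.out.println("((" + w + ", " + value.getValue().getKey() + "), "
              + value.getValue().getValue() + ")");
        }
      }
    }

After grouping on such composite keys, windowedKV restores a WindowedValue per (window, key) group
using the window's maximum timestamp.
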
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
index 6391ba4600c..61306cb993c 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTranslatorBatch.java
@@ -17,74 +17,274 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+import static org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGlobalGroupBy;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.eligibleForGroupByWindow;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.explodeWindowedKey;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueKey;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.valueValue;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.batch.GroupByKeyHelpers.windowedKV;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers.toByteArray;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderOf;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.concat;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun2;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.javaIterator;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.listOf;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
+import static org.apache.beam.sdk.transforms.windowing.PaneInfo.NO_FIRING;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.sql.functions.collect_list;
+import static org.apache.spark.sql.functions.explode;
+import static org.apache.spark.sql.functions.max;
+import static org.apache.spark.sql.functions.min;
+import static org.apache.spark.sql.functions.struct;
+
import java.io.Serializable;
import org.apache.beam.runners.core.InMemoryStateInternals;
-import org.apache.beam.runners.core.StateInternals;
+import org.apache.beam.runners.core.ReduceFnRunner;
import org.apache.beam.runners.core.StateInternalsFactory;
import org.apache.beam.runners.core.SystemReduceFn;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
import org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.GroupAlsoByWindowViaOutputBufferFn;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.KVHelpers;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.KeyValueGroupedDataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.sql.catalyst.expressions.CreateArray;
+import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.Literal;
+import org.apache.spark.sql.catalyst.expressions.Literal$;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import scala.Tuple2;
+import scala.collection.Iterator;
+import scala.collection.Seq;
+import scala.collection.immutable.List;
+/**
+ * Translator for {@link GroupByKey} using {@link Dataset#groupByKey} with the built-in aggregation
+ * function {@code collect_list} when applicable.
+ *
+ * <p>Note: Using {@code collect_list} isn't any worse than using {@link ReduceFnRunner}. In the
+ * latter case the entire group (iterator) has to be loaded into memory as well. Either way there's
+ * a risk of OOM errors. When disabling {@link #useCollectList}, a more memory-sensitive iterable is
+ * used that can be traversed just once. Attempting to traverse the iterable again will throw.
+ *
+ * <ul>
+ * <li>When using the default global window, window information is dropped and restored after the
+ * aggregation.
+ * <li>For non-merging windows, windows are exploded and moved into a composite key for better
+ * distribution. Though, to keep the amount of shuffled data low, this is only done if values
+ * are assigned to a single window or if there are only a few keys and distributing the data is
+ * important. After the aggregation, windowed values are restored from the composite key.
+ * <li>All other cases are implemented using the SDK {@link ReduceFnRunner}.
+ * </ul>
+ */
class GroupByKeyTranslatorBatch<K, V>
- implements TransformTranslator<
- PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>>> {
+ extends TransformTranslator<
+ PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupByKey<K, V>> {
+
+ /** Literal of the binary-encoded pane info. */
+ private static final Expression PANE_NO_FIRING = lit(toByteArray(NO_FIRING, PaneInfoCoder.of()));
+
+ /** Defaults for value in single global window. */
+ private static final List<Expression> GLOBAL_WINDOW_DETAILS =
+ windowDetails(lit(new byte[][] {EMPTY_BYTE_ARRAY}));
+
+ private boolean useCollectList = true;
+
+ public GroupByKeyTranslatorBatch() {}
+
+ public GroupByKeyTranslatorBatch(boolean useCollectList) {
+ this.useCollectList = useCollectList;
+ }
@Override
- public void translateTransform(
- PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> transform,
- AbstractTranslationContext context) {
-
- @SuppressWarnings("unchecked")
- final PCollection<KV<K, V>> inputPCollection = (PCollection<KV<K, V>>) context.getInput();
- Dataset<WindowedValue<KV<K, V>>> input = context.getDataset(inputPCollection);
- WindowingStrategy<?, ?> windowingStrategy = inputPCollection.getWindowingStrategy();
- KvCoder<K, V> kvCoder = (KvCoder<K, V>) inputPCollection.getCoder();
- Coder<V> valueCoder = kvCoder.getValueCoder();
-
- // group by key only
- Coder<K> keyCoder = kvCoder.getKeyCoder();
- KeyValueGroupedDataset<K, WindowedValue<KV<K, V>>> groupByKeyOnly =
- input.groupByKey(KVHelpers.extractKey(), EncoderHelpers.fromBeamCoder(keyCoder));
-
- // group also by windows
- WindowedValue.FullWindowedValueCoder<KV<K, Iterable<V>>> outputCoder =
- WindowedValue.FullWindowedValueCoder.of(
- KvCoder.of(keyCoder, IterableCoder.of(valueCoder)),
- windowingStrategy.getWindowFn().windowCoder());
- Dataset<WindowedValue<KV<K, Iterable<V>>>> output =
- groupByKeyOnly.flatMapGroups(
- new GroupAlsoByWindowViaOutputBufferFn<>(
- windowingStrategy,
- new InMemoryStateInternalsFactory<>(),
- SystemReduceFn.buffering(valueCoder),
- context.getSerializableOptions()),
- EncoderHelpers.fromBeamCoder(outputCoder));
-
- context.putDataset(context.getOutput(), output);
+ public void translate(GroupByKey<K, V> transform, Context cxt) {
+ WindowingStrategy<?, ?> windowing = cxt.getInput().getWindowingStrategy();
+ TimestampCombiner tsCombiner = windowing.getTimestampCombiner();
+
+ Dataset<WindowedValue<KV<K, V>>> input = cxt.getDataset(cxt.getInput());
+
+ KvCoder<K, V> inputCoder = (KvCoder<K, V>) cxt.getInput().getCoder();
+ KvCoder<K, Iterable<V>> outputCoder = (KvCoder<K, Iterable<V>>) cxt.getOutput().getCoder();
+
+ Encoder<V> valueEnc = cxt.valueEncoderOf(inputCoder);
+ Encoder<K> keyEnc = cxt.keyEncoderOf(inputCoder);
+
+ // In batch we can ignore triggering and allowed lateness parameters
+ final Dataset<WindowedValue<KV<K, Iterable<V>>>> result;
+
+ if (useCollectList && eligibleForGlobalGroupBy(windowing, false)) {
+ // Collects all values per key in memory. This might be problematic if there are only a few
+ // keys or the distribution is highly skewed.
+ result =
+ input
+ .groupBy(col("value.key").as("key"))
+ .agg(collect_list(col("value.value")).as("values"), timestampAggregator(tsCombiner))
+ .select(
+ inGlobalWindow(
+ keyValue(col("key").as(keyEnc), col("values").as(iterableEnc(valueEnc))),
+ windowTimestamp(tsCombiner)));
+
+ } else if (eligibleForGlobalGroupBy(windowing, true)) {
+ // Produces an iterable that can be traversed exactly once. On the plus side, data is not
+ // collected in memory until it is serialized or the user does so explicitly.
+ result =
+ cxt.getDataset(cxt.getInput())
+ .groupByKey(valueKey(), keyEnc)
+ .mapValues(valueValue(), cxt.valueEncoderOf(inputCoder))
+ .mapGroups(fun2((k, it) -> KV.of(k, iterableOnce(it))), cxt.kvEncoderOf(outputCoder))
+ .map(fun1(WindowedValue::valueInGlobalWindow), cxt.windowedEncoder(outputCoder));
+
+ } else if (useCollectList
+ && eligibleForGroupByWindow(windowing, false)
+ && (windowing.getWindowFn().assignsToOneWindow() || transform.fewKeys())) {
+ // Using the window as part of the key should help to better distribute the data. However, if
+ // values are assigned to multiple windows, more data would be shuffled around. If there are
+ // only a few keys, this is still worthwhile.
+ // Collects all values per key & window in memory.
+ result =
+ input
+ .select(explode(col("windows")).as("window"), col("value"), col("timestamp"))
+ .groupBy(col("value.key"), col("window"))
+ .agg(collect_list(col("value.value")).as("values"), timestampAggregator(tsCombiner))
+ .select(
+ inSingleWindow(
+ keyValue(col("key").as(keyEnc), col("values").as(iterableEnc(valueEnc))),
+ col("window").as(cxt.windowEncoder()),
+ windowTimestamp(tsCombiner)));
+
+ } else if (eligibleForGroupByWindow(windowing, true)
+ && (windowing.getWindowFn().assignsToOneWindow() || transform.fewKeys())) {
+ // Using the window as part of the key should help to better distribute the data. However, if
+ // values are assigned to multiple windows, more data would be shuffled around. If there's few
+ // keys only, this is still valuable.
+ // Produces an iterable that can be traversed exactly once. On the plus side, data is not
+ // collected in memory until it is serialized or the user does so explicitly.
+ Encoder<Tuple2<BoundedWindow, K>> windowedKeyEnc =
+ cxt.tupleEncoder(cxt.windowEncoder(), keyEnc);
+ result =
+ cxt.getDataset(cxt.getInput())
+ .flatMap(explodeWindowedKey(valueValue()), cxt.tupleEncoder(windowedKeyEnc, valueEnc))
+ .groupByKey(fun1(Tuple2::_1), windowedKeyEnc)
+ .mapValues(fun1(Tuple2::_2), valueEnc)
+ .mapGroups(
+ fun2((wKey, it) -> windowedKV(wKey, iterableOnce(it))),
+ cxt.windowedEncoder(outputCoder));
+
+ } else {
+ // Collects all values per key in memory. This might be problematic if there are only a few
+ // keys or the distribution is highly skewed.
+
+ // FIXME Revisit this case, implementation is far from ideal:
+ // - iterator traversed at least twice, forcing materialization in memory
+
+ // group by key, then by windows
+ result =
+ input
+ .groupByKey(valueKey(), keyEnc)
+ .flatMapGroups(
+ new GroupAlsoByWindowViaOutputBufferFn<>(
+ windowing,
+ (SerStateInternalsFactory) key -> InMemoryStateInternals.forKey(key),
+ SystemReduceFn.buffering(inputCoder.getValueCoder()),
+ cxt.getSerializableOptions()),
+ cxt.windowedEncoder(outputCoder));
+ }
+
+ cxt.putDataset(cxt.getOutput(), result);
+ }
+
+ /** Serializable in-memory state internals factory. */
+ private interface SerStateInternalsFactory<K> extends StateInternalsFactory<K>, Serializable {}
+
+ private Encoder<Iterable<V>> iterableEnc(Encoder<V> enc) {
+ // safe to use a list encoder with collect_list
+ return (Encoder) collectionEncoder(enc);
+ }
+
+ private static Column[] timestampAggregator(TimestampCombiner tsCombiner) {
+ if (tsCombiner.equals(TimestampCombiner.END_OF_WINDOW)) {
+ return new Column[0]; // no aggregation needed
+ }
+ Column agg =
+ tsCombiner.equals(TimestampCombiner.EARLIEST)
+ ? min(col("timestamp"))
+ : max(col("timestamp"));
+ return new Column[] {agg.as("timestamp")};
+ }
+
+ private static Expression windowTimestamp(TimestampCombiner tsCombiner) {
+ if (tsCombiner.equals(TimestampCombiner.END_OF_WINDOW)) {
+ // null will be set to END_OF_WINDOW by the respective deserializer
+ return litNull(DataTypes.LongType);
+ }
+ return col("timestamp").expr();
}
/**
- * In-memory state internals factory.
- *
- * @param <K> State key type.
+ * Java {@link Iterable} from Scala {@link Iterator} that can be iterated just once so that we
+ * don't have to load all data into memory.
*/
- static class InMemoryStateInternalsFactory<K> implements StateInternalsFactory<K>, Serializable {
- @Override
- public StateInternals stateInternalsForKey(K key) {
- return InMemoryStateInternals.forKey(key);
- }
+ private static <T extends @NonNull Object> Iterable<T> iterableOnce(Iterator<T> it) {
+ return () -> {
+ checkState(!it.isEmpty(), "Iterator on values can only be consumed once!");
+ return javaIterator(it);
+ };
+ }
+
+ private <T> TypedColumn<?, KV<K, T>> keyValue(TypedColumn<?, K> key, TypedColumn<?, T> value) {
+ return struct(key.as("key"), value.as("value")).as(kvEncoder(key.encoder(), value.encoder()));
+ }
+
+ private static <InT, T> TypedColumn<InT, WindowedValue<T>> inGlobalWindow(
+ TypedColumn<?, T> value, Expression ts) {
+ List<Expression> fields = concat(timestampedValue(value, ts), GLOBAL_WINDOW_DETAILS);
+ Encoder<WindowedValue<T>> enc =
+ windowedValueEncoder(value.encoder(), encoderOf(GlobalWindow.class));
+ return (TypedColumn<InT, WindowedValue<T>>) new Column(new CreateNamedStruct(fields)).as(enc);
+ }
+
+ public static <InT, T> TypedColumn<InT, WindowedValue<T>> inSingleWindow(
+ TypedColumn<?, T> value, TypedColumn<?, ? extends BoundedWindow> window, Expression ts) {
+ Expression windows = new CreateArray(listOf(window.expr()));
+ Seq<Expression> fields = concat(timestampedValue(value, ts), windowDetails(windows));
+ Encoder<WindowedValue<T>> enc = windowedValueEncoder(value.encoder(), window.encoder());
+ return (TypedColumn<InT, WindowedValue<T>>) new Column(new CreateNamedStruct(fields)).as(enc);
+ }
+
+ private static List<Expression> timestampedValue(Column value, Expression ts) {
+ return seqOf(lit("value"), value.expr(), lit("timestamp"), ts).toList();
+ }
+
+ private static List<Expression> windowDetails(Expression windows) {
+ return seqOf(lit("windows"), windows, lit("pane"), PANE_NO_FIRING).toList();
+ }
+
+ private static <T extends @NonNull Object> Expression lit(T t) {
+ return Literal$.MODULE$.apply(t);
+ }
+
+ @SuppressWarnings("nullness") // NULL literal
+ private static Expression litNull(DataType dataType) {
+ return new Literal(null, dataType);
}
}
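
A minimal sketch (assuming a local SparkSession) of the collect_list grouping that the
global-window branch above builds on: all values per key are aggregated into one in-memory list.

    import static org.apache.spark.sql.functions.col;
    import static org.apache.spark.sql.functions.collect_list;

    import java.util.Arrays;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoders;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;
    import scala.Tuple2;

    public class CollectListSketch {
      public static void main(String[] args) {
        SparkSession spark =
            SparkSession.builder().master("local[1]").appName("sketch").getOrCreate();
        Dataset<Tuple2<String, Integer>> input =
            spark.createDataset(
                Arrays.asList(new Tuple2<>("a", 1), new Tuple2<>("a", 2), new Tuple2<>("b", 3)),
                Encoders.tuple(Encoders.STRING(), Encoders.INT()));
        // group by key and collect all values per key into a list
        Dataset<Row> grouped =
            input.groupBy(col("_1").as("key")).agg(collect_list(col("_2")).as("values"));
        grouped.show(); // a -> [1, 2], b -> [3] (list order not guaranteed)
        spark.stop();
      }
    }

This is exactly where the memory caveat in the class javadoc comes from: each group's collected
list must fit on a single executor.
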
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ImpulseTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ImpulseTranslatorBatch.java
index 65f496c772b..cf0d2e7ab09 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ImpulseTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ImpulseTranslatorBatch.java
@@ -17,33 +17,27 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-import java.util.Collections;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
+import static org.apache.beam.repackaged.core.org.apache.commons.lang3.ArrayUtils.EMPTY_BYTE_ARRAY;
+
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
import org.apache.beam.sdk.coders.ByteArrayCoder;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.Impulse;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.spark.sql.Dataset;
public class ImpulseTranslatorBatch
- implements TransformTranslator<PTransform<PBegin, PCollection<byte[]>>> {
+ extends TransformTranslator<PBegin, PCollection<byte[]>, Impulse> {
@Override
- public void translateTransform(
- PTransform<PBegin, PCollection<byte[]>> transform, AbstractTranslationContext context) {
- Coder<WindowedValue<byte[]>> windowedValueCoder =
- WindowedValue.FullWindowedValueCoder.of(ByteArrayCoder.of(), GlobalWindow.Coder.INSTANCE);
+ public void translate(Impulse transform, Context cxt) {
Dataset<WindowedValue<byte[]>> dataset =
- context
- .getSparkSession()
- .createDataset(
- Collections.singletonList(WindowedValue.valueInGlobalWindow(new byte[0])),
- EncoderHelpers.fromBeamCoder(windowedValueCoder));
- context.putDataset(context.getOutput(), dataset);
+ cxt.createDataset(
+ ImmutableList.of(WindowedValue.valueInGlobalWindow(EMPTY_BYTE_ARRAY)),
+ cxt.windowedEncoder(ByteArrayCoder.of(), GlobalWindow.Coder.INSTANCE));
+ cxt.putDataset(cxt.getOutput(), dataset);
}
}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
index 52c2d5ae642..131b285be13 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
@@ -17,64 +17,81 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+import static java.util.stream.Collectors.toList;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.oneOfEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+import static org.apache.spark.sql.functions.col;
+import static org.apache.spark.storage.StorageLevel.MEMORY_ONLY;
import java.io.IOException;
+import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayList;
-import java.util.HashMap;
+import java.util.Collection;
import java.util.List;
import java.util.Map;
+import java.util.Map.Entry;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.core.DoFnRunners;
import org.apache.beam.runners.core.construction.ParDoTranslation;
-import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator;
-import org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsContainerStepMapAccumulator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.CoderHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.MultiOutputCoder;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SideInputBroadcast;
import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.SerializableCoder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
-import org.apache.beam.sdk.values.WindowingStrategy;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.FilterFunction;
-import org.apache.spark.api.java.function.MapFunction;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams;
+import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.TypedColumn;
+import org.apache.spark.storage.StorageLevel;
+import scala.Function1;
import scala.Tuple2;
+import scala.collection.Iterator;
+import scala.reflect.ClassTag;
/**
- * TODO: Add support for state and timers.
+ * Translator for {@link ParDo.MultiOutput} based on {@link DoFnRunners#simpleRunner}.
*
- * @param <InputT>
- * @param <OutputT>
+ * <p>Each tag is encoded as an individual column with its own schema and encoder.
+ *
+ * <p>TODO:
+ *
+ * <ul>
+ *   <li>Add support for state and timers.
+ *   <li>Add support for SplittableDoFn.
+ * </ul>
*/
-@SuppressWarnings({
- "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
- "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
class ParDoTranslatorBatch<InputT, OutputT>
- implements TransformTranslator<PTransform<PCollection<InputT>, PCollectionTuple>> {
+ extends TransformTranslator<
+ PCollection<? extends InputT>, PCollectionTuple, ParDo.MultiOutput<InputT, OutputT>> {
+
+ private static final ClassTag<WindowedValue<Object>> WINDOWED_VALUE_CTAG =
+ ClassTag.apply(WindowedValue.class);
+
+ private static final ClassTag<Tuple2<Integer, WindowedValue<Object>>> TUPLE2_CTAG =
+ ClassTag.apply(Tuple2.class);
@Override
- public void translateTransform(
- PTransform<PCollection<InputT>, PCollectionTuple> transform,
- AbstractTranslationContext context) {
- String stepName = context.getCurrentTransform().getFullName();
+ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, Context cxt)
+ throws IOException {
+ String stepName = cxt.getCurrentTransform().getFullName();
+
+ SparkCommonPipelineOptions opts = cxt.getOptions().as(SparkCommonPipelineOptions.class);
+ StorageLevel storageLevel = StorageLevel.fromString(opts.getStorageLevel());
// Check for not supported advanced features
// TODO: add support of Splittable DoFn
- DoFn<InputT, OutputT> doFn = getDoFn(context);
+ DoFn<InputT, OutputT> doFn = transform.getFn();
checkState(
!DoFnSignatures.isSplittable(doFn),
"Not expected to directly translate splittable DoFn, should have been overridden: %s",
@@ -86,98 +103,124 @@ class ParDoTranslatorBatch<InputT, OutputT>
checkState(
!DoFnSignatures.requiresTimeSortedInput(doFn),
- "@RequiresTimeSortedInput is not " + "supported for the moment");
-
- DoFnSchemaInformation doFnSchemaInformation =
- ParDoTranslation.getSchemaInformation(context.getCurrentTransform());
-
- // Init main variables
- PValue input = context.getInput();
- Dataset<WindowedValue<InputT>> inputDataSet = context.getDataset(input);
- Map<TupleTag<?>, PCollection<?>> outputs = context.getOutputs();
- TupleTag<?> mainOutputTag = getTupleTag(context);
- List<TupleTag<?>> outputTags = new ArrayList<>(outputs.keySet());
- WindowingStrategy<?, ?> windowingStrategy =
- ((PCollection<InputT>) input).getWindowingStrategy();
- Coder<InputT> inputCoder = ((PCollection<InputT>) input).getCoder();
- Coder<? extends BoundedWindow> windowCoder = windowingStrategy.getWindowFn().windowCoder();
-
- // construct a map from side input to WindowingStrategy so that
- // the DoFn runner can map main-input windows to side input windows
- List<PCollectionView<?>> sideInputs = getSideInputs(context);
- Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputStrategies = new HashMap<>();
- for (PCollectionView<?> sideInput : sideInputs) {
- sideInputStrategies.put(sideInput, sideInput.getPCollection().getWindowingStrategy());
- }
-
- SideInputBroadcast broadcastStateData = createBroadcastSideInputs(sideInputs, context);
+ "@RequiresTimeSortedInput is not supported for the moment");
- Map<TupleTag<?>, Coder<?>> outputCoderMap = context.getOutputCoders();
- MetricsContainerStepMapAccumulator metricsAccum = MetricsAccumulator.getInstance();
+ TupleTag<OutputT> mainOutputTag = transform.getMainOutputTag();
- List<TupleTag<?>> additionalOutputTags = new ArrayList<>();
- for (TupleTag<?> tag : outputTags) {
- if (!tag.equals(mainOutputTag)) {
- additionalOutputTags.add(tag);
- }
- }
+ DoFnSchemaInformation doFnSchema =
+ ParDoTranslation.getSchemaInformation(cxt.getCurrentTransform());
- Map<String, PCollectionView<?>> sideInputMapping =
- ParDoTranslation.getSideInputMapping(context.getCurrentTransform());
- @SuppressWarnings("unchecked")
- DoFnFunction<InputT, OutputT> doFnWrapper =
- new DoFnFunction(
- metricsAccum,
+ PCollection<InputT> input = (PCollection<InputT>) cxt.getInput();
+ DoFnMapPartitionsFactory<InputT, OutputT> factory =
+ new DoFnMapPartitionsFactory<>(
stepName,
doFn,
- windowingStrategy,
- sideInputStrategies,
- context.getSerializableOptions(),
- additionalOutputTags,
+ doFnSchema,
+ cxt.getSerializableOptions(),
+ input,
mainOutputTag,
- inputCoder,
- outputCoderMap,
- broadcastStateData,
- doFnSchemaInformation,
- sideInputMapping);
-
- MultiOutputCoder multipleOutputCoder =
- MultiOutputCoder.of(SerializableCoder.of(TupleTag.class), outputCoderMap, windowCoder);
- Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputs =
- inputDataSet.mapPartitions(doFnWrapper, EncoderHelpers.fromBeamCoder(multipleOutputCoder));
- if (outputs.entrySet().size() > 1) {
- allOutputs.persist();
- for (Map.Entry<TupleTag<?>, PCollection<?>> output : outputs.entrySet()) {
- pruneOutputFilteredByTag(context, allOutputs, output, windowCoder);
+ cxt.getOutputs(),
+ transform.getSideInputs(),
+ createBroadcastSideInputs(transform.getSideInputs().values(), cxt));
+
+ Dataset<WindowedValue<InputT>> inputDs = cxt.getDataset(input);
+ if (cxt.getOutputs().size() > 1) {
+ // In case of multiple outputs / tags, map each tag to a column by index.
+ // At the end, split the result into multiple datasets, selecting one column each.
+ Map<TupleTag<?>, Integer> tags = ImmutableMap.copyOf(zipWithIndex(cxt.getOutputs().keySet()));
+
+ List<Encoder<WindowedValue<Object>>> encoders =
+ createEncoders(cxt.getOutputs(), (Iterable<TupleTag<?>>) tags.keySet(), cxt);
+
+ Function1<Iterator<WindowedValue<InputT>>, Iterator<Tuple2<Integer, WindowedValue<Object>>>>
+ doFnMapper = factory.create((tag, v) -> tuple(tags.get(tag), (WindowedValue<Object>) v));
+
+ // FIXME What's the right strategy to unpersist Datasets / RDDs once they are no longer needed?
+
+ // If using storage level MEMORY_ONLY, it's best to persist the dataset as an RDD to avoid any
+ // serialization / use of encoders. Persisting a Dataset, even with a "deserialized" storage
+ // level, involves converting the data to the internal representation (InternalRow) using an
+ // encoder.
+ // For any other storage level, persist as a Dataset so we can select columns by TupleTag
+ // individually without restoring the entire row.
+ if (MEMORY_ONLY().equals(storageLevel)) {
+
+ RDD<Tuple2<Integer, WindowedValue<Object>>> allTagsRDD =
+ inputDs.rdd().mapPartitions(doFnMapper, false, TUPLE2_CTAG);
+ allTagsRDD.persist();
+
+ // divide into separate output datasets per tag
+ for (Entry<TupleTag<?>, Integer> e : tags.entrySet()) {
+ TupleTag<Object> key = (TupleTag<Object>) e.getKey();
+ Integer id = e.getValue();
+
+ RDD<WindowedValue<Object>> rddByTag =
+ allTagsRDD
+ .filter(fun1(t -> t._1.equals(id)))
+ .map(fun1(Tuple2::_2), WINDOWED_VALUE_CTAG);
+
+ cxt.putDataset(
+ cxt.getOutput(key), cxt.getSparkSession().createDataset(rddByTag, encoders.get(id)));
+ }
+ } else {
+ // Persist as wide rows with one column per TupleTag to support different schemas
+ Dataset<Tuple2<Integer, WindowedValue<Object>>> allTagsDS =
+ inputDs.mapPartitions(doFnMapper, oneOfEncoder(encoders));
+ allTagsDS.persist(storageLevel);
+
+ // divide into separate output datasets per tag
+ for (Entry<TupleTag<?>, Integer> e : tags.entrySet()) {
+ TupleTag<Object> key = (TupleTag<Object>) e.getKey();
+ Integer id = e.getValue();
+
+ // Resolve specific column matching the tuple tag (by id)
+ TypedColumn<Tuple2<Integer, WindowedValue<Object>>, WindowedValue<Object>> col =
+ (TypedColumn) col(id.toString()).as(encoders.get(id));
+
+ cxt.putDataset(cxt.getOutput(key), allTagsDS.filter(col.isNotNull()).select(col));
+ }
}
} else {
- Coder<OutputT> outputCoder = ((PCollection<OutputT>) outputs.get(mainOutputTag)).getCoder();
- Coder<WindowedValue<?>> windowedValueCoder =
- (Coder<WindowedValue<?>>) (Coder<?>) WindowedValue.getFullCoder(outputCoder, windowCoder);
- Dataset<WindowedValue<?>> outputDataset =
- allOutputs.map(
- (MapFunction<Tuple2<TupleTag<?>, WindowedValue<?>>, WindowedValue<?>>)
- value -> value._2,
- EncoderHelpers.fromBeamCoder(windowedValueCoder));
- context.putDatasetWildcard(outputs.entrySet().iterator().next().getValue(), outputDataset);
+ PCollection<OutputT> output = cxt.getOutput(mainOutputTag);
+ Dataset<WindowedValue<OutputT>> mainDS =
+ inputDs.mapPartitions(
+ factory.create((tag, value) -> (WindowedValue<OutputT>) value),
+ cxt.windowedEncoder(output.getCoder()));
+
+ cxt.putDataset(output, mainDS);
+ }
+ }
+
+ private List<Encoder<WindowedValue<Object>>> createEncoders(
+ Map<TupleTag<?>, PCollection<?>> outputs, Iterable<TupleTag<?>> columns, Context ctx) {
+ return Streams.stream(columns)
+ .map(tag -> ctx.windowedEncoder(getCoder(outputs.get(tag), tag)))
+ .collect(toList());
+ }
+
+ private Coder<Object> getCoder(@Nullable PCollection<?> pc, TupleTag<?> tag) {
+ if (pc == null) {
+ throw new NullPointerException("No PCollection for tag " + tag);
}
+ return (Coder<Object>) pc.getCoder();
}
- private static SideInputBroadcast createBroadcastSideInputs(
- List<PCollectionView<?>> sideInputs, AbstractTranslationContext context) {
- JavaSparkContext jsc =
- JavaSparkContext.fromSparkContext(context.getSparkSession().sparkContext());
+ // FIXME Is there a better way to materialize and broadcast side inputs?
+ private SideInputBroadcast createBroadcastSideInputs(
+ Collection<PCollectionView<?>> sideInputs, Context context) {
SideInputBroadcast sideInputBroadcast = new SideInputBroadcast();
for (PCollectionView<?> sideInput : sideInputs) {
+ PCollection<?> pc = sideInput.getPCollection();
+ if (pc == null) {
+ throw new NullPointerException("PCollection for SideInput is null");
+ }
Coder<? extends BoundedWindow> windowCoder =
- sideInput.getPCollection().getWindowingStrategy().getWindowFn().windowCoder();
-
+ pc.getWindowingStrategy().getWindowFn().windowCoder();
Coder<WindowedValue<?>> windowedValueCoder =
(Coder<WindowedValue<?>>)
- (Coder<?>)
- WindowedValue.getFullCoder(sideInput.getPCollection().getCoder(), windowCoder);
- Dataset<WindowedValue<?>> broadcastSet = context.getSideInputDataSet(sideInput);
+ (Coder<?>) WindowedValue.getFullCoder(pc.getCoder(), windowCoder);
+ Dataset<WindowedValue<?>> broadcastSet = context.getSideInputDataset(sideInput);
List<WindowedValue<?>> valuesList = broadcastSet.collectAsList();
List<byte[]> codedValues = new ArrayList<>();
for (WindowedValue<?> v : valuesList) {
@@ -185,73 +228,17 @@ class ParDoTranslatorBatch<InputT, OutputT>
}
sideInputBroadcast.add(
- sideInput.getTagInternal().getId(), jsc.broadcast(codedValues), windowedValueCoder);
+ sideInput.getTagInternal().getId(), context.broadcast(codedValues), windowedValueCoder);
}
return sideInputBroadcast;
}
- private List<PCollectionView<?>> getSideInputs(AbstractTranslationContext context) {
- List<PCollectionView<?>> sideInputs;
- try {
- sideInputs = ParDoTranslation.getSideInputs(context.getCurrentTransform());
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- return sideInputs;
- }
-
- private TupleTag<?> getTupleTag(AbstractTranslationContext context) {
- TupleTag<?> mainOutputTag;
- try {
- mainOutputTag = ParDoTranslation.getMainOutputTag(context.getCurrentTransform());
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- return mainOutputTag;
- }
-
- @SuppressWarnings("unchecked")
- private DoFn<InputT, OutputT> getDoFn(AbstractTranslationContext context) {
- DoFn<InputT, OutputT> doFn;
- try {
- doFn = (DoFn<InputT, OutputT>) ParDoTranslation.getDoFn(context.getCurrentTransform());
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- return doFn;
- }
-
- private void pruneOutputFilteredByTag(
- AbstractTranslationContext context,
- Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputs,
- Map.Entry<TupleTag<?>, PCollection<?>> output,
- Coder<? extends BoundedWindow> windowCoder) {
- Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> filteredDataset =
- allOutputs.filter(new DoFnFilterFunction(output.getKey()));
- Coder<WindowedValue<?>> windowedValueCoder =
- (Coder<WindowedValue<?>>)
- (Coder<?>)
- WindowedValue.getFullCoder(
- ((PCollection<OutputT>) output.getValue()).getCoder(), windowCoder);
- Dataset<WindowedValue<?>> outputDataset =
- filteredDataset.map(
- (MapFunction<Tuple2<TupleTag<?>, WindowedValue<?>>, WindowedValue<?>>)
- value -> value._2,
- EncoderHelpers.fromBeamCoder(windowedValueCoder));
- context.putDatasetWildcard(output.getValue(), outputDataset);
- }
-
- static class DoFnFilterFunction implements FilterFunction<Tuple2<TupleTag<?>, WindowedValue<?>>> {
-
- private final TupleTag<?> key;
-
- DoFnFilterFunction(TupleTag<?> key) {
- this.key = key;
- }
-
- @Override
- public boolean call(Tuple2<TupleTag<?>, WindowedValue<?>> value) {
- return value._1.equals(key);
+ private static <T> Collection<Entry<T, Integer>> zipWithIndex(Collection<T> col) {
+ ArrayList<Entry<T, Integer>> zipped = new ArrayList<>(col.size());
+ int i = 0;
+ for (T t : col) {
+ zipped.add(new SimpleImmutableEntry<>(t, i++));
}
+ return zipped;
}
}
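
For illustration, a minimal self-contained sketch of the tag-splitting strategy used in the
MEMORY_ONLY path above: records are tagged with an integer output index, persisted once, and
one RDD per tag is then derived via filter + map. All names below are illustrative stand-ins,
not the runner's actual classes.

    import java.util.Arrays;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.storage.StorageLevel;
    import scala.Tuple2;

    public class TagSplitSketch {
      public static void main(String[] args) {
        JavaSparkContext jsc = new JavaSparkContext("local[2]", "tag-split-sketch");
        // Tag each record with the index of the output it belongs to (0 = main output).
        JavaRDD<Tuple2<Integer, String>> tagged =
            jsc.parallelize(Arrays.asList(new Tuple2<>(0, "main"), new Tuple2<>(1, "side")));
        // Persist once so the upstream computation is not re-evaluated per output.
        tagged.persist(StorageLevel.MEMORY_ONLY());
        JavaRDD<String> mainOut = tagged.filter(t -> t._1 == 0).map(t -> t._2);
        JavaRDD<String> sideOut = tagged.filter(t -> t._1 == 1).map(t -> t._2);
        System.out.println(mainOut.collect() + " / " + sideOut.collect());
        jsc.stop();
      }
    }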
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java
index 5789db6cd30..62d79632cfa 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java
@@ -25,7 +25,6 @@ import org.apache.beam.runners.spark.structuredstreaming.translation.PipelineTra
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext;
import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.GroupByKey;
@@ -34,6 +33,8 @@ import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.values.PInput;
+import org.apache.beam.sdk.values.POutput;
import org.checkerframework.checker.nullness.qual.Nullable;
/**
@@ -41,10 +42,6 @@ import org.checkerframework.checker.nullness.qual.Nullable;
* only the components specific to batch: registry of batch {@link TransformTranslator} and registry
* lookup code.
*/
-@SuppressWarnings({
- "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
- "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
public class PipelineTranslatorBatch extends PipelineTranslator {
// --------------------------------------------------------------------------------------------
@@ -65,23 +62,24 @@ public class PipelineTranslatorBatch extends PipelineTranslator {
static {
TRANSFORM_TRANSLATORS.put(Impulse.class, new ImpulseTranslatorBatch());
- TRANSFORM_TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslatorBatch());
- TRANSFORM_TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslatorBatch());
+ TRANSFORM_TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslatorBatch<>());
+ TRANSFORM_TRANSLATORS.put(Combine.Globally.class, new CombineGloballyTranslatorBatch<>());
+ TRANSFORM_TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslatorBatch<>());
// TODO: Do we need to have a dedicated translator for {@code Reshuffle} if it's deprecated?
// TRANSFORM_TRANSLATORS.put(Reshuffle.class, new ReshuffleTranslatorBatch());
- TRANSFORM_TRANSLATORS.put(Flatten.PCollections.class, new FlattenTranslatorBatch());
+ TRANSFORM_TRANSLATORS.put(Flatten.PCollections.class, new FlattenTranslatorBatch<>());
- TRANSFORM_TRANSLATORS.put(Window.Assign.class, new WindowAssignTranslatorBatch());
+ TRANSFORM_TRANSLATORS.put(Window.Assign.class, new WindowAssignTranslatorBatch<>());
- TRANSFORM_TRANSLATORS.put(ParDo.MultiOutput.class, new ParDoTranslatorBatch());
+ TRANSFORM_TRANSLATORS.put(ParDo.MultiOutput.class, new ParDoTranslatorBatch<>());
TRANSFORM_TRANSLATORS.put(
- SplittableParDo.PrimitiveBoundedRead.class, new ReadSourceTranslatorBatch());
+ SplittableParDo.PrimitiveBoundedRead.class, new ReadSourceTranslatorBatch<>());
TRANSFORM_TRANSLATORS.put(
- View.CreatePCollectionView.class, new CreatePCollectionViewTranslatorBatch());
+ View.CreatePCollectionView.class, new CreatePCollectionViewTranslatorBatch<>());
}
public PipelineTranslatorBatch(SparkStructuredStreamingPipelineOptions options) {
@@ -90,8 +88,10 @@ public class PipelineTranslatorBatch extends PipelineTranslator {
/** Returns a translator for the given node, if it is possible, otherwise null. */
@Override
- protected TransformTranslator<?> getTransformTranslator(TransformHierarchy.Node node) {
- @Nullable PTransform<?, ?> transform = node.getTransform();
+ @Nullable
+ protected <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+ TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
+ @Nullable TransformT transform) {
// Root of the graph is null
if (transform == null) {
return null;
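
The registry lookup above reduces to a class-keyed map: each transform class maps to one
translator instance, and translation selects by the concrete class of the applied transform
(returning null at the root of the graph). A generic sketch of the pattern, with hypothetical
names:

    import java.util.HashMap;
    import java.util.Map;

    interface Translator<T> {
      void translate(T transform);
    }

    final class TranslatorRegistry {
      private static final Map<Class<?>, Translator<?>> TRANSLATORS = new HashMap<>();

      static <T> void register(Class<T> cls, Translator<? super T> translator) {
        TRANSLATORS.put(cls, translator);
      }

      // Returns the translator registered for the transform's concrete class, or null.
      @SuppressWarnings("unchecked")
      static <T> Translator<T> lookup(T transform) {
        return transform == null ? null : (Translator<T>) TRANSLATORS.get(transform.getClass());
      }
    }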
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ProcessContext.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ProcessContext.java
deleted file mode 100644
index db64bfd19f3..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ProcessContext.java
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-
-import java.util.ArrayList;
-import java.util.Iterator;
-import org.apache.beam.runners.core.DoFnRunner;
-import org.apache.beam.runners.core.DoFnRunners.OutputManager;
-import org.apache.beam.runners.core.TimerInternals;
-import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator;
-
-/** Spark runner process context processes Spark partitions using Beam's {@link DoFnRunner}. */
-class ProcessContext<FnInputT, FnOutputT, OutputT> {
-
- private final DoFn<FnInputT, FnOutputT> doFn;
- private final DoFnRunner<FnInputT, FnOutputT> doFnRunner;
- private final ProcessOutputManager<OutputT> outputManager;
- private final Iterator<TimerInternals.TimerData> timerDataIterator;
-
- ProcessContext(
- DoFn<FnInputT, FnOutputT> doFn,
- DoFnRunner<FnInputT, FnOutputT> doFnRunner,
- ProcessOutputManager<OutputT> outputManager,
- Iterator<TimerInternals.TimerData> timerDataIterator) {
-
- this.doFn = doFn;
- this.doFnRunner = doFnRunner;
- this.outputManager = outputManager;
- this.timerDataIterator = timerDataIterator;
- }
-
- Iterable<OutputT> processPartition(Iterator<WindowedValue<FnInputT>> partition) {
-
- // skip if partition is empty.
- if (!partition.hasNext()) {
- return new ArrayList<>();
- }
-
- // process the partition; finishBundle() is called from within the output iterator.
- return this.getOutputIterable(partition, doFnRunner);
- }
-
- private void clearOutput() {
- outputManager.clear();
- }
-
- private Iterator<OutputT> getOutputIterator() {
- return outputManager.iterator();
- }
-
- private Iterable<OutputT> getOutputIterable(
- final Iterator<WindowedValue<FnInputT>> iter,
- final DoFnRunner<FnInputT, FnOutputT> doFnRunner) {
- return () -> new ProcCtxtIterator(iter, doFnRunner);
- }
-
- interface ProcessOutputManager<T> extends OutputManager, Iterable<T> {
- void clear();
- }
-
- private class ProcCtxtIterator extends AbstractIterator<OutputT> {
-
- private final Iterator<WindowedValue<FnInputT>> inputIterator;
- private final DoFnRunner<FnInputT, FnOutputT> doFnRunner;
- private Iterator<OutputT> outputIterator;
- private boolean isBundleStarted;
- private boolean isBundleFinished;
-
- ProcCtxtIterator(
- Iterator<WindowedValue<FnInputT>> iterator, DoFnRunner<FnInputT, FnOutputT> doFnRunner) {
- this.inputIterator = iterator;
- this.doFnRunner = doFnRunner;
- this.outputIterator = getOutputIterator();
- }
-
- @Override
- protected OutputT computeNext() {
- try {
- // Process each element from the (input) iterator, which produces zero, one or more
- // output elements (of type V) in the output iterator. Note that the output
- // collection (and iterator) is reset between each call to processElement, so the
- // collection only holds the output values for each call to processElement, rather
- // than for the whole partition (which would use too much memory).
- if (!isBundleStarted) {
- isBundleStarted = true;
- // call startBundle() before beginning to process the partition.
- doFnRunner.startBundle();
- }
-
- while (true) {
- if (outputIterator.hasNext()) {
- return outputIterator.next();
- }
-
- clearOutput();
- if (inputIterator.hasNext()) {
- // grab the next element and process it.
- doFnRunner.processElement(inputIterator.next());
- outputIterator = getOutputIterator();
- } else if (timerDataIterator.hasNext()) {
- outputIterator = getOutputIterator();
- } else {
- // no more input to consume, but finishBundle can produce more output
- if (!isBundleFinished) {
- isBundleFinished = true;
- doFnRunner.finishBundle();
- outputIterator = getOutputIterator();
- continue; // try to consume outputIterator from start of loop
- }
- DoFnInvokers.invokerFor(doFn).invokeTeardown();
- return endOfData();
- }
- }
- } catch (final RuntimeException re) {
- DoFnInvokers.invokerFor(doFn).invokeTeardown();
- throw re;
- }
- }
- }
-}
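
The deleted ProcessContext implemented a lazy bundle iterator: the outputs produced by one
input element are buffered and fully drained before the next element is processed, and a final
finish-bundle step may still emit trailing output. A generic standalone sketch of that pattern
(hypothetical names; the original used Beam's vendored Guava AbstractIterator):

    import java.util.ArrayDeque;
    import java.util.Deque;
    import java.util.Iterator;
    import java.util.List;
    import java.util.function.Function;
    import java.util.function.Supplier;
    import com.google.common.collect.AbstractIterator;

    class LazyBundleIterator<I, O> extends AbstractIterator<O> {
      private final Iterator<I> input;
      private final Function<I, List<O>> processElement; // outputs per input element
      private final Supplier<List<O>> finishBundle;      // trailing outputs, invoked once
      private final Deque<O> buffer = new ArrayDeque<>();
      private boolean bundleFinished;

      LazyBundleIterator(
          Iterator<I> input, Function<I, List<O>> processElement, Supplier<List<O>> finishBundle) {
        this.input = input;
        this.processElement = processElement;
        this.finishBundle = finishBundle;
      }

      @Override
      protected O computeNext() {
        while (true) {
          if (!buffer.isEmpty()) {
            return buffer.poll(); // only the current element's outputs are held in memory
          }
          if (input.hasNext()) {
            buffer.addAll(processElement.apply(input.next()));
          } else if (!bundleFinished) {
            bundleFinished = true; // finishBundle may still produce output
            buffer.addAll(finishBundle.get());
          } else {
            return endOfData();
          }
        }
      }
    }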
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java
index ebeb8a96eda..30b599c7e5e 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReadSourceTranslatorBatch.java
@@ -17,72 +17,38 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.BEAM_SOURCE_OPTION;
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.DEFAULT_PARALLELISM;
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.PIPELINE_OPTIONS;
-
import java.io.IOException;
-import org.apache.beam.runners.core.construction.ReadTranslation;
-import org.apache.beam.runners.core.serialization.Base64Serializer;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.core.construction.SplittableParDo;
+import org.apache.beam.runners.spark.structuredstreaming.io.BoundedDatasetFactory;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.RowHelpers;
import org.apache.beam.sdk.io.BoundedSource;
-import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
+import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.SparkSession;
+/**
+ * Translator for a {@link SplittableParDo.PrimitiveBoundedRead} that creates a Dataset via an RDD
+ * to avoid an additional serialization roundtrip.
+ */
class ReadSourceTranslatorBatch<T>
- implements TransformTranslator<PTransform<PBegin, PCollection<T>>> {
+ extends TransformTranslator<PBegin, PCollection<T>, SplittableParDo.PrimitiveBoundedRead<T>> {
- private static final String sourceProviderClass = DatasetSourceBatch.class.getCanonicalName();
-
- @SuppressWarnings("unchecked")
@Override
- public void translateTransform(
- PTransform<PBegin, PCollection<T>> transform, AbstractTranslationContext context) {
- AppliedPTransform<PBegin, PCollection<T>, PTransform<PBegin, PCollection<T>>> rootTransform =
- (AppliedPTransform<PBegin, PCollection<T>, PTransform<PBegin, PCollection<T>>>)
- context.getCurrentTransform();
-
- BoundedSource<T> source;
- try {
- source = ReadTranslation.boundedSourceFromTransform(rootTransform);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- SparkSession sparkSession = context.getSparkSession();
-
- String serializedSource = Base64Serializer.serializeUnchecked(source);
- Dataset<Row> rowDataset =
- sparkSession
- .read()
- .format(sourceProviderClass)
- .option(BEAM_SOURCE_OPTION, serializedSource)
- .option(
- DEFAULT_PARALLELISM,
- String.valueOf(context.getSparkSession().sparkContext().defaultParallelism()))
- .option(PIPELINE_OPTIONS, context.getSerializableOptions().toString())
- .load();
-
- // extract windowedValue from Row
- WindowedValue.FullWindowedValueCoder<T> windowedValueCoder =
- WindowedValue.FullWindowedValueCoder.of(
- source.getOutputCoder(), GlobalWindow.Coder.INSTANCE);
-
- Dataset<WindowedValue<T>> dataset =
- rowDataset.map(
- RowHelpers.extractWindowedValueFromRowMapFunction(windowedValueCoder),
- EncoderHelpers.fromBeamCoder(windowedValueCoder));
-
- PCollection<T> output = (PCollection<T>) context.getOutput();
- context.putDataset(output, dataset);
+ public void translate(SplittableParDo.PrimitiveBoundedRead<T> transform, Context cxt)
+ throws IOException {
+ SparkSession session = cxt.getSparkSession();
+ BoundedSource<T> source = transform.getSource();
+ SerializablePipelineOptions options = cxt.getSerializableOptions();
+
+ Encoder<WindowedValue<T>> encoder =
+ cxt.windowedEncoder(source.getOutputCoder(), GlobalWindow.Coder.INSTANCE);
+
+ cxt.putDataset(
+ cxt.getOutput(),
+ BoundedDatasetFactory.createDatasetFromRDD(session, source, options, encoder));
}
}
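
The new read path builds a typed Dataset directly from an RDD with an explicit Encoder instead
of going through the Row-based DataSource format. A minimal standalone sketch of that mechanism
(illustrative values only; BoundedDatasetFactory additionally splits the BoundedSource into
partitions):

    import java.util.Arrays;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.rdd.RDD;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Encoders;
    import org.apache.spark.sql.SparkSession;

    public class RddToDatasetSketch {
      public static void main(String[] args) {
        SparkSession session =
            SparkSession.builder().master("local[2]").appName("rdd-to-dataset").getOrCreate();
        JavaSparkContext jsc = JavaSparkContext.fromSparkContext(session.sparkContext());
        RDD<String> rdd = jsc.parallelize(Arrays.asList("a", "b", "c")).rdd();
        // No Row conversion here; the encoder is applied lazily when Spark needs
        // the internal representation.
        Dataset<String> ds = session.createDataset(rdd, Encoders.STRING());
        ds.show();
        session.stop();
      }
    }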
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReshuffleTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReshuffleTranslatorBatch.java
deleted file mode 100644
index a88d5454667..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ReshuffleTranslatorBatch.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
-import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.sdk.transforms.Reshuffle;
-
-/** TODO: Should be removed if {@link Reshuffle} won't be translated. */
-class ReshuffleTranslatorBatch<K, InputT> implements TransformTranslator<Reshuffle<K, InputT>> {
-
- @Override
- public void translateTransform(
- Reshuffle<K, InputT> transform, AbstractTranslationContext context) {}
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTranslatorBatch.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTranslatorBatch.java
index 875a983b401..3b993a3ce19 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTranslatorBatch.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTranslatorBatch.java
@@ -17,45 +17,83 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
+import java.util.Collection;
import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.WindowingHelpers;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import org.joda.time.Instant;
-@SuppressWarnings({
- "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
class WindowAssignTranslatorBatch<T>
- implements TransformTranslator<PTransform<PCollection<T>, PCollection<T>>> {
+ extends TransformTranslator<PCollection<T>, PCollection<T>, Window.Assign<T>> {
@Override
- public void translateTransform(
- PTransform<PCollection<T>, PCollection<T>> transform, AbstractTranslationContext context) {
-
- Window.Assign<T> assignTransform = (Window.Assign<T>) transform;
- @SuppressWarnings("unchecked")
- final PCollection<T> input = (PCollection<T>) context.getInput();
- @SuppressWarnings("unchecked")
- final PCollection<T> output = (PCollection<T>) context.getOutput();
-
- Dataset<WindowedValue<T>> inputDataset = context.getDataset(input);
- if (WindowingHelpers.skipAssignWindows(assignTransform, context)) {
- context.putDataset(output, inputDataset);
+ public void translate(Window.Assign<T> transform, Context cxt) {
+ WindowFn<T, ?> windowFn = transform.getWindowFn();
+ PCollection<T> input = cxt.getInput();
+ Dataset<WindowedValue<T>> inputDataset = cxt.getDataset(input);
+
+ if (windowFn == null || skipAssignWindows(windowFn, input)) {
+ cxt.putDataset(cxt.getOutput(), inputDataset);
} else {
- WindowFn<T, ?> windowFn = assignTransform.getWindowFn();
- WindowedValue.FullWindowedValueCoder<T> windowedValueCoder =
- WindowedValue.FullWindowedValueCoder.of(input.getCoder(), windowFn.windowCoder());
Dataset<WindowedValue<T>> outputDataset =
inputDataset.map(
- WindowingHelpers.assignWindowsMapFunction(windowFn),
- EncoderHelpers.fromBeamCoder(windowedValueCoder));
- context.putDataset(output, outputDataset);
+ assignWindows(windowFn),
+ cxt.windowedEncoder(input.getCoder(), windowFn.windowCoder()));
+
+ cxt.putDataset(cxt.getOutput(), outputDataset);
}
}
+
+ /**
+ * Checks if the window transformation should be applied or skipped.
+ *
+ * <p>Window assignment can be skipped if both the source and destination use the global window,
+ * or if the user has not specified a WindowFn (meaning they are only adjusting triggering or
+ * allowed lateness).
+ */
+ private boolean skipAssignWindows(WindowFn<T, ?> newFn, PCollection<T> input) {
+ WindowFn<?, ?> currentFn = input.getWindowingStrategy().getWindowFn();
+ return currentFn instanceof GlobalWindows && newFn instanceof GlobalWindows;
+ }
+
+ private static <T, W extends @NonNull BoundedWindow>
+ MapFunction<WindowedValue<T>, WindowedValue<T>> assignWindows(WindowFn<T, W> windowFn) {
+ return value -> {
+ final BoundedWindow window = getOnlyWindow(value);
+ final T element = value.getValue();
+ final Instant timestamp = value.getTimestamp();
+ Collection<W> windows =
+ windowFn.assignWindows(
+ windowFn.new AssignContext() {
+
+ @Override
+ public T element() {
+ return element;
+ }
+
+ @Override
+ public @NonNull Instant timestamp() {
+ return timestamp;
+ }
+
+ @Override
+ public @NonNull BoundedWindow window() {
+ return window;
+ }
+ });
+ return WindowedValue.of(element, timestamp, windows, value.getPane());
+ };
+ }
+
+ private static <T> BoundedWindow getOnlyWindow(WindowedValue<T> wv) {
+ return Iterables.getOnlyElement((Iterable<BoundedWindow>) wv.getWindows());
+ }
}
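
As a concrete illustration of what the assignWindows MapFunction computes per element, a tiny
standalone example using Beam's FixedWindows (a partitioning WindowFn, so each timestamp maps
to exactly one window; the translator above generalizes this to arbitrary WindowFns via
AssignContext):

    import org.apache.beam.sdk.transforms.windowing.FixedWindows;
    import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
    import org.joda.time.Duration;
    import org.joda.time.Instant;

    public class AssignWindowSketch {
      public static void main(String[] args) {
        FixedWindows fn = FixedWindows.of(Duration.standardMinutes(10));
        // An element at minute 25 lands in the [minute 20, minute 30) window.
        IntervalWindow w = fn.assignWindow(new Instant(0).plus(Duration.standardMinutes(25)));
        System.out.println(w); // prints the [00:20, 00:30) window of 1970-01-01
      }
    }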
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java
index fe3f39ef51e..f8c63bc34f1 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/CoderHelpers.java
@@ -17,10 +17,9 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
import java.io.IOException;
import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.util.CoderUtils;
/** Serialization utility class. */
public final class CoderHelpers {
@@ -35,13 +34,11 @@ public final class CoderHelpers {
* @return Byte array representing serialized object.
*/
public static <T> byte[] toByteArray(T value, Coder<T> coder) {
- ByteArrayOutputStream baos = new ByteArrayOutputStream();
try {
- coder.encode(value, baos);
+ return CoderUtils.encodeToByteArray(coder, value);
} catch (IOException e) {
throw new IllegalStateException("Error encoding value: " + value, e);
}
- return baos.toByteArray();
}
/**
@@ -53,9 +50,8 @@ public final class CoderHelpers {
* @return Deserialized object.
*/
public static <T> T fromByteArray(byte[] serialized, Coder<T> coder) {
- ByteArrayInputStream bais = new ByteArrayInputStream(serialized);
try {
- return coder.decode(bais);
+ return CoderUtils.decodeFromByteArray(coder, serialized);
} catch (IOException e) {
throw new IllegalStateException("Error decoding bytes for coder: " + coder, e);
}
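
The simplification above delegates to Beam's CoderUtils. For reference, a round trip through
these helpers looks like this (standalone example):

    import org.apache.beam.sdk.coders.StringUtf8Coder;
    import org.apache.beam.sdk.util.CoderUtils;

    public class CoderRoundTripSketch {
      public static void main(String[] args) throws Exception {
        StringUtf8Coder coder = StringUtf8Coder.of();
        byte[] bytes = CoderUtils.encodeToByteArray(coder, "hello");
        String decoded = CoderUtils.decodeFromByteArray(coder, bytes);
        System.out.println(decoded); // prints "hello"
      }
    }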
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
index c7d69c0b8ad..e70cc7253f8 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderFactory.java
@@ -17,13 +17,17 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.emptyList;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
+
import java.lang.reflect.Constructor;
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke;
+import org.apache.spark.sql.catalyst.expressions.objects.NewInstance;
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke;
import org.apache.spark.sql.types.DataType;
-import scala.collection.immutable.Nil$;
-import scala.collection.mutable.WrappedArray;
+import scala.Option;
import scala.reflect.ClassTag;
public class EncoderFactory {
@@ -31,6 +35,12 @@ public class EncoderFactory {
private static final Constructor<StaticInvoke> STATIC_INVOKE_CONSTRUCTOR =
(Constructor<StaticInvoke>) StaticInvoke.class.getConstructors()[0];
+ private static final Constructor<Invoke> INVOKE_CONSTRUCTOR =
+ (Constructor<Invoke>) Invoke.class.getConstructors()[0];
+
+ private static final Constructor<NewInstance> NEW_INSTANCE_CONSTRUCTOR =
+ (Constructor<NewInstance>) NewInstance.class.getConstructors()[0];
+
static <T> ExpressionEncoder<T> create(
Expression serializer, Expression deserializer, Class<? super T> clazz) {
return new ExpressionEncoder<>(serializer, deserializer, ClassTag.apply(clazz));
@@ -39,21 +49,68 @@ public class EncoderFactory {
/**
* Invoke method {@code fun} on Class {@code cls}, immediately propagating {@code null} if any
* input arg is {@code null}.
- *
- * <p>To address breaking interfaces between various versions of Spark 3 these are created
- * reflectively. This is fine as it's just needed once to create the query plan.
*/
static Expression invokeIfNotNull(Class<?> cls, String fun, DataType type, Expression... args) {
+ return invoke(cls, fun, type, true, args);
+ }
+
+ /** Invoke method {@code fun} on Class {@code cls}. */
+ static Expression invoke(Class<?> cls, String fun, DataType type, Expression... args) {
+ return invoke(cls, fun, type, false, args);
+ }
+
+ private static Expression invoke(
+ Class<?> cls, String fun, DataType type, boolean propagateNull, Expression... args) {
try {
+ // To address breaking interfaces between various versions of Spark 3, expressions are
+ // created reflectively. This is fine as it's just needed once to create the query plan.
switch (STATIC_INVOKE_CONSTRUCTOR.getParameterCount()) {
case 6:
// Spark 3.1.x
return STATIC_INVOKE_CONSTRUCTOR.newInstance(
- cls, type, fun, new WrappedArray.ofRef<>(args), true, true);
+ cls, type, fun, seqOf(args), propagateNull, true);
case 8:
// Spark 3.2.x, 3.3.x
return STATIC_INVOKE_CONSTRUCTOR.newInstance(
- cls, type, fun, new WrappedArray.ofRef<>(args), Nil$.MODULE$, true, true, true);
+ cls, type, fun, seqOf(args), emptyList(), propagateNull, true, true);
+ default:
+ throw new RuntimeException("Unsupported version of Spark");
+ }
+ } catch (IllegalArgumentException | ReflectiveOperationException ex) {
+ throw new RuntimeException(ex);
+ }
+ }
+
+ /** Invoke method {@code fun} on {@code obj} with provided {@code args}. */
+ static Expression invoke(
+ Expression obj, String fun, DataType type, boolean nullable, Expression... args) {
+ try {
+ // To address breaking interfaces between various versions of Spark 3, expressions are
+ // created reflectively. This is fine as it's just needed once to create the query plan.
+ switch (STATIC_INVOKE_CONSTRUCTOR.getParameterCount()) {
+ case 6:
+ return INVOKE_CONSTRUCTOR.newInstance(obj, fun, type, seqOf(args), false, nullable);
+ case 8:
+ return INVOKE_CONSTRUCTOR.newInstance(
+ obj, fun, type, seqOf(args), emptyList(), false, nullable, true);
+ default:
+ throw new RuntimeException("Unsupported version of Spark");
+ }
+ } catch (IllegalArgumentException | ReflectiveOperationException ex) {
+ throw new RuntimeException(ex);
+ }
+ }
+
+ static Expression newInstance(Class<?> cls, DataType type, Expression... args) {
+ try {
+ // To address breaking interfaces between various versions of Spark 3, expressions are
+ // created reflectively. This is fine as it's just needed once to create the query plan.
+ switch (NEW_INSTANCE_CONSTRUCTOR.getParameterCount()) {
+ case 5:
+ return NEW_INSTANCE_CONSTRUCTOR.newInstance(cls, seqOf(args), true, type, Option.empty());
+ case 6:
+ return NEW_INSTANCE_CONSTRUCTOR.newInstance(
+ cls, seqOf(args), emptyList(), true, type, Option.empty());
default:
throw new RuntimeException("Unsupported version of Spark");
}
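
The version-compatibility trick above selects the constructor arity reflectively so a single
artifact supports several Spark 3.x minor versions. A standalone sketch of the same idea (the
Api class below is a stand-in, not Spark's StaticInvoke):

    import java.lang.reflect.Constructor;

    public class ReflectiveCtorSketch {
      // Stand-in for an API class whose constructor arity differs across versions.
      static class Api {
        final String label;

        public Api(String a, int b, boolean c) {
          this.label = a + ":" + b + ":" + c;
        }
      }

      public static void main(String[] args) throws Exception {
        Constructor<?> ctor = Api.class.getConstructors()[0];
        Object instance;
        switch (ctor.getParameterCount()) {
          case 2: // hypothetical older version
            instance = ctor.newInstance("x", 1);
            break;
          case 3: // hypothetical newer version
            instance = ctor.newInstance("x", 1, true);
            break;
          default:
            throw new RuntimeException("Unsupported version of Api");
        }
        System.out.println(((Api) instance).label); // prints "x:1:true"
      }
    }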
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
index 68738cf0308..f89f6bdb9d3 100644
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
@@ -17,44 +17,488 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.invoke;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.invokeIfNotNull;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderFactory.newInstance;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.match;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.replace;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
import static org.apache.spark.sql.types.DataTypes.BinaryType;
+import static org.apache.spark.sql.types.DataTypes.IntegerType;
+import static org.apache.spark.sql.types.DataTypes.LongType;
+import java.math.BigDecimal;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+import java.util.function.Function;
import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.catalyst.SerializerBuildHelper;
+import org.apache.spark.sql.catalyst.SerializerBuildHelper.MapElementInformation;
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
import org.apache.spark.sql.catalyst.expressions.BoundReference;
+import org.apache.spark.sql.catalyst.expressions.Coalesce;
+import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct;
+import org.apache.spark.sql.catalyst.expressions.EqualTo;
import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.GetStructField;
+import org.apache.spark.sql.catalyst.expressions.If;
+import org.apache.spark.sql.catalyst.expressions.IsNotNull;
+import org.apache.spark.sql.catalyst.expressions.IsNull;
import org.apache.spark.sql.catalyst.expressions.Literal;
+import org.apache.spark.sql.catalyst.expressions.Literal$;
+import org.apache.spark.sql.catalyst.expressions.MapKeys;
+import org.apache.spark.sql.catalyst.expressions.MapValues;
+import org.apache.spark.sql.catalyst.expressions.objects.MapObjects$;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.ObjectType;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.util.MutablePair;
import org.checkerframework.checker.nullness.qual.NonNull;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import org.joda.time.Instant;
+import scala.Option;
+import scala.Some;
+import scala.Tuple2;
+import scala.collection.IndexedSeq;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
+/** {@link Encoders} utility class. */
public class EncoderHelpers {
private static final DataType OBJECT_TYPE = new ObjectType(Object.class);
+ private static final DataType TUPLE2_TYPE = new ObjectType(Tuple2.class);
+ private static final DataType WINDOWED_VALUE = new ObjectType(WindowedValue.class);
+ private static final DataType KV_TYPE = new ObjectType(KV.class);
+ private static final DataType MUTABLE_PAIR_TYPE = new ObjectType(MutablePair.class);
+
+ // Collections / maps of these types can be (de)serialized without (de)serializing each member
+ private static final Set<Class<?>> PRIMITIVE_TYPES =
+ ImmutableSet.of(
+ Boolean.class,
+ Byte.class,
+ Short.class,
+ Integer.class,
+ Long.class,
+ Float.class,
+ Double.class);
+
+ // Default encoders by class
+ private static final Map<Class<?>, Encoder<?>> DEFAULT_ENCODERS = new HashMap<>();
+
+ // Factory for default encoders by class
+ private static final Function<Class<?>, @Nullable Encoder<?>> ENCODER_FACTORY =
+ cls -> {
+ if (cls.equals(PaneInfo.class)) {
+ return paneInfoEncoder();
+ } else if (cls.equals(GlobalWindow.class)) {
+ return binaryEncoder(GlobalWindow.Coder.INSTANCE, false);
+ } else if (cls.equals(IntervalWindow.class)) {
+ return binaryEncoder(IntervalWindowCoder.of(), false);
+ } else if (cls.equals(Instant.class)) {
+ return instantEncoder();
+ } else if (cls.equals(String.class)) {
+ return Encoders.STRING();
+ } else if (cls.equals(Boolean.class)) {
+ return Encoders.BOOLEAN();
+ } else if (cls.equals(Integer.class)) {
+ return Encoders.INT();
+ } else if (cls.equals(Long.class)) {
+ return Encoders.LONG();
+ } else if (cls.equals(Float.class)) {
+ return Encoders.FLOAT();
+ } else if (cls.equals(Double.class)) {
+ return Encoders.DOUBLE();
+ } else if (cls.equals(BigDecimal.class)) {
+ return Encoders.DECIMAL();
+ } else if (cls.equals(byte[].class)) {
+ return Encoders.BINARY();
+ } else if (cls.equals(Byte.class)) {
+ return Encoders.BYTE();
+ } else if (cls.equals(Short.class)) {
+ return Encoders.SHORT();
+ }
+ return null;
+ };
+
+ private static <T> @Nullable Encoder<T> getOrCreateDefaultEncoder(Class<? super T> cls) {
+ return (Encoder<T>) DEFAULT_ENCODERS.computeIfAbsent(cls, ENCODER_FACTORY);
+ }
+
+ /** Gets or creates a default {@link Encoder} for {@link T}. */
+ public static <T> Encoder<T> encoderOf(Class<? super T> cls) {
+ Encoder<T> enc = getOrCreateDefaultEncoder(cls);
+ if (enc == null) {
+ throw new IllegalArgumentException("No default coder available for class " + cls);
+ }
+ return enc;
+ }
+
+ /**
+ * Creates a Spark {@link Encoder} for {@link T} of {@link DataTypes#BinaryType BinaryType}
+ * delegating to a Beam {@link Coder} underneath.
+ *
+ * <p>Note: For common types, if available, default Spark {@link Encoder}s are used instead.
+ *
+ * @param coder Beam {@link Coder}
+ */
+ public static <T> Encoder<T> encoderFor(Coder<T> coder) {
+ Encoder<T> enc = getOrCreateDefaultEncoder(coder.getEncodedTypeDescriptor().getRawType());
+ return enc != null ? enc : binaryEncoder(coder, true);
+ }
+
+ /**
+ * Creates a Spark {@link Encoder} for {@link T} of {@link StructType} with fields {@code value},
+ * {@code timestamp}, {@code windows} and {@code pane}.
+ *
+ * @param value {@link Encoder} to encode field `{@code value}`.
+ * @param window {@link Encoder} to encode individual windows in field `{@code windows}`
+ */
+ public static <T, W extends BoundedWindow> Encoder<WindowedValue<T>> windowedValueEncoder(
+ Encoder<T> value, Encoder<W> window) {
+ Encoder<Instant> timestamp = encoderOf(Instant.class);
+ Encoder<PaneInfo> pane = encoderOf(PaneInfo.class);
+ Encoder<Collection<W>> windows = collectionEncoder(window);
+ Expression serializer =
+ serializeWindowedValue(rootRef(WINDOWED_VALUE, true), value, timestamp, windows, pane);
+ Expression deserializer =
+ deserializeWindowedValue(rootCol(serializer.dataType()), value, timestamp, windows, pane);
+ return EncoderFactory.create(serializer, deserializer, WindowedValue.class);
+ }
+
+ /**
+ * Creates a one-of Spark {@link Encoder} of {@link StructType} where each alternative is
+ * represented as a column / field named by its index, each with a separate {@link Encoder}.
+ *
+ * <p>Externally this is represented as a tuple {@code (index, data)} where the index corresponds
+ * to an {@link Encoder} in the provided list.
+ *
+ * @param encoders {@link Encoder}s for each alternative.
+ */
+ public static <T> Encoder<Tuple2<Integer, T>> oneOfEncoder(List<Encoder<T>> encoders) {
+ Expression serializer = serializeOneOf(rootRef(TUPLE2_TYPE, true), encoders);
+ Expression deserializer = deserializeOneOf(rootCol(serializer.dataType()), encoders);
+ return EncoderFactory.create(serializer, deserializer, Tuple2.class);
+ }
+
+ /**
+ * Creates a Spark {@link Encoder} for {@link KV} of {@link StructType} with fields {@code key}
+ * and {@code value}.
+ *
+ * @param key {@link Encoder} to encode field `{@code key}`.
+ * @param value {@link Encoder} to encode field `{@code value}`
+ */
+ public static <K, V> Encoder<KV<K, V>> kvEncoder(Encoder<K> key, Encoder<V> value) {
+ Expression serializer = serializeKV(rootRef(KV_TYPE, true), key, value);
+ Expression deserializer = deserializeKV(rootCol(serializer.dataType()), key, value);
+ return EncoderFactory.create(serializer, deserializer, KV.class);
+ }
+
+ /**
+ * Creates a Spark {@link Encoder} of {@link ArrayType} for Java {@link Collection}s with nullable
+ * elements.
+ *
+ * @param enc {@link Encoder} to encode collection elements
+ */
+ public static <T> Encoder<Collection<T>> collectionEncoder(Encoder<T> enc) {
+ return collectionEncoder(enc, true);
+ }
+
+ /**
+ * Creates a Spark {@link Encoder} of {@link ArrayType} for Java {@link Collection}s.
+ *
+ * @param enc {@link Encoder} to encode collection elements
+ * @param nullable Allow nullable collection elements
+ */
+ public static <T> Encoder<Collection<T>> collectionEncoder(Encoder<T> enc, boolean nullable) {
+ DataType type = new ObjectType(Collection.class);
+ Expression serializer = serializeSeq(rootRef(type, true), enc, nullable);
+ Expression deserializer = deserializeSeq(rootCol(serializer.dataType()), enc, nullable, true);
+ return EncoderFactory.create(serializer, deserializer, Collection.class);
+ }
+
+ /**
+ * Creates a Spark {@link Encoder} of {@link MapType} that deserializes to {@link MapT}.
+ *
+ * @param key {@link Encoder} to encode keys
+ * @param value {@link Encoder} to encode values
+ * @param cls Specific map class to use; supported classes are {@link HashMap} and {@link TreeMap}
+ */
+ public static <MapT extends Map<K, V>, K, V> Encoder<MapT> mapEncoder(
+ Encoder<K> key, Encoder<V> value, Class<MapT> cls) {
+ Expression serializer = mapSerializer(rootRef(new ObjectType(cls), true), key, value);
+ Expression deserializer = mapDeserializer(rootCol(serializer.dataType()), key, value, cls);
+ return EncoderFactory.create(serializer, deserializer, cls);
+ }
+
+ /**
+ * Creates a Spark {@link Encoder} for Spark's {@link MutablePair} of {@link StructType} with
+ * fields `{@code _1}` and `{@code _2}`.
+ *
+ * <p>This is intended to be used in places such as aggregators.
+ *
+ * @param enc1 {@link Encoder} to encode `{@code _1}`
+ * @param enc2 {@link Encoder} to encode `{@code _2}`
+ */
+ public static <T1, T2> Encoder<MutablePair<T1, T2>> mutablePairEncoder(
+ Encoder<T1> enc1, Encoder<T2> enc2) {
+ Expression serializer = serializeMutablePair(rootRef(MUTABLE_PAIR_TYPE, true), enc1, enc2);
+ Expression deserializer = deserializeMutablePair(rootCol(serializer.dataType()), enc1, enc2);
+ return EncoderFactory.create(serializer, deserializer, MutablePair.class);
+ }
+
+ /**
+ * Creates a Spark {@link Encoder} for {@link PaneInfo} of {@link DataTypes#BinaryType
+ * BinaryType}.
+ */
+ private static Encoder<PaneInfo> paneInfoEncoder() {
+ DataType type = new ObjectType(PaneInfo.class);
+ return EncoderFactory.create(
+ invokeIfNotNull(Utils.class, "paneInfoToBytes", BinaryType, rootRef(type, false)),
+ invokeIfNotNull(Utils.class, "paneInfoFromBytes", type, rootCol(BinaryType)),
+ PaneInfo.class);
+ }
/**
- * Wrap a Beam coder into a Spark Encoder using Catalyst Expression Encoders (which uses java code
- * generation).
+ * Creates a Spark {@link Encoder} for Joda {@link Instant} of {@link DataTypes#LongType
+ * LongType}.
*/
- public static <T> Encoder<T> fromBeamCoder(Coder<T> coder) {
- Class<? super T> clazz = coder.getEncodedTypeDescriptor().getRawType();
- // Class T could be private, therefore use OBJECT_TYPE to not risk an IllegalAccessError
+ private static Encoder<Instant> instantEncoder() {
+ DataType type = new ObjectType(Instant.class);
+ Expression instant = rootRef(type, true);
+ Expression millis = rootCol(LongType);
return EncoderFactory.create(
- beamSerializer(rootRef(OBJECT_TYPE, true), coder),
- beamDeserializer(rootCol(BinaryType), coder),
- clazz);
+ nullSafe(instant, invoke(instant, "getMillis", LongType, false)),
+ nullSafe(millis, invoke(Instant.class, "ofEpochMilli", type, millis)),
+ Instant.class);
+ }
+
+ /**
+ * Creates a Spark {@link Encoder} for {@link T} of {@link DataTypes#BinaryType BinaryType}
+ * delegating to a Beam {@link Coder} underneath.
+ *
+ * @param coder Beam {@link Coder}
+ * @param nullable Whether to allow nullable items
+ */
+ private static <T> Encoder<T> binaryEncoder(Coder<T> coder, boolean nullable) {
+ Literal litCoder = lit(coder, Coder.class);
+ // T could be private, use OBJECT_TYPE for code generation to not risk an IllegalAccessError
+ return EncoderFactory.create(
+ invokeIfNotNull(
+ CoderHelpers.class,
+ "toByteArray",
+ BinaryType,
+ rootRef(OBJECT_TYPE, nullable),
+ litCoder),
+ invokeIfNotNull(
+ CoderHelpers.class, "fromByteArray", OBJECT_TYPE, rootCol(BinaryType), litCoder),
+ coder.getEncodedTypeDescriptor().getRawType());
+ }
+
+ private static <T, W extends BoundedWindow> Expression serializeWindowedValue(
+ Expression in,
+ Encoder<T> valueEnc,
+ Encoder<Instant> timestampEnc,
+ Encoder<Collection<W>> windowsEnc,
+ Encoder<PaneInfo> paneEnc) {
+ return serializerObject(
+ in,
+ tuple("value", serializeField(in, valueEnc, "getValue")),
+ tuple("timestamp", serializeField(in, timestampEnc, "getTimestamp")),
+ tuple("windows", serializeField(in, windowsEnc, "getWindows")),
+ tuple("pane", serializeField(in, paneEnc, "getPane")));
+ }
+
+ private static Expression serializerObject(Expression in, Tuple2<String, Expression>... fields) {
+ return SerializerBuildHelper.createSerializerForObject(in, seqOf(fields));
+ }
+
+ private static <T, W extends BoundedWindow> Expression deserializeWindowedValue(
+ Expression in,
+ Encoder<T> valueEnc,
+ Encoder<Instant> timestampEnc,
+ Encoder<Collection<W>> windowsEnc,
+ Encoder<PaneInfo> paneEnc) {
+ Expression value = deserializeField(in, valueEnc, 0, "value");
+ Expression windows = deserializeField(in, windowsEnc, 2, "windows");
+ Expression timestamp = deserializeField(in, timestampEnc, 1, "timestamp");
+ Expression pane = deserializeField(in, paneEnc, 3, "pane");
+ // set timestamp to end of window (maxTimestamp) if null
+ timestamp =
+ ifNotNull(timestamp, invoke(Utils.class, "maxTimestamp", timestamp.dataType(), windows));
+ Expression[] fields = new Expression[] {value, timestamp, windows, pane};
+
+ return nullSafe(pane, invoke(WindowedValue.class, "of", WINDOWED_VALUE, fields));
+ }
+
+ private static <K, V> Expression serializeMutablePair(
+ Expression in, Encoder<K> enc1, Encoder<V> enc2) {
+ return serializerObject(
+ in,
+ tuple("_1", serializeField(in, enc1, "_1")),
+ tuple("_2", serializeField(in, enc2, "_2")));
}
- /** Catalyst Expression that serializes elements using Beam {@link Coder}. */
- private static <T> Expression beamSerializer(Expression obj, Coder<T> coder) {
- Expression[] args = {obj, lit(coder, Coder.class)};
- return EncoderFactory.invokeIfNotNull(CoderHelpers.class, "toByteArray", BinaryType, args);
+ private static <K, V> Expression deserializeMutablePair(
+ Expression in, Encoder<K> enc1, Encoder<V> enc2) {
+ Expression field1 = deserializeField(in, enc1, 0, "_1");
+ Expression field2 = deserializeField(in, enc2, 1, "_2");
+ return invoke(MutablePair.class, "apply", MUTABLE_PAIR_TYPE, field1, field2);
}
- /** Catalyst Expression that deserializes elements using Beam {@link Coder}. */
- private static <T> Expression beamDeserializer(Expression bytes, Coder<T> coder) {
- Expression[] args = {bytes, lit(coder, Coder.class)};
- return EncoderFactory.invokeIfNotNull(CoderHelpers.class, "fromByteArray", OBJECT_TYPE, args);
+ private static <K, V> Expression serializeKV(
+ Expression in, Encoder<K> keyEnc, Encoder<V> valueEnc) {
+ return serializerObject(
+ in,
+ tuple("key", serializeField(in, keyEnc, "getKey")),
+ tuple("value", serializeField(in, valueEnc, "getValue")));
+ }
+
+ private static <K, V> Expression deserializeKV(
+ Expression in, Encoder<K> keyEnc, Encoder<V> valueEnc) {
+ Expression key = deserializeField(in, keyEnc, 0, "key");
+ Expression value = deserializeField(in, valueEnc, 1, "value");
+ return invoke(KV.class, "of", KV_TYPE, key, value);
+ }
+
+ public static <T> Expression serializeOneOf(Expression in, List<Encoder<T>> encoders) {
+ Expression type = invoke(in, "_1", IntegerType, false);
+ Expression[] args = new Expression[encoders.size() * 2];
+ for (int i = 0; i < encoders.size(); i++) {
+ args[i * 2] = lit(String.valueOf(i));
+ args[i * 2 + 1] = serializeOneOfField(in, type, encoders.get(i), i);
+ }
+ return new CreateNamedStruct(seqOf(args));
+ }
+
+ public static <T> Expression deserializeOneOf(Expression in, List<Encoder<T>> encoders) {
+ Expression[] args = new Expression[encoders.size()];
+ for (int i = 0; i < encoders.size(); i++) {
+ args[i] = deserializeOneOfField(in, encoders.get(i), i);
+ }
+ return new Coalesce(seqOf(args));
+ }
+
+ private static <T> Expression serializeOneOfField(
+ Expression in, Expression type, Encoder<T> enc, int typeIdx) {
+ Expression litNull = lit(null, serializedType(enc));
+ Expression value = invoke(in, "_2", deserializedType(enc), false);
+ return new If(new EqualTo(type, lit(typeIdx)), serialize(value, enc), litNull);
+ }
+
+ private static <T> Expression deserializeOneOfField(Expression in, Encoder<T> enc, int idx) {
+ GetStructField field = new GetStructField(in, idx, Option.empty());
+ Expression litNull = lit(null, TUPLE2_TYPE);
+ Expression newTuple = newInstance(Tuple2.class, TUPLE2_TYPE, lit(idx), deserialize(field, enc));
+ return new If(new IsNull(field), litNull, newTuple);
+ }
+
+ private static <T> Expression serializeField(Expression in, Encoder<T> enc, String getterName) {
+ Expression ref = serializer(enc).collect(match(BoundReference.class)).head();
+ return serialize(invoke(in, getterName, ref.dataType(), ref.nullable()), enc);
+ }
+
+ private static <T> Expression deserializeField(
+ Expression in, Encoder<T> enc, int idx, String name) {
+ return deserialize(new GetStructField(in, idx, new Some<>(name)), enc);
+ }
+
+ // Note: Currently this doesn't support nullable primitive values
+ private static <K, V> Expression mapSerializer(Expression map, Encoder<K> key, Encoder<V> value) {
+ DataType keyType = deserializedType(key);
+ DataType valueType = deserializedType(value);
+ return SerializerBuildHelper.createSerializerForMap(
+ map,
+ new MapElementInformation(keyType, false, e -> serialize(e, key)),
+ new MapElementInformation(valueType, false, e -> serialize(e, value)));
+ }
+
+ private static <MapT extends Map<K, V>, K, V> Expression mapDeserializer(
+ Expression in, Encoder<K> key, Encoder<V> value, Class<MapT> cls) {
+ Preconditions.checkArgument(cls.isAssignableFrom(HashMap.class) || cls.equals(TreeMap.class));
+ Expression keys = deserializeSeq(new MapKeys(in), key, false, false);
+ Expression values = deserializeSeq(new MapValues(in), value, false, false);
+ String fn = cls.equals(TreeMap.class) ? "toTreeMap" : "toMap";
+ return invoke(
+ Utils.class, fn, new ObjectType(cls), keys, values, mapItemType(key), mapItemType(value));
+ }
+
+ // Use the serialized type for primitive types (to avoid boxing); otherwise the deserialized type
+ private static Literal mapItemType(Encoder<?> enc) {
+ return lit(isPrimitiveEnc(enc) ? serializedType(enc) : deserializedType(enc), DataType.class);
+ }
+
+ private static <T> Expression serializeSeq(Expression in, Encoder<T> enc, boolean nullable) {
+ if (isPrimitiveEnc(enc)) {
+ Expression array = invoke(in, "toArray", new ObjectType(Object[].class), false);
+ return SerializerBuildHelper.createSerializerForGenericArray(
+ array, serializedType(enc), nullable);
+ }
+ Expression seq = invoke(Utils.class, "toSeq", new ObjectType(Seq.class), in);
+ return MapObjects$.MODULE$.apply(
+ exp -> serialize(exp, enc), seq, deserializedType(enc), nullable, Option.empty());
+ }
+
+ private static <T> Expression deserializeSeq(
+ Expression in, Encoder<T> enc, boolean nullable, boolean asJava) {
+ DataType type = serializedType(enc); // input type is the serializer result type
+ if (isPrimitiveEnc(enc)) {
+ ObjectType listType = new ObjectType(List.class);
+ return asJava ? invoke(Utils.class, "toList", listType, in, lit(type, DataType.class)) : in;
+ }
+ Option<Class<?>> optCls = asJava ? Option.apply(List.class) : Option.empty();
+ return MapObjects$.MODULE$.apply(exp -> deserialize(exp, enc), in, type, nullable, optCls);
+ }
+
+ private static <T> boolean isPrimitiveEnc(Encoder<T> enc) {
+ return PRIMITIV_TYPES.contains(enc.clsTag().runtimeClass());
+ }
+
+ private static <T> Expression serialize(Expression input, Encoder<T> enc) {
+ return serializer(enc).transformUp(replace(BoundReference.class, input));
+ }
+
+ private static <T> Expression deserialize(Expression input, Encoder<T> enc) {
+ return deserializer(enc).transformUp(replace(GetColumnByOrdinal.class, input));
+ }
+
+ private static <T> Expression serializer(Encoder<T> enc) {
+ return ((ExpressionEncoder<T>) enc).objSerializer();
+ }
+
+ private static <T> Expression deserializer(Encoder<T> enc) {
+ return ((ExpressionEncoder<T>) enc).objDeserializer();
+ }
+
+ private static <T> DataType serializedType(Encoder<T> enc) {
+ return ((ExpressionEncoder<T>) enc).objSerializer().dataType();
+ }
+
+ private static <T> DataType deserializedType(Encoder<T> enc) {
+ return ((ExpressionEncoder<T>) enc).objDeserializer().dataType();
}
private static Expression rootRef(DataType dt, boolean nullable) {
@@ -65,7 +509,77 @@ public class EncoderHelpers {
return new GetColumnByOrdinal(0, dt);
}
+ private static Expression nullSafe(Expression in, Expression out) {
+ return new If(new IsNull(in), lit(null, out.dataType()), out);
+ }
+
+ private static Expression ifNotNull(Expression expr, Expression otherwise) {
+ return new If(new IsNotNull(expr), expr, otherwise);
+ }
+
+ private static <T extends @NonNull Object> Expression lit(T t) {
+ return Literal$.MODULE$.apply(t);
+ }
+
+ @SuppressWarnings("nullness") // literal NULL is allowed
+ private static <T> Expression lit(@Nullable T t, DataType dataType) {
+ return new Literal(t, dataType);
+ }
+
private static <T extends @NonNull Object> Literal lit(T obj, Class<? extends T> cls) {
return Literal.fromObject(obj, new ObjectType(cls));
}
+
+ /** Encoder / expression utils that are called from generated code. */
+ public static class Utils {
+
+ public static PaneInfo paneInfoFromBytes(byte[] bytes) {
+ return CoderHelpers.fromByteArray(bytes, PaneInfoCoder.of());
+ }
+
+ public static byte[] paneInfoToBytes(PaneInfo pane) {
+ return CoderHelpers.toByteArray(pane, PaneInfoCoder.of());
+ }
+
+ /** The end of the only window (max timestamp). */
+ public static Instant maxTimestamp(Iterable<BoundedWindow> windows) {
+ return Iterables.getOnlyElement(windows).maxTimestamp();
+ }
+
+ public static List<Object> toList(ArrayData arrayData, DataType type) {
+ return JavaConverters.seqAsJavaList(arrayData.toSeq(type));
+ }
+
+ public static Seq<Object> toSeq(ArrayData arrayData) {
+ return arrayData.toSeq(OBJECT_TYPE);
+ }
+
+ public static Seq<Object> toSeq(Collection<Object> col) {
+ if (col instanceof List) {
+ return JavaConverters.asScalaBuffer((List<Object>) col);
+ }
+ return JavaConverters.collectionAsScalaIterable(col).toSeq();
+ }
+
+ public static TreeMap<Object, Object> toTreeMap(
+ ArrayData keys, ArrayData values, DataType keyType, DataType valueType) {
+ return toMap(new TreeMap<>(), keys, values, keyType, valueType);
+ }
+
+ public static HashMap<Object, Object> toMap(
+ ArrayData keys, ArrayData values, DataType keyType, DataType valueType) {
+ HashMap<Object, Object> map = Maps.newHashMapWithExpectedSize(keys.numElements());
+ return toMap(map, keys, values, keyType, valueType);
+ }
+
+ private static <MapT extends Map<Object, Object>> MapT toMap(
+ MapT map, ArrayData keys, ArrayData values, DataType keyType, DataType valueType) {
+ IndexedSeq<Object> keysSeq = keys.toSeq(keyType);
+ IndexedSeq<Object> valuesSeq = values.toSeq(valueType);
+ for (int i = 0; i < keysSeq.size(); i++) {
+ map.put(keysSeq.apply(i), valuesSeq.apply(i));
+ }
+ return map;
+ }
+ }
}
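
For readers less familiar with Catalyst: serializeOneOf/deserializeOneOf above encode a tagged
union as a struct with one nullable field per candidate encoder. Only the field matching the
type tag is populated, and decoding takes the first non-null field (the role of the Coalesce).
A minimal plain-Java sketch of that layout, with hypothetical names and plain arrays standing
in for Catalyst structs:

    import java.util.Arrays;

    public class OneOfSketch {
      // One slot per candidate type; only the slot matching the tag is non-null.
      static Object[] serialize(int typeIdx, Object value, int arity) {
        Object[] struct = new Object[arity];
        struct[typeIdx] = value;
        return struct;
      }

      // Decoding recovers the tag by scanning for the first non-null slot.
      static int tagOf(Object[] struct) {
        for (int i = 0; i < struct.length; i++) {
          if (struct[i] != null) {
            return i;
          }
        }
        throw new IllegalStateException("empty one-of struct");
      }

      public static void main(String[] args) {
        Object[] struct = serialize(1, "hello", 3);
        System.out.println(Arrays.toString(struct)); // [null, hello, null]
        System.out.println(tagOf(struct)); // 1
      }
    }
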
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/MultiOutputCoder.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/MultiOutputCoder.java
deleted file mode 100644
index f77fcea6796..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/MultiOutputCoder.java
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
-
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStream;
-import java.util.Map;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.CoderException;
-import org.apache.beam.sdk.coders.CustomCoder;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.TupleTag;
-import scala.Tuple2;
-
-/**
- * Coder to serialize and deserialize {@code}Tuple2<TupleTag<T>, WindowedValue<T>{/@code} to be used
- * in spark encoders while applying {@link org.apache.beam.sdk.transforms.DoFn}.
- *
- * @param <T> type of the elements in the collection
- */
-@SuppressWarnings({
- "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
- "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public class MultiOutputCoder<T> extends CustomCoder<Tuple2<TupleTag<T>, WindowedValue<T>>> {
- Coder<TupleTag> tupleTagCoder;
- Map<TupleTag<?>, Coder<?>> coderMap;
- Coder<? extends BoundedWindow> windowCoder;
-
- public static MultiOutputCoder of(
- Coder<TupleTag> tupleTagCoder,
- Map<TupleTag<?>, Coder<?>> coderMap,
- Coder<? extends BoundedWindow> windowCoder) {
- return new MultiOutputCoder(tupleTagCoder, coderMap, windowCoder);
- }
-
- private MultiOutputCoder(
- Coder<TupleTag> tupleTagCoder,
- Map<TupleTag<?>, Coder<?>> coderMap,
- Coder<? extends BoundedWindow> windowCoder) {
- this.tupleTagCoder = tupleTagCoder;
- this.coderMap = coderMap;
- this.windowCoder = windowCoder;
- }
-
- @Override
- public void encode(Tuple2<TupleTag<T>, WindowedValue<T>> tuple2, OutputStream outStream)
- throws IOException {
- TupleTag<T> tupleTag = tuple2._1();
- tupleTagCoder.encode(tupleTag, outStream);
- Coder<T> valueCoder = (Coder<T>) coderMap.get(tupleTag);
- WindowedValue.FullWindowedValueCoder<T> wvCoder =
- WindowedValue.FullWindowedValueCoder.of(valueCoder, windowCoder);
- wvCoder.encode(tuple2._2(), outStream);
- }
-
- @Override
- public Tuple2<TupleTag<T>, WindowedValue<T>> decode(InputStream inStream)
- throws CoderException, IOException {
- TupleTag<T> tupleTag = (TupleTag<T>) tupleTagCoder.decode(inStream);
- Coder<T> valueCoder = (Coder<T>) coderMap.get(tupleTag);
- WindowedValue.FullWindowedValueCoder<T> wvCoder =
- WindowedValue.FullWindowedValueCoder.of(valueCoder, windowCoder);
- WindowedValue<T> wv = wvCoder.decode(inStream);
- return Tuple2.apply(tupleTag, wv);
- }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/RowHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/RowHelpers.java
deleted file mode 100644
index 9b5d5da2b2c..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/RowHelpers.java
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
-
-import static scala.collection.JavaConversions.asScalaBuffer;
-
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.spark.api.java.function.MapFunction;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.catalyst.InternalRow;
-
-/** Helper functions for working with {@link Row}. */
-public final class RowHelpers {
-
- /**
- * A Spark {@link MapFunction} for extracting a {@link WindowedValue} from a Row in which the
- * {@link WindowedValue} was serialized to bytes using its {@link
- * WindowedValue.WindowedValueCoder}.
- *
- * @param <T> The type of the object.
- * @return A {@link MapFunction} that accepts a {@link Row} and returns its {@link WindowedValue}.
- */
- public static <T> MapFunction<Row, WindowedValue<T>> extractWindowedValueFromRowMapFunction(
- WindowedValue.WindowedValueCoder<T> windowedValueCoder) {
- return (MapFunction<Row, WindowedValue<T>>)
- value -> {
- // there is only one value put in each Row by the InputPartitionReader
- byte[] bytes = (byte[]) value.get(0);
- return windowedValueCoder.decode(new ByteArrayInputStream(bytes));
- };
- }
-
- /**
- * Serialize a windowedValue to bytes using windowedValueCoder {@link
- * WindowedValue.FullWindowedValueCoder} and stores it an InternalRow.
- */
- public static <T> InternalRow storeWindowedValueInRow(
- WindowedValue<T> windowedValue, Coder<T> coder) {
- List<Object> list = new ArrayList<>();
- // serialize the windowedValue to bytes array to comply with dataset binary schema
- WindowedValue.FullWindowedValueCoder<T> windowedValueCoder =
- WindowedValue.FullWindowedValueCoder.of(coder, GlobalWindow.Coder.INSTANCE);
- ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
- try {
- windowedValueCoder.encode(windowedValue, byteArrayOutputStream);
- byte[] bytes = byteArrayOutputStream.toByteArray();
- list.add(bytes);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- return InternalRow.apply(asScalaBuffer(list).toList());
- }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/SchemaHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/SchemaHelpers.java
deleted file mode 100644
index 71dca5264dd..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/SchemaHelpers.java
+++ /dev/null
@@ -1,39 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
-
-import org.apache.spark.sql.types.DataTypes;
-import org.apache.spark.sql.types.Metadata;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
-
-/** A {@link SchemaHelpers} for the Spark Batch Runner. */
-public class SchemaHelpers {
- private static final StructType BINARY_SCHEMA =
- new StructType(
- new StructField[] {
- StructField.apply("binaryStructField", DataTypes.BinaryType, true, Metadata.empty())
- });
-
- public static StructType binarySchema() {
- // we use a binary schema for now because:
- // using a empty schema raises a indexOutOfBoundsException
- // using a NullType schema stores null in the elements
- return BINARY_SCHEMA;
- }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/WindowingHelpers.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/WindowingHelpers.java
deleted file mode 100644
index 5085eb9f796..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/WindowingHelpers.java
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
-
-import java.util.Collection;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
-import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
-import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
-import org.apache.beam.sdk.transforms.windowing.Window;
-import org.apache.beam.sdk.transforms.windowing.WindowFn;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
-import org.apache.spark.api.java.function.MapFunction;
-import org.joda.time.Instant;
-
-/** Helper functions for working with windows. */
-public final class WindowingHelpers {
-
- /**
- * Checks if the window transformation should be applied or skipped.
- *
- * <p>Avoid running assign windows if both source and destination are global window or if the user
- * has not specified the WindowFn (meaning they are just messing with triggering or allowed
- * lateness).
- */
- @SuppressWarnings("unchecked")
- public static <T, W extends BoundedWindow> boolean skipAssignWindows(
- Window.Assign<T> transform, AbstractTranslationContext context) {
- WindowFn<? super T, W> windowFnToApply = (WindowFn<? super T, W>) transform.getWindowFn();
- PCollection<T> input = (PCollection<T>) context.getInput();
- WindowFn<?, ?> windowFnOfInput = input.getWindowingStrategy().getWindowFn();
- return windowFnToApply == null
- || (windowFnOfInput instanceof GlobalWindows && windowFnToApply instanceof GlobalWindows);
- }
-
- public static <T, W extends BoundedWindow>
- MapFunction<WindowedValue<T>, WindowedValue<T>> assignWindowsMapFunction(
- WindowFn<T, W> windowFn) {
- return (MapFunction<WindowedValue<T>, WindowedValue<T>>)
- windowedValue -> {
- final BoundedWindow boundedWindow = Iterables.getOnlyElement(windowedValue.getWindows());
- final T element = windowedValue.getValue();
- final Instant timestamp = windowedValue.getTimestamp();
- Collection<W> windows =
- windowFn.assignWindows(
- windowFn.new AssignContext() {
-
- @Override
- public T element() {
- return element;
- }
-
- @Override
- public Instant timestamp() {
- return timestamp;
- }
-
- @Override
- public BoundedWindow window() {
- return boundedWindow;
- }
- });
- return WindowedValue.of(element, timestamp, windows, windowedValue.getPane());
- };
- }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/DatasetSourceStreaming.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/DatasetSourceStreaming.java
deleted file mode 100644
index 5eb60f68cb3..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/DatasetSourceStreaming.java
+++ /dev/null
@@ -1,25 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.streaming;
-
-/**
- * Spark structured streaming framework does not support more than one aggregation in streaming mode
- * because of watermark implementation. As a consequence, this runner, does not support streaming
- * mode yet see https://github.com/apache/beam/issues/20241
- */
-class DatasetSourceStreaming {}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/PipelineTranslatorStreaming.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/PipelineTranslatorStreaming.java
deleted file mode 100644
index 73d99efa463..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/PipelineTranslatorStreaming.java
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.streaming;
-
-import java.util.HashMap;
-import java.util.Map;
-import org.apache.beam.runners.core.construction.SplittableParDo;
-import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
-import org.apache.beam.runners.spark.structuredstreaming.translation.PipelineTranslator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.TranslationContext;
-import org.apache.beam.sdk.Pipeline;
-import org.apache.beam.sdk.runners.TransformHierarchy;
-import org.apache.beam.sdk.transforms.PTransform;
-import org.checkerframework.checker.nullness.qual.Nullable;
-
-/**
- * {@link PipelineTranslator} for executing a {@link Pipeline} in Spark in streaming mode. This
- * contains only the components specific to streaming: registry of streaming {@link
- * TransformTranslator} and registry lookup code.
- */
-@SuppressWarnings({
- "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
- "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public class PipelineTranslatorStreaming extends PipelineTranslator {
- // --------------------------------------------------------------------------------------------
- // Transform Translator Registry
- // --------------------------------------------------------------------------------------------
-
- @SuppressWarnings("rawtypes")
- private static final Map<Class<? extends PTransform>, TransformTranslator> TRANSFORM_TRANSLATORS =
- new HashMap<>();
-
- // TODO the ability to have more than one TransformTranslator per URN
- // that could be dynamically chosen by a predicated that evaluates based on PCollection
- // obtainable though node.getInputs.getValue()
- // See
- // https://github.com/seznam/euphoria/blob/master/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/SparkFlowTranslator.java#L83
- // And
- // https://github.com/seznam/euphoria/blob/master/euphoria-spark/src/main/java/cz/seznam/euphoria/spark/SparkFlowTranslator.java#L106
-
- static {
- // TRANSFORM_TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslatorBatch());
- // TRANSFORM_TRANSLATORS.put(Combine.Globally.class, new CombineGloballyTranslatorBatch());
- // TRANSFORM_TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslatorBatch());
-
- // TODO: Do we need to have a dedicated translator for {@code Reshuffle} if it's deprecated?
- // TRANSFORM_TRANSLATORS.put(Reshuffle.class, new ReshuffleTranslatorBatch());
-
- // TRANSFORM_TRANSLATORS.put(Flatten.PCollections.class, new FlattenTranslatorBatch());
- //
- // TRANSFORM_TRANSLATORS.put(Window.Assign.class, new WindowAssignTranslatorBatch());
- //
- // TRANSFORM_TRANSLATORS.put(ParDo.MultiOutput.class, new ParDoTranslatorBatch());
-
- TRANSFORM_TRANSLATORS.put(
- SplittableParDo.PrimitiveUnboundedRead.class, new ReadSourceTranslatorStreaming());
-
- // TRANSFORM_TRANSLATORS
- // .put(View.CreatePCollectionView.class, new CreatePCollectionViewTranslatorBatch());
- }
-
- public PipelineTranslatorStreaming(SparkStructuredStreamingPipelineOptions options) {
- translationContext = new TranslationContext(options);
- }
-
- /** Returns a translator for the given node, if it is possible, otherwise null. */
- @Override
- protected TransformTranslator<?> getTransformTranslator(TransformHierarchy.Node node) {
- @Nullable PTransform<?, ?> transform = node.getTransform();
- // Root of the graph is null
- if (transform == null) {
- return null;
- }
- return TRANSFORM_TRANSLATORS.get(transform.getClass());
- }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/ReadSourceTranslatorStreaming.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/ReadSourceTranslatorStreaming.java
deleted file mode 100644
index 8abc8771a4e..00000000000
--- a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/streaming/ReadSourceTranslatorStreaming.java
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.translation.streaming;
-
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.BEAM_SOURCE_OPTION;
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.DEFAULT_PARALLELISM;
-import static org.apache.beam.runners.spark.structuredstreaming.Constants.PIPELINE_OPTIONS;
-
-import java.io.IOException;
-import org.apache.beam.runners.core.construction.ReadTranslation;
-import org.apache.beam.runners.core.serialization.Base64Serializer;
-import org.apache.beam.runners.spark.structuredstreaming.translation.AbstractTranslationContext;
-import org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
-import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.RowHelpers;
-import org.apache.beam.sdk.io.UnboundedSource;
-import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.PBegin;
-import org.apache.beam.sdk.values.PCollection;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
-
-class ReadSourceTranslatorStreaming<T>
- implements TransformTranslator<PTransform<PBegin, PCollection<T>>> {
-
- private static final String sourceProviderClass = DatasetSourceStreaming.class.getCanonicalName();
-
- @SuppressWarnings("unchecked")
- @Override
- public void translateTransform(
- PTransform<PBegin, PCollection<T>> transform, AbstractTranslationContext context) {
- AppliedPTransform<PBegin, PCollection<T>, PTransform<PBegin, PCollection<T>>> rootTransform =
- (AppliedPTransform<PBegin, PCollection<T>, PTransform<PBegin, PCollection<T>>>)
- context.getCurrentTransform();
-
- UnboundedSource<T, UnboundedSource.CheckpointMark> source;
- try {
- source = ReadTranslation.unboundedSourceFromTransform(rootTransform);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- SparkSession sparkSession = context.getSparkSession();
-
- String serializedSource = Base64Serializer.serializeUnchecked(source);
- Dataset<Row> rowDataset =
- sparkSession
- .readStream()
- .format(sourceProviderClass)
- .option(BEAM_SOURCE_OPTION, serializedSource)
- .option(
- DEFAULT_PARALLELISM,
- String.valueOf(context.getSparkSession().sparkContext().defaultParallelism()))
- .option(PIPELINE_OPTIONS, context.getSerializableOptions().toString())
- .load();
-
- // extract windowedValue from Row
- WindowedValue.FullWindowedValueCoder<T> windowedValueCoder =
- WindowedValue.FullWindowedValueCoder.of(
- source.getOutputCoder(), GlobalWindow.Coder.INSTANCE);
- Dataset<WindowedValue<T>> dataset =
- rowDataset.map(
- RowHelpers.extractWindowedValueFromRowMapFunction(windowedValueCoder),
- EncoderHelpers.fromBeamCoder(windowedValueCoder));
-
- PCollection<T> output = (PCollection<T>) context.getOutput();
- context.putDataset(output, dataset);
- }
-}
diff --git a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java
new file mode 100644
index 00000000000..1908cbf2bba
--- /dev/null
+++ b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/utils/ScalaInterop.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.utils;
+
+import java.io.Serializable;
+import org.checkerframework.checker.nullness.qual.NonNull;
+import scala.Function1;
+import scala.Function2;
+import scala.PartialFunction;
+import scala.Tuple2;
+import scala.collection.Iterator;
+import scala.collection.JavaConverters;
+import scala.collection.Seq;
+import scala.collection.immutable.List;
+import scala.collection.immutable.Nil$;
+import scala.collection.mutable.WrappedArray;
+
+/** Utilities for easier interoperability with the Spark Scala API. */
+public class ScalaInterop {
+ private ScalaInterop() {}
+
+ public static <T> Seq<T> seqOf(T... t) {
+ return new WrappedArray.ofRef<>(t);
+ }
+
+ public static <T> List<T> concat(List<T> a, List<T> b) {
+ return b.$colon$colon$colon(a);
+ }
+
+ public static <T> Seq<T> listOf(T t) {
+ return emptyList().$colon$colon(t);
+ }
+
+ public static <T> List<T> emptyList() {
+ return (List<T>) Nil$.MODULE$;
+ }
+
+ /** Scala {@link Iterator} of Java {@link Iterable}. */
+ public static <T extends @NonNull Object> Iterator<T> scalaIterator(Iterable<T> iterable) {
+ return scalaIterator(iterable.iterator());
+ }
+
+ /** Scala {@link Iterator} of Java {@link java.util.Iterator}. */
+ public static <T extends @NonNull Object> Iterator<T> scalaIterator(java.util.Iterator<T> it) {
+ return JavaConverters.asScalaIterator(it);
+ }
+
+ /** Java {@link java.util.Iterator} of Scala {@link Iterator}. */
+ public static <T extends @NonNull Object> java.util.Iterator<T> javaIterator(Iterator<T> it) {
+ return JavaConverters.asJavaIterator(it);
+ }
+
+ public static <T1, T2> Tuple2<T1, T2> tuple(T1 t1, T2 t2) {
+ return new Tuple2<>(t1, t2);
+ }
+
+ public static <T extends @NonNull Object, V> PartialFunction<T, T> replace(
+ Class<V> clazz, T replace) {
+ return new PartialFunction<T, T>() {
+
+ @Override
+ public boolean isDefinedAt(T x) {
+ return clazz.isAssignableFrom(x.getClass());
+ }
+
+ @Override
+ public T apply(T x) {
+ return replace;
+ }
+ };
+ }
+
+ public static <T extends @NonNull Object, V> PartialFunction<T, V> match(Class<V> clazz) {
+ return new PartialFunction<T, V>() {
+
+ @Override
+ public boolean isDefinedAt(T x) {
+ return clazz.isAssignableFrom(x.getClass());
+ }
+
+ @Override
+ public V apply(T x) {
+ return (V) x;
+ }
+ };
+ }
+
+ public static <T, V> Fun1<T, V> fun1(Fun1<T, V> fun) {
+ return fun;
+ }
+
+ public static <T1, T2, V> Fun2<T1, T2, V> fun2(Fun2<T1, T2, V> fun) {
+ return fun;
+ }
+
+ public interface Fun1<T, V> extends Function1<T, V>, Serializable {}
+
+ public interface Fun2<T1, T2, V> extends Function2<T1, T2, V>, Serializable {}
+}
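
To see how these helpers read at a call site, here is a small illustrative snippet;
ScalaInteropDemo is a hypothetical class, not part of this commit:

    import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.fun1;
    import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.seqOf;
    import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;

    import org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.Fun1;
    import scala.Tuple2;
    import scala.collection.Seq;

    public class ScalaInteropDemo {
      public static void main(String[] args) {
        Seq<String> seq = seqOf("a", "b", "c");         // Scala Seq from Java varargs
        Tuple2<String, Integer> kv = tuple("count", 1); // Tuple2 without new Tuple2<>(...)
        Fun1<Integer, Integer> inc = fun1(i -> i + 1);  // serializable scala.Function1
        System.out.println(seq.size() + " " + kv._1() + " " + inc.apply(41));
      }
    }

The Fun1/Fun2 interfaces matter because a plain Java lambda targeting scala.Function1 is not
Serializable, which Spark requires when shipping functions to executors.
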
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java
index f994f7712b3..7f2eaa10e80 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/aggregators/metrics/sink/InMemoryMetrics.java
@@ -49,7 +49,7 @@ public class InMemoryMetrics implements Sink {
internalMetricRegistry = metricRegistry;
}
- @SuppressWarnings({"TypeParameterUnusedInFormals", "rawtypes"})
+ @SuppressWarnings({"TypeParameterUnusedInFormals", "rawtypes"}) // because of getGauges
public static <T> T valueOf(final String name) {
// this might fail in case we have multiple aggregators with the same suffix after
// the last dot, but it should be good enough for tests.
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorsTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorsTest.java
new file mode 100644
index 00000000000..b5b07db0e38
--- /dev/null
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/AggregatorsTest.java
@@ -0,0 +1,370 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderFor;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+import static org.joda.time.Duration.standardMinutes;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Map;
+import java.util.TreeMap;
+import java.util.stream.Collectors;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
+import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.Sessions;
+import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
+import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.transforms.windowing.WindowFn;
+import org.apache.beam.sdk.transforms.windowing.WindowMappingFn;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.WindowingStrategy;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.expressions.Aggregator;
+import org.apache.spark.util.MutablePair;
+import org.hamcrest.Matcher;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
+import org.junit.runner.RunWith;
+
+@RunWith(Enclosed.class)
+public class AggregatorsTest {
+
+ // just something easily readable
+ private static final Instant NOW = Instant.parse("2000-01-01T00:00Z");
+
+ /** Tests for NonMergingWindowedAggregator in {@link Aggregators}. */
+ public static class NonMergingWindowedAggregatorTest {
+
+ private SlidingWindows sliding =
+ SlidingWindows.of(standardMinutes(15)).every(standardMinutes(5));
+
+ private Aggregator<
+ WindowedValue<Integer>,
+ Map<IntervalWindow, MutablePair<Instant, Integer>>,
+ Collection<WindowedValue<Integer>>>
+ agg = windowedAgg(sliding);
+
+ @Test
+ public void testReduce() {
+ Map<IntervalWindow, MutablePair<Instant, Integer>> acc;
+
+ acc = agg.reduce(agg.zero(), windowedValue(1, at(10)));
+ assertThat(
+ acc,
+ equalsToMap(
+ KV.of(intervalWindow(0, 15), pair(at(10), 1)),
+ KV.of(intervalWindow(5, 20), pair(at(10), 1)),
+ KV.of(intervalWindow(10, 25), pair(at(10), 1))));
+
+ acc = agg.reduce(acc, windowedValue(2, at(16)));
+ assertThat(
+ acc,
+ equalsToMap(
+ KV.of(intervalWindow(0, 15), pair(at(10), 1)),
+ KV.of(intervalWindow(5, 20), pair(at(16), 3)),
+ KV.of(intervalWindow(10, 25), pair(at(16), 3)),
+ KV.of(intervalWindow(15, 30), pair(at(16), 2))));
+ }
+
+ @Test
+ public void testMerge() {
+ Map<IntervalWindow, MutablePair<Instant, Integer>> acc;
+
+ assertThat(agg.merge(agg.zero(), agg.zero()), equalTo(agg.zero()));
+
+ acc = mapOf(KV.of(intervalWindow(0, 15), pair(at(0), 1)));
+
+ assertThat(agg.merge(acc, agg.zero()), equalTo(acc));
+ assertThat(agg.merge(agg.zero(), acc), equalTo(acc));
+
+ acc = agg.merge(acc, acc);
+ assertThat(acc, equalsToMap(KV.of(intervalWindow(0, 15), pair(at(0), 1 + 1))));
+
+ acc = agg.merge(acc, mapOf(KV.of(intervalWindow(5, 20), pair(at(5), 3))));
+ assertThat(
+ acc,
+ equalsToMap(
+ KV.of(intervalWindow(0, 15), pair(at(0), 1 + 1)),
+ KV.of(intervalWindow(5, 20), pair(at(5), 3))));
+
+ acc = agg.merge(mapOf(KV.of(intervalWindow(10, 25), pair(at(10), 4))), acc);
+ assertThat(
+ acc,
+ equalsToMap(
+ KV.of(intervalWindow(0, 15), pair(at(0), 1 + 1)),
+ KV.of(intervalWindow(5, 20), pair(at(5), 3)),
+ KV.of(intervalWindow(10, 25), pair(at(10), 4))));
+ }
+
+ private WindowedValue<Integer> windowedValue(Integer value, Instant ts) {
+ return WindowedValue.of(value, ts, sliding.assignWindows(ts), PaneInfo.NO_FIRING);
+ }
+ }
+
+ /**
+ * Shared implementation of tests for SessionsAggregator and MergingWindowedAggregator in {@link
+ * Aggregators}.
+ */
+ public abstract static class AbstractSessionsTest<
+ AccT extends Map<IntervalWindow, MutablePair<Instant, Integer>>> {
+
+ static final Duration SESSIONS_GAP = standardMinutes(15);
+
+ final Aggregator<WindowedValue<Integer>, AccT, Collection<WindowedValue<Integer>>> agg;
+
+ AbstractSessionsTest(WindowFn<?, ?> windowFn) {
+ agg = windowedAgg(windowFn);
+ }
+
+ abstract AccT accOf(KV<IntervalWindow, MutablePair<Instant, Integer>>... entries);
+
+ @Test
+ public void testReduce() {
+ AccT acc;
+
+ acc = agg.reduce(agg.zero(), sessionValue(10, at(0)));
+ assertThat(acc, equalsToMap(KV.of(sessionWindow(0), pair(at(0), 10))));
+
+ // 2nd session after 1st
+ acc = agg.reduce(acc, sessionValue(7, at(20)));
+ assertThat(
+ acc,
+ equalsToMap(
+ KV.of(sessionWindow(0), pair(at(0), 10)), KV.of(sessionWindow(20), pair(at(20), 7))));
+
+ // merge into 2nd session
+ acc = agg.reduce(acc, sessionValue(6, at(18)));
+ assertThat(
+ acc,
+ equalsToMap(
+ KV.of(sessionWindow(0), pair(at(0), 10)),
+ KV.of(sessionWindow(18, 35), pair(at(20), 7 + 6))));
+
+ // merge into 2nd session
+ acc = agg.reduce(acc, sessionValue(5, at(21)));
+ assertThat(
+ acc,
+ equalsToMap(
+ KV.of(sessionWindow(0), pair(at(0), 10)),
+ KV.of(sessionWindow(18, 36), pair(at(21), 7 + 6 + 5))));
+
+ // 3rd session after 2nd
+ acc = agg.reduce(acc, sessionValue(2, NOW.plus(standardMinutes(45))));
+ assertThat(
+ acc,
+ equalsToMap(
+ KV.of(sessionWindow(0), pair(at(0), 10)),
+ KV.of(sessionWindow(18, 36), pair(at(21), 7 + 6 + 5)),
+ KV.of(sessionWindow(45), pair(at(45), 2))));
+
+ // merge with 1st and 2nd
+ acc = agg.reduce(acc, sessionValue(1, at(10)));
+ assertThat(
+ acc,
+ equalsToMap(
+ KV.of(sessionWindow(0, 36), pair(at(21), 10 + 7 + 6 + 5 + 1)),
+ KV.of(sessionWindow(45), pair(at(45), 2))));
+ }
+
+ @Test
+ public void testMerge() {
+ AccT acc;
+
+ assertThat(agg.merge(agg.zero(), agg.zero()), equalTo(agg.zero()));
+
+ acc = accOf(KV.of(sessionWindow(0), pair(at(0), 1)));
+
+ assertThat(agg.merge(acc, agg.zero()), equalTo(acc));
+ assertThat(agg.merge(agg.zero(), acc), equalTo(acc));
+
+ acc = agg.merge(acc, acc);
+ assertThat(acc, equalsToMap(KV.of(sessionWindow(0), pair(at(0), 1 + 1))));
+
+ acc = agg.merge(acc, accOf(KV.of(sessionWindow(20), pair(at(20), 2))));
+ assertThat(
+ acc,
+ equalsToMap(
+ KV.of(sessionWindow(0), pair(at(0), 1 + 1)),
+ KV.of(sessionWindow(20), pair(at(20), 2))));
+
+ acc = agg.merge(accOf(KV.of(sessionWindow(40), pair(at(40), 3))), acc);
+ assertThat(
+ acc,
+ equalsToMap(
+ KV.of(sessionWindow(0), pair(at(0), 1 + 1)),
+ KV.of(sessionWindow(20), pair(at(20), 2)),
+ KV.of(sessionWindow(40), pair(at(40), 3))));
+
+ acc = agg.merge(acc, accOf(KV.of(sessionWindow(10), pair(at(10), 4))));
+ assertThat(
+ acc,
+ equalsToMap(
+ KV.of(sessionWindow(0, 35), pair(at(20), 1 + 1 + 2 + 4)),
+ KV.of(sessionWindow(40), pair(at(40), 3))));
+
+ acc = agg.merge(accOf(KV.of(sessionWindow(5, 45), pair(at(30), 5))), acc);
+ assertThat(
+ acc, equalsToMap(KV.of(sessionWindow(0, 55), pair(at(40), 1 + 1 + 2 + 4 + 3 + 5))));
+ }
+
+ private WindowedValue<Integer> sessionValue(Integer value, Instant ts) {
+ return WindowedValue.of(value, ts, new IntervalWindow(ts, SESSIONS_GAP), PaneInfo.NO_FIRING);
+ }
+
+ private IntervalWindow sessionWindow(int fromMinutes) {
+ return new IntervalWindow(at(fromMinutes), SESSIONS_GAP);
+ }
+
+ private static IntervalWindow sessionWindow(int fromMinutes, int toMinutes) {
+ return intervalWindow(fromMinutes, toMinutes);
+ }
+ }
+
+ /** Tests for specialized SessionsAggregator in {@link Aggregators}. */
+ public static class SessionsAggregatorTest
+ extends AbstractSessionsTest<TreeMap<IntervalWindow, MutablePair<Instant, Integer>>> {
+
+ public SessionsAggregatorTest() {
+ super(Sessions.withGapDuration(SESSIONS_GAP));
+ }
+
+ @Override
+ TreeMap<IntervalWindow, MutablePair<Instant, Integer>> accOf(
+ KV<IntervalWindow, MutablePair<Instant, Integer>>... entries) {
+ return new TreeMap<>(mapOf(entries));
+ }
+ }
+
+ /** Tests for MergingWindowedAggregator in {@link Aggregators}. */
+ public static class MergingWindowedAggregatorTest
+ extends AbstractSessionsTest<Map<IntervalWindow, MutablePair<Instant, Integer>>> {
+
+ public MergingWindowedAggregatorTest() {
+ super(new CustomSessions<>());
+ }
+
+ @Override
+ Map<IntervalWindow, MutablePair<Instant, Integer>> accOf(
+ KV<IntervalWindow, MutablePair<Instant, Integer>>... entries) {
+ return mapOf(entries);
+ }
+
+ /** Wrapper around {@link Sessions} to test the MergingWindowedAggregator. */
+ private static class CustomSessions<T> extends WindowFn<T, IntervalWindow> {
+ private final Sessions sessions = Sessions.withGapDuration(SESSIONS_GAP);
+
+ @Override
+ public Collection<IntervalWindow> assignWindows(WindowFn<T, IntervalWindow>.AssignContext c) {
+ return sessions.assignWindows((WindowFn.AssignContext) c);
+ }
+
+ @Override
+ public void mergeWindows(WindowFn<T, IntervalWindow>.MergeContext c) throws Exception {
+ sessions.mergeWindows((WindowFn<Object, IntervalWindow>.MergeContext) c);
+ }
+
+ @Override
+ public boolean isCompatible(WindowFn<?, ?> other) {
+ return sessions.isCompatible(other);
+ }
+
+ @Override
+ public Coder<IntervalWindow> windowCoder() {
+ return sessions.windowCoder();
+ }
+
+ @Override
+ public WindowMappingFn<IntervalWindow> getDefaultWindowMappingFn() {
+ return sessions.getDefaultWindowMappingFn();
+ }
+ }
+ }
+
+ private static IntervalWindow intervalWindow(int fromMinutes, int toMinutes) {
+ return new IntervalWindow(at(fromMinutes), at(toMinutes));
+ }
+
+ private static Instant at(int minutes) {
+ return NOW.plus(standardMinutes(minutes));
+ }
+
+ private static Matcher<Map<IntervalWindow, MutablePair<Instant, Integer>>> equalsToMap(
+ KV<IntervalWindow, MutablePair<Instant, Integer>>... entries) {
+ return equalTo(mapOf(entries));
+ }
+
+ private static Map<IntervalWindow, MutablePair<Instant, Integer>> mapOf(
+ KV<IntervalWindow, MutablePair<Instant, Integer>>... entries) {
+ return Arrays.asList(entries).stream().collect(Collectors.toMap(KV::getKey, KV::getValue));
+ }
+
+ private static MutablePair<Instant, Integer> pair(Instant ts, int value) {
+ return new MutablePair<>(ts, value);
+ }
+
+ private static <AccT>
+ Aggregator<WindowedValue<Integer>, AccT, Collection<WindowedValue<Integer>>> windowedAgg(
+ WindowFn<?, ?> windowFn) {
+ Encoder<Integer> intEnc = EncoderHelpers.encoderOf(Integer.class);
+ Encoder<BoundedWindow> windowEnc = encoderFor((Coder) IntervalWindow.getCoder());
+ Encoder<WindowedValue<Integer>> outputEnc = windowedValueEncoder(intEnc, windowEnc);
+
+ WindowingStrategy<?, ?> windowing =
+ WindowingStrategy.of(windowFn).withTimestampCombiner(TimestampCombiner.LATEST);
+
+ Aggregator<WindowedValue<Integer>, ?, Collection<WindowedValue<Integer>>> agg =
+ Aggregators.windowedValue(
+ new SimpleSum(), WindowedValue::getValue, windowing, windowEnc, intEnc, outputEnc);
+ return (Aggregator) agg;
+ }
+
+ private static class SimpleSum extends Combine.CombineFn<Integer, Integer, Integer> {
+
+ @Override
+ public Integer createAccumulator() {
+ return 0;
+ }
+
+ @Override
+ public Integer addInput(Integer acc, Integer input) {
+ return acc + input;
+ }
+
+ @Override
+ public Integer mergeAccumulators(Iterable<Integer> accs) {
+ return Streams.stream(accs.iterator()).reduce((a, b) -> a + b).orElseGet(() -> 0);
+ }
+
+ @Override
+ public Integer extractOutput(Integer acc) {
+ return acc;
+ }
+ }
+}
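
Conceptually, the accumulator exercised above is a map from window to a pair of (latest
timestamp, partial combine result): reduce inserts a new entry for an unseen window and
otherwise combines in place, keeping the latest timestamp to match TimestampCombiner.LATEST.
A standalone sketch of that reduce step under those assumptions, with java.util types
standing in for Spark's MutablePair and hypothetical names throughout:

    import java.util.AbstractMap.SimpleEntry;
    import java.util.HashMap;
    import java.util.Map;

    public class WindowedReduceSketch {
      // Accumulator: window start (minutes) -> (latest timestamp, running sum).
      static void reduce(
          Map<Integer, SimpleEntry<Integer, Integer>> acc, int window, int ts, int value) {
        acc.merge(
            window,
            new SimpleEntry<>(ts, value),
            (old, in) ->
                new SimpleEntry<>(
                    Math.max(old.getKey(), in.getKey()), old.getValue() + in.getValue()));
      }

      public static void main(String[] args) {
        Map<Integer, SimpleEntry<Integer, Integer>> acc = new HashMap<>();
        reduce(acc, 0, 10, 1); // first value in window starting at 0
        reduce(acc, 0, 16, 2); // folds in: timestamp advances to 16, sum becomes 3
        System.out.println(acc); // {0=16=3}
      }
    }

Note this models only the non-merging case; the session aggregators additionally merge
overlapping windows, as the testReduce cases above show.
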
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTest.java
similarity index 56%
copy from runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineTest.java
copy to runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTest.java
index 52e60a3db54..dca8b664bd3 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineGloballyTest.java
@@ -18,43 +18,44 @@
package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.transforms.Combine.BinaryCombineFn;
+import org.apache.beam.sdk.transforms.CombineFnBase;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.transforms.SerializableBiFunction;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.Sessions;
import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
-import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TimestampedValue;
import org.joda.time.Duration;
import org.joda.time.Instant;
-import org.junit.BeforeClass;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
-/** Test class for beam to spark {@link org.apache.beam.sdk.transforms.Combine} translation. */
+/**
+ * Test class for Beam to Spark {@link Combine#globally(CombineFnBase.GlobalCombineFn)} translation.
+ */
@RunWith(JUnit4.class)
-public class CombineTest implements Serializable {
- private static Pipeline pipeline;
+public class CombineGloballyTest implements Serializable {
+ @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
- @BeforeClass
- public static void beforeClass() {
+ private static PipelineOptions testOptions() {
SparkStructuredStreamingPipelineOptions options =
PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
options.setRunner(SparkStructuredStreamingRunner.class);
options.setTestMode(true);
- pipeline = Pipeline.create(options);
+ return options;
}
@Test
@@ -79,65 +80,59 @@ public class CombineTest implements Serializable {
TimestampedValue.of(5, new Instant(11)),
TimestampedValue.of(6, new Instant(12))))
.apply(Window.into(FixedWindows.of(Duration.millis(10))))
- .apply(Combine.globally(Sum.ofIntegers()).withoutDefaults());
+ .apply(Sum.integersGlobally().withoutDefaults());
PAssert.that(input).containsInAnyOrder(7, 14);
+ pipeline.run();
}
@Test
- public void testCombinePerKey() {
- List<KV<Integer, Integer>> elems = new ArrayList<>();
- elems.add(KV.of(1, 1));
- elems.add(KV.of(1, 3));
- elems.add(KV.of(1, 5));
- elems.add(KV.of(2, 2));
- elems.add(KV.of(2, 4));
- elems.add(KV.of(2, 6));
-
- PCollection<KV<Integer, Integer>> input =
- pipeline.apply(Create.of(elems)).apply(Sum.integersPerKey());
- PAssert.that(input).containsInAnyOrder(KV.of(1, 9), KV.of(2, 12));
+ public void testCombineGloballyWithSlidingWindows() {
+ PCollection<Integer> input =
+ pipeline
+ .apply(
+ Create.timestamped(
+ TimestampedValue.of(1, new Instant(1)),
+ TimestampedValue.of(3, new Instant(2)),
+ TimestampedValue.of(5, new Instant(3)),
+ TimestampedValue.of(2, new Instant(1)),
+ TimestampedValue.of(4, new Instant(2)),
+ TimestampedValue.of(6, new Instant(3))))
+ .apply(Window.into(SlidingWindows.of(Duration.millis(3)).every(Duration.millis(1))))
+ .apply(Sum.integersGlobally().withoutDefaults());
+ PAssert.that(input)
+ .containsInAnyOrder(1 + 2, 1 + 2 + 3 + 4, 1 + 3 + 5 + 2 + 4 + 6, 3 + 4 + 5 + 6, 5 + 6);
pipeline.run();
}
@Test
- public void testCombinePerKeyPreservesWindowing() {
- PCollection<KV<Integer, Integer>> input =
+ public void testCombineGloballyWithMergingWindows() {
+ PCollection<Integer> input =
pipeline
.apply(
Create.timestamped(
- TimestampedValue.of(KV.of(1, 1), new Instant(1)),
- TimestampedValue.of(KV.of(1, 3), new Instant(2)),
- TimestampedValue.of(KV.of(1, 5), new Instant(11)),
- TimestampedValue.of(KV.of(2, 2), new Instant(3)),
- TimestampedValue.of(KV.of(2, 4), new Instant(11)),
- TimestampedValue.of(KV.of(2, 6), new Instant(12))))
- .apply(Window.into(FixedWindows.of(Duration.millis(10))))
- .apply(Sum.integersPerKey());
- PAssert.that(input).containsInAnyOrder(KV.of(1, 4), KV.of(1, 5), KV.of(2, 2), KV.of(2, 10));
+ TimestampedValue.of(2, new Instant(5)),
+ TimestampedValue.of(4, new Instant(11)),
+ TimestampedValue.of(6, new Instant(12))))
+ .apply(Window.into(Sessions.withGapDuration(Duration.millis(5))))
+ .apply(Sum.integersGlobally().withoutDefaults());
+
+ PAssert.that(input).containsInAnyOrder(2 /*window [5-10)*/, 10 /*window [11-17)*/);
pipeline.run();
}
@Test
- public void testCombinePerKeyWithSlidingWindows() {
- PCollection<KV<Integer, Integer>> input =
+ public void testCountGloballyWithSlidingWindows() {
+ PCollection<String> input =
pipeline
.apply(
Create.timestamped(
- TimestampedValue.of(KV.of(1, 1), new Instant(1)),
- TimestampedValue.of(KV.of(1, 3), new Instant(2)),
- TimestampedValue.of(KV.of(1, 5), new Instant(3)),
- TimestampedValue.of(KV.of(1, 2), new Instant(1)),
- TimestampedValue.of(KV.of(1, 4), new Instant(2)),
- TimestampedValue.of(KV.of(1, 6), new Instant(3))))
- .apply(Window.into(SlidingWindows.of(Duration.millis(3)).every(Duration.millis(1))))
- .apply(Sum.integersPerKey());
- PAssert.that(input)
- .containsInAnyOrder(
- KV.of(1, 1 + 2),
- KV.of(1, 1 + 2 + 3 + 4),
- KV.of(1, 1 + 3 + 5 + 2 + 4 + 6),
- KV.of(1, 3 + 4 + 5 + 6),
- KV.of(1, 5 + 6));
+ TimestampedValue.of("a", new Instant(1)),
+ TimestampedValue.of("a", new Instant(2)),
+ TimestampedValue.of("a", new Instant(2))))
+ .apply(Window.into(SlidingWindows.of(Duration.millis(2)).every(Duration.millis(1))));
+ PCollection<Long> output =
+ input.apply(Combine.globally(Count.<String>combineFn()).withoutDefaults());
+ PAssert.that(output).containsInAnyOrder(1L, 3L, 2L);
pipeline.run();
}
@@ -152,35 +147,9 @@ public class CombineTest implements Serializable {
TimestampedValue.of(5, new Instant(3))))
.apply(Window.into(SlidingWindows.of(Duration.millis(3)).every(Duration.millis(1))))
.apply(
- Combine.globally(
- Combine.BinaryCombineFn.of(
- (SerializableBiFunction<Integer, Integer, Integer>)
- (integer1, integer2) -> integer1 > integer2 ? integer1 : integer2))
+ Combine.globally(BinaryCombineFn.<Integer>of((i1, i2) -> i1 > i2 ? i1 : i2))
.withoutDefaults());
PAssert.that(input).containsInAnyOrder(1, 3, 5, 5, 5);
pipeline.run();
}
-
- @Test
- public void testCountPerElementWithSlidingWindows() {
- PCollection<String> input =
- pipeline
- .apply(
- Create.timestamped(
- TimestampedValue.of("a", new Instant(1)),
- TimestampedValue.of("a", new Instant(2)),
- TimestampedValue.of("b", new Instant(3)),
- TimestampedValue.of("b", new Instant(4))))
- .apply(Window.into(SlidingWindows.of(Duration.millis(2)).every(Duration.millis(1))));
- PCollection<KV<String, Long>> output = input.apply(Count.perElement());
- PAssert.that(output)
- .containsInAnyOrder(
- KV.of("a", 1L),
- KV.of("a", 2L),
- KV.of("a", 1L),
- KV.of("b", 1L),
- KV.of("b", 2L),
- KV.of("b", 1L));
- pipeline.run();
- }
}
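
The sums asserted in testCombineGloballyWithSlidingWindows follow from each timestamp landing
in three overlapping windows of size 3ms sliding every 1ms. A quick way to check the
assignment, reusing the same SlidingWindows.assignWindows(Instant) convenience method the
tests call; SlidingWindowDemo is a hypothetical name:

    import java.util.Collection;
    import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
    import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
    import org.joda.time.Duration;
    import org.joda.time.Instant;

    public class SlidingWindowDemo {
      public static void main(String[] args) {
        // Same windowing as the test: size 3ms, sliding every 1ms.
        SlidingWindows windows = SlidingWindows.of(Duration.millis(3)).every(Duration.millis(1));
        for (long ts = 1; ts <= 3; ts++) {
          Collection<IntervalWindow> assigned = windows.assignWindows(new Instant(ts));
          // ts=1 falls into [-1,2), [0,3) and [1,4); ts=2 into [0,3), [1,4), [2,5); etc.
          System.out.println(ts + " -> " + assigned);
        }
      }
    }
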
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTest.java
similarity index 71%
rename from runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineTest.java
rename to runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTest.java
index 52e60a3db54..c8b25b3355d 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombineTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/CombinePerKeyTest.java
@@ -22,65 +22,44 @@ import java.util.ArrayList;
import java.util.List;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
-import org.apache.beam.sdk.transforms.Combine;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.CombineFnBase;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.transforms.SerializableBiFunction;
+import org.apache.beam.sdk.transforms.Distinct;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.Sessions;
import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TimestampedValue;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.joda.time.Duration;
import org.joda.time.Instant;
-import org.junit.BeforeClass;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
-/** Test class for beam to spark {@link org.apache.beam.sdk.transforms.Combine} translation. */
+/**
+ * Test class for beam to spark {@link
+ * org.apache.beam.sdk.transforms.Combine#perKey(CombineFnBase.GlobalCombineFn)} translation.
+ */
@RunWith(JUnit4.class)
-public class CombineTest implements Serializable {
- private static Pipeline pipeline;
+public class CombinePerKeyTest implements Serializable {
+ @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
- @BeforeClass
- public static void beforeClass() {
+ private static PipelineOptions testOptions() {
SparkStructuredStreamingPipelineOptions options =
PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
options.setRunner(SparkStructuredStreamingRunner.class);
options.setTestMode(true);
- pipeline = Pipeline.create(options);
- }
-
- @Test
- public void testCombineGlobally() {
- PCollection<Integer> input =
- pipeline.apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)).apply(Sum.integersGlobally());
- PAssert.that(input).containsInAnyOrder(55);
- // uses combine per key
- pipeline.run();
- }
-
- @Test
- public void testCombineGloballyPreservesWindowing() {
- PCollection<Integer> input =
- pipeline
- .apply(
- Create.timestamped(
- TimestampedValue.of(1, new Instant(1)),
- TimestampedValue.of(2, new Instant(2)),
- TimestampedValue.of(3, new Instant(11)),
- TimestampedValue.of(4, new Instant(3)),
- TimestampedValue.of(5, new Instant(11)),
- TimestampedValue.of(6, new Instant(12))))
- .apply(Window.into(FixedWindows.of(Duration.millis(10))))
- .apply(Combine.globally(Sum.ofIntegers()).withoutDefaults());
- PAssert.that(input).containsInAnyOrder(7, 14);
+ return options;
}
@Test
@@ -99,6 +78,17 @@ public class CombineTest implements Serializable {
pipeline.run();
}
+ @Test
+ public void testDistinctViaCombinePerKey() {
+ List<Integer> elems = Lists.newArrayList(1, 2, 3, 3, 4, 4, 4, 4, 5, 5);
+
+ // Distinct is implemented in terms of CombinePerKey
+ PCollection<Integer> result = pipeline.apply(Create.of(elems)).apply(Distinct.create());
+
+ PAssert.that(result).containsInAnyOrder(1, 2, 3, 4, 5);
+ pipeline.run();
+ }
+
@Test
public void testCombinePerKeyPreservesWindowing() {
PCollection<KV<Integer, Integer>> input =
@@ -142,22 +132,26 @@ public class CombineTest implements Serializable {
}
@Test
- public void testBinaryCombineWithSlidingWindows() {
- PCollection<Integer> input =
+ public void testCombineByKeyWithMergingWindows() {
+ PCollection<KV<Integer, Integer>> input =
pipeline
.apply(
Create.timestamped(
- TimestampedValue.of(1, new Instant(1)),
- TimestampedValue.of(3, new Instant(2)),
- TimestampedValue.of(5, new Instant(3))))
- .apply(Window.into(SlidingWindows.of(Duration.millis(3)).every(Duration.millis(1))))
- .apply(
- Combine.globally(
- Combine.BinaryCombineFn.of(
- (SerializableBiFunction<Integer, Integer, Integer>)
- (integer1, integer2) -> integer1 > integer2 ? integer1 : integer2))
- .withoutDefaults());
- PAssert.that(input).containsInAnyOrder(1, 3, 5, 5, 5);
+ TimestampedValue.of(KV.of(1, 1), new Instant(5)),
+ TimestampedValue.of(KV.of(1, 3), new Instant(7)),
+ TimestampedValue.of(KV.of(1, 5), new Instant(11)),
+ TimestampedValue.of(KV.of(2, 2), new Instant(5)),
+ TimestampedValue.of(KV.of(2, 4), new Instant(11)),
+ TimestampedValue.of(KV.of(2, 6), new Instant(12))))
+ .apply(Window.into(Sessions.withGapDuration(Duration.millis(5))))
+ .apply(Sum.integersPerKey());
+
+ PAssert.that(input)
+ .containsInAnyOrder(
+ KV.of(1, 9), // window [5-16)
+ KV.of(2, 2), // window [5-10)
+ KV.of(2, 10) // window [11-17)
+ );
pipeline.run();
}
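The window comments in testCombineByKeyWithMergingWindows follow the same session merging, applied per key (worked sketch):

    key 1: [5,10), [7,12), [11,16) chain/overlap -> merged [5,16),  sum = 1 + 3 + 5 = 9
    key 2: [5,10) stands alone                   ->        [5,10),  sum = 2
           [11,16) overlaps [12,17)              -> merged [11,17), sum = 4 + 6 = 10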
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ComplexSourceTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ComplexSourceTest.java
index 0175d03f875..582a31a05a6 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ComplexSourceTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ComplexSourceTest.java
@@ -27,13 +27,15 @@ import java.util.ArrayList;
import java.util.List;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.TextIO;
+import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.values.PCollection;
import org.junit.BeforeClass;
import org.junit.ClassRule;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runner.RunWith;
@@ -46,15 +48,18 @@ public class ComplexSourceTest implements Serializable {
private static File file;
private static List<String> lines = createLines(30);
- private static Pipeline pipeline;
+ @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
- @BeforeClass
- public static void beforeClass() throws IOException {
+ private static PipelineOptions testOptions() {
SparkStructuredStreamingPipelineOptions options =
PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
options.setRunner(SparkStructuredStreamingRunner.class);
options.setTestMode(true);
- pipeline = Pipeline.create(options);
+ return options;
+ }
+
+ @BeforeClass
+ public static void beforeClass() throws IOException {
file = createFile(lines);
}
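The same refactoring recurs in every test class in this commit: the static Pipeline created in @BeforeClass is replaced by a per-test TestPipeline rule built from a static options factory. The pattern, as used throughout these files:

    // One fresh pipeline per test; as a @Rule, TestPipeline additionally
    // verifies the pipeline is actually run and its PAssert checks evaluated.
    @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());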
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTest.java
index e126d06e685..50b443da9ae 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/FlattenTest.java
@@ -20,14 +20,15 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
import java.io.Serializable;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionList;
-import org.junit.BeforeClass;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -35,15 +36,14 @@ import org.junit.runners.JUnit4;
/** Test class for beam to spark flatten translation. */
@RunWith(JUnit4.class)
public class FlattenTest implements Serializable {
- private static Pipeline pipeline;
+ @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
- @BeforeClass
- public static void beforeClass() {
+ private static PipelineOptions testOptions() {
SparkStructuredStreamingPipelineOptions options =
PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
options.setRunner(SparkStructuredStreamingRunner.class);
options.setTestMode(true);
- pipeline = Pipeline.create(options);
+ return options;
}
@Test
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTest.java
index 07850232853..1a84466b319 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/GroupByKeyTest.java
@@ -17,30 +17,41 @@
*/
package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
+import static java.util.Arrays.stream;
+import static java.util.stream.Collectors.groupingBy;
+import static java.util.stream.Collectors.mapping;
+import static java.util.stream.Collectors.toList;
import static org.apache.beam.sdk.testing.SerializableMatchers.containsInAnyOrder;
import static org.hamcrest.MatcherAssert.assertThat;
import java.io.Serializable;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.List;
+import java.util.Map;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.SerializableMatcher;
+import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
+import org.apache.beam.sdk.transforms.windowing.Sessions;
+import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TimestampedValue;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.joda.time.Duration;
import org.joda.time.Instant;
-import org.junit.BeforeClass;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -48,15 +59,14 @@ import org.junit.runners.JUnit4;
/** Test class for beam to spark {@link ParDo} translation. */
@RunWith(JUnit4.class)
public class GroupByKeyTest implements Serializable {
- private static Pipeline pipeline;
+ @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
- @BeforeClass
- public static void beforeClass() {
+ private static PipelineOptions testOptions() {
SparkStructuredStreamingPipelineOptions options =
PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
options.setRunner(SparkStructuredStreamingRunner.class);
options.setTestMode(true);
- pipeline = Pipeline.create(options);
+ return options;
}
@Test
@@ -64,54 +74,89 @@ public class GroupByKeyTest implements Serializable {
pipeline
.apply(
Create.timestamped(
- TimestampedValue.of(KV.of(1, 1), new Instant(1)),
- TimestampedValue.of(KV.of(1, 3), new Instant(2)),
- TimestampedValue.of(KV.of(1, 5), new Instant(11)),
- TimestampedValue.of(KV.of(2, 2), new Instant(3)),
- TimestampedValue.of(KV.of(2, 4), new Instant(11)),
- TimestampedValue.of(KV.of(2, 6), new Instant(12))))
+ shuffleRandomly(
+ TimestampedValue.of(KV.of(1, 1), new Instant(1)),
+ TimestampedValue.of(KV.of(1, 3), new Instant(2)),
+ TimestampedValue.of(KV.of(1, 5), new Instant(11)),
+ TimestampedValue.of(KV.of(2, 2), new Instant(3)),
+ TimestampedValue.of(KV.of(2, 4), new Instant(11)),
+ TimestampedValue.of(KV.of(2, 6), new Instant(12)))))
.apply(Window.into(FixedWindows.of(Duration.millis(10))))
.apply(GroupByKey.create())
- // do manual assertion for windows because Passert do not support multiple kv with same key
- // (because multiple windows)
+ // PAssert does not support multiple KVs with the same key (because of multiple windows)
.apply(
ParDo.of(
- new DoFn<KV<Integer, Iterable<Integer>>, KV<Integer, Iterable<Integer>>>() {
+ new AssertContains<>(
+ KV.of(1, containsInAnyOrder(1, 3)), // window [0-10)
+ KV.of(1, containsInAnyOrder(5)), // window [10-20)
+ KV.of(2, containsInAnyOrder(4, 6)), // window [10-20)
+ KV.of(2, containsInAnyOrder(2)) // window [0-10)
+ )));
+ pipeline.run();
+ }
- @ProcessElement
- public void processElement(ProcessContext context) {
- KV<Integer, Iterable<Integer>> element = context.element();
- if (element.getKey() == 1) {
- if (Iterables.size(element.getValue()) == 2) {
- assertThat(element.getValue(), containsInAnyOrder(1, 3)); // window [0-10)
- } else {
- assertThat(element.getValue(), containsInAnyOrder(5)); // window [10-20)
- }
- } else { // key == 2
- if (Iterables.size(element.getValue()) == 2) {
- assertThat(element.getValue(), containsInAnyOrder(4, 6)); // window [10-20)
- } else {
- assertThat(element.getValue(), containsInAnyOrder(2)); // window [0-10)
- }
- }
- context.output(element);
- }
- }));
+ @Test
+ public void testGroupByKeyExplodesMultipleWindows() {
+ pipeline
+ .apply(
+ Create.timestamped(
+ shuffleRandomly(
+ TimestampedValue.of(KV.of(1, 1), new Instant(5)),
+ TimestampedValue.of(KV.of(1, 3), new Instant(7)),
+ TimestampedValue.of(KV.of(1, 5), new Instant(11)),
+ TimestampedValue.of(KV.of(2, 2), new Instant(5)),
+ TimestampedValue.of(KV.of(2, 4), new Instant(11)),
+ TimestampedValue.of(KV.of(2, 6), new Instant(12)))))
+ .apply(Window.into(SlidingWindows.of(Duration.millis(10)).every(Duration.millis(5))))
+ .apply(GroupByKey.create())
+ // PAssert does not support multiple KVs with the same key (because of multiple windows)
+ .apply(
+ ParDo.of(
+ new AssertContains<>(
+ KV.of(1, containsInAnyOrder(1, 3)), // window [0-10)
+ KV.of(1, containsInAnyOrder(1, 3, 5)), // window [5-15)
+ KV.of(1, containsInAnyOrder(5)), // window [10-20)
+ KV.of(2, containsInAnyOrder(2)), // window [0-10)
+ KV.of(2, containsInAnyOrder(2, 4, 6)), // window [5-15)
+ KV.of(2, containsInAnyOrder(4, 6)) // window [10-20)
+ )));
+ pipeline.run();
+ }
+
+ @Test
+ public void testGroupByKeyWithMergingWindows() {
+ pipeline
+ .apply(
+ Create.timestamped(
+ shuffleRandomly(
+ TimestampedValue.of(KV.of(1, 1), new Instant(5)),
+ TimestampedValue.of(KV.of(1, 3), new Instant(7)),
+ TimestampedValue.of(KV.of(1, 5), new Instant(11)),
+ TimestampedValue.of(KV.of(2, 2), new Instant(5)),
+ TimestampedValue.of(KV.of(2, 4), new Instant(11)),
+ TimestampedValue.of(KV.of(2, 6), new Instant(12)))))
+ .apply(Window.into(Sessions.withGapDuration(Duration.millis(5))))
+ .apply(GroupByKey.create())
+ // PAssert does not support multiple KVs with the same key (because of multiple windows)
+ .apply(
+ ParDo.of(
+ new AssertContains<>(
+ KV.of(1, containsInAnyOrder(1, 3, 5)), // window [5-16)
+ KV.of(2, containsInAnyOrder(2)), // window [5-10)
+ KV.of(2, containsInAnyOrder(4, 6)) // window [11-17)
+ )));
pipeline.run();
}
@Test
public void testGroupByKey() {
- List<KV<Integer, Integer>> elems = new ArrayList<>();
- elems.add(KV.of(1, 1));
- elems.add(KV.of(1, 3));
- elems.add(KV.of(1, 5));
- elems.add(KV.of(2, 2));
- elems.add(KV.of(2, 4));
- elems.add(KV.of(2, 6));
+ List<KV<Integer, Integer>> elems =
+ shuffleRandomly(
+ KV.of(1, 1), KV.of(1, 3), KV.of(1, 5), KV.of(2, 2), KV.of(2, 4), KV.of(2, 6));
PCollection<KV<Integer, Iterable<Integer>>> input =
pipeline.apply(Create.of(elems)).apply(GroupByKey.create());
+
PAssert.thatMap(input)
.satisfies(
results -> {
@@ -121,4 +166,27 @@ public class GroupByKeyTest implements Serializable {
});
pipeline.run();
}
+
+ static class AssertContains<K, V> extends DoFn<KV<K, Iterable<V>>, Void> {
+ private final Map<K, List<SerializableMatcher<Iterable<? extends V>>>> byKey;
+
+ public AssertContains(KV<K, SerializableMatcher<Iterable<? extends V>>>... matchers) {
+ byKey = stream(matchers).collect(groupingBy(KV::getKey, mapping(KV::getValue, toList())));
+ }
+
+ @ProcessElement
+ public void processElement(@Element KV<K, Iterable<V>> elem) {
+ assertThat("Unexpected key: " + elem.getKey(), byKey.containsKey(elem.getKey()));
+ List<V> values = ImmutableList.copyOf(elem.getValue());
+ assertThat(
+ "Unexpected values " + values + " for key " + elem.getKey(),
+ byKey.get(elem.getKey()).stream().anyMatch(m -> m.matches(values)));
+ }
+ }
+
+ private <T> List<T> shuffleRandomly(T... elems) {
+ ArrayList<T> list = Lists.newArrayList(elems);
+ Collections.shuffle(list);
+ return list;
+ }
}
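Why testGroupByKeyExplodesMultipleWindows expects several groups per key: with SlidingWindows each element belongs to size/period = 2 windows, and GroupByKey emits one KV per key and window. A minimal sketch of the assignment (assuming SlidingWindows' public assignWindows(Instant) helper):

    import org.apache.beam.sdk.transforms.windowing.SlidingWindows;
    import org.joda.time.Duration;
    import org.joda.time.Instant;

    SlidingWindows fn = SlidingWindows.of(Duration.millis(10)).every(Duration.millis(5));
    fn.assignWindows(new Instant(7));   // -> [0,10) and [5,15)
    fn.assignWindows(new Instant(11));  // -> [5,15) and [10,20)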
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
index 16d9a8b7fa8..f319173ed2b 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTest.java
@@ -20,19 +20,24 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PCollectionView;
-import org.junit.BeforeClass;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -40,15 +45,14 @@ import org.junit.runners.JUnit4;
/** Test class for beam to spark {@link ParDo} translation. */
@RunWith(JUnit4.class)
public class ParDoTest implements Serializable {
- private static Pipeline pipeline;
+ @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
- @BeforeClass
- public static void beforeClass() {
+ private static PipelineOptions testOptions() {
SparkStructuredStreamingPipelineOptions options =
PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
options.setRunner(SparkStructuredStreamingRunner.class);
options.setTestMode(true);
- pipeline = Pipeline.create(options);
+ return options;
}
@Test
@@ -59,6 +63,44 @@ public class ParDoTest implements Serializable {
pipeline.run();
}
+ @Test
+ public void testPardoWithOutputTagsCachedRDD() {
+ pardoWithOutputTags("MEMORY_ONLY");
+ }
+
+ @Test
+ public void testPardoWithOutputTagsCachedDataset() {
+ pardoWithOutputTags("MEMORY_AND_DISK");
+ }
+
+ private void pardoWithOutputTags(String storageLevel) {
+ pipeline.getOptions().as(SparkCommonPipelineOptions.class).setStorageLevel(storageLevel);
+
+ TupleTag<Integer> even = new TupleTag<Integer>() {};
+ TupleTag<String> unevenAsString = new TupleTag<String>() {};
+
+ DoFn<Integer, Integer> doFn =
+ new DoFn<Integer, Integer>() {
+ @ProcessElement
+ public void processElement(@Element Integer i, MultiOutputReceiver out) {
+ if (i % 2 == 0) {
+ out.get(even).output(i);
+ } else {
+ out.get(unevenAsString).output(i.toString());
+ }
+ }
+ };
+
+ PCollectionTuple outputs =
+ pipeline
+ .apply(Create.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10))
+ .apply(ParDo.of(doFn).withOutputTags(even, TupleTagList.of(unevenAsString)));
+
+ PAssert.that(outputs.get(even)).containsInAnyOrder(2, 4, 6, 8, 10);
+ PAssert.that(outputs.get(unevenAsString)).containsInAnyOrder("1", "3", "5", "7", "9");
+ pipeline.run();
+ }
+
@Test
public void testTwoPardoInRow() {
PCollection<Integer> input =
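Judging by their names, the two testPardoWithOutputTagsCached* variants above exercise the runner's two caching paths: an RDD-based cache for MEMORY_ONLY and a persisted Dataset for MEMORY_AND_DISK (an inference from the test names, not stated in the diff). The knob itself is the new common option, set exactly as in pardoWithOutputTags:

    // Storage level names follow org.apache.spark.storage.StorageLevel.
    pipeline.getOptions().as(SparkCommonPipelineOptions.class).setStorageLevel("MEMORY_AND_DISK");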
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/SimpleSourceTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/SimpleSourceTest.java
index 70cdca630b9..0f16b644222 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/SimpleSourceTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/SimpleSourceTest.java
@@ -20,12 +20,13 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
import java.io.Serializable;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
-import org.junit.BeforeClass;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -33,15 +34,14 @@ import org.junit.runners.JUnit4;
/** Test class for beam to spark source translation. */
@RunWith(JUnit4.class)
public class SimpleSourceTest implements Serializable {
- private static Pipeline pipeline;
+ @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
- @BeforeClass
- public static void beforeClass() {
+ private static PipelineOptions testOptions() {
SparkStructuredStreamingPipelineOptions options =
PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
options.setRunner(SparkStructuredStreamingRunner.class);
options.setTestMode(true);
- pipeline = Pipeline.create(options);
+ return options;
}
@Test
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTest.java
index b8b41010a24..28efe754ddf 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/WindowAssignTest.java
@@ -20,9 +20,10 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
import java.io.Serializable;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingPipelineOptions;
import org.apache.beam.runners.spark.structuredstreaming.SparkStructuredStreamingRunner;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.windowing.FixedWindows;
@@ -31,7 +32,7 @@ import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TimestampedValue;
import org.joda.time.Duration;
import org.joda.time.Instant;
-import org.junit.BeforeClass;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -39,15 +40,14 @@ import org.junit.runners.JUnit4;
/** Test class for beam to spark window assign translation. */
@RunWith(JUnit4.class)
public class WindowAssignTest implements Serializable {
- private static Pipeline pipeline;
+ @Rule public transient TestPipeline pipeline = TestPipeline.fromOptions(testOptions());
- @BeforeClass
- public static void beforeClass() {
+ private static PipelineOptions testOptions() {
SparkStructuredStreamingPipelineOptions options =
PipelineOptionsFactory.create().as(SparkStructuredStreamingPipelineOptions.class);
options.setRunner(SparkStructuredStreamingRunner.class);
options.setTestMode(true);
- pipeline = Pipeline.create(options);
+ return options;
}
@Test
diff --git a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java
index c8a8fba8d28..ab6e3083c54 100644
--- a/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java
+++ b/runners/spark/3/src/test/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpersTest.java
@@ -18,32 +18,95 @@
package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
import static java.util.Arrays.asList;
-import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.fromBeamCoder;
+import static java.util.function.Function.identity;
+import static java.util.stream.Collectors.toList;
+import static java.util.stream.Collectors.toMap;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.collectionEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.encoderFor;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.mapEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.oneOfEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.windowedValueEncoder;
+import static org.apache.beam.runners.spark.structuredstreaming.translation.utils.ScalaInterop.tuple;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Predicates.notNull;
+import static org.apache.spark.sql.types.DataTypes.IntegerType;
+import static org.apache.spark.sql.types.DataTypes.StringType;
+import static org.apache.spark.sql.types.DataTypes.createStructField;
+import static org.apache.spark.sql.types.DataTypes.createStructType;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
-import static org.junit.Assert.assertEquals;
+import static org.hamcrest.Matchers.instanceOf;
-import java.util.Arrays;
+import java.math.BigDecimal;
+import java.math.MathContext;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.Objects;
+import java.util.TreeMap;
+import java.util.function.Function;
import org.apache.beam.runners.spark.structuredstreaming.SparkSessionRule;
+import org.apache.beam.sdk.coders.BigDecimalCoder;
+import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
+import org.apache.beam.sdk.coders.BigEndianLongCoder;
+import org.apache.beam.sdk.coders.BigEndianShortCoder;
+import org.apache.beam.sdk.coders.BooleanCoder;
+import org.apache.beam.sdk.coders.ByteCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.DelegateCoder;
+import org.apache.beam.sdk.coders.DoubleCoder;
+import org.apache.beam.sdk.coders.FloatCoder;
+import org.apache.beam.sdk.coders.InstantCoder;
+import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.sql.types.DecimalType;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.joda.time.Instant;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
+import scala.Tuple2;
/** Test of the wrapping of Beam Coders as Spark ExpressionEncoders. */
@RunWith(JUnit4.class)
public class EncoderHelpersTest {
+ @ClassRule public static SparkSessionRule sessionRule = new SparkSessionRule("local[1]");
- @ClassRule public static SparkSessionRule sessionRule = new SparkSessionRule();
+ private static final Encoder<GlobalWindow> windowEnc =
+ EncoderHelpers.encoderOf(GlobalWindow.class);
+
+ private static final Map<Coder<?>, List<?>> BASIC_CASES =
+ ImmutableMap.<Coder<?>, List<?>>builder()
+ .put(BooleanCoder.of(), asList(true, false, null))
+ .put(ByteCoder.of(), asList((byte) 1, null))
+ .put(BigEndianShortCoder.of(), asList((short) 1, null))
+ .put(BigEndianIntegerCoder.of(), asList(1, 2, 3, null))
+ .put(VarIntCoder.of(), asList(1, 2, 3, null))
+ .put(BigEndianLongCoder.of(), asList(1L, 2L, 3L, null))
+ .put(VarLongCoder.of(), asList(1L, 2L, 3L, null))
+ .put(FloatCoder.of(), asList((float) 1.0, (float) 2.0, null))
+ .put(DoubleCoder.of(), asList(1.0, 2.0, null))
+ .put(StringUtf8Coder.of(), asList("1", "2", null))
+ .put(BigDecimalCoder.of(), asList(bigDecimalOf(1L), bigDecimalOf(2L), null))
+ .put(InstantCoder.of(), asList(Instant.ofEpochMilli(1), null))
+ .build();
private <T> Dataset<T> createDataset(List<T> data, Encoder<T> encoder) {
Dataset<T> ds = sessionRule.getSession().createDataset(data, encoder);
@@ -52,10 +115,14 @@ public class EncoderHelpersTest {
}
@Test
- public void beamCoderToSparkEncoderTest() {
- List<Integer> data = Arrays.asList(1, 2, 3);
- Dataset<Integer> dataset = createDataset(data, EncoderHelpers.fromBeamCoder(VarIntCoder.of()));
- assertEquals(data, dataset.collectAsList());
+ public void testBeamEncoderMappings() {
+ BASIC_CASES.forEach(
+ (coder, data) -> {
+ Encoder<?> encoder = encoderFor(coder);
+ serializeAndDeserialize(data.get(0), (Encoder) encoder);
+ Dataset<?> dataset = createDataset(data, (Encoder) encoder);
+ assertThat(dataset.collect(), equalTo(data.toArray()));
+ });
}
@Test
@@ -63,10 +130,135 @@ public class EncoderHelpersTest {
// Verify concrete types are not used in coder generation.
// In case of private types this would cause an IllegalAccessError.
List<PrivateString> data = asList(new PrivateString("1"), new PrivateString("2"));
- Dataset<PrivateString> dataset = createDataset(data, fromBeamCoder(PrivateString.CODER));
+ Dataset<PrivateString> dataset = createDataset(data, encoderFor(PrivateString.CODER));
+ assertThat(dataset.collect(), equalTo(data.toArray()));
+ }
+
+ @Test
+ public void testBeamWindowedValueEncoderMappings() {
+ BASIC_CASES.forEach(
+ (coder, data) -> {
+ List<WindowedValue<?>> windowed =
+ Lists.transform(data, WindowedValue::valueInGlobalWindow);
+
+ Encoder<?> encoder = windowedValueEncoder(encoderFor(coder), windowEnc);
+ serializeAndDeserialize(windowed.get(0), (Encoder) encoder);
+
+ Dataset<?> dataset = createDataset(windowed, (Encoder) encoder);
+ assertThat(dataset.collect(), equalTo(windowed.toArray()));
+ });
+ }
+
+ @Test
+ public void testCollectionEncoder() {
+ BASIC_CASES.forEach(
+ (coder, data) -> {
+ Encoder<? extends Collection<?>> encoder = collectionEncoder(encoderFor(coder), true);
+ Collection<?> collection = Collections.unmodifiableCollection(data);
+
+ Dataset<Collection<?>> dataset = createDataset(asList(collection), (Encoder) encoder);
+ assertThat(dataset.head(), equalTo(data));
+ });
+ }
+
+ private void testMapEncoder(Class<?> cls, Function<Map<?, ?>, Map<?, ?>> decorator) {
+ BASIC_CASES.forEach(
+ (coder, data) -> {
+ Encoder<?> enc = encoderFor(coder);
+ Encoder<Map<?, ?>> mapEncoder = mapEncoder(enc, enc, (Class) cls);
+ Map<?, ?> map =
+ decorator.apply(
+ data.stream().filter(notNull()).collect(toMap(identity(), identity())));
+
+ Dataset<Map<?, ?>> dataset = createDataset(asList(map), mapEncoder);
+ Map<?, ?> head = dataset.head();
+ assertThat(head, equalTo(map));
+ assertThat(head, instanceOf(cls));
+ });
+ }
+
+ @Test
+ public void testMapEncoder() {
+ testMapEncoder(Map.class, identity());
+ }
+
+ @Test
+ public void testHashMapEncoder() {
+ testMapEncoder(HashMap.class, identity());
+ }
+
+ @Test
+ public void testTreeMapEncoder() {
+ testMapEncoder(TreeMap.class, TreeMap::new);
+ }
+
+ @Test
+ public void testBeamBinaryEncoder() {
+ List<List<String>> data = asList(asList("a1", "a2", "a3"), asList("b1", "b2"), asList("c1"));
+
+ Encoder<List<String>> encoder = encoderFor(ListCoder.of(StringUtf8Coder.of()));
+ serializeAndDeserialize(data.get(0), encoder);
+
+ Dataset<List<String>> dataset = createDataset(data, encoder);
assertThat(dataset.collect(), equalTo(data.toArray()));
}
+ @Test
+ public void testEncoderForKVCoder() {
+ List<KV<Integer, String>> data =
+ asList(KV.of(1, "value1"), KV.of(null, "value2"), KV.of(3, null));
+
+ Encoder<KV<Integer, String>> encoder =
+ kvEncoder(encoderFor(VarIntCoder.of()), encoderFor(StringUtf8Coder.of()));
+ serializeAndDeserialize(data.get(0), encoder);
+
+ Dataset<KV<Integer, String>> dataset = createDataset(data, encoder);
+
+ StructType kvSchema =
+ createStructType(
+ new StructField[] {
+ createStructField("key", IntegerType, true),
+ createStructField("value", StringType, true)
+ });
+
+ assertThat(dataset.schema(), equalTo(kvSchema));
+ assertThat(dataset.collectAsList(), equalTo(data));
+ }
+
+ @Test
+ public void testOneOfEncoder() {
+ List<Coder<?>> coders = ImmutableList.copyOf(BASIC_CASES.keySet());
+ List<Encoder<?>> encoders = coders.stream().map(EncoderHelpers::encoderFor).collect(toList());
+
+ // build oneOf tuples of type index and corresponding value
+ List<Tuple2<Integer, ?>> data =
+ BASIC_CASES.entrySet().stream()
+ .map(e -> tuple(coders.indexOf(e.getKey()), (Object) e.getValue().get(0)))
+ .collect(toList());
+
+ // dataset is a sparse dataset with only one column set per row
+ Dataset<Tuple2<Integer, ?>> dataset = createDataset(data, oneOfEncoder((List) encoders));
+ assertThat(dataset.collectAsList(), equalTo(data));
+ }
+
+ // fix scale/precision to system default to compare using equals
+ private static BigDecimal bigDecimalOf(long l) {
+ DecimalType type = DecimalType.SYSTEM_DEFAULT();
+ return new BigDecimal(l, new MathContext(type.precision())).setScale(type.scale());
+ }
+
+ // test an explicit serialization roundtrip
+ private static <T> void serializeAndDeserialize(T data, Encoder<T> enc) {
+ ExpressionEncoder<T> bound = (ExpressionEncoder<T>) enc;
+ bound =
+ bound.resolveAndBind(bound.resolveAndBind$default$1(), bound.resolveAndBind$default$2());
+
+ InternalRow row = bound.createSerializer().apply(data);
+ T deserialized = bound.createDeserializer().apply(row);
+
+ assertThat(deserialized, equalTo(data));
+ }
+
private static class PrivateString {
private static final Coder<PrivateString> CODER =
DelegateCoder.of(
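For reference, the basic shape all of these encoder tests exercise, wrapping a Beam Coder as a Spark Encoder (a minimal sketch; assumes a local SparkSession named session):

    Encoder<Integer> enc = EncoderHelpers.encoderFor(VarIntCoder.of());
    Dataset<Integer> ds = session.createDataset(asList(1, 2, 3), enc);
    ds.collectAsList(); // [1, 2, 3], round-tripped through the coder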
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkCommonPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkCommonPipelineOptions.java
index 95e7865e6f9..461cfc80917 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkCommonPipelineOptions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkCommonPipelineOptions.java
@@ -50,6 +50,12 @@ public interface SparkCommonPipelineOptions
void setCheckpointDir(String checkpointDir);
+ @Description("Batch default storage level")
+ @Default.String("MEMORY_ONLY")
+ String getStorageLevel();
+
+ void setStorageLevel(String storageLevel);
+
@Description("Enable/disable sending aggregator values to Spark's metric sinks")
@Default.Boolean(true)
Boolean getEnableSparkMetricSinks();
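Moving storageLevel up from SparkPipelineOptions (removed below) into the shared SparkCommonPipelineOptions makes it available to the structured streaming runner as well. Following Beam's PipelineOptions naming conventions, it should be settable on the command line roughly as:

    --runner=SparkStructuredStreamingRunner --storageLevel=MEMORY_AND_DISK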
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
index ecf933336c9..9a5229f21ae 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/SparkPipelineOptions.java
@@ -34,12 +34,6 @@ public interface SparkPipelineOptions extends SparkCommonPipelineOptions {
void setBatchIntervalMillis(Long batchInterval);
- @Description("Batch default storage level")
- @Default.String("MEMORY_ONLY")
- String getStorageLevel();
-
- void setStorageLevel(String storageLevel);
-
@Description("Minimum time to spend on read, for each micro-batch.")
@Default.Long(200)
Long getMinReadTimeMillis();