Posted to github@beam.apache.org by GitBox <gi...@apache.org> on 2022/11/07 14:14:51 UTC

[GitHub] [beam] mosche opened a new pull request, #24009: [Spark dataset runner] Cache datasets if used multiple times

mosche opened a new pull request, #24009:
URL: https://github.com/apache/beam/pull/24009

   Cache datasets that are used as input multiple times, to prevent their lazy evaluation from being repeated (closes #24008).
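Some background on why this matters: Spark Datasets are evaluated lazily. A dataset that feeds several downstream transforms is therefore recomputed by every action whose plan includes it, unless it is persisted (this PR does so via `dataset.persist(storageLevel)`). The stdlib-only Java sketch below models that effect with a hypothetical `LazyDataset` class. It is an analogy, not Spark's API.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

// Hypothetical stand-in for a lazily evaluated Spark Dataset (not Spark's API):
// nothing is computed until get() is called, and without persist() every call
// re-runs the whole computation.
final class LazyDataset<T> {
  private final Supplier<T> plan;
  private boolean persisted = false;
  private boolean materialized = false;
  private T value;

  LazyDataset(Supplier<T> plan) {
    this.plan = plan;
  }

  // Loosely analogous to Dataset.persist(): keep the result once computed.
  LazyDataset<T> persist() {
    persisted = true;
    return this;
  }

  T get() {
    if (!persisted) {
      return plan.get(); // recompute the lineage on every access
    }
    if (!materialized) {
      value = plan.get(); // first access computes and stores the result
      materialized = true;
    }
    return value;
  }
}

final class LazyEvaluationDemo {
  // Reads one upstream dataset from two downstream consumers and returns how
  // often the upstream computation actually ran.
  static int upstreamEvaluations(boolean persist) {
    AtomicInteger evaluations = new AtomicInteger();
    LazyDataset<Integer> upstream = new LazyDataset<>(() -> {
      evaluations.incrementAndGet();
      return 42;
    });
    if (persist) {
      upstream.persist();
    }
    upstream.get(); // consumer 1
    upstream.get(); // consumer 2
    return evaluations.get();
  }

  public static void main(String[] args) {
    System.out.println(upstreamEvaluations(false)); // prints 2
    System.out.println(upstreamEvaluations(true)); // prints 1
  }
}
```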
   
   ------------------------
   
   Thank you for your contribution! Follow this checklist to help us incorporate your contribution quickly and easily:
   
    - [ ] [**Choose reviewer(s)**](https://beam.apache.org/contribute/#make-your-change) and mention them in a comment (`R: @username`).
    - [ ] Mention the appropriate issue in your description (for example: `addresses #123`), if applicable. This will automatically add a link to the pull request in the issue. If you would like the issue to automatically close on merging the pull request, comment `fixes #<ISSUE NUMBER>` instead.
    - [ ] Update `CHANGES.md` with noteworthy changes.
    - [ ] If this contribution is large, please file an Apache [Individual Contributor License Agreement](https://www.apache.org/licenses/icla.pdf).
   
   See the [Contributor Guide](https://beam.apache.org/contribute) for more tips on [how to make review process smoother](https://beam.apache.org/contribute/get-started-contributing/#make-the-reviewers-job-easier).
   
   To check the build health, please visit [https://github.com/apache/beam/blob/master/.test-infra/BUILD_STATUS.md](https://github.com/apache/beam/blob/master/.test-infra/BUILD_STATUS.md)
   
   GitHub Actions Tests Status (on master branch)
   ------------------------------------------------------------------------------------------------
   [![Build python source distribution and wheels](https://github.com/apache/beam/workflows/Build%20python%20source%20distribution%20and%20wheels/badge.svg?branch=master&event=schedule)](https://github.com/apache/beam/actions?query=workflow%3A%22Build+python+source+distribution+and+wheels%22+branch%3Amaster+event%3Aschedule)
   [![Python tests](https://github.com/apache/beam/workflows/Python%20tests/badge.svg?branch=master&event=schedule)](https://github.com/apache/beam/actions?query=workflow%3A%22Python+Tests%22+branch%3Amaster+event%3Aschedule)
   [![Java tests](https://github.com/apache/beam/workflows/Java%20Tests/badge.svg?branch=master&event=schedule)](https://github.com/apache/beam/actions?query=workflow%3A%22Java+Tests%22+branch%3Amaster+event%3Aschedule)
   [![Go tests](https://github.com/apache/beam/workflows/Go%20tests/badge.svg?branch=master&event=schedule)](https://github.com/apache/beam/actions?query=workflow%3A%22Go+tests%22+branch%3Amaster+event%3Aschedule)
   
   See [CI.md](https://github.com/apache/beam/blob/master/CI.md) for more information about GitHub Actions CI.
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [beam] mosche commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
mosche commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1305822663

   Run Spark Runner Tpcds Tests




[GitHub] [beam] mosche commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
mosche commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1025003902


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java:
##########
@@ -118,69 +126,53 @@ public <T> PCollection<T> getOutput(TupleTag<T> tag) {
       return pc;
     }
 
-    public Map<TupleTag<?>, PCollection<?>> getOutputs() {
-      return transform.getOutputs();
-    }
-
     public AppliedPTransform<InT, OutT, PTransform<InT, OutT>> getCurrentTransform() {
       return transform;
     }
 
+    @Override
     public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
-      return cxt.getDataset(pCollection);
+      return state.getDataset(pCollection);
     }
 
-    public <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
-      cxt.putDataset(pCollection, dataset);
+    @Override
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {

Review Comment:
   Makes sense, changed that 👍 





[GitHub] [beam] mosche commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
mosche commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1025029738


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence, that can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as result of translating a {@link PTransform}
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and omitted afterwards. */
+  interface TranslationState extends EncoderProvider {
+    <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
+
+    <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache);
 
-    /** Uses the batch mode. */
-    BATCH,
+    default <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      putDataset(pCollection, dataset, false);
+    }
 
-    /** Uses the streaming mode. */
-    STREAMING
+    SerializablePipelineOptions getSerializableOptions();
+
+    SparkSession getSparkSession();
   }
 
-  /** Traverses the Pipeline to determine the {@link TranslationMode} for this pipeline. */
-  private static class TranslationModeDetector extends Pipeline.PipelineVisitor.Defaults {
-    private static final Logger LOG = LoggerFactory.getLogger(TranslationModeDetector.class);
+  /**
+   * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their
+   * Spark correspondence.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted.
+   */
+  private class TranslatingVisitor extends PTransformVisitor implements TranslationState {
+    private final Map<PCollection<?>, TranslationResult<?>> translationResults;
+    private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
+    private final SparkSession sparkSession;
+    private final SerializablePipelineOptions serializableOptions;
+    private final StorageLevel storageLevel;
+
+    private final Set<TranslationResult<?>> leaves;
+
+    public TranslatingVisitor(
+        SparkSession sparkSession,
+        SparkCommonPipelineOptions options,
+        Map<PCollection<?>, TranslationResult<?>> translationResults) {
+      this.sparkSession = sparkSession;
+      this.translationResults = translationResults;
+      this.serializableOptions = new SerializablePipelineOptions(options);
+      this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
+      this.encoders = new HashMap<>();
+      this.leaves = new HashSet<>();
+    }
 
-    private TranslationMode translationMode;
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+          (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+      try {
+        LOG.info(
+            "Translating {}: {}",
+            node.isCompositeNode() ? "composite" : "primitive",
+            node.getFullName());
+        translator.translate(transform, appliedTransform, this);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
 
-    TranslationModeDetector(TranslationMode defaultMode) {
-      this.translationMode = defaultMode;
+    @Override
+    public <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory) {
+      return (Encoder<T>) encoders.computeIfAbsent(coder, (Factory) factory);
     }
 
-    TranslationModeDetector() {
-      this(TranslationMode.BATCH);
+    private <T> TranslationResult<T> getResult(PCollection<T> pCollection) {
+      return (TranslationResult<T>) checkStateNotNull(translationResults.get(pCollection));
     }
 
-    TranslationMode getTranslationMode() {
-      return translationMode;
+    @Override
+    public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+      return checkStateNotNull(getResult(pCollection).dataset);
     }
 
     @Override
-    public void visitValue(PValue value, TransformHierarchy.Node producer) {
-      if (translationMode.equals(TranslationMode.BATCH)) {
-        if (value instanceof PCollection
-            && ((PCollection) value).isBounded() == PCollection.IsBounded.UNBOUNDED) {
-          LOG.info(
-              "Found unbounded PCollection {}. Switching to streaming execution.", value.getName());
-          translationMode = TranslationMode.STREAMING;
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {
+      TranslationResult<T> result = getResult(pCollection);
+      if (!noCache && result.dependentTransforms.size() > 1) {
+        LOG.info("Dataset {} will be cached.", result.name);
+        result.dataset = dataset.persist(storageLevel); // use NONE to disable
+      } else {
+        result.dataset = dataset;
+        if (result.dependentTransforms.isEmpty()) {
+          leaves.add(result);
         }
       }
     }
-  }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline utility methods
-  // --------------------------------------------------------------------------------------------
+    @Override
+    public SerializablePipelineOptions getSerializableOptions() {
+      return serializableOptions;
+    }
 
-  /** Get a {@link TransformTranslator} for the given {@link TransformHierarchy.Node}. */
-  protected abstract @Nullable <
-          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
-          @Nullable TransformT transform);
-
-  /** Apply the given TransformTranslator to the given node. */
-  private <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      void applyTransformTranslator(
-          TransformHierarchy.Node node,
-          TransformT transform,
-          TransformTranslator<InT, OutT, TransformT> transformTranslator) {
-    // create the applied PTransform on the translationContext
-    AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
-        (AppliedPTransform) node.toAppliedPTransform(getPipeline());
-    try {
-      transformTranslator.translate(transform, appliedTransform, translationContext);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
+    @Override
+    public SparkSession getSparkSession() {
+      return sparkSession;
     }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline visitor entry point
-  // --------------------------------------------------------------------------------------------
-
   /**
-   * Translates the pipeline by passing this class as a visitor.
+   * {@link PTransformVisitor} that analyses dependencies of supported {@link PTransform
+   * PTransforms} to help identify cache candidates.
    *
-   * @param pipeline The pipeline to be translated
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
    */
-  public void translate(Pipeline pipeline) {
-    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
-    pipeline.traverseTopologically(this);
+  private class DependencyVisitor extends PTransformVisitor {
+    private final Map<PCollection<?>, TranslationResult<?>> results = new HashMap<>();
+
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+      for (PCollection<?> pOut : node.getOutputs().values()) {
+        results.put(pOut, new TranslationResult<>(pOut));
+        for (Map.Entry<TupleTag<?>, PCollection<?>> entry : node.getInputs().entrySet()) {
+          TranslationResult<?> input = checkStateNotNull(results.get(entry.getValue()));
+          input.dependentTransforms.add(transform);
+        }
+      }
+    }

Review Comment:
   This should be:
   ```java
         // add new translation result for every output of `transform`
         for (PCollection<?> pOut : node.getOutputs().values()) {
           results.put(pOut, new TranslationResult<>(pOut));
         }
         // track `transform` as downstream dependency for every input
         for (Map.Entry<TupleTag<?>, PCollection<?>> entry : node.getInputs().entrySet()) {
           TranslationResult<?> input = checkStateNotNull(results.get(entry.getValue()));
           input.dependentTransforms.add(transform);
         }
       }
   ```
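The difference between the original nesting and the suggested fix can be sketched with a simplified, stdlib-only model. `PipelineNode` and `DependencyTracking` are hypothetical names, not Beam classes, and PCollections are represented by strings. In this model, nesting the input loop inside the output loop means a transform without any PCollection outputs is never recorded as a dependent of its inputs, so that input appears to have one consumer fewer than it really has.

```java
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical, simplified pipeline node: a transform name plus the names of the
// PCollections it consumes and produces (not Beam's TransformHierarchy.Node).
final class PipelineNode {
  final String transform;
  final List<String> inputs;
  final List<String> outputs;

  PipelineNode(String transform, List<String> inputs, List<String> outputs) {
    this.transform = transform;
    this.inputs = inputs;
    this.outputs = outputs;
  }
}

final class DependencyTracking {
  // Structure of the suggested fix: register outputs, then track inputs separately.
  static Map<String, Set<String>> dependents(List<PipelineNode> topologicalOrder) {
    Map<String, Set<String>> results = new HashMap<>();
    for (PipelineNode node : topologicalOrder) {
      // add a new (empty) result for every output of the transform
      for (String out : node.outputs) {
        results.put(out, new LinkedHashSet<>());
      }
      // track the transform as downstream dependent of every input,
      // even if the transform itself has no outputs
      for (String in : node.inputs) {
        Set<String> deps = results.get(in);
        if (deps == null) {
          throw new IllegalStateException("Input not produced upstream: " + in);
        }
        deps.add(node.transform);
      }
    }
    return results;
  }

  // Original nesting: inputs are only visited inside the output loop.
  static Map<String, Set<String>> dependentsNested(List<PipelineNode> topologicalOrder) {
    Map<String, Set<String>> results = new HashMap<>();
    for (PipelineNode node : topologicalOrder) {
      for (String out : node.outputs) {
        results.put(out, new LinkedHashSet<>());
        for (String in : node.inputs) {
          results.get(in).add(node.transform);
        }
      }
    }
    return results;
  }

  public static void main(String[] args) {
    List<PipelineNode> pipeline = List.of(
        new PipelineNode("Read", List.of(), List.of("data")),
        new PipelineNode("Map", List.of("data"), List.of("mapped")),
        new PipelineNode("Sink", List.of("data"), List.of()));
    System.out.println(dependents(pipeline).get("data")); // [Map, Sink]
    System.out.println(dependentsNested(pipeline).get("data")); // [Map]
  }
}
```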


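The dependent counts collected by the `DependencyVisitor` pass drive the decision in `putDataset` during the translation pass of the diff: persist when more than one transform consumes the dataset (unless `noCache` is set), and record it as a leaf when nothing consumes it. A simplified model of that decision follows, with a hypothetical `CacheDecision` class; the real code calls `dataset.persist(storageLevel)` and adds the result to `leaves` instead of returning an enum.

```java
// Hypothetical model of the decision putDataset makes in the diff once the
// number of dependent transforms is known (not Beam's or Spark's API).
final class CacheDecision {
  enum Action {
    PERSIST, // consumed more than once: persist so later uses can reuse it
    LEAF, // consumed by nothing: must be evaluated explicitly at the end
    PLAIN // not cached, but consumed downstream (or caching was disabled)
  }

  static Action decide(int dependentTransforms, boolean noCache) {
    if (!noCache && dependentTransforms > 1) {
      return Action.PERSIST;
    }
    return dependentTransforms == 0 ? Action.LEAF : Action.PLAIN;
  }

  public static void main(String[] args) {
    System.out.println(decide(2, false)); // PERSIST
    System.out.println(decide(2, true)); // PLAIN
    System.out.println(decide(0, false)); // LEAF
  }
}
```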



[GitHub] [beam] mosche commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
mosche commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1305778250

   Run Spark StructuredStreaming ValidatesRunner




[GitHub] [beam] mosche commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
mosche commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1305743304

   Spark Runner Tpcds Tests




[GitHub] [beam] echauchot commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
echauchot commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1026514453


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java:
##########
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation;
+
+import java.util.Collection;
+import java.util.concurrent.Callable;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Throwables;
+import org.apache.spark.api.java.function.ForeachFunction;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.execution.ExplainMode;
+import org.apache.spark.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The {@link EvaluationContext} is the result of a pipeline {@link PipelineTranslator#translate
+ * translation} and can be used to evaluate / run the pipeline.
+ *
+ * <p>However, in some cases pipeline translation involves the early evaluation of some parts of the
+ * pipeline. For example, this is necessary to materialize side-inputs. The {@link
+ * EvaluationContext} won't re-evaluate such datasets.
+ */
+@Internal
+public final class EvaluationContext {
+  private static final Logger LOG = LoggerFactory.getLogger(EvaluationContext.class);
+
+  interface NamedDataset<T> {
+    String name();
+
+    @Nullable
+    Dataset<WindowedValue<T>> dataset();
+  }
+
+  private final Collection<? extends NamedDataset<?>> leaveDatasets;
+  private final SparkSession session;
+
+  EvaluationContext(Collection<? extends NamedDataset<?>> leaveDatasets, SparkSession session) {
+    this.leaveDatasets = leaveDatasets;
+    this.session = session;
+  }
+
+  /** Trigger evaluation of all leave datasets. */
+  public void evaluate() {
+    for (NamedDataset<?> ds : leaveDatasets) {
+      final Dataset<?> dataset = ds.dataset();
+      if (dataset == null) {
+        continue;
+      }
+      if (LOG.isDebugEnabled()) {
+        ExplainMode explainMode = ExplainMode.fromString("simple");
+        String execPlan = dataset.queryExecution().explainString(explainMode);
+        LOG.debug("Evaluating dataset {}:\n{}", ds.name(), execPlan);
+      }
+      // force evaluation using a dummy foreach action
+      evaluate(ds.name(), () -> dataset.foreach(NOOP));
+    }
+  }
+
+  /**
+   * The purpose of this utility is to mark the evaluation of Spark actions, both during Pipeline
+   * translation, when evaluation is required, and when finally evaluating the pipeline.
+   */
+  public static void evaluate(String name, Runnable action) {
+    long startMs = System.currentTimeMillis();
+    try {
+      action.run();
+      LOG.info("Evaluated dataset {} in {}", name, durationSince(startMs));
+    } catch (RuntimeException e) {
+      LOG.error("Failed to evaluate dataset {}: {}", name, Throwables.getRootCause(e).getMessage());
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * The purpose of this utility is to mark the evaluation of Spark actions, both during Pipeline
+   * translation, when evaluation is required, and when finally evaluating the pipeline.
+   */
+  public static <T> T evaluate(String name, Callable<T> action) {

Review Comment:
   ok
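The `Runnable` variant of `EvaluationContext#evaluate` in the diff wraps a Spark action, logs its name and duration, and on failure logs the root cause and rethrows wrapped in a `RuntimeException`. A dependency-free sketch of that pattern follows; the class name `TimedEvaluation` is hypothetical, `java.util.logging` replaces SLF4J, and a hand-written loop replaces Guava's `Throwables.getRootCause`. Only the `Runnable` form is sketched, since the `Callable` overload was flagged as unused.

```java
import java.util.logging.Logger;

// Sketch of a named, timed evaluation wrapper modeled on
// EvaluationContext#evaluate(String, Runnable); not the Beam implementation.
final class TimedEvaluation {
  private static final Logger LOG = Logger.getLogger(TimedEvaluation.class.getName());

  static void evaluate(String name, Runnable action) {
    long startMs = System.currentTimeMillis();
    try {
      action.run();
      long elapsedMs = System.currentTimeMillis() - startMs;
      LOG.info("Evaluated dataset " + name + " in " + elapsedMs + " ms");
    } catch (RuntimeException e) {
      LOG.severe("Failed to evaluate dataset " + name + ": " + rootCause(e).getMessage());
      throw new RuntimeException(e); // rethrow, wrapped, as the diff does
    }
  }

  // Follows the cause chain to the innermost throwable.
  static Throwable rootCause(Throwable t) {
    Throwable root = t;
    while (root.getCause() != null && root.getCause() != root) {
      root = root.getCause();
    }
    return root;
  }

  public static void main(String[] args) {
    // in the diff the action is e.g. () -> dataset.foreach(NOOP); here a stub
    evaluate("example", () -> System.out.println("running action"));
  }
}
```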





[GitHub] [beam] echauchot commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
echauchot commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1021700719


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java:
##########
@@ -0,0 +1,119 @@
+  /**
+   * The purpose of this utility is to mark the evaluation of Spark actions, both during Pipeline
+   * translation, when evaluation is required, and when finally evaluating the pipeline.
+   */
+  public static <T> T evaluate(String name, Callable<T> action) {

Review Comment:
   unused
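For reference, the timed-evaluation pattern behind `EvaluationContext.evaluate` only needs the `Runnable` overload. The following is a minimal standalone sketch, not the PR's code. It makes some assumptions: SLF4J is replaced by a caller-supplied `Consumer<String>` log sink, the root-cause lookup that the diff gets from `Throwables.getRootCause` is hand-rolled, and the class name `TimedEvaluation` is hypothetical.

```java
import java.util.function.Consumer;

/** Hypothetical standalone version of the timed-evaluation helper. */
final class TimedEvaluation {
  private TimedEvaluation() {}

  /** Runs the action and logs its duration on success, or its root cause on failure. */
  static void evaluate(String name, Runnable action, Consumer<String> log) {
    long startMs = System.currentTimeMillis();
    try {
      action.run();
      long elapsedMs = System.currentTimeMillis() - startMs;
      log.accept("Evaluated dataset " + name + " in " + elapsedMs + " ms");
    } catch (RuntimeException e) {
      log.accept("Failed to evaluate dataset " + name + ": " + rootCause(e).getMessage());
      // Rethrow the original exception instead of wrapping it a second time.
      throw e;
    }
  }

  /** Follows the cause chain to its last element (no protection against cause cycles). */
  static Throwable rootCause(Throwable t) {
    Throwable current = t;
    while (current.getCause() != null) {
      current = current.getCause();
    }
    return current;
  }
}
```

Rethrowing `e` unchanged keeps the original exception type visible to callers; the diff instead wraps it in a new `RuntimeException`.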



##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java:
##########
@@ -81,27 +80,13 @@ public class PipelineTranslatorBatch extends PipelineTranslator {
 
     TRANSFORM_TRANSLATORS.put(
         SplittableParDo.PrimitiveBoundedRead.class, new ReadSourceTranslatorBatch<>());
-

Review Comment:
   It's true that the CreatePCollectionView translation did nothing except set the view inside the translation context. Don't you need this anymore? Is this the preparation for [#24035](https://github.com/apache/beam/issues/24035) you mentioned?



##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java:
##########
@@ -118,69 +126,53 @@ public <T> PCollection<T> getOutput(TupleTag<T> tag) {
       return pc;
     }
 
-    public Map<TupleTag<?>, PCollection<?>> getOutputs() {
-      return transform.getOutputs();
-    }
-
     public AppliedPTransform<InT, OutT, PTransform<InT, OutT>> getCurrentTransform() {
       return transform;
     }
 
+    @Override
     public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
-      return cxt.getDataset(pCollection);
+      return state.getDataset(pCollection);
     }
 
-    public <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
-      cxt.putDataset(pCollection, dataset);
+    @Override
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {

Review Comment:
   nit: I'm not a big fan of double negations: cache = false seems better than noCache = true
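To illustrate the nit, here is a minimal sketch of the same overload pair with a positive flag. It is simplified and hypothetical: collections and datasets are plain strings, and `cache = true` is assumed to mean "eligible for caching if the dataset turns out to be reused", which is what `noCache = false` means in the diff's default overload.

```java
import java.util.HashMap;
import java.util.Map;

/** Simplified, hypothetical variant of TranslationState with a positive caching flag. */
interface TranslationState {
  /** Registers a dataset; {@code cache} marks it as eligible for caching if it is reused. */
  void putDataset(String pCollection, String dataset, boolean cache);

  /** Most translators call this overload, which keeps caching enabled. */
  default void putDataset(String pCollection, String dataset) {
    putDataset(pCollection, dataset, true);
  }
}

/** Records the flag passed for each collection so the behavior can be inspected. */
final class RecordingState implements TranslationState {
  final Map<String, Boolean> cacheFlags = new HashMap<>();

  @Override
  public void putDataset(String pCollection, String dataset, boolean cache) {
    cacheFlags.put(pCollection, cache);
  }
}
```

Call sites that opt out then read `putDataset(pc, ds, false)`, with no double negative to untangle.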



##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence that can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>Finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it is expanded further into its
+ *       parts, which are then translated.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note that in some cases this involves the early evaluation of some parts of the pipeline.
+   * For example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation, the corresponding Spark {@link org.apache.spark.sql.Dataset
+   * Dataset} might have to be collected and broadcast before the translation can continue.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(

Review Comment:
   I like this API :+1: 
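The second traversal described in the class Javadoc (identifying datasets that are repeatedly used as input) boils down to counting the distinct consumers of each collection. Below is a minimal sketch on a hypothetical toy graph rather than Beam's `TransformHierarchy`: `ToyTransform` and `DependencyCounter` are illustrative names, and each input tracks a set of dependent transforms, similar to `TranslationResult.dependentTransforms` in the diff. In the runner, a dataset selected this way would then be persisted, for example via Spark's `Dataset.persist(StorageLevel)`.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/** Hypothetical node of a toy pipeline graph: a named transform with inputs and one output. */
final class ToyTransform {
  final String name;
  final List<String> inputs;
  final String output;

  ToyTransform(String name, List<String> inputs, String output) {
    this.name = name;
    this.inputs = inputs;
    this.output = output;
  }
}

final class DependencyCounter {
  private DependencyCounter() {}

  /** Returns the collections read by more than one transform, i.e. the candidates for caching. */
  static Set<String> collectionsToCache(List<ToyTransform> transforms) {
    // For each collection, the distinct transforms that consume it.
    Map<String, Set<String>> dependents = new HashMap<>();
    for (ToyTransform t : transforms) {
      for (String input : t.inputs) {
        dependents.computeIfAbsent(input, k -> new HashSet<>()).add(t.name);
      }
    }
    Set<String> toCache = new TreeSet<>();
    for (Map.Entry<String, Set<String>> entry : dependents.entrySet()) {
      if (entry.getValue().size() > 1) {
        toCache.add(entry.getKey());
      }
    }
    return toCache;
  }

  /** Returns outputs that no transform consumes: the leaves that must be evaluated explicitly. */
  static Set<String> leaves(List<ToyTransform> transforms) {
    Set<String> consumed = new HashSet<>();
    for (ToyTransform t : transforms) {
      consumed.addAll(t.inputs);
    }
    Set<String> leaves = new TreeSet<>();
    for (ToyTransform t : transforms) {
      if (!consumed.contains(t.output)) {
        leaves.add(t.output);
      }
    }
    return leaves;
  }
}
```

For a pipeline where two transforms both read `lines` and a flatten merges their outputs into `all`, only `lines` is selected for caching and `all` is the single leaf.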



##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@

Review Comment:
   elegant use of state and transparent encoder capabilities :+1: 
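One way a `TranslationState` could provide these encoder capabilities transparently is to memoize `encoderOf(coder, factory)` per coder, so repeated lookups reuse the same encoder. The excerpt does not show how the PR implements this; the sketch below is hypothetical, uses stand-in `Coder`/`Encoder` classes instead of the Beam and Spark types, and compares coders by identity.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

/** Stand-in for a Beam coder (identity equality). */
final class Coder<T> {
  final String name;

  Coder(String name) {
    this.name = name;
  }
}

/** Stand-in for a Spark encoder. */
final class Encoder<T> {
  final String description;

  Encoder(String description) {
    this.description = description;
  }
}

interface EncoderProvider {
  <T> Encoder<T> encoderOf(Coder<T> coder, Function<Coder<T>, Encoder<T>> factory);
}

/** Hypothetical translation state that creates each encoder at most once per coder. */
final class MemoizingState implements EncoderProvider {
  private final Map<Coder<?>, Encoder<?>> encoders = new HashMap<>();
  int created = 0;

  @Override
  @SuppressWarnings("unchecked") // the entry stored for a Coder<T> is always an Encoder<T>
  public <T> Encoder<T> encoderOf(Coder<T> coder, Function<Coder<T>, Encoder<T>> factory) {
    return (Encoder<T>)
        encoders.computeIfAbsent(
            coder,
            c -> {
              created++;
              return factory.apply(coder);
            });
  }
}
```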



##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderProvider.java:
##########
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder;
+
+import java.util.function.Function;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.values.KV;
+import org.apache.spark.sql.Encoder;
+
+@Internal
+public interface EncoderProvider {
+  interface Factory<T> extends Function<Coder<T>, Encoder<T>> {
+    Factory<?> INSTANCE = EncoderHelpers::encoderFor;
+  }
+
+  <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory);
+
+  default <T> Encoder<T> encoderOf(Coder<T> coder) {
+    return coder instanceof KvCoder
+        ? (Encoder<T>) kvEncoderOf((KvCoder) coder)
+        : encoderOf(coder, encoderFactory());
+  }
+
+  default <K, V> Encoder<KV<K, V>> kvEncoderOf(KvCoder<K, V> coder) {
+    return encoderOf(coder, c -> kvEncoder(keyEncoderOf(coder), valueEncoderOf(coder)));
+  }
+
+  default <K, V> Encoder<K> keyEncoderOf(KvCoder<K, V> coder) {
+    return encoderOf(coder.getKeyCoder(), encoderFactory());
+  }
+
+  default <K, V> Encoder<V> valueEncoderOf(KvCoder<K, V> coder) {
+    return encoderOf(coder.getValueCoder(), encoderFactory());
+  }
+
+  default <T> Factory<T> encoderFactory() {
+    return (Factory<T>) Factory.INSTANCE;

Review Comment:
   I would prefer that you inline the INSTANCE here (as it is used only here) and leave Factory as a simple tagging interface over Function. That would remove the need for the cast and the strange Factory containing a Factory.
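A sketch of the suggested refactoring follows. `Coder`, `Encoder` and the stub `EncoderHelpers.encoderFor` are placeholders for the Beam and Spark types, not their real implementations. `Factory` becomes a plain tagging interface and `encoderFactory()` returns the method reference directly. Because the target type `Factory<T>` lets the compiler infer the type argument of `encoderFor`, the unchecked cast disappears.

```java
import java.util.function.Function;

/** Stand-in for a Beam coder. */
final class Coder<T> {
  final String name;

  Coder(String name) {
    this.name = name;
  }
}

/** Stand-in for a Spark encoder. */
final class Encoder<T> {
  final String description;

  Encoder(String description) {
    this.description = description;
  }
}

/** Stub standing in for the runner's EncoderHelpers.encoderFor. */
final class EncoderHelpers {
  private EncoderHelpers() {}

  static <T> Encoder<T> encoderFor(Coder<T> coder) {
    return new Encoder<>("encoder for " + coder.name);
  }
}

interface EncoderProvider {
  /** A plain tagging interface over Function, without a nested INSTANCE constant. */
  interface Factory<T> extends Function<Coder<T>, Encoder<T>> {}

  <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory);

  default <T> Encoder<T> encoderOf(Coder<T> coder) {
    return encoderOf(coder, encoderFactory());
  }

  /** The method reference is inlined here, so no unchecked cast is needed. */
  default <T> Factory<T> encoderFactory() {
    return EncoderHelpers::encoderFor;
  }
}

/** Minimal provider that simply applies the factory. */
final class DirectEncoderProvider implements EncoderProvider {
  @Override
  public <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory) {
    return factory.apply(coder);
  }
}
```

With `Factory` reduced to a tag, call sites can still pass a lambda, as `kvEncoderOf` does in the diff.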



##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@

Review Comment:
   Also good that this avoids storing all the datasets: it just keeps the current one and then forgets about it.
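Taken one step further, a result holder could drop its dataset reference once its last dependent transform has been translated, so the translator never holds more than it still needs. This is a hypothetical sketch of that idea, not code from the PR: `ReleasingResult` and `markTranslated` are illustrative names, the dataset is a plain `Object`, and leaf results (no dependents) keep their dataset for the final evaluation.

```java
import java.util.HashSet;
import java.util.Set;

/** Hypothetical per-collection result that forgets its dataset once all dependents are done. */
final class ReleasingResult {
  private final Set<String> pendingDependents;
  private Object dataset; // stand-in for Dataset<WindowedValue<T>>

  ReleasingResult(Object dataset, Set<String> dependents) {
    this.dataset = dataset;
    this.pendingDependents = new HashSet<>(dependents);
  }

  /** Returns the dataset, or null if it has already been released. */
  Object dataset() {
    return dataset;
  }

  /** Called after a dependent transform was translated; releases the dataset after the last one. */
  void markTranslated(String transform) {
    if (pendingDependents.remove(transform) && pendingDependents.isEmpty()) {
      dataset = null;
    }
  }
}
```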



##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as the result of translating a {@link PTransform},
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and omitted afterwards. */
+  interface TranslationState extends EncoderProvider {
+    <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
+
+    <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache);
 
-    /** Uses the batch mode. */
-    BATCH,
+    default <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      putDataset(pCollection, dataset, false);
+    }
 
-    /** Uses the streaming mode. */
-    STREAMING
+    SerializablePipelineOptions getSerializableOptions();
+
+    SparkSession getSparkSession();
   }
 
-  /** Traverses the Pipeline to determine the {@link TranslationMode} for this pipeline. */
-  private static class TranslationModeDetector extends Pipeline.PipelineVisitor.Defaults {
-    private static final Logger LOG = LoggerFactory.getLogger(TranslationModeDetector.class);
+  /**
+   * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their
+   * Spark correspondence.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted.
+   */
+  private class TranslatingVisitor extends PTransformVisitor implements TranslationState {
+    private final Map<PCollection<?>, TranslationResult<?>> translationResults;
+    private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
+    private final SparkSession sparkSession;
+    private final SerializablePipelineOptions serializableOptions;
+    private final StorageLevel storageLevel;
+
+    private final Set<TranslationResult<?>> leaves;
+
+    public TranslatingVisitor(
+        SparkSession sparkSession,
+        SparkCommonPipelineOptions options,
+        Map<PCollection<?>, TranslationResult<?>> translationResults) {
+      this.sparkSession = sparkSession;
+      this.translationResults = translationResults;
+      this.serializableOptions = new SerializablePipelineOptions(options);
+      this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
+      this.encoders = new HashMap<>();
+      this.leaves = new HashSet<>();
+    }
 
-    private TranslationMode translationMode;
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+          (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+      try {
+        LOG.info(
+            "Translating {}: {}",
+            node.isCompositeNode() ? "composite" : "primitive",
+            node.getFullName());
+        translator.translate(transform, appliedTransform, this);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
 
-    TranslationModeDetector(TranslationMode defaultMode) {
-      this.translationMode = defaultMode;
+    @Override
+    public <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory) {
+      return (Encoder<T>) encoders.computeIfAbsent(coder, (Factory) factory);
     }
 
-    TranslationModeDetector() {
-      this(TranslationMode.BATCH);
+    private <T> TranslationResult<T> getResult(PCollection<T> pCollection) {
+      return (TranslationResult<T>) checkStateNotNull(translationResults.get(pCollection));
     }
 
-    TranslationMode getTranslationMode() {
-      return translationMode;
+    @Override
+    public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+      return checkStateNotNull(getResult(pCollection).dataset);
     }
 
     @Override
-    public void visitValue(PValue value, TransformHierarchy.Node producer) {
-      if (translationMode.equals(TranslationMode.BATCH)) {
-        if (value instanceof PCollection
-            && ((PCollection) value).isBounded() == PCollection.IsBounded.UNBOUNDED) {
-          LOG.info(
-              "Found unbounded PCollection {}. Switching to streaming execution.", value.getName());
-          translationMode = TranslationMode.STREAMING;
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {
+      TranslationResult<T> result = getResult(pCollection);
+      if (!noCache && result.dependentTransforms.size() > 1) {
+        LOG.info("Dataset {} will be cached.", result.name);
+        result.dataset = dataset.persist(storageLevel); // use NONE to disable
+      } else {
+        result.dataset = dataset;
+        if (result.dependentTransforms.isEmpty()) {
+          leaves.add(result);
         }
       }
     }
-  }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline utility methods
-  // --------------------------------------------------------------------------------------------
+    @Override
+    public SerializablePipelineOptions getSerializableOptions() {
+      return serializableOptions;
+    }
 
-  /** Get a {@link TransformTranslator} for the given {@link TransformHierarchy.Node}. */
-  protected abstract @Nullable <
-          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
-          @Nullable TransformT transform);
-
-  /** Apply the given TransformTranslator to the given node. */
-  private <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      void applyTransformTranslator(
-          TransformHierarchy.Node node,
-          TransformT transform,
-          TransformTranslator<InT, OutT, TransformT> transformTranslator) {
-    // create the applied PTransform on the translationContext
-    AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
-        (AppliedPTransform) node.toAppliedPTransform(getPipeline());
-    try {
-      transformTranslator.translate(transform, appliedTransform, translationContext);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
+    @Override
+    public SparkSession getSparkSession() {
+      return sparkSession;
     }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline visitor entry point
-  // --------------------------------------------------------------------------------------------
-
   /**
-   * Translates the pipeline by passing this class as a visitor.
+   * {@link PTransformVisitor} that analyses dependencies of supported {@link PTransform
+   * PTransforms} to help identify cache candidates.
    *
-   * @param pipeline The pipeline to be translated
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
    */
-  public void translate(Pipeline pipeline) {
-    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
-    pipeline.traverseTopologically(this);
+  private class DependencyVisitor extends PTransformVisitor {
+    private final Map<PCollection<?>, TranslationResult<?>> results = new HashMap<>();
+
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+      for (PCollection<?> pOut : node.getOutputs().values()) {
+        results.put(pOut, new TranslationResult<>(pOut));
+        for (Map.Entry<TupleTag<?>, PCollection<?>> entry : node.getInputs().entrySet()) {
+          TranslationResult<?> input = checkStateNotNull(results.get(entry.getValue()));
+          input.dependentTransforms.add(transform);
+        }
+      }
+    }

Review Comment:
   I don't understand this algorithm, especially the nested loops: `pOut` is not referenced inside the inner loop.
   Why not populate `results` from the outputs first and then run the second loop over the inputs separately?
   
   If I'm missing something, please add some comments to the algorithm for ease of maintenance.
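   A minimal, Beam-free sketch of the two-loop structure suggested above. PCollections are plain strings and `Transform` is a hypothetical stand-in for a pipeline node. Because `dependentTransforms` is a `Set`, repeating the inner loop once per output is idempotent; the visible difference is that in the nested form a transform with no output PCollections never records itself as a dependent of its inputs. The class and method names are illustrative only.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical model: PCollections are strings, a transform is a name plus inputs and outputs.
public class DependencySketch {
  static final class Transform {
    final String name;
    final List<String> inputs;
    final List<String> outputs;

    Transform(String name, List<String> inputs, List<String> outputs) {
      this.name = name;
      this.inputs = inputs;
      this.outputs = outputs;
    }
  }

  // Maps each PCollection to the transforms consuming it; expects topological order.
  static Map<String, Set<String>> dependents(List<Transform> topologicalOrder) {
    Map<String, Set<String>> results = new HashMap<>();
    for (Transform t : topologicalOrder) {
      // Loop 1: register an (initially empty) entry for every output.
      for (String out : t.outputs) {
        results.put(out, new HashSet<>());
      }
      // Loop 2: record this transform once per input, independent of its outputs,
      // so transforms without outputs (e.g. sinks) are counted as dependents too.
      for (String in : t.inputs) {
        Set<String> deps = results.get(in);
        if (deps == null) {
          throw new IllegalStateException("Input " + in + " was not produced upstream");
        }
        deps.add(t.name);
      }
    }
    return results;
  }

  // PCollections read by more than one transform are the cache candidates.
  static Set<String> cacheCandidates(Map<String, Set<String>> dependents) {
    Set<String> candidates = new HashSet<>();
    for (Map.Entry<String, Set<String>> e : dependents.entrySet()) {
      if (e.getValue().size() > 1) {
        candidates.add(e.getKey());
      }
    }
    return candidates;
  }

  public static void main(String[] args) {
    List<Transform> pipeline =
        Arrays.asList(
            new Transform("Read", Collections.emptyList(), Arrays.asList("lines")),
            new Transform("Count", Arrays.asList("lines"), Arrays.asList("counts")),
            new Transform("Write", Arrays.asList("lines"), Collections.emptyList()));
    System.out.println(cacheCandidates(dependents(pipeline))); // prints [lines]
  }
}
```

   With the nested form, `Write` (no outputs) would not be recorded, so `lines` would have a single dependent and would not be cached.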



##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence, that can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as result of translating a {@link PTransform}
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and omitted afterwards. */
+  interface TranslationState extends EncoderProvider {
+    <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
+
+    <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache);
 
-    /** Uses the batch mode. */
-    BATCH,
+    default <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      putDataset(pCollection, dataset, false);
+    }
 
-    /** Uses the streaming mode. */
-    STREAMING
+    SerializablePipelineOptions getSerializableOptions();
+
+    SparkSession getSparkSession();
   }
 
-  /** Traverses the Pipeline to determine the {@link TranslationMode} for this pipeline. */
-  private static class TranslationModeDetector extends Pipeline.PipelineVisitor.Defaults {
-    private static final Logger LOG = LoggerFactory.getLogger(TranslationModeDetector.class);
+  /**
+   * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their
+   * Spark correspondence.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted.
+   */
+  private class TranslatingVisitor extends PTransformVisitor implements TranslationState {
+    private final Map<PCollection<?>, TranslationResult<?>> translationResults;
+    private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
+    private final SparkSession sparkSession;
+    private final SerializablePipelineOptions serializableOptions;
+    private final StorageLevel storageLevel;
+
+    private final Set<TranslationResult<?>> leaves;
+
+    public TranslatingVisitor(
+        SparkSession sparkSession,
+        SparkCommonPipelineOptions options,
+        Map<PCollection<?>, TranslationResult<?>> translationResults) {
+      this.sparkSession = sparkSession;
+      this.translationResults = translationResults;
+      this.serializableOptions = new SerializablePipelineOptions(options);
+      this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
+      this.encoders = new HashMap<>();
+      this.leaves = new HashSet<>();
+    }
 
-    private TranslationMode translationMode;
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+          (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+      try {
+        LOG.info(
+            "Translating {}: {}",
+            node.isCompositeNode() ? "composite" : "primitive",
+            node.getFullName());
+        translator.translate(transform, appliedTransform, this);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
 
-    TranslationModeDetector(TranslationMode defaultMode) {
-      this.translationMode = defaultMode;
+    @Override
+    public <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory) {
+      return (Encoder<T>) encoders.computeIfAbsent(coder, (Factory) factory);
     }
 
-    TranslationModeDetector() {
-      this(TranslationMode.BATCH);
+    private <T> TranslationResult<T> getResult(PCollection<T> pCollection) {
+      return (TranslationResult<T>) checkStateNotNull(translationResults.get(pCollection));
     }
 
-    TranslationMode getTranslationMode() {
-      return translationMode;
+    @Override
+    public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+      return checkStateNotNull(getResult(pCollection).dataset);
     }
 
     @Override
-    public void visitValue(PValue value, TransformHierarchy.Node producer) {
-      if (translationMode.equals(TranslationMode.BATCH)) {
-        if (value instanceof PCollection
-            && ((PCollection) value).isBounded() == PCollection.IsBounded.UNBOUNDED) {
-          LOG.info(
-              "Found unbounded PCollection {}. Switching to streaming execution.", value.getName());
-          translationMode = TranslationMode.STREAMING;
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {
+      TranslationResult<T> result = getResult(pCollection);
+      if (!noCache && result.dependentTransforms.size() > 1) {
+        LOG.info("Dataset {} will be cached.", result.name);
+        result.dataset = dataset.persist(storageLevel); // use NONE to disable
+      } else {
+        result.dataset = dataset;
+        if (result.dependentTransforms.isEmpty()) {
+          leaves.add(result);
         }
       }
     }
-  }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline utility methods
-  // --------------------------------------------------------------------------------------------
+    @Override
+    public SerializablePipelineOptions getSerializableOptions() {
+      return serializableOptions;
+    }
 
-  /** Get a {@link TransformTranslator} for the given {@link TransformHierarchy.Node}. */
-  protected abstract @Nullable <
-          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
-          @Nullable TransformT transform);
-
-  /** Apply the given TransformTranslator to the given node. */
-  private <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      void applyTransformTranslator(
-          TransformHierarchy.Node node,
-          TransformT transform,
-          TransformTranslator<InT, OutT, TransformT> transformTranslator) {
-    // create the applied PTransform on the translationContext
-    AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
-        (AppliedPTransform) node.toAppliedPTransform(getPipeline());
-    try {
-      transformTranslator.translate(transform, appliedTransform, translationContext);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
+    @Override
+    public SparkSession getSparkSession() {
+      return sparkSession;
     }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline visitor entry point
-  // --------------------------------------------------------------------------------------------
-
   /**
-   * Translates the pipeline by passing this class as a visitor.
+   * {@link PTransformVisitor} that analyses dependencies of supported {@link PTransform
+   * PTransforms} to help identify cache candidates.
    *
-   * @param pipeline The pipeline to be translated
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
    */
-  public void translate(Pipeline pipeline) {
-    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
-    pipeline.traverseTopologically(this);
+  private class DependencyVisitor extends PTransformVisitor {
+    private final Map<PCollection<?>, TranslationResult<?>> results = new HashMap<>();
+
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+      for (PCollection<?> pOut : node.getOutputs().values()) {
+        results.put(pOut, new TranslationResult<>(pOut));
+        for (Map.Entry<TupleTag<?>, PCollection<?>> entry : node.getInputs().entrySet()) {
+          TranslationResult<?> input = checkStateNotNull(results.get(entry.getValue()));
+          input.dependentTransforms.add(transform);
+        }
+      }
+    }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline Visitor Methods
-  // --------------------------------------------------------------------------------------------
+  /**
+   * An abstract {@link PipelineVisitor} that visits all translatable {@link PTransform} pipeline
+   * nodes of a pipeline with the respective {@link TransformTranslator}.
+   *
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
+   */
+  private abstract class PTransformVisitor extends PipelineVisitor.Defaults {
 
-  @Override
-  public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) {
-    PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
-    TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
-        getTransformTranslator(transform);
+    /** Visit the {@link PTransform} with its respective {@link TransformTranslator}. */
+    abstract <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator);
 
-    if (transformTranslator != null) {
-      LOG.info("Translating composite: {}", node.getFullName());
-      applyTransformTranslator(node, transform, transformTranslator);
-      return CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
-    } else {
-      return CompositeBehavior.ENTER_TRANSFORM;
+    @Override
+    public final CompositeBehavior enterCompositeTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> translator =
+          getTranslator(transform);
+      if (transform != null && translator != null) {
+        visit(node, transform, translator);
+        return DO_NOT_ENTER_TRANSFORM;
+      } else {
+        return ENTER_TRANSFORM;
+      }
+    }
+
+    @Override
+    public final void visitPrimitiveTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      if (transform == null || transform.getClass().equals(View.CreatePCollectionView.class)) {
+        return; // ignore, nothing to be translated here
+      }
+      TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> translator =
+          getTranslator(transform);
+      if (translator == null) {
+        String urn = PTransformTranslation.urnForTransform(transform);
+        throw new UnsupportedOperationException("Transform " + urn + " is not supported.");
+      }
+      visit(node, transform, translator);
     }
-  }
 
-  @Override
-  public void visitPrimitiveTransform(TransformHierarchy.Node node) {
-    LOG.info("Translating primitive: {}", node.getFullName());
-    // get the transformation corresponding to the node we are
-    // currently visiting and translate it into its Spark alternative.
-    PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
-    TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
-        getTransformTranslator(transform);
-
-    if (transformTranslator == null) {
-      String transformUrn = PTransformTranslation.urnForTransform(node.getTransform());
-      throw new UnsupportedOperationException(
-          "The transform " + transformUrn + " is currently not supported.");
-    }
-    applyTransformTranslator(node, transform, transformTranslator);
+    /** {@link TransformTranslator} for {@link PTransform} if translation is known and supported. */
+    private @Nullable TransformTranslator<PInput, POutput, PTransform<PInput, POutput>>

Review Comment:
   Please rename this to `getTransformTranslatorIfTranslatable` for clarity.
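   A plain-Java sketch of the traversal rules from the class Javadoc in this diff: a node with a known and supported translator is translated as a whole and not entered; an untranslatable composite is expanded into its parts; an untranslatable primitive is an error. `Node` and `Translator` are hypothetical stand-ins (the real `TransformTranslator` also translates; here only `canTranslate` is modelled), and the helper uses the name proposed above.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical model of the PTransformVisitor rules; names are illustrative only.
public class TraversalSketch {
  static final class Node {
    final String name;
    final List<Node> parts; // empty for primitives

    Node(String name, Node... parts) {
      this.name = name;
      this.parts = Arrays.asList(parts);
    }

    boolean isComposite() {
      return !parts.isEmpty();
    }
  }

  // Stand-in for a translator; only the "is this node supported?" check is modelled.
  interface Translator {
    boolean canTranslate(Node node);
  }

  // Returns a translator only if one is known for the node AND it supports the node.
  static Translator getTransformTranslatorIfTranslatable(Map<String, Translator> known, Node node) {
    Translator t = known.get(node.name);
    return (t != null && t.canTranslate(node)) ? t : null;
  }

  static void visit(Node node, Map<String, Translator> known, List<String> translated) {
    Translator t = getTransformTranslatorIfTranslatable(known, node);
    if (t != null) {
      translated.add(node.name); // translated as a whole, parts are not entered
    } else if (node.isComposite()) {
      for (Node part : node.parts) {
        visit(part, known, translated); // expand the composite instead
      }
    } else {
      throw new UnsupportedOperationException("Transform " + node.name + " is not supported.");
    }
  }

  public static void main(String[] args) {
    Map<String, Translator> known = new HashMap<>();
    known.put("Read", n -> true);
    known.put("GroupByKey", n -> true);
    known.put("CombineValues", n -> true);
    known.put("CombinePerKey", n -> false); // known but unsupported here, so it is expanded
    Node root =
        new Node("Root",
            new Node("Read"),
            new Node("CombinePerKey", new Node("GroupByKey"), new Node("CombineValues")));
    List<String> translated = new ArrayList<>();
    visit(root, known, translated);
    System.out.println(translated); // prints [Read, GroupByKey, CombineValues]
  }
}
```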



##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java:
##########
@@ -220,8 +225,9 @@ private SideInputBroadcast createBroadcastSideInputs(
       Coder<WindowedValue<?>> windowedValueCoder =
           (Coder<WindowedValue<?>>)
               (Coder<?>) WindowedValue.getFullCoder(pc.getCoder(), windowCoder);
-      Dataset<WindowedValue<?>> broadcastSet = context.getSideInputDataset(sideInput);
-      List<WindowedValue<?>> valuesList = broadcastSet.collectAsList();
+      Dataset<WindowedValue<?>> broadcastSet = context.getDataset((PCollection) pc);
+      List<WindowedValue<?>> valuesList =

Review Comment:
   Now I see the side-input preparation you mentioned. :+1:
   You evaluate the associated PCollection when you need to broadcast it as a side input.
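   Collecting the dataset for the broadcast is a Spark action, so if the same dataset also feeds downstream transforms, its lazy lineage is evaluated once per consumer unless it is persisted; this repeated evaluation is what the PR's caching targets. The sketch below is a plain-Java analogy, not Spark's API: the hypothetical `persist` helper is a memoizing `Supplier` standing in for `Dataset.persist`, and a counter shows how often the upstream computation runs.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

// Plain-Java analogy of lazy datasets; "persist" is a hypothetical memoizer, not Spark.
public class LazyReuseSketch {
  static final AtomicInteger evaluations = new AtomicInteger();

  // Upstream computation; each call stands for one full re-evaluation of its lineage.
  static List<Integer> expensiveSource() {
    evaluations.incrementAndGet();
    return Arrays.asList(1, 2, 3);
  }

  // Memoizes the first result, loosely like persisting a dataset before reusing it.
  static <T> Supplier<T> persist(Supplier<T> lazy) {
    return new Supplier<T>() {
      private T cached;
      private boolean done;

      @Override
      public T get() {
        if (!done) {
          cached = lazy.get();
          done = true;
        }
        return cached;
      }
    };
  }

  // Two consumers: one collects the data for a side-input broadcast, one sums it downstream.
  static int runTwoConsumers(Supplier<List<Integer>> dataset) {
    List<Integer> broadcast = dataset.get();
    int sum = 0;
    for (int v : dataset.get()) {
      sum += v;
    }
    return broadcast.size() + sum;
  }

  public static void main(String[] args) {
    evaluations.set(0);
    runTwoConsumers(LazyReuseSketch::expensiveSource);
    System.out.println("uncached evaluations: " + evaluations.get()); // prints 2
    evaluations.set(0);
    runTwoConsumers(persist(LazyReuseSketch::expensiveSource));
    System.out.println("cached evaluations: " + evaluations.get()); // prints 1
  }
}
```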



##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence, that can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.spark.sql.Dataset Dataset} might have to be collected and
+   * broadcasted to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);

Review Comment:
   Very clear! :+1:
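The two traversals in `translate()` can be illustrated without Beam or Spark. The following is a minimal sketch with some stated assumptions. `ToyNode` and `CachePlanner` are hypothetical names, PCollections are plain strings, and caching is only recorded in a set rather than done by calling `persist`. The first pass records the consumers of every PCollection, as `DependencyVisitor` does. The second pass applies the rule from `putDataset`: a dataset consumed by more than one transform is persisted, and a dataset with no consumers is treated as a leaf.

```java
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Hypothetical stand-in for an applied transform: its name plus the names of its PCollections. */
final class ToyNode {
  final String name;
  final List<String> inputs;
  final List<String> outputs;

  ToyNode(String name, List<String> inputs, List<String> outputs) {
    this.name = name;
    this.inputs = inputs;
    this.outputs = outputs;
  }
}

/** Two passes over nodes given in topological order, modelled on translate(). */
final class CachePlanner {
  final Set<String> cached = new LinkedHashSet<>();
  final Set<String> leaves = new LinkedHashSet<>();

  CachePlanner(List<ToyNode> topologicalOrder) {
    // Pass 1 (cf. DependencyVisitor): record the transforms consuming each PCollection.
    // Inputs are handled independently of outputs, so a consumer without outputs still counts.
    Map<String, Set<String>> dependents = new HashMap<>();
    for (ToyNode node : topologicalOrder) {
      for (String in : node.inputs) {
        Set<String> consumers = dependents.get(in);
        if (consumers == null) {
          throw new IllegalStateException("Input " + in + " was not produced upstream");
        }
        consumers.add(node.name);
      }
      for (String out : node.outputs) {
        dependents.put(out, new LinkedHashSet<>());
      }
    }
    // Pass 2 (cf. TranslatingVisitor#putDataset): apply the caching rule to every output.
    for (ToyNode node : topologicalOrder) {
      for (String out : node.outputs) {
        int consumers = dependents.get(out).size();
        if (consumers > 1) {
          cached.add(out); // in the runner: dataset.persist(storageLevel)
        } else if (consumers == 0) {
          leaves.add(out); // nothing downstream: handed to the EvaluationContext
        }
      }
    }
  }
}
```

Take a `Read` whose output `lines` feeds both `Count` and `Sample`. In that case `lines` ends up in `cached`, and `count` and `sample` become leaves. There is one design difference from the PR. The `DependencyVisitor` in this PR registers dependents inside its loop over outputs. This sketch registers them per input instead, so a transform without outputs still counts as a consumer.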



##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java:
##########
@@ -51,22 +48,24 @@
 import scala.reflect.ClassTag;
 
 /**
- * Supports translation between a Beam transform, and Spark's operations on Datasets.
+ * A {@link TransformTranslator} provides the capability to translate a specific primitive or
+ * composite {@link PTransform} into its Spark correspondence.
  *
- * <p>WARNING: Do not make this class serializable! It could easily hide situations where
- * unnecessary references leak into Spark closures.
+ * <p>WARNING: {@link TransformTranslator TransformTranslators} should never be serializable! This
+ * could easily hide situations where unnecessary references leak into Spark closures.
  */
+@Internal
 public abstract class TransformTranslator<
     InT extends PInput, OutT extends POutput, TransformT extends PTransform<? extends InT, OutT>> {
 
   protected abstract void translate(TransformT transform, Context cxt) throws IOException;
 
-  public final void translate(
+  protected final void translate(

Review Comment:
   This can be package-private now that this method is no longer the one being overridden.
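The suggestion amounts to a template method whose entry point is visible only inside the translation package. Below is a minimal sketch. `ToyTransformTranslator` and `UpperCaseTranslator` are illustrative names, not Beam classes. Concrete translators override only the `protected abstract` hook. The `final`, package-private overload does the shared work and is called by the pipeline translator in the same package.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical translator base class illustrating the suggested visibility. */
abstract class ToyTransformTranslator<T> {

  /** Hook implemented by concrete translators, which may live in other packages. */
  protected abstract String translate(T transform);

  /**
   * Entry point used only by the pipeline translator in the same package. Being final and
   * package-private, it can neither be overridden nor called from other packages.
   */
  final String translate(T transform, List<String> log) {
    log.add("translating " + transform); // shared bookkeeping before delegating
    return translate(transform);
  }
}

/** Example concrete translator: only the protected hook is implemented. */
final class UpperCaseTranslator extends ToyTransformTranslator<String> {
  @Override
  protected String translate(String transform) {
    return transform.toUpperCase();
  }
}

final class TemplateMethodDemo {
  public static void main(String[] args) {
    List<String> log = new ArrayList<>();
    System.out.println(new UpperCaseTranslator().translate("flatten", log));
  }
}
```

The entry point is both `final` and package-private. As a result, subclasses in other packages, such as `batch`, can neither override it nor call it, and that encapsulation is what the comment is after.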



##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence, that can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.spark.sql.Dataset Dataset} might have to be collected and
+   * broadcasted to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as result of translating a {@link PTransform}
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and discarded afterwards. */
+  interface TranslationState extends EncoderProvider {
+    <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
+
+    <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache);
 
-    /** Uses the batch mode. */
-    BATCH,
+    default <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      putDataset(pCollection, dataset, false);
+    }
 
-    /** Uses the streaming mode. */
-    STREAMING
+    SerializablePipelineOptions getSerializableOptions();
+
+    SparkSession getSparkSession();
   }
 
-  /** Traverses the Pipeline to determine the {@link TranslationMode} for this pipeline. */
-  private static class TranslationModeDetector extends Pipeline.PipelineVisitor.Defaults {
-    private static final Logger LOG = LoggerFactory.getLogger(TranslationModeDetector.class);
+  /**
+   * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their
+   * Spark correspondence.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.spark.sql.Dataset Dataset} might have to be collected and
+   * broadcasted.
+   */
+  private class TranslatingVisitor extends PTransformVisitor implements TranslationState {
+    private final Map<PCollection<?>, TranslationResult<?>> translationResults;
+    private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
+    private final SparkSession sparkSession;
+    private final SerializablePipelineOptions serializableOptions;
+    private final StorageLevel storageLevel;
+
+    private final Set<TranslationResult<?>> leaves;
+
+    public TranslatingVisitor(
+        SparkSession sparkSession,
+        SparkCommonPipelineOptions options,
+        Map<PCollection<?>, TranslationResult<?>> translationResults) {
+      this.sparkSession = sparkSession;
+      this.translationResults = translationResults;
+      this.serializableOptions = new SerializablePipelineOptions(options);
+      this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
+      this.encoders = new HashMap<>();
+      this.leaves = new HashSet<>();
+    }
 
-    private TranslationMode translationMode;
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =

Review Comment:
   Better use of types :+1: 
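The hunk above replaces `TranslationModeDetector` with a `StreamingModeDetector` seeded from `options.isStreaming()`. The body of the new class is not part of this excerpt. The sketch below reconstructs its likely behaviour from the removed `visitValue` logic, so it is an assumption rather than the PR's code. `Boundedness` and `ToyStreamingModeDetector` are hypothetical stand-ins.

```java
import java.util.List;

/** Hypothetical stand-in for PCollection.IsBounded. */
enum Boundedness {
  BOUNDED,
  UNBOUNDED
}

/** Sketch of streaming-mode detection, modelled on the removed TranslationModeDetector. */
final class ToyStreamingModeDetector {
  boolean streaming;

  ToyStreamingModeDetector(boolean initiallyStreaming) {
    // Seeded with options.isStreaming(), so a pipeline already in streaming mode stays there.
    this.streaming = initiallyStreaming;
  }

  /** Called for every PCollection produced in the pipeline. */
  void visitValue(Boundedness boundedness) {
    if (boundedness == Boundedness.UNBOUNDED) {
      streaming = true; // any unbounded PCollection requires streaming translation
    }
  }

  /** Counterpart of detectStreamingMode: traverse, then return the flag to write back. */
  static boolean detect(boolean initiallyStreaming, List<Boundedness> values) {
    ToyStreamingModeDetector detector = new ToyStreamingModeDetector(initiallyStreaming);
    for (Boundedness value : values) {
      detector.visitValue(value);
    }
    return detector.streaming;
  }
}
```

This assumes the detector never resets the flag. Under that assumption, seeding it with the current option means `options.setStreaming(detector.streaming)` can keep or enable streaming mode. It can never switch an explicitly streaming pipeline back to batch.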



##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java:
##########
@@ -81,27 +80,13 @@ public class PipelineTranslatorBatch extends PipelineTranslator {
 
     TRANSFORM_TRANSLATORS.put(
         SplittableParDo.PrimitiveBoundedRead.class, new ReadSourceTranslatorBatch<>());
-
-    TRANSFORM_TRANSLATORS.put(
-        View.CreatePCollectionView.class, new CreatePCollectionViewTranslatorBatch<>());
-  }
-
-  public PipelineTranslatorBatch(SparkStructuredStreamingPipelineOptions options) {
-    translationContext = new TranslationContext(options);
   }
 
-  /** Returns a translator for the given node, if it is possible, otherwise null. */
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
   @Override
   @Nullable
   protected <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
-          @Nullable TransformT transform) {
-    // Root of the graph is null
-    if (transform == null) {
-      return null;
-    }
-    TransformTranslator<InT, OutT, TransformT> translator =
-        TRANSFORM_TRANSLATORS.get(transform.getClass());
-    return translator != null && translator.canTranslate(transform) ? translator : null;

Review Comment:
   You no longer check here whether the transform can be translated; you do that in the visitor instead, right?
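As the question suggests, the split appears to work like this. `getTransformTranslator` is a pure lookup that answers whether a transform is *known*. The visitor then decides whether it is *supported*. The sketch below assumes that `PTransformVisitor#getTranslator`, which is not shown in this excerpt, combines both checks as the removed code did. `ToyTranslator`, `ToyRead`, `ToyFlatten` and `ToyRegistry` are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical translator with a support check, mirroring TransformTranslator#canTranslate. */
interface ToyTranslator<T> {
  boolean canTranslate(T transform);
}

/** Hypothetical transform types. */
class ToyRead {
  final boolean bounded;

  ToyRead(boolean bounded) {
    this.bounded = bounded;
  }
}

class ToyFlatten {}

final class ToyRegistry {
  private final Map<Class<?>, ToyTranslator<?>> translators = new HashMap<>();

  ToyRegistry() {
    ToyTranslator<ToyRead> readTranslator = read -> read.bounded; // only bounded reads supported
    translators.put(ToyRead.class, readTranslator);
  }

  /** Like getTransformTranslator: a pure lookup by class ("known", not necessarily supported). */
  @SuppressWarnings("unchecked")
  <T> ToyTranslator<T> lookup(T transform) {
    return (ToyTranslator<T>) translators.get(transform.getClass());
  }

  /** Like the visitor-side check: known AND supported, otherwise null. */
  <T> ToyTranslator<T> getTranslator(T transform) {
    ToyTranslator<T> translator = lookup(transform);
    return translator != null && translator.canTranslate(transform) ? translator : null;
  }
}
```

For a known but unsupported composite, returning `null` lets the visitor fall back to `ENTER_TRANSFORM` and translate its parts instead.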



##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence, that can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.spark.sql.Dataset Dataset} might have to be collected and
+   * broadcasted to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as result of translating a {@link PTransform}
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and discarded afterwards. */
+  interface TranslationState extends EncoderProvider {
+    <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
+
+    <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache);
 
-    /** Uses the batch mode. */
-    BATCH,
+    default <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      putDataset(pCollection, dataset, false);
+    }
 
-    /** Uses the streaming mode. */
-    STREAMING
+    SerializablePipelineOptions getSerializableOptions();
+
+    SparkSession getSparkSession();
   }
 
-  /** Traverses the Pipeline to determine the {@link TranslationMode} for this pipeline. */
-  private static class TranslationModeDetector extends Pipeline.PipelineVisitor.Defaults {
-    private static final Logger LOG = LoggerFactory.getLogger(TranslationModeDetector.class);
+  /**
+   * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their
+   * Spark correspondence.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.spark.sql.Dataset Dataset} might have to be collected and
+   * broadcasted.
+   */
+  private class TranslatingVisitor extends PTransformVisitor implements TranslationState {
+    private final Map<PCollection<?>, TranslationResult<?>> translationResults;
+    private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
+    private final SparkSession sparkSession;
+    private final SerializablePipelineOptions serializableOptions;
+    private final StorageLevel storageLevel;
+
+    private final Set<TranslationResult<?>> leaves;
+
+    public TranslatingVisitor(
+        SparkSession sparkSession,
+        SparkCommonPipelineOptions options,
+        Map<PCollection<?>, TranslationResult<?>> translationResults) {
+      this.sparkSession = sparkSession;
+      this.translationResults = translationResults;
+      this.serializableOptions = new SerializablePipelineOptions(options);
+      this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
+      this.encoders = new HashMap<>();
+      this.leaves = new HashSet<>();
+    }
 
-    private TranslationMode translationMode;
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+          (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+      try {
+        LOG.info(
+            "Translating {}: {}",
+            node.isCompositeNode() ? "composite" : "primitive",
+            node.getFullName());
+        translator.translate(transform, appliedTransform, this);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
 
-    TranslationModeDetector(TranslationMode defaultMode) {
-      this.translationMode = defaultMode;
+    @Override
+    public <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory) {
+      return (Encoder<T>) encoders.computeIfAbsent(coder, (Factory) factory);
     }
 
-    TranslationModeDetector() {
-      this(TranslationMode.BATCH);
+    private <T> TranslationResult<T> getResult(PCollection<T> pCollection) {
+      return (TranslationResult<T>) checkStateNotNull(translationResults.get(pCollection));
     }
 
-    TranslationMode getTranslationMode() {
-      return translationMode;
+    @Override
+    public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+      return checkStateNotNull(getResult(pCollection).dataset);
     }
 
     @Override
-    public void visitValue(PValue value, TransformHierarchy.Node producer) {
-      if (translationMode.equals(TranslationMode.BATCH)) {
-        if (value instanceof PCollection
-            && ((PCollection) value).isBounded() == PCollection.IsBounded.UNBOUNDED) {
-          LOG.info(
-              "Found unbounded PCollection {}. Switching to streaming execution.", value.getName());
-          translationMode = TranslationMode.STREAMING;
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {
+      TranslationResult<T> result = getResult(pCollection);
+      if (!noCache && result.dependentTransforms.size() > 1) {
+        LOG.info("Dataset {} will be cached.", result.name);
+        result.dataset = dataset.persist(storageLevel); // use NONE to disable
+      } else {
+        result.dataset = dataset;
+        if (result.dependentTransforms.isEmpty()) {
+          leaves.add(result);
         }
       }
     }
-  }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline utility methods
-  // --------------------------------------------------------------------------------------------
+    @Override
+    public SerializablePipelineOptions getSerializableOptions() {
+      return serializableOptions;
+    }
 
-  /** Get a {@link TransformTranslator} for the given {@link TransformHierarchy.Node}. */
-  protected abstract @Nullable <
-          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
-          @Nullable TransformT transform);
-
-  /** Apply the given TransformTranslator to the given node. */
-  private <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      void applyTransformTranslator(
-          TransformHierarchy.Node node,
-          TransformT transform,
-          TransformTranslator<InT, OutT, TransformT> transformTranslator) {
-    // create the applied PTransform on the translationContext
-    AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
-        (AppliedPTransform) node.toAppliedPTransform(getPipeline());
-    try {
-      transformTranslator.translate(transform, appliedTransform, translationContext);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
+    @Override
+    public SparkSession getSparkSession() {
+      return sparkSession;
     }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline visitor entry point
-  // --------------------------------------------------------------------------------------------
-
   /**
-   * Translates the pipeline by passing this class as a visitor.
+   * {@link PTransformVisitor} that analyses dependencies of supported {@link PTransform
+   * PTransforms} to help identify cache candidates.
    *
-   * @param pipeline The pipeline to be translated
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
    */
-  public void translate(Pipeline pipeline) {
-    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
-    pipeline.traverseTopologically(this);
+  private class DependencyVisitor extends PTransformVisitor {
+    private final Map<PCollection<?>, TranslationResult<?>> results = new HashMap<>();
+
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+      for (PCollection<?> pOut : node.getOutputs().values()) {
+        results.put(pOut, new TranslationResult<>(pOut));
+        for (Map.Entry<TupleTag<?>, PCollection<?>> entry : node.getInputs().entrySet()) {
+          TranslationResult<?> input = checkStateNotNull(results.get(entry.getValue()));
+          input.dependentTransforms.add(transform);
+        }
+      }
+    }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline Visitor Methods
-  // --------------------------------------------------------------------------------------------
+  /**
+   * An abstract {@link PipelineVisitor} that visits all translatable {@link PTransform} pipeline
+   * nodes of a pipeline with the respective {@link TransformTranslator}.
+   *
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
+   */
+  private abstract class PTransformVisitor extends PipelineVisitor.Defaults {
 
-  @Override
-  public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) {
-    PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
-    TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
-        getTransformTranslator(transform);
+    /** Visit the {@link PTransform} with its respective {@link TransformTranslator}. */
+    abstract <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator);
 
-    if (transformTranslator != null) {
-      LOG.info("Translating composite: {}", node.getFullName());
-      applyTransformTranslator(node, transform, transformTranslator);
-      return CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
-    } else {
-      return CompositeBehavior.ENTER_TRANSFORM;
+    @Override
+    public final CompositeBehavior enterCompositeTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> translator =
+          getTranslator(transform);
+      if (transform != null && translator != null) {
+        visit(node, transform, translator);
+        return DO_NOT_ENTER_TRANSFORM;
+      } else {
+        return ENTER_TRANSFORM;
+      }
+    }
+
+    @Override
+    public final void visitPrimitiveTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      if (transform == null || transform.getClass().equals(View.CreatePCollectionView.class)) {
+        return; // ignore, nothing to be translated here

Review Comment:
   Cf. the question on [PipelineTranslatorBatch.java](https://github.com/apache/beam/pull/24009/files#diff-135056b6cae8cfffc97af038ebe5d427e0ce4f58a0c947e510a748e0879cb2b3). Can you add a comment elaborating on the `PCollectionView` case?
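The composite/primitive dispatch in `PTransformVisitor` above can be modelled without Beam roughly as follows. This is a simplified sketch, not the runner's code. `TraversalSketch`, its `Node` class, and the `supported`/`ignored` sets are hypothetical stand-ins for `TransformHierarchy.Node`, "a translator is known and `canTranslate` returns true", and skipped primitives such as `View.CreatePCollectionView`. Unlike `Pipeline.traverseTopologically`, it simply recurses depth-first.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/** Hypothetical, Beam-free model of how composite and primitive nodes are dispatched. */
final class TraversalSketch {

  /** Stand-in for a pipeline node: composites have parts, primitives have none. */
  static final class Node {
    final String name;
    final List<Node> parts;

    Node(String name, List<Node> parts) {
      this.name = name;
      this.parts = parts;
    }

    boolean isComposite() {
      return !parts.isEmpty();
    }
  }

  /** Returns the names of the nodes that would be translated, in visiting order. */
  static List<String> translate(Node root, Set<String> supported, Set<String> ignored) {
    List<String> translated = new ArrayList<>();
    visit(root, supported, ignored, translated);
    return translated;
  }

  private static void visit(
      Node node, Set<String> supported, Set<String> ignored, List<String> out) {
    if (node.isComposite()) {
      if (supported.contains(node.name)) {
        out.add(node.name); // translate the composite as a whole and do not enter it
      } else {
        for (Node part : node.parts) { // otherwise expand the composite into its parts
          visit(part, supported, ignored, out);
        }
      }
    } else if (!ignored.contains(node.name)) {
      if (!supported.contains(node.name)) {
        // An unsupported primitive cannot be expanded any further, so translation fails.
        throw new UnsupportedOperationException("Transform " + node.name + " is not supported.");
      }
      out.add(node.name);
    }
  }
}
```

The design point the sketch shows: a known composite translator short-circuits expansion, while an unknown primitive is a hard error because there is nothing left to expand.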



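The dependency pass in `DependencyVisitor` records, for every output `PCollection`, which transforms consume it; collections with more than one consumer become cache candidates. A Beam-free sketch of that idea follows, assuming nodes arrive in topological order. `DependencySketch` and its `Node` class are hypothetical. One deliberate difference: in the diff, the loop over `node.getInputs()` is nested inside the loop over `node.getOutputs()`, whereas this sketch records inputs separately, so a consuming transform that produces no outputs (for example a sink) still counts as a dependent. Whether that distinction matters for the runner is not visible in this excerpt.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Hypothetical, Beam-free model of the dependency analysis used to find cache candidates. */
final class DependencySketch {

  /** Stand-in for a transform node with named input and output collections. */
  static final class Node {
    final String name;
    final List<String> inputs;
    final List<String> outputs;

    Node(String name, List<String> inputs, List<String> outputs) {
      this.name = name;
      this.inputs = inputs;
      this.outputs = outputs;
    }
  }

  /** Maps each collection to the transforms consuming it; nodes must be topologically ordered. */
  static Map<String, Set<String>> dependents(List<Node> topologicalOrder) {
    Map<String, Set<String>> result = new HashMap<>();
    for (Node node : topologicalOrder) {
      // Inputs are recorded independently of outputs, so output-less consumers are counted too.
      for (String input : node.inputs) {
        Set<String> deps = result.get(input);
        if (deps == null) {
          throw new IllegalStateException("Unknown input " + input + " of " + node.name);
        }
        deps.add(node.name);
      }
      for (String output : node.outputs) {
        result.put(output, new LinkedHashSet<>());
      }
    }
    return result;
  }

  /** Collections consumed by more than one transform are candidates for caching. */
  static List<String> cacheCandidates(Map<String, Set<String>> dependents) {
    List<String> candidates = new ArrayList<>();
    for (Map.Entry<String, Set<String>> entry : dependents.entrySet()) {
      if (entry.getValue().size() > 1) {
        candidates.add(entry.getKey());
      }
    }
    return candidates;
  }
}
```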
##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence, that can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as result of translating a {@link PTransform}
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and omitted afterwards. */
+  interface TranslationState extends EncoderProvider {
+    <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
+
+    <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache);
 
-    /** Uses the batch mode. */
-    BATCH,
+    default <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      putDataset(pCollection, dataset, false);
+    }
 
-    /** Uses the streaming mode. */
-    STREAMING
+    SerializablePipelineOptions getSerializableOptions();
+
+    SparkSession getSparkSession();
   }
 
-  /** Traverses the Pipeline to determine the {@link TranslationMode} for this pipeline. */
-  private static class TranslationModeDetector extends Pipeline.PipelineVisitor.Defaults {
-    private static final Logger LOG = LoggerFactory.getLogger(TranslationModeDetector.class);
+  /**
+   * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their
+   * Spark correspondence.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted.
+   */
+  private class TranslatingVisitor extends PTransformVisitor implements TranslationState {
+    private final Map<PCollection<?>, TranslationResult<?>> translationResults;
+    private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
+    private final SparkSession sparkSession;
+    private final SerializablePipelineOptions serializableOptions;
+    private final StorageLevel storageLevel;
+
+    private final Set<TranslationResult<?>> leaves;
+
+    public TranslatingVisitor(
+        SparkSession sparkSession,
+        SparkCommonPipelineOptions options,
+        Map<PCollection<?>, TranslationResult<?>> translationResults) {
+      this.sparkSession = sparkSession;
+      this.translationResults = translationResults;
+      this.serializableOptions = new SerializablePipelineOptions(options);
+      this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
+      this.encoders = new HashMap<>();
+      this.leaves = new HashSet<>();
+    }
 
-    private TranslationMode translationMode;
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+          (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+      try {
+        LOG.info(
+            "Translating {}: {}",
+            node.isCompositeNode() ? "composite" : "primitive",
+            node.getFullName());
+        translator.translate(transform, appliedTransform, this);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
 
-    TranslationModeDetector(TranslationMode defaultMode) {
-      this.translationMode = defaultMode;
+    @Override
+    public <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory) {
+      return (Encoder<T>) encoders.computeIfAbsent(coder, (Factory) factory);
     }
 
-    TranslationModeDetector() {
-      this(TranslationMode.BATCH);
+    private <T> TranslationResult<T> getResult(PCollection<T> pCollection) {
+      return (TranslationResult<T>) checkStateNotNull(translationResults.get(pCollection));
     }
 
-    TranslationMode getTranslationMode() {
-      return translationMode;
+    @Override
+    public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+      return checkStateNotNull(getResult(pCollection).dataset);
     }
 
     @Override
-    public void visitValue(PValue value, TransformHierarchy.Node producer) {
-      if (translationMode.equals(TranslationMode.BATCH)) {
-        if (value instanceof PCollection
-            && ((PCollection) value).isBounded() == PCollection.IsBounded.UNBOUNDED) {
-          LOG.info(
-              "Found unbounded PCollection {}. Switching to streaming execution.", value.getName());
-          translationMode = TranslationMode.STREAMING;
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {
+      TranslationResult<T> result = getResult(pCollection);
+      if (!noCache && result.dependentTransforms.size() > 1) {

Review Comment:
   nit: for symmetry with the else branch, use `!result.dependentTransforms.isEmpty()`.
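The decision `putDataset` makes can be summarised in a small Spark-free model: persist when more than one transform depends on the dataset (unless `noCache` is set), and record the dataset as a leaf when nothing depends on it. Note that `!isEmpty()` is equivalent to `size() > 0`, so it would differ from the current `size() > 1` condition for a dataset with exactly one dependent transform. In this hypothetical sketch, `CacheDecisionSketch`, `Result`, and `Outcome` are illustrative names, and the `persist` function stands in for `dataset.persist(storageLevel)`, whose level the diff reads from `options.getStorageLevel()`.

```java
import java.util.Set;
import java.util.function.UnaryOperator;

/** Hypothetical model of the decision made when a translated dataset is registered. */
final class CacheDecisionSketch {

  enum Outcome {
    CACHED,
    LEAF,
    PLAIN
  }

  /** Stand-in for TranslationResult: dependents from the dependency pass plus the dataset. */
  static final class Result<D> {
    final String name;
    final Set<String> dependents;
    D dataset;

    Result(String name, Set<String> dependents) {
      this.name = name;
      this.dependents = dependents;
    }
  }

  static <D> Outcome put(
      Result<D> result, D dataset, boolean noCache, UnaryOperator<D> persist, Set<Result<?>> leaves) {
    if (!noCache && result.dependents.size() > 1) {
      // Several consumers: persist so they share one evaluation instead of recomputing it.
      result.dataset = persist.apply(dataset);
      return Outcome.CACHED;
    }
    result.dataset = dataset;
    if (result.dependents.isEmpty()) {
      // No consumers: keep it as a leaf so it can be evaluated explicitly later.
      leaves.add(result);
      return Outcome.LEAF;
    }
    return Outcome.PLAIN;
  }
}
```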



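`detectStreamingMode` now seeds the detector with `options.isStreaming()` and writes the detector's flag back unconditionally. A minimal sketch of that behaviour follows, assuming (as in the removed `TranslationModeDetector`) that the detector only ever switches from batch to streaming when it sees an unbounded `PCollection`; the new detector's body is cut off in this excerpt. `StreamingDetectionSketch` and its `IsBounded` enum are hypothetical stand-ins.

```java
import java.util.List;

/** Hypothetical model of streaming-mode detection over the boundedness of all collections. */
final class StreamingDetectionSketch {

  /** Stand-in for PCollection.IsBounded. */
  enum IsBounded {
    BOUNDED,
    UNBOUNDED
  }

  /** Mirrors a visitor whose flag starts at the requested mode and can only flip to true. */
  static final class Detector {
    boolean streaming;

    Detector(boolean streaming) {
      this.streaming = streaming;
    }

    void visitValue(IsBounded boundedness) {
      if (!streaming && boundedness == IsBounded.UNBOUNDED) {
        streaming = true; // a single unbounded collection is enough to require streaming
      }
    }
  }

  static boolean detectStreaming(boolean requested, List<IsBounded> collections) {
    Detector detector = new Detector(requested);
    for (IsBounded boundedness : collections) {
      detector.visitValue(boundedness);
    }
    return detector.streaming;
  }
}
```

Because the flag starts at the requested value and never flips back, writing it back unconditionally cannot turn an explicitly streaming pipeline into a batch one.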
##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence, that can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as result of translating a {@link PTransform}
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and omitted afterwards. */
+  interface TranslationState extends EncoderProvider {
+    <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
+
+    <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache);
 
-    /** Uses the batch mode. */
-    BATCH,
+    default <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      putDataset(pCollection, dataset, false);
+    }
 
-    /** Uses the streaming mode. */
-    STREAMING
+    SerializablePipelineOptions getSerializableOptions();
+
+    SparkSession getSparkSession();
   }
 
-  /** Traverses the Pipeline to determine the {@link TranslationMode} for this pipeline. */
-  private static class TranslationModeDetector extends Pipeline.PipelineVisitor.Defaults {
-    private static final Logger LOG = LoggerFactory.getLogger(TranslationModeDetector.class);
+  /**
+   * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their
+   * Spark correspondence.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted.
+   */
+  private class TranslatingVisitor extends PTransformVisitor implements TranslationState {
+    private final Map<PCollection<?>, TranslationResult<?>> translationResults;
+    private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
+    private final SparkSession sparkSession;
+    private final SerializablePipelineOptions serializableOptions;
+    private final StorageLevel storageLevel;
+
+    private final Set<TranslationResult<?>> leaves;
+
+    public TranslatingVisitor(
+        SparkSession sparkSession,
+        SparkCommonPipelineOptions options,
+        Map<PCollection<?>, TranslationResult<?>> translationResults) {
+      this.sparkSession = sparkSession;
+      this.translationResults = translationResults;
+      this.serializableOptions = new SerializablePipelineOptions(options);
+      this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
+      this.encoders = new HashMap<>();
+      this.leaves = new HashSet<>();
+    }
 
-    private TranslationMode translationMode;
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+          (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+      try {
+        LOG.info(
+            "Translating {}: {}",
+            node.isCompositeNode() ? "composite" : "primitive",
+            node.getFullName());
+        translator.translate(transform, appliedTransform, this);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
 
-    TranslationModeDetector(TranslationMode defaultMode) {
-      this.translationMode = defaultMode;
+    @Override
+    public <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory) {
+      return (Encoder<T>) encoders.computeIfAbsent(coder, (Factory) factory);
     }
 
-    TranslationModeDetector() {
-      this(TranslationMode.BATCH);
+    private <T> TranslationResult<T> getResult(PCollection<T> pCollection) {
+      return (TranslationResult<T>) checkStateNotNull(translationResults.get(pCollection));
     }
 
-    TranslationMode getTranslationMode() {
-      return translationMode;
+    @Override
+    public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+      return checkStateNotNull(getResult(pCollection).dataset);
     }
 
     @Override
-    public void visitValue(PValue value, TransformHierarchy.Node producer) {
-      if (translationMode.equals(TranslationMode.BATCH)) {
-        if (value instanceof PCollection
-            && ((PCollection) value).isBounded() == PCollection.IsBounded.UNBOUNDED) {
-          LOG.info(
-              "Found unbounded PCollection {}. Switching to streaming execution.", value.getName());
-          translationMode = TranslationMode.STREAMING;
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {
+      TranslationResult<T> result = getResult(pCollection);
+      if (!noCache && result.dependentTransforms.size() > 1) {
+        LOG.info("Dataset {} will be cached.", result.name);
+        result.dataset = dataset.persist(storageLevel); // use NONE to disable
+      } else {
+        result.dataset = dataset;
+        if (result.dependentTransforms.isEmpty()) {
+          leaves.add(result);
         }
       }
     }
-  }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline utility methods
-  // --------------------------------------------------------------------------------------------
+    @Override
+    public SerializablePipelineOptions getSerializableOptions() {
+      return serializableOptions;
+    }
 
-  /** Get a {@link TransformTranslator} for the given {@link TransformHierarchy.Node}. */
-  protected abstract @Nullable <
-          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
-          @Nullable TransformT transform);
-
-  /** Apply the given TransformTranslator to the given node. */
-  private <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      void applyTransformTranslator(
-          TransformHierarchy.Node node,
-          TransformT transform,
-          TransformTranslator<InT, OutT, TransformT> transformTranslator) {
-    // create the applied PTransform on the translationContext
-    AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
-        (AppliedPTransform) node.toAppliedPTransform(getPipeline());
-    try {
-      transformTranslator.translate(transform, appliedTransform, translationContext);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
+    @Override
+    public SparkSession getSparkSession() {
+      return sparkSession;
     }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline visitor entry point
-  // --------------------------------------------------------------------------------------------
-
   /**
-   * Translates the pipeline by passing this class as a visitor.
+   * {@link PTransformVisitor} that analyses dependencies of supported {@link PTransform
+   * PTransforms} to help identify cache candidates.
    *
-   * @param pipeline The pipeline to be translated
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
    */
-  public void translate(Pipeline pipeline) {
-    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
-    pipeline.traverseTopologically(this);
+  private class DependencyVisitor extends PTransformVisitor {
+    private final Map<PCollection<?>, TranslationResult<?>> results = new HashMap<>();
+
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+      for (PCollection<?> pOut : node.getOutputs().values()) {
+        results.put(pOut, new TranslationResult<>(pOut));
+        for (Map.Entry<TupleTag<?>, PCollection<?>> entry : node.getInputs().entrySet()) {
+          TranslationResult<?> input = checkStateNotNull(results.get(entry.getValue()));
+          input.dependentTransforms.add(transform);
+        }
+      }
+    }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline Visitor Methods
-  // --------------------------------------------------------------------------------------------
+  /**
+   * An abstract {@link PipelineVisitor} that visits all translatable {@link PTransform} pipeline
+   * nodes of a pipeline with the respective {@link TransformTranslator}.
+   *
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
+   */
+  private abstract class PTransformVisitor extends PipelineVisitor.Defaults {
 
-  @Override
-  public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) {
-    PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
-    TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
-        getTransformTranslator(transform);
+    /** Visit the {@link PTransform} with its respective {@link TransformTranslator}. */
+    abstract <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator);
 
-    if (transformTranslator != null) {
-      LOG.info("Translating composite: {}", node.getFullName());
-      applyTransformTranslator(node, transform, transformTranslator);
-      return CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
-    } else {
-      return CompositeBehavior.ENTER_TRANSFORM;
+    @Override
+    public final CompositeBehavior enterCompositeTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> translator =
+          getTranslator(transform);
+      if (transform != null && translator != null) {
+        visit(node, transform, translator);
+        return DO_NOT_ENTER_TRANSFORM;
+      } else {
+        return ENTER_TRANSFORM;
+      }
+    }
+
+    @Override
+    public final void visitPrimitiveTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      if (transform == null || transform.getClass().equals(View.CreatePCollectionView.class)) {
+        return; // ignore, nothing to be translated here
+      }
+      TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> translator =
+          getTranslator(transform);
+      if (translator == null) {
+        String urn = PTransformTranslation.urnForTransform(transform);
+        throw new UnsupportedOperationException("Transform " + urn + " is not supported.");
+      }
+      visit(node, transform, translator);
     }
-  }
 
-  @Override
-  public void visitPrimitiveTransform(TransformHierarchy.Node node) {
-    LOG.info("Translating primitive: {}", node.getFullName());
-    // get the transformation corresponding to the node we are
-    // currently visiting and translate it into its Spark alternative.
-    PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
-    TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
-        getTransformTranslator(transform);
-
-    if (transformTranslator == null) {
-      String transformUrn = PTransformTranslation.urnForTransform(node.getTransform());
-      throw new UnsupportedOperationException(
-          "The transform " + transformUrn + " is currently not supported.");
-    }
-    applyTransformTranslator(node, transform, transformTranslator);
+    /** {@link TransformTranslator} for {@link PTransform} if translation is known and supported. */
+    private @Nullable TransformTranslator<PInput, POutput, PTransform<PInput, POutput>>
+        getTranslator(@Nullable PTransform<PInput, POutput> transform) {
+      if (transform == null) {
+        return null;
+      }
+      TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> translator =
+          getTransformTranslator(transform);
+      return translator != null && translator.canTranslate(transform) ? translator : null;
+    }
   }
 
-  public TranslationContext getTranslationContext() {
-    return translationContext;
+  /**
+   * Traverse the pipeline to check for unbounded {@link PCollection PCollections} that would
+   * require streaming mode unless streaming mode is already enabled.
+   */
+  private static class StreamingModeDetector extends PipelineVisitor.Defaults {
+    private boolean streaming;
+
+    StreamingModeDetector(boolean streaming) {
+      this.streaming = streaming;
+    }
+
+    @Override
+    public CompositeBehavior enterCompositeTransform(Node node) {
+      return streaming ? DO_NOT_ENTER_TRANSFORM : ENTER_TRANSFORM; // stop if in streaming mode

Review Comment:
   Good improvement to stop traversing the hierarchy if streaming mode is already forced!
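The early exit can be illustrated with a small, self-contained model. This is a sketch under simplifying assumptions, not Beam's `PipelineVisitor` API: `SimpleNode` and `StreamingDetectorSketch` are hypothetical names, and each node carries its own boundedness flag instead of Beam's `visitValue` callback on produced `PCollection`s. As in the diff, once streaming mode is known (detected or forced by the options), the detector stops descending into children, which is what returning `DO_NOT_ENTER_TRANSFORM` achieves for composites.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for a node of the transform hierarchy (not a Beam class). */
final class SimpleNode {
  final String name;
  final boolean producesUnbounded; // assumption: boundedness is known per node
  final List<SimpleNode> children = new ArrayList<>();

  SimpleNode(String name, boolean producesUnbounded) {
    this.name = name;
    this.producesUnbounded = producesUnbounded;
  }

  SimpleNode add(SimpleNode child) {
    children.add(child);
    return this;
  }
}

/** Sketch of the streaming detection idea: skip children once streaming mode is known. */
final class StreamingDetectorSketch {
  private boolean streaming;
  private int visited; // how many nodes were actually entered

  StreamingDetectorSketch(boolean streaming) {
    this.streaming = streaming; // e.g. already forced via the pipeline options
  }

  void enter(SimpleNode node) {
    visited++;
    if (node.producesUnbounded) {
      streaming = true; // an unbounded collection requires streaming mode
    }
    if (streaming) {
      return; // analogous to DO_NOT_ENTER_TRANSFORM: nothing left to detect below
    }
    for (SimpleNode child : node.children) {
      enter(child);
    }
  }

  boolean isStreaming() {
    return streaming;
  }

  int visited() {
    return visited;
  }
}
```

In this model, siblings are still entered but their children are skipped; the sketch does not claim to reproduce every detail of Beam's traversal order.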



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [beam] mosche commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
mosche commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1025008863


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence, that can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as result of translating a {@link PTransform}
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and omitted afterwards. */
+  interface TranslationState extends EncoderProvider {

Review Comment:
   👍 





[GitHub] [beam] echauchot commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
echauchot commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1318326424

   > > @mosche @aromanenko-dev as promised, I'm reviewing
   > 
   > Almost done, should finish today.
   
   Finishing the review of the dependency visitor, then I'm done.




[GitHub] [beam] mosche commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
mosche commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1305742906

   Run Spark StructuredStreaming ValidatesRunner




[GitHub] [beam] echauchot commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
echauchot commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1313730401

   @mosche @aromanenko-dev as promised, I'm reviewing




[GitHub] [beam] mosche commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
mosche commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1026211619


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java:
##########
@@ -81,27 +80,13 @@ public class PipelineTranslatorBatch extends PipelineTranslator {
 
     TRANSFORM_TRANSLATORS.put(
         SplittableParDo.PrimitiveBoundedRead.class, new ReadSourceTranslatorBatch<>());
-

Review Comment:
   This is unrelated to #24035; see the comment below:
   > The PCollectionView translation just stored the same Spark dataset (reference!) again for a different PTransform. That's obviously problematic for caching, as we're not gathering metadata on that dataset in a single place. Also, the Beam runner guidelines discourage translating PCollectionView; such translations only exist for legacy reasons.
   
   In terms of preparation for #24035, that's mostly the introduction of `TranslationResult`, which captures all kinds of metadata and context for a specific Spark dataset.
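A minimal sketch of keeping all metadata for a dataset in a single place, assuming a simplified model where collections and transforms are identified by name (the PR keys results by `PCollection` and stores dependents as `PTransform` instances). `ResultSketch` and `DependencySketch` are hypothetical names. Each collection maps to exactly one result object, and a dependency pass records every transform that consumes it; that count is what a later caching decision can rely on.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Hypothetical counterpart of TranslationResult: one metadata holder per collection. */
final class ResultSketch {
  final String name;
  final Set<String> dependentTransforms = new HashSet<>();

  ResultSketch(String name) {
    this.name = name;
  }
}

/** Hypothetical dependency pass: records which transforms consume which collections. */
final class DependencySketch {
  final Map<String, ResultSketch> results = new HashMap<>();

  /** Registers the output of a transform; repeated calls reuse the same result. */
  ResultSketch output(String collection) {
    return results.computeIfAbsent(collection, ResultSketch::new);
  }

  /** Records that {@code transform} consumes each of the given inputs. */
  void consume(String transform, List<String> inputs) {
    for (String input : inputs) {
      // A Set ignores a transform that reads the same input twice (e.g. main + side input).
      output(input).dependentTransforms.add(transform);
    }
  }
}
```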





[GitHub] [beam] mosche commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
mosche commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1025011738


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence, that can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as result of translating a {@link PTransform}
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and omitted afterwards. */
+  interface TranslationState extends EncoderProvider {
+    <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
+
+    <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache);
 
-    /** Uses the batch mode. */
-    BATCH,
+    default <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      putDataset(pCollection, dataset, false);
+    }
 
-    /** Uses the streaming mode. */
-    STREAMING
+    SerializablePipelineOptions getSerializableOptions();
+
+    SparkSession getSparkSession();
   }
 
-  /** Traverses the Pipeline to determine the {@link TranslationMode} for this pipeline. */
-  private static class TranslationModeDetector extends Pipeline.PipelineVisitor.Defaults {
-    private static final Logger LOG = LoggerFactory.getLogger(TranslationModeDetector.class);
+  /**
+   * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their
+   * Spark correspondence.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted.
+   */
+  private class TranslatingVisitor extends PTransformVisitor implements TranslationState {
+    private final Map<PCollection<?>, TranslationResult<?>> translationResults;
+    private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
+    private final SparkSession sparkSession;
+    private final SerializablePipelineOptions serializableOptions;
+    private final StorageLevel storageLevel;
+
+    private final Set<TranslationResult<?>> leaves;
+
+    public TranslatingVisitor(
+        SparkSession sparkSession,
+        SparkCommonPipelineOptions options,
+        Map<PCollection<?>, TranslationResult<?>> translationResults) {
+      this.sparkSession = sparkSession;
+      this.translationResults = translationResults;
+      this.serializableOptions = new SerializablePipelineOptions(options);
+      this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
+      this.encoders = new HashMap<>();
+      this.leaves = new HashSet<>();
+    }
 
-    private TranslationMode translationMode;
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+          (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+      try {
+        LOG.info(
+            "Translating {}: {}",
+            node.isCompositeNode() ? "composite" : "primitive",
+            node.getFullName());
+        translator.translate(transform, appliedTransform, this);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
 
-    TranslationModeDetector(TranslationMode defaultMode) {
-      this.translationMode = defaultMode;
+    @Override
+    public <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory) {
+      return (Encoder<T>) encoders.computeIfAbsent(coder, (Factory) factory);
     }
 
-    TranslationModeDetector() {
-      this(TranslationMode.BATCH);
+    private <T> TranslationResult<T> getResult(PCollection<T> pCollection) {
+      return (TranslationResult<T>) checkStateNotNull(translationResults.get(pCollection));
     }
 
-    TranslationMode getTranslationMode() {
-      return translationMode;
+    @Override
+    public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+      return checkStateNotNull(getResult(pCollection).dataset);
     }
 
     @Override
-    public void visitValue(PValue value, TransformHierarchy.Node producer) {
-      if (translationMode.equals(TranslationMode.BATCH)) {
-        if (value instanceof PCollection
-            && ((PCollection) value).isBounded() == PCollection.IsBounded.UNBOUNDED) {
-          LOG.info(
-              "Found unbounded PCollection {}. Switching to streaming execution.", value.getName());
-          translationMode = TranslationMode.STREAMING;
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {
+      TranslationResult<T> result = getResult(pCollection);
+      if (!noCache && result.dependentTransforms.size() > 1) {

Review Comment:
   That would be wrong; there are three cases: no dependents, exactly one, and multiple.
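The three cases can be made explicit with a small decision helper. This is a hypothetical sketch (`CacheDecision` and `Handling` are illustrative names, not part of the PR); it assumes, consistent with the `leaves` set handed to `EvaluationContext` in the diff, that a dataset without dependents is a pipeline leaf that has to be evaluated, while only datasets with more than one dependent are worth persisting.

```java
/** Illustrative outcomes for a translated dataset. */
enum Handling {
  LEAF, // no dependents: nothing downstream triggers it, so it must be evaluated directly
  PASS_THROUGH, // exactly one dependent (or caching disabled): keep the lazy plan as-is
  CACHE // several dependents: persist to avoid repeated lazy evaluation
}

final class CacheDecision {
  static Handling decide(int dependents, boolean noCache) {
    if (dependents == 0) {
      return Handling.LEAF;
    }
    if (dependents > 1 && !noCache) {
      return Handling.CACHE;
    }
    return Handling.PASS_THROUGH;
  }
}
```

A single two-way branch on `size() > 1` cannot distinguish all three outcomes, which is why the empty case needs its own handling.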





[GitHub] [beam] mosche commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
mosche commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1305685184

   Spark Runner Tpcds Tests




[GitHub] [beam] mosche commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
mosche commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1319758326

   @echauchot I've addressed your comments, thanks for the feedback :) Please have another look.




[GitHub] [beam] mosche commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
mosche commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1305821227

   Spark Runner Tpcds Tests




[GitHub] [beam] mosche commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
mosche commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1320337223

   Thx @echauchot :)




[GitHub] [beam] echauchot commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
echauchot commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1026481858


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java:
##########
@@ -118,69 +126,53 @@ public <T> PCollection<T> getOutput(TupleTag<T> tag) {
       return pc;
     }
 
-    public Map<TupleTag<?>, PCollection<?>> getOutputs() {
-      return transform.getOutputs();
-    }
-
     public AppliedPTransform<InT, OutT, PTransform<InT, OutT>> getCurrentTransform() {
       return transform;
     }
 
+    @Override
     public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
-      return cxt.getDataset(pCollection);
+      return state.getDataset(pCollection);
     }
 
-    public <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
-      cxt.putDataset(pCollection, dataset);
+    @Override
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {

Review Comment:
   Yes, and it was even more difficult to read with `noCache = false` :smile:
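One way to keep call sites readable, and the shape the `TranslationState` interface in this PR takes, is a convenience overload that supplies the default flag, so most callers never pass a bare `false`. The sketch below models that pattern with plain strings instead of `PCollection` and `Dataset`; `DatasetStateSketch` and `RecordingState` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical, string-based model of the putDataset overloads. */
interface DatasetStateSketch {
  void putDataset(String collection, String dataset, boolean noCache);

  /** Common case: caching stays enabled, so callers need not spell out the flag. */
  default void putDataset(String collection, String dataset) {
    putDataset(collection, dataset, false);
  }
}

/** Records every call so the delegation can be observed. */
final class RecordingState implements DatasetStateSketch {
  final List<String> calls = new ArrayList<>();

  @Override
  public void putDataset(String collection, String dataset, boolean noCache) {
    calls.add(collection + ":" + dataset + ":noCache=" + noCache);
  }
}
```

Only the rare opt-out then reads `putDataset(pc, ds, true)`, which makes it stand out at the call site.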





[GitHub] [beam] echauchot commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
echauchot commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1026482935


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java:
##########
@@ -220,8 +225,9 @@ private SideInputBroadcast createBroadcastSideInputs(
       Coder<WindowedValue<?>> windowedValueCoder =
           (Coder<WindowedValue<?>>)
               (Coder<?>) WindowedValue.getFullCoder(pc.getCoder(), windowCoder);
-      Dataset<WindowedValue<?>> broadcastSet = context.getSideInputDataset(sideInput);
-      List<WindowedValue<?>> valuesList = broadcastSet.collectAsList();
+      Dataset<WindowedValue<?>> broadcastSet = context.getDataset((PCollection) pc);
+      List<WindowedValue<?>> valuesList =

Review Comment:
   Makes sense.
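Collecting the side-input dataset (as `collectAsList` does here) is a Spark action, and Spark datasets are evaluated lazily, so if the same dataset is also consumed elsewhere, its upstream plan may be computed again unless it is persisted. The following stdlib-only analogy illustrates that cost; `memoize` loosely plays the role of `Dataset.persist`, and all names are hypothetical rather than Spark or Beam API.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

final class LazyReuseSketch {
  /** Computes the delegate at most once, loosely like persisting a dataset. */
  static <T> Supplier<T> memoize(Supplier<T> delegate) {
    return new Supplier<T>() {
      private T value;
      private boolean computed;

      @Override
      public synchronized T get() {
        if (!computed) {
          value = delegate.get();
          computed = true;
        }
        return value;
      }
    };
  }

  /** Builds an "upstream plan" that counts how often it is evaluated. */
  static Supplier<List<Integer>> countingPlan(AtomicInteger evaluations) {
    return () -> {
      evaluations.incrementAndGet(); // stands in for re-running the upstream computation
      return Arrays.asList(1, 2, 3);
    };
  }
}
```

Evaluating the plain plan twice (once for the broadcast, once as a main input) runs it twice; the memoized version runs it once.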





[GitHub] [beam] mosche commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
mosche commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1025017277


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence, that can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as result of translating a {@link PTransform}
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and omitted afterwards. */
+  interface TranslationState extends EncoderProvider {
+    <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
+
+    <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache);
 
-    /** Uses the batch mode. */
-    BATCH,
+    default <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      putDataset(pCollection, dataset, false);
+    }
 
-    /** Uses the streaming mode. */
-    STREAMING
+    SerializablePipelineOptions getSerializableOptions();
+
+    SparkSession getSparkSession();
   }
 
-  /** Traverses the Pipeline to determine the {@link TranslationMode} for this pipeline. */
-  private static class TranslationModeDetector extends Pipeline.PipelineVisitor.Defaults {
-    private static final Logger LOG = LoggerFactory.getLogger(TranslationModeDetector.class);
+  /**
+   * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their
+   * Spark correspondence.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted.
+   */
+  private class TranslatingVisitor extends PTransformVisitor implements TranslationState {
+    private final Map<PCollection<?>, TranslationResult<?>> translationResults;
+    private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
+    private final SparkSession sparkSession;
+    private final SerializablePipelineOptions serializableOptions;
+    private final StorageLevel storageLevel;
+
+    private final Set<TranslationResult<?>> leaves;
+
+    public TranslatingVisitor(
+        SparkSession sparkSession,
+        SparkCommonPipelineOptions options,
+        Map<PCollection<?>, TranslationResult<?>> translationResults) {
+      this.sparkSession = sparkSession;
+      this.translationResults = translationResults;
+      this.serializableOptions = new SerializablePipelineOptions(options);
+      this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
+      this.encoders = new HashMap<>();
+      this.leaves = new HashSet<>();
+    }
 
-    private TranslationMode translationMode;
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+          (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+      try {
+        LOG.info(
+            "Translating {}: {}",
+            node.isCompositeNode() ? "composite" : "primitive",
+            node.getFullName());
+        translator.translate(transform, appliedTransform, this);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
 
-    TranslationModeDetector(TranslationMode defaultMode) {
-      this.translationMode = defaultMode;
+    @Override
+    public <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory) {
+      return (Encoder<T>) encoders.computeIfAbsent(coder, (Factory) factory);
     }
 
-    TranslationModeDetector() {
-      this(TranslationMode.BATCH);
+    private <T> TranslationResult<T> getResult(PCollection<T> pCollection) {
+      return (TranslationResult<T>) checkStateNotNull(translationResults.get(pCollection));
     }
 
-    TranslationMode getTranslationMode() {
-      return translationMode;
+    @Override
+    public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+      return checkStateNotNull(getResult(pCollection).dataset);
     }
 
     @Override
-    public void visitValue(PValue value, TransformHierarchy.Node producer) {
-      if (translationMode.equals(TranslationMode.BATCH)) {
-        if (value instanceof PCollection
-            && ((PCollection) value).isBounded() == PCollection.IsBounded.UNBOUNDED) {
-          LOG.info(
-              "Found unbounded PCollection {}. Switching to streaming execution.", value.getName());
-          translationMode = TranslationMode.STREAMING;
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {
+      TranslationResult<T> result = getResult(pCollection);
+      if (!noCache && result.dependentTransforms.size() > 1) {
+        LOG.info("Dataset {} will be cached.", result.name);
+        result.dataset = dataset.persist(storageLevel); // use NONE to disable
+      } else {
+        result.dataset = dataset;
+        if (result.dependentTransforms.isEmpty()) {
+          leaves.add(result);
         }
       }
     }
-  }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline utility methods
-  // --------------------------------------------------------------------------------------------
+    @Override
+    public SerializablePipelineOptions getSerializableOptions() {
+      return serializableOptions;
+    }
 
-  /** Get a {@link TransformTranslator} for the given {@link TransformHierarchy.Node}. */
-  protected abstract @Nullable <
-          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
-          @Nullable TransformT transform);
-
-  /** Apply the given TransformTranslator to the given node. */
-  private <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      void applyTransformTranslator(
-          TransformHierarchy.Node node,
-          TransformT transform,
-          TransformTranslator<InT, OutT, TransformT> transformTranslator) {
-    // create the applied PTransform on the translationContext
-    AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
-        (AppliedPTransform) node.toAppliedPTransform(getPipeline());
-    try {
-      transformTranslator.translate(transform, appliedTransform, translationContext);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
+    @Override
+    public SparkSession getSparkSession() {
+      return sparkSession;
     }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline visitor entry point
-  // --------------------------------------------------------------------------------------------
-
   /**
-   * Translates the pipeline by passing this class as a visitor.
+   * {@link PTransformVisitor} that analyses dependencies of supported {@link PTransform
+   * PTransforms} to help identify cache candidates.
    *
-   * @param pipeline The pipeline to be translated
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
    */
-  public void translate(Pipeline pipeline) {
-    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
-    pipeline.traverseTopologically(this);
+  private class DependencyVisitor extends PTransformVisitor {
+    private final Map<PCollection<?>, TranslationResult<?>> results = new HashMap<>();
+
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+      for (PCollection<?> pOut : node.getOutputs().values()) {
+        results.put(pOut, new TranslationResult<>(pOut));
+        for (Map.Entry<TupleTag<?>, PCollection<?>> entry : node.getInputs().entrySet()) {
+          TranslationResult<?> input = checkStateNotNull(results.get(entry.getValue()));
+          input.dependentTransforms.add(transform);
+        }
+      }
+    }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline Visitor Methods
-  // --------------------------------------------------------------------------------------------
+  /**
+   * An abstract {@link PipelineVisitor} that visits all translatable {@link PTransform} pipeline
+   * nodes of a pipeline with the respective {@link TransformTranslator}.
+   *
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
+   */
+  private abstract class PTransformVisitor extends PipelineVisitor.Defaults {
 
-  @Override
-  public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) {
-    PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
-    TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
-        getTransformTranslator(transform);
+    /** Visit the {@link PTransform} with its respective {@link TransformTranslator}. */
+    abstract <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator);
 
-    if (transformTranslator != null) {
-      LOG.info("Translating composite: {}", node.getFullName());
-      applyTransformTranslator(node, transform, transformTranslator);
-      return CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
-    } else {
-      return CompositeBehavior.ENTER_TRANSFORM;
+    @Override
+    public final CompositeBehavior enterCompositeTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> translator =
+          getTranslator(transform);
+      if (transform != null && translator != null) {
+        visit(node, transform, translator);
+        return DO_NOT_ENTER_TRANSFORM;
+      } else {
+        return ENTER_TRANSFORM;
+      }
+    }
+
+    @Override
+    public final void visitPrimitiveTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      if (transform == null || transform.getClass().equals(View.CreatePCollectionView.class)) {
+        return; // ignore, nothing to be translated here

Review Comment:
   Yes. The PCollectionView translation just stored the same Spark dataset (the same reference!) again for a different PTransform. That's obviously problematic for caching, since we'd no longer gather the metadata for that dataset in a single place. Also, the Beam runner guidelines discourage translating PCollectionView; it is only kept for legacy reasons.
   
   > ignore, nothing to be translated here, views are handled on the consumer side
   
   Is this sufficient as a comment, or what would you suggest?
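Consumer-side handling of views can be modelled roughly as shown here. This is a simplified plain-Java sketch, not the runner's code. All names are hypothetical, and strings stand in for PCollections, views and transforms. A view gets no metadata entry of its own. It only records which PCollection backs it, so a side-input consumer is counted against that PCollection, and the "more than one dependent" check sees every consumer in one place.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class ConsumerSideViewSketch {
  // Metadata kept once per PCollection (hypothetical stand-in for TranslationResult).
  static final class Result {
    String dataset;
    final Set<String> dependents = new HashSet<>();
  }

  final Map<String, Result> results = new HashMap<>();
  // A side-input view only records its backing PCollection; it gets no Result of its own.
  final Map<String, String> viewToPCollection = new HashMap<>();

  Result resultOf(String pCollection) {
    return results.computeIfAbsent(pCollection, k -> new Result());
  }

  // Main input or side input: either way the consumer is recorded on the backing PCollection.
  void addConsumer(String inputOrView, String transform) {
    String pCollection = viewToPCollection.getOrDefault(inputOrView, inputOrView);
    resultOf(pCollection).dependents.add(transform);
  }

  boolean shouldCache(String pCollection) {
    return resultOf(pCollection).dependents.size() > 1;
  }

  public static void main(String[] args) {
    ConsumerSideViewSketch s = new ConsumerSideViewSketch();
    s.viewToPCollection.put("wordsView", "words");
    s.addConsumer("words", "CountWords");
    s.addConsumer("wordsView", "FilterBySideInput");
    System.out.println(s.shouldCache("words")); // true
  }
}
```

In this model, if the view instead had its own entry pointing at the same dataset, FilterBySideInput would be counted under wordsView, words would see only one consumer, and the cache candidate would be missed.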



[GitHub] [beam] echauchot commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
echauchot commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1026487652


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderProvider.java:
##########
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder;
+
+import java.util.function.Function;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.values.KV;
+import org.apache.spark.sql.Encoder;
+
+@Internal
+public interface EncoderProvider {
+  interface Factory<T> extends Function<Coder<T>, Encoder<T>> {
+    Factory<?> INSTANCE = EncoderHelpers::encoderFor;
+  }
+
+  <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory);
+
+  default <T> Encoder<T> encoderOf(Coder<T> coder) {
+    return coder instanceof KvCoder
+        ? (Encoder<T>) kvEncoderOf((KvCoder) coder)
+        : encoderOf(coder, encoderFactory());
+  }
+
+  default <K, V> Encoder<KV<K, V>> kvEncoderOf(KvCoder<K, V> coder) {
+    return encoderOf(coder, c -> kvEncoder(keyEncoderOf(coder), valueEncoderOf(coder)));
+  }
+
+  default <K, V> Encoder<K> keyEncoderOf(KvCoder<K, V> coder) {
+    return encoderOf(coder.getKeyCoder(), encoderFactory());
+  }
+
+  default <K, V> Encoder<V> valueEncoderOf(KvCoder<K, V> coder) {
+    return encoderOf(coder.getValueCoder(), encoderFactory());
+  }
+
+  default <T> Factory<T> encoderFactory() {
+    return (Factory<T>) Factory.INSTANCE;

Review Comment:
   I was misled by the interface; I did not see you were creating a singleton here.
   Fair enough to keep it, of course.
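Fields declared in a Java interface are implicitly public, static and final. As a result, Factory.INSTANCE is initialized once, and every encoderFactory() call hands back that same object. The sketch below demonstrates this in plain Java. It is a stand-in built on assumptions, not the runner's code: strings replace Coder and Encoder, and the static ENCODERS map is a hypothetical imitation of the per-coder computeIfAbsent cache in TranslatingVisitor.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

public class InterfaceSingletonSketch {
  // Strings stand in for Coder (input) and Encoder (output) in this sketch.
  interface Factory<T> extends Function<String, T> {
    // Interface fields are implicitly public static final, so this lambda exists exactly once.
    Factory<?> INSTANCE = coder -> "encoder(" + coder + ")";
  }

  // Hypothetical per-coder cache, imitating encoders.computeIfAbsent(coder, factory).
  static final Map<String, Object> ENCODERS = new HashMap<>();

  @SuppressWarnings("unchecked")
  static <T> Factory<T> encoderFactory() {
    return (Factory<T>) Factory.INSTANCE; // unchecked cast of the shared instance
  }

  @SuppressWarnings("unchecked")
  static <T> T encoderOf(String coder, Factory<T> factory) {
    return (T) ENCODERS.computeIfAbsent(coder, factory);
  }

  public static void main(String[] args) {
    Object a = InterfaceSingletonSketch.<String>encoderFactory();
    Object b = InterfaceSingletonSketch.<Integer>encoderFactory();
    System.out.println(a == b); // true: both calls return Factory.INSTANCE

    Factory<String> factory = encoderFactory();
    String first = encoderOf("StringUtf8Coder", factory);
    String second = encoderOf("StringUtf8Coder", factory);
    System.out.println(first == second); // true: the second lookup hits the cache
  }
}
```

Because the generic accessor only casts the shared instance, no new factory object is allocated per call. The type parameter exists purely for the caller's convenience.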



[GitHub] [beam] echauchot commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
echauchot commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1320085050

   Run Spark StructuredStreaming ValidatesRunner


[GitHub] [beam] echauchot merged pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
echauchot merged PR #24009:
URL: https://github.com/apache/beam/pull/24009


[GitHub] [beam] mosche commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
mosche commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1025017277


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence, that can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as result of translating a {@link PTransform}
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and omitted afterwards. */
+  interface TranslationState extends EncoderProvider {
+    <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
+
+    <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache);
 
-    /** Uses the batch mode. */
-    BATCH,
+    default <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      putDataset(pCollection, dataset, false);
+    }
 
-    /** Uses the streaming mode. */
-    STREAMING
+    SerializablePipelineOptions getSerializableOptions();
+
+    SparkSession getSparkSession();
   }
 
-  /** Traverses the Pipeline to determine the {@link TranslationMode} for this pipeline. */
-  private static class TranslationModeDetector extends Pipeline.PipelineVisitor.Defaults {
-    private static final Logger LOG = LoggerFactory.getLogger(TranslationModeDetector.class);
+  /**
+   * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their
+   * Spark correspondence.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted.
+   */
+  private class TranslatingVisitor extends PTransformVisitor implements TranslationState {
+    private final Map<PCollection<?>, TranslationResult<?>> translationResults;
+    private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
+    private final SparkSession sparkSession;
+    private final SerializablePipelineOptions serializableOptions;
+    private final StorageLevel storageLevel;
+
+    private final Set<TranslationResult<?>> leaves;
+
+    public TranslatingVisitor(
+        SparkSession sparkSession,
+        SparkCommonPipelineOptions options,
+        Map<PCollection<?>, TranslationResult<?>> translationResults) {
+      this.sparkSession = sparkSession;
+      this.translationResults = translationResults;
+      this.serializableOptions = new SerializablePipelineOptions(options);
+      this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
+      this.encoders = new HashMap<>();
+      this.leaves = new HashSet<>();
+    }
 
-    private TranslationMode translationMode;
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+          (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+      try {
+        LOG.info(
+            "Translating {}: {}",
+            node.isCompositeNode() ? "composite" : "primitive",
+            node.getFullName());
+        translator.translate(transform, appliedTransform, this);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
 
-    TranslationModeDetector(TranslationMode defaultMode) {
-      this.translationMode = defaultMode;
+    @Override
+    public <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory) {
+      return (Encoder<T>) encoders.computeIfAbsent(coder, (Factory) factory);
     }
 
-    TranslationModeDetector() {
-      this(TranslationMode.BATCH);
+    private <T> TranslationResult<T> getResult(PCollection<T> pCollection) {
+      return (TranslationResult<T>) checkStateNotNull(translationResults.get(pCollection));
     }
 
-    TranslationMode getTranslationMode() {
-      return translationMode;
+    @Override
+    public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+      return checkStateNotNull(getResult(pCollection).dataset);
     }
 
     @Override
-    public void visitValue(PValue value, TransformHierarchy.Node producer) {
-      if (translationMode.equals(TranslationMode.BATCH)) {
-        if (value instanceof PCollection
-            && ((PCollection) value).isBounded() == PCollection.IsBounded.UNBOUNDED) {
-          LOG.info(
-              "Found unbounded PCollection {}. Switching to streaming execution.", value.getName());
-          translationMode = TranslationMode.STREAMING;
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {
+      TranslationResult<T> result = getResult(pCollection);
+      if (!noCache && result.dependentTransforms.size() > 1) {
+        LOG.info("Dataset {} will be cached.", result.name);
+        result.dataset = dataset.persist(storageLevel); // use NONE to disable
+      } else {
+        result.dataset = dataset;
+        if (result.dependentTransforms.isEmpty()) {
+          leaves.add(result);
         }
       }
     }
-  }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline utility methods
-  // --------------------------------------------------------------------------------------------
+    @Override
+    public SerializablePipelineOptions getSerializableOptions() {
+      return serializableOptions;
+    }
 
-  /** Get a {@link TransformTranslator} for the given {@link TransformHierarchy.Node}. */
-  protected abstract @Nullable <
-          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
-          @Nullable TransformT transform);
-
-  /** Apply the given TransformTranslator to the given node. */
-  private <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      void applyTransformTranslator(
-          TransformHierarchy.Node node,
-          TransformT transform,
-          TransformTranslator<InT, OutT, TransformT> transformTranslator) {
-    // create the applied PTransform on the translationContext
-    AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
-        (AppliedPTransform) node.toAppliedPTransform(getPipeline());
-    try {
-      transformTranslator.translate(transform, appliedTransform, translationContext);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
+    @Override
+    public SparkSession getSparkSession() {
+      return sparkSession;
     }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline visitor entry point
-  // --------------------------------------------------------------------------------------------
-
   /**
-   * Translates the pipeline by passing this class as a visitor.
+   * {@link PTransformVisitor} that analyses dependencies of supported {@link PTransform
+   * PTransforms} to help identify cache candidates.
    *
-   * @param pipeline The pipeline to be translated
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
    */
-  public void translate(Pipeline pipeline) {
-    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
-    pipeline.traverseTopologically(this);
+  private class DependencyVisitor extends PTransformVisitor {
+    private final Map<PCollection<?>, TranslationResult<?>> results = new HashMap<>();
+
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+      for (PCollection<?> pOut : node.getOutputs().values()) {
+        results.put(pOut, new TranslationResult<>(pOut));
+        for (Map.Entry<TupleTag<?>, PCollection<?>> entry : node.getInputs().entrySet()) {
+          TranslationResult<?> input = checkStateNotNull(results.get(entry.getValue()));
+          input.dependentTransforms.add(transform);
+        }
+      }
+    }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline Visitor Methods
-  // --------------------------------------------------------------------------------------------
+  /**
+   * An abstract {@link PipelineVisitor} that visits all translatable {@link PTransform} pipeline
+   * nodes of a pipeline with the respective {@link TransformTranslator}.
+   *
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
+   */
+  private abstract class PTransformVisitor extends PipelineVisitor.Defaults {
 
-  @Override
-  public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) {
-    PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
-    TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
-        getTransformTranslator(transform);
+    /** Visit the {@link PTransform} with its respective {@link TransformTranslator}. */
+    abstract <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator);
 
-    if (transformTranslator != null) {
-      LOG.info("Translating composite: {}", node.getFullName());
-      applyTransformTranslator(node, transform, transformTranslator);
-      return CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
-    } else {
-      return CompositeBehavior.ENTER_TRANSFORM;
+    @Override
+    public final CompositeBehavior enterCompositeTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> translator =
+          getTranslator(transform);
+      if (transform != null && translator != null) {
+        visit(node, transform, translator);
+        return DO_NOT_ENTER_TRANSFORM;
+      } else {
+        return ENTER_TRANSFORM;
+      }
+    }
+
+    @Override
+    public final void visitPrimitiveTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      if (transform == null || transform.getClass().equals(View.CreatePCollectionView.class)) {
+        return; // ignore, nothing to be translated here

Review Comment:
   Yes, the `PCollectionView` translation just stored the same Spark dataset (the same reference!) again for a different `PTransform`. That's obviously problematic for caching, as we're no longer gathering the metadata of that dataset in a single place. Also, the Beam runner guidelines discourage translating `PCollectionView`; it only exists for legacy reasons.
   
   Not sure what kind of comment you'd expect here...
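To make the caching concern concrete, here is a minimal sketch in plain Java (no Beam or Spark). It assumes dependency metadata is tracked per `PCollection` key, as with `TranslationResult.dependentTransforms` in the diff above. When the view translation registers the same dataset reference under a second key, each key sees only one consumer, so the "more than one dependent transform" rule in `putDataset` never fires. `AliasingDemo`, `Result`, and the key and transform names are hypothetical.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/** Hypothetical demo: per-key dependency tracking versus aliasing one dataset under two keys. */
final class AliasingDemo {
  /** Hypothetical stand-in for TranslationResult: one metadata record per PCollection key. */
  static final class Result {
    final Object dataset; // reference to the (possibly shared) underlying dataset
    final Set<String> dependentTransforms = new HashSet<>();

    Result(Object dataset) {
      this.dataset = dataset;
    }
  }

  /** Mirrors the caching rule in putDataset: cache if more than one transform consumes it. */
  static boolean anyCached(Map<String, Result> results) {
    for (Result r : results.values()) {
      if (r.dependentTransforms.size() > 1) {
        return true;
      }
    }
    return false;
  }

  /** The same dataset is registered again under a second key, splitting its consumers. */
  static boolean cachedWithAlias() {
    Object dataset = new Object();
    Map<String, Result> results = new HashMap<>();
    results.put("pCol", new Result(dataset));
    results.put("pColView", new Result(dataset)); // alias: same reference, separate metadata
    results.get("pCol").dependentTransforms.add("ParDoA");
    results.get("pColView").dependentTransforms.add("ParDoB");
    return anyCached(results);
  }

  /** Both consumers are recorded on a single entry. */
  static boolean cachedWithoutAlias() {
    Map<String, Result> results = new HashMap<>();
    results.put("pCol", new Result(new Object()));
    results.get("pCol").dependentTransforms.add("ParDoA");
    results.get("pCol").dependentTransforms.add("ParDoB");
    return anyCached(results);
  }

  public static void main(String[] args) {
    System.out.println("aliased -> cached: " + cachedWithAlias()); // false
    System.out.println("single  -> cached: " + cachedWithoutAlias()); // true
  }
}
```

Keeping all consumers of a dataset on one record is what lets a single counter decide whether the dataset is worth persisting.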



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [beam] echauchot commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
echauchot commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1026473635


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java:
##########
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation;
+
+import java.util.Collection;
+import java.util.concurrent.Callable;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Throwables;
+import org.apache.spark.api.java.function.ForeachFunction;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.execution.ExplainMode;
+import org.apache.spark.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The {@link EvaluationContext} is the result of a pipeline {@link PipelineTranslator#translate
+ * translation} and can be used to evaluate / run the pipeline.
+ *
+ * <p>However, in some cases pipeline translation involves the early evaluation of some parts of the
+ * pipeline. For example, this is necessary to materialize side-inputs. The {@link
+ * EvaluationContext} won't re-evaluate such datasets.
+ */
+@Internal
+public final class EvaluationContext {
+  private static final Logger LOG = LoggerFactory.getLogger(EvaluationContext.class);
+
+  interface NamedDataset<T> {
+    String name();
+
+    @Nullable
+    Dataset<WindowedValue<T>> dataset();
+  }
+
+  private final Collection<? extends NamedDataset<?>> leaveDatasets;
+  private final SparkSession session;
+
+  EvaluationContext(Collection<? extends NamedDataset<?>> leaveDatasets, SparkSession session) {
+    this.leaveDatasets = leaveDatasets;
+    this.session = session;
+  }
+
+  /** Trigger evaluation of all leave datasets. */
+  public void evaluate() {
+    for (NamedDataset<?> ds : leaveDatasets) {
+      final Dataset<?> dataset = ds.dataset();
+      if (dataset == null) {
+        continue;
+      }
+      if (LOG.isDebugEnabled()) {
+        ExplainMode explainMode = ExplainMode.fromString("simple");
+        String execPlan = dataset.queryExecution().explainString(explainMode);
+        LOG.debug("Evaluating dataset {}:\n{}", ds.name(), execPlan);
+      }
+      // force evaluation using a dummy foreach action
+      evaluate(ds.name(), () -> dataset.foreach(NOOP));
+    }
+  }
+
+  /**
+   * The purpose of this utility is to mark the evaluation of Spark actions, both during Pipeline
+   * translation, when evaluation is required, and when finally evaluating the pipeline.
+   */
+  public static void evaluate(String name, Runnable action) {
+    long startMs = System.currentTimeMillis();
+    try {
+      action.run();
+      LOG.info("Evaluated dataset {} in {}", name, durationSince(startMs));
+    } catch (RuntimeException e) {
+      LOG.error("Failed to evaluate dataset {}: {}", name, Throwables.getRootCause(e).getMessage());
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * The purpose of this utility is to mark the evaluation of Spark actions, both during Pipeline
+   * translation, when evaluation is required, and when finally evaluating the pipeline.
+   */
+  public static <T> T evaluate(String name, Callable<T> action) {

Review Comment:
   https://github.com/apache/beam/commit/dd6d0781c7b0321999168586da058d8b66d0b138 is a commit you added after my review, so at the time of my review this code was unused.
   Have I misunderstood something?





[GitHub] [beam] mosche commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

mosche commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1305681016

   Run Spark StructuredStreaming ValidatesRunner




[GitHub] [beam] echauchot commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

echauchot commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1308601816

   @mosche @aromanenko-dev I plan on reviewing this PR on Monday.




[GitHub] [beam] mosche commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

mosche commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1024996169


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java:
##########
@@ -81,27 +80,13 @@ public class PipelineTranslatorBatch extends PipelineTranslator {
 
     TRANSFORM_TRANSLATORS.put(
         SplittableParDo.PrimitiveBoundedRead.class, new ReadSourceTranslatorBatch<>());
-
-    TRANSFORM_TRANSLATORS.put(
-        View.CreatePCollectionView.class, new CreatePCollectionViewTranslatorBatch<>());
-  }
-
-  public PipelineTranslatorBatch(SparkStructuredStreamingPipelineOptions options) {
-    translationContext = new TranslationContext(options);
   }
 
-  /** Returns a translator for the given node, if it is possible, otherwise null. */
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
   @Override
   @Nullable
   protected <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
-          @Nullable TransformT transform) {
-    // Root of the graph is null
-    if (transform == null) {
-      return null;
-    }
-    TransformTranslator<InT, OutT, TransformT> translator =
-        TRANSFORM_TRANSLATORS.get(transform.getClass());
-    return translator != null && translator.canTranslate(transform) ? translator : null;

Review Comment:
   exactly 👍 





[GitHub] [beam] echauchot commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

echauchot commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1313751222

   > This PR changes the PipelineTranslator to support caching datasets if they are used multiple times. This feature is still missing compared to the RDD runner and is critical for running batch pipelines at large scale, as it prevents the potentially costly re-evaluation of datasets (closes #24008). Additionally, this PR is a preparation to fix / simplify the handling of side inputs (see #24035).
   
   Why does this PR address 2 separate tickets when the guidelines state to address only one per PR (for isolation, diagnosis, and revert reasons)? Can't the 2 issues be addressed separately, or are they too coupled?
   
   > 
   > In our benchmarks, considering the small scale of both the Nexmark and TPC-DS (1GB) tests, the impact of this change is expected to be neutral or even negative.
   
   Sure, maybe we should set up a separate CI job with TPC-DS runs at a larger scale, in a different PR.
   
   > 
   > Pipeline translation is changed as follows:
   > 
   > 1. Detect if `streaming` mode is required.
   >    
   >    * Unchanged, just removed unnecessary clutter.
   > 2. Identify datasets that are repeatedly used as input and should be cached.
   >    
   >    * New translation step, similar to the one done in the RDD runner.
   >    * Rather than just gathering translation state as `Map<PCollection, Dataset>`, this has turned into `Map<PCollection, TranslationResult>` to track the `Dataset` but also additional metadata such as dependencies (and broadcast variables once [[Bug]:  Improve handling of side inputs in Spark dataset runner #24035](https://github.com/apache/beam/issues/24035) is addressed)
   > 3. And finally, translate each primitive or composite `PTransform` that has a `TransformTranslator` and is supported (`TransformTranslator.canTranslate`) into its Spark counterpart. If a composite is not supported, it is expanded further into its parts, which are then translated instead.
   >    
   >    * Logically unchanged, but using the extended state.
   >    * Also, this PR attempts to better distinguish between translation and evaluation. Instead of returning a translation context after the translation, this translation state is dropped as it is (mostly) not needed anymore. Only the necessary parts (leave datasets) are returned as `EvaluationContext`, which can be used to trigger the evaluation of the pipeline.
   
   You mean **leaf** datasets, i.e. the datasets at the very end of the dataset tree, right?
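For readers following the three translation steps described above, here is a minimal, framework-free sketch of steps 2 and 3 together with the leaf collection. A first pass records the consumers of every output; a second pass marks outputs with more than one consumer as cache candidates (in the PR, `putDataset` calls `dataset.persist(storageLevel)` for these) and treats outputs without any consumer as the leaves to evaluate at the end. `TwoPassSketch`, `Transform`, and the toy pipeline are hypothetical; strings stand in for `PCollection`s and Spark datasets.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** Hypothetical, framework-free sketch: strings stand in for PCollections and datasets. */
final class TwoPassSketch {
  /** A toy transform that consumes named inputs and produces a single named output. */
  static final class Transform {
    final String name;
    final List<String> inputs;
    final String output;

    Transform(String name, List<String> inputs, String output) {
      this.name = name;
      this.inputs = inputs;
      this.output = output;
    }
  }

  /** A toy pipeline in topological order: "read" feeds two transforms. */
  static List<Transform> example() {
    return Arrays.asList(
        new Transform("Read", Collections.<String>emptyList(), "read"),
        new Transform("MapA", Arrays.asList("read"), "a"),
        new Transform("MapB", Arrays.asList("read"), "b"),
        new Transform("WriteA", Arrays.asList("a"), "doneA"));
  }

  /** Pass 1 (cf. DependencyVisitor): record the consuming transforms of every output. */
  static Map<String, Set<String>> dependents(List<Transform> topologicalOrder) {
    Map<String, Set<String>> deps = new LinkedHashMap<>();
    for (Transform t : topologicalOrder) {
      deps.put(t.output, new LinkedHashSet<>());
      for (String input : t.inputs) {
        Set<String> consumers = deps.get(input); // produced earlier in topological order
        if (consumers == null) {
          throw new IllegalStateException("Unknown input " + input);
        }
        consumers.add(t.name);
      }
    }
    return deps;
  }

  /** Pass 2 (cf. putDataset): outputs consumed more than once are cache candidates. */
  static List<String> cacheCandidates(Map<String, Set<String>> deps) {
    List<String> result = new ArrayList<>();
    for (Map.Entry<String, Set<String>> e : deps.entrySet()) {
      if (e.getValue().size() > 1) {
        result.add(e.getKey());
      }
    }
    return result;
  }

  /** Outputs without any consumer are leaves; only these need to be evaluated at the end. */
  static List<String> leaves(Map<String, Set<String>> deps) {
    List<String> result = new ArrayList<>();
    for (Map.Entry<String, Set<String>> e : deps.entrySet()) {
      if (e.getValue().isEmpty()) {
        result.add(e.getKey());
      }
    }
    return result;
  }

  public static void main(String[] args) {
    Map<String, Set<String>> deps = dependents(example());
    System.out.println("cache: " + cacheCandidates(deps)); // cache: [read]
    System.out.println("leaves: " + leaves(deps)); // leaves: [b, doneA]
  }
}
```

Evaluating only the leaves is enough because Spark evaluates each leaf's lineage lazily; caching `read` is what stops that shared lineage from being recomputed once per consumer.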


[GitHub] [beam] mosche commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

mosche commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1313860923

   > Why is it that this PR addresses 2 separate tickets when the guidelines state to address only one per PR (for isolation, diagnostic, and revert reasons) ? Can't the 2 issues be addressed separately ? They are too coupled ?
   
   @echauchot This PR addresses just one issue, but it does so in a way that allows #24035 to be fixed in a follow-up without introducing massive changes yet again. #24035 itself is not addressed at all yet.




[GitHub] [beam] echauchot commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

echauchot commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1026481858


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/TransformTranslator.java:
##########
@@ -118,69 +126,53 @@ public <T> PCollection<T> getOutput(TupleTag<T> tag) {
       return pc;
     }
 
-    public Map<TupleTag<?>, PCollection<?>> getOutputs() {
-      return transform.getOutputs();
-    }
-
     public AppliedPTransform<InT, OutT, PTransform<InT, OutT>> getCurrentTransform() {
       return transform;
     }
 
+    @Override
     public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
-      return cxt.getDataset(pCollection);
+      return state.getDataset(pCollection);
     }
 
-    public <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
-      cxt.putDataset(pCollection, dataset);
+    @Override
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {

Review Comment:
   Yes, and it's even more difficult to read with `noCache = false` :smile:





[GitHub] [beam] mosche commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

mosche commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1026208479


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java:
##########
@@ -0,0 +1,119 @@
+  public static <T> T evaluate(String name, Callable<T> action) {

Review Comment:
   Actually it is used, see https://github.com/apache/beam/pull/24009/commits/dd6d0781c7b0321999168586da058d8b66d0b138





[GitHub] [beam] mosche commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

mosche commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1026478459


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/EvaluationContext.java:
##########
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation;
+
+import java.util.Collection;
+import java.util.concurrent.Callable;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Throwables;
+import org.apache.spark.api.java.function.ForeachFunction;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.execution.ExplainMode;
+import org.apache.spark.util.Utils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The {@link EvaluationContext} is the result of a pipeline {@link PipelineTranslator#translate
+ * translation} and can be used to evaluate / run the pipeline.
+ *
+ * <p>However, in some cases pipeline translation involves the early evaluation of some parts of the
+ * pipeline. For example, this is necessary to materialize side-inputs. The {@link
+ * EvaluationContext} won't re-evaluate such datasets.
+ */
+@Internal
+public final class EvaluationContext {
+  private static final Logger LOG = LoggerFactory.getLogger(EvaluationContext.class);
+
+  interface NamedDataset<T> {
+    String name();
+
+    @Nullable
+    Dataset<WindowedValue<T>> dataset();
+  }
+
+  private final Collection<? extends NamedDataset<?>> leaveDatasets;
+  private final SparkSession session;
+
+  EvaluationContext(Collection<? extends NamedDataset<?>> leaveDatasets, SparkSession session) {
+    this.leaveDatasets = leaveDatasets;
+    this.session = session;
+  }
+
+  /** Trigger evaluation of all leaf datasets. */
+  public void evaluate() {
+    for (NamedDataset<?> ds : leaveDatasets) {
+      final Dataset<?> dataset = ds.dataset();
+      if (dataset == null) {
+        continue;
+      }
+      if (LOG.isDebugEnabled()) {
+        ExplainMode explainMode = ExplainMode.fromString("simple");
+        String execPlan = dataset.queryExecution().explainString(explainMode);
+        LOG.debug("Evaluating dataset {}:\n{}", ds.name(), execPlan);
+      }
+      // force evaluation using a dummy foreach action
+      evaluate(ds.name(), () -> dataset.foreach(NOOP));
+    }
+  }
+
+  /**
+   * The purpose of this utility is to mark the evaluation of Spark actions, both during Pipeline
+   * translation, when evaluation is required, and when finally evaluating the pipeline.
+   */
+  public static void evaluate(String name, Runnable action) {
+    long startMs = System.currentTimeMillis();
+    try {
+      action.run();
+      LOG.info("Evaluated dataset {} in {}", name, durationSince(startMs));
+    } catch (RuntimeException e) {
+      LOG.error("Failed to evaluate dataset {}: {}", name, Throwables.getRootCause(e).getMessage());
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * The purpose of this utility is to mark the evaluation of Spark actions, both during Pipeline
+   * translation, when evaluation is required, and when finally evaluating the pipeline.
+   */
+  public static <T> T evaluate(String name, Callable<T> action) {

Review Comment:
   It was used here: https://github.com/apache/beam/commit/dd6d0781c7b0321999168586da058d8b66d0b138#diff-d7904f7a21f6d5b3aeccce61863bcca39db112696d972762deca5b4623c87da6L229-L230
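Whether the `Callable` overload is reachable depends on Java overload resolution, because both `evaluate(String, Runnable)` and `evaluate(String, Callable<T>)` accept a zero-argument lambda. The self-contained sketch below uses a hypothetical class `OverloadSketch`, not Beam code, and records which overload each call binds to:

* A lambda whose body is a `void` method call can only match `Runnable`.
* A lambda that returns a plain value can only match `Callable`.
* A lambda whose body is a method call returning a value matches both. The compiler then picks the more specific `Callable` overload, which is the same rule that makes `ExecutorService.submit(() -> list.add(x))` return a `Future<Boolean>`.

Because `Dataset.foreach` returns `void`, the `evaluate(ds.name(), () -> dataset.foreach(NOOP))` call in the diff binds to the `Runnable` overload.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;

final class OverloadSketch {
  /** Name of the overload chosen by the most recent call. */
  static String lastOverload = "";

  static void evaluate(String name, Runnable action) {
    lastOverload = "Runnable";
    action.run();
  }

  static <T> T evaluate(String name, Callable<T> action) {
    lastOverload = "Callable";
    try {
      return action.call();
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

  public static void main(String[] args) {
    List<String> list = new ArrayList<>();

    // println returns void: only the Runnable overload is applicable.
    evaluate("print", () -> System.out.println("side effect"));
    System.out.println(lastOverload); // Runnable

    // A constant is not a statement expression: only the Callable overload is applicable.
    Integer answer = evaluate("answer", () -> 42);
    System.out.println(lastOverload + " " + answer); // Callable 42

    // list.add returns boolean: both overloads apply, and Callable is more specific.
    evaluate("add", () -> list.add("x"));
    System.out.println(lastOverload); // Callable
  }
}
```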

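`EvaluationContext.evaluate()` in the quoted file triggers only the leaf datasets, meaning those that no other transform consumes. It skips any entry whose dataset is `null` and forces the others using a dummy `foreach` action. The sketch below models that loop without Spark. `NamedAction` is a hypothetical stand-in for `NamedDataset`, and a `Runnable` stands in for `dataset.foreach(NOOP)`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class LeafEvaluationSketch {
  /** Hypothetical stand-in for EvaluationContext.NamedDataset. */
  interface NamedAction {
    String name();

    /** May be null if nothing was translated for this output. */
    Runnable action();
  }

  static NamedAction of(String name, Runnable action) {
    return new NamedAction() {
      public String name() {
        return name;
      }

      public Runnable action() {
        return action;
      }
    };
  }

  /** Triggers every leaf that has an action and returns the names in evaluation order. */
  static List<String> evaluateLeaves(List<NamedAction> leaves) {
    List<String> evaluated = new ArrayList<>();
    for (NamedAction leaf : leaves) {
      Runnable action = leaf.action();
      if (action == null) {
        continue; // mirrors the null check on ds.dataset()
      }
      action.run(); // stands in for dataset.foreach(NOOP)
      evaluated.add(leaf.name());
    }
    return evaluated;
  }

  public static void main(String[] args) {
    List<NamedAction> leaves =
        Arrays.asList(of("write", () -> System.out.println("writing")), of("empty", null));
    System.out.println(evaluateLeaves(leaves)); // [write]
  }
}
```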




[GitHub] [beam] echauchot commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
echauchot commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1026493733


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into its Spark counterpart, which can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.spark.sql.Dataset Dataset} might have to be collected and
+   * broadcast to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as result of translating a {@link PTransform}
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and discarded afterwards. */
+  interface TranslationState extends EncoderProvider {
+    <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
+
+    <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache);
 
-    /** Uses the batch mode. */
-    BATCH,
+    default <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      putDataset(pCollection, dataset, false);
+    }
 
-    /** Uses the streaming mode. */
-    STREAMING
+    SerializablePipelineOptions getSerializableOptions();
+
+    SparkSession getSparkSession();
   }
 
-  /** Traverses the Pipeline to determine the {@link TranslationMode} for this pipeline. */
-  private static class TranslationModeDetector extends Pipeline.PipelineVisitor.Defaults {
-    private static final Logger LOG = LoggerFactory.getLogger(TranslationModeDetector.class);
+  /**
+   * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their
+   * Spark correspondence.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.spark.sql.Dataset Dataset} might have to be collected and
+   * broadcast.
+   */
+  private class TranslatingVisitor extends PTransformVisitor implements TranslationState {
+    private final Map<PCollection<?>, TranslationResult<?>> translationResults;
+    private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
+    private final SparkSession sparkSession;
+    private final SerializablePipelineOptions serializableOptions;
+    private final StorageLevel storageLevel;
+
+    private final Set<TranslationResult<?>> leaves;
+
+    public TranslatingVisitor(
+        SparkSession sparkSession,
+        SparkCommonPipelineOptions options,
+        Map<PCollection<?>, TranslationResult<?>> translationResults) {
+      this.sparkSession = sparkSession;
+      this.translationResults = translationResults;
+      this.serializableOptions = new SerializablePipelineOptions(options);
+      this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
+      this.encoders = new HashMap<>();
+      this.leaves = new HashSet<>();
+    }
 
-    private TranslationMode translationMode;
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+          (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+      try {
+        LOG.info(
+            "Translating {}: {}",
+            node.isCompositeNode() ? "composite" : "primitive",
+            node.getFullName());
+        translator.translate(transform, appliedTransform, this);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
 
-    TranslationModeDetector(TranslationMode defaultMode) {
-      this.translationMode = defaultMode;
+    @Override
+    public <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory) {
+      return (Encoder<T>) encoders.computeIfAbsent(coder, (Factory) factory);
     }
 
-    TranslationModeDetector() {
-      this(TranslationMode.BATCH);
+    private <T> TranslationResult<T> getResult(PCollection<T> pCollection) {
+      return (TranslationResult<T>) checkStateNotNull(translationResults.get(pCollection));
     }
 
-    TranslationMode getTranslationMode() {
-      return translationMode;
+    @Override
+    public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+      return checkStateNotNull(getResult(pCollection).dataset);
     }
 
     @Override
-    public void visitValue(PValue value, TransformHierarchy.Node producer) {
-      if (translationMode.equals(TranslationMode.BATCH)) {
-        if (value instanceof PCollection
-            && ((PCollection) value).isBounded() == PCollection.IsBounded.UNBOUNDED) {
-          LOG.info(
-              "Found unbounded PCollection {}. Switching to streaming execution.", value.getName());
-          translationMode = TranslationMode.STREAMING;
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {
+      TranslationResult<T> result = getResult(pCollection);
+      if (!noCache && result.dependentTransforms.size() > 1) {
+        LOG.info("Dataset {} will be cached.", result.name);
+        result.dataset = dataset.persist(storageLevel); // use NONE to disable
+      } else {
+        result.dataset = dataset;
+        if (result.dependentTransforms.isEmpty()) {
+          leaves.add(result);
         }
       }
     }
-  }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline utility methods
-  // --------------------------------------------------------------------------------------------
+    @Override
+    public SerializablePipelineOptions getSerializableOptions() {
+      return serializableOptions;
+    }
 
-  /** Get a {@link TransformTranslator} for the given {@link TransformHierarchy.Node}. */
-  protected abstract @Nullable <
-          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
-          @Nullable TransformT transform);
-
-  /** Apply the given TransformTranslator to the given node. */
-  private <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      void applyTransformTranslator(
-          TransformHierarchy.Node node,
-          TransformT transform,
-          TransformTranslator<InT, OutT, TransformT> transformTranslator) {
-    // create the applied PTransform on the translationContext
-    AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
-        (AppliedPTransform) node.toAppliedPTransform(getPipeline());
-    try {
-      transformTranslator.translate(transform, appliedTransform, translationContext);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
+    @Override
+    public SparkSession getSparkSession() {
+      return sparkSession;
     }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline visitor entry point
-  // --------------------------------------------------------------------------------------------
-
   /**
-   * Translates the pipeline by passing this class as a visitor.
+   * {@link PTransformVisitor} that analyses dependencies of supported {@link PTransform
+   * PTransforms} to help identify cache candidates.
    *
-   * @param pipeline The pipeline to be translated
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
    */
-  public void translate(Pipeline pipeline) {
-    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
-    pipeline.traverseTopologically(this);
+  private class DependencyVisitor extends PTransformVisitor {
+    private final Map<PCollection<?>, TranslationResult<?>> results = new HashMap<>();
+
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+      for (PCollection<?> pOut : node.getOutputs().values()) {
+        results.put(pOut, new TranslationResult<>(pOut));
+        for (Map.Entry<TupleTag<?>, PCollection<?>> entry : node.getInputs().entrySet()) {
+          TranslationResult<?> input = checkStateNotNull(results.get(entry.getValue()));
+          input.dependentTransforms.add(transform);
+        }
+      }
+    }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline Visitor Methods
-  // --------------------------------------------------------------------------------------------
+  /**
+   * An abstract {@link PipelineVisitor} that visits all translatable {@link PTransform} pipeline
+   * nodes of a pipeline with the respective {@link TransformTranslator}.
+   *
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
+   */
+  private abstract class PTransformVisitor extends PipelineVisitor.Defaults {
 
-  @Override
-  public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) {
-    PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
-    TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
-        getTransformTranslator(transform);
+    /** Visit the {@link PTransform} with its respective {@link TransformTranslator}. */
+    abstract <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator);
 
-    if (transformTranslator != null) {
-      LOG.info("Translating composite: {}", node.getFullName());
-      applyTransformTranslator(node, transform, transformTranslator);
-      return CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
-    } else {
-      return CompositeBehavior.ENTER_TRANSFORM;
+    @Override
+    public final CompositeBehavior enterCompositeTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> translator =
+          getTranslator(transform);
+      if (transform != null && translator != null) {
+        visit(node, transform, translator);
+        return DO_NOT_ENTER_TRANSFORM;
+      } else {
+        return ENTER_TRANSFORM;
+      }
+    }
+
+    @Override
+    public final void visitPrimitiveTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      if (transform == null || transform.getClass().equals(View.CreatePCollectionView.class)) {
+        return; // ignore, nothing to be translated here

Review Comment:
   OK, that is what I thought. The comment you added in commit #8dce7a6f LGTM.
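The caching decision in the quoted `PipelineTranslator` spans two passes. First, `DependencyVisitor` records, for every output `PCollection`, the set of transforms that consume it. Then `TranslatingVisitor.putDataset` persists any dataset with more than one dependent (unless `noCache` is set) and registers any dataset with no dependents as a leaf. The sketch below models both passes on a plain string graph, keeping the nested loop structure of `DependencyVisitor.visit`. `CacheDecisionSketch`, the string names, and the `.persist()` suffix are hypothetical stand-ins for `PTransform`, `PCollection`, and `Dataset.persist(storageLevel)`.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

final class CacheDecisionSketch {
  /** Stand-in for TranslationResult: an output plus the transforms that consume it. */
  static final class Result {
    final Set<String> dependents = new HashSet<>();
    String dataset; // stands in for Dataset<WindowedValue<T>>
  }

  final Map<String, Result> results = new HashMap<>();
  final Set<String> leaves = new LinkedHashSet<>();

  /** Pass 1 (DependencyVisitor): register outputs and record the transform on its inputs. */
  void visit(String transform, List<String> inputs, List<String> outputs) {
    for (String out : outputs) {
      results.put(out, new Result());
      for (String in : inputs) {
        Result input = results.get(in);
        if (input == null) {
          throw new IllegalStateException("Input not yet visited: " + in);
        }
        input.dependents.add(transform);
      }
    }
  }

  /** Pass 2 (putDataset): persist shared datasets, remember unconsumed ones as leaves. */
  void putDataset(String pCollection, String dataset, boolean noCache) {
    Result result = Objects.requireNonNull(results.get(pCollection), "unknown " + pCollection);
    if (!noCache && result.dependents.size() > 1) {
      result.dataset = dataset + ".persist()";
    } else {
      result.dataset = dataset;
      if (result.dependents.isEmpty()) {
        leaves.add(pCollection);
      }
    }
  }

  public static void main(String[] args) {
    CacheDecisionSketch s = new CacheDecisionSketch();
    s.visit("Read", List.of(), List.of("lines"));
    s.visit("CountWords", List.of("lines"), List.of("counts"));
    s.visit("FilterEmpty", List.of("lines"), List.of("nonEmpty"));
    s.putDataset("lines", "ds0", false);
    s.putDataset("counts", "ds1", false);
    s.putDataset("nonEmpty", "ds2", false);
    System.out.println(s.results.get("lines").dataset); // ds0.persist()
    System.out.println(s.leaves); // [counts, nonEmpty]
  }
}
```

Counting dependents in a separate pass is what lets the translator decide whether to cache at the moment a dataset is first produced. Without that pass, it would not yet know how many transforms will read the dataset.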

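`detectStreamingMode` seeds `StreamingModeDetector` with the current `options.isStreaming()` value and writes the detector's result back into the options. The detector's body is not part of the quoted hunk. Judging from the removed `TranslationModeDetector`, it switches to streaming as soon as an `UNBOUNDED` `PCollection` is visited. A minimal sketch under that assumption, with a hypothetical `IsBounded` enum standing in for `PCollection.IsBounded`:

```java
import java.util.List;

final class StreamingModeSketch {
  /** Hypothetical stand-in for PCollection.IsBounded. */
  enum IsBounded {
    BOUNDED,
    UNBOUNDED
  }

  /**
   * Returns the streaming flag to write back into the options: an explicit streaming=true is kept,
   * and any unbounded PCollection switches a batch pipeline to streaming.
   */
  static boolean detectStreaming(boolean streamingOption, List<IsBounded> pCollections) {
    boolean streaming = streamingOption;
    for (IsBounded boundedness : pCollections) {
      if (!streaming && boundedness == IsBounded.UNBOUNDED) {
        streaming = true; // the removed TranslationModeDetector logged the switch at this point
      }
    }
    return streaming;
  }

  public static void main(String[] args) {
    System.out.println(detectStreaming(false, List.of(IsBounded.BOUNDED, IsBounded.UNBOUNDED))); // true
    System.out.println(detectStreaming(false, List.of(IsBounded.BOUNDED))); // false
    System.out.println(detectStreaming(true, List.of(IsBounded.BOUNDED))); // true
  }
}
```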

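Per the class javadoc, `PTransformVisitor.enterCompositeTransform` translates a composite as a whole (returning `DO_NOT_ENTER_TRANSFORM`) when a translator for it is known and supports it. Otherwise it returns `ENTER_TRANSFORM`, so the composite's parts are visited individually. The sketch below reproduces that decision on a hypothetical tree of named nodes, modeling the available translators as a set of names. Throwing on an untranslatable primitive is an assumption: the quoted `visitPrimitiveTransform` is cut off after it skips `null` transforms and `View.CreatePCollectionView`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;

final class CompositeTraversalSketch {
  /** Hypothetical pipeline node: a composite if it has parts, a primitive otherwise. */
  static final class Node {
    final String name;
    final List<Node> parts;

    Node(String name, Node... parts) {
      this.name = name;
      this.parts = Arrays.asList(parts);
    }
  }

  /** Returns the names of the nodes translated as a unit, in visiting order. */
  static List<String> translate(Node root, Set<String> translatable) {
    List<String> translated = new ArrayList<>();
    visit(root, translatable, translated);
    return translated;
  }

  private static void visit(Node node, Set<String> translatable, List<String> translated) {
    if (translatable.contains(node.name)) {
      translated.add(node.name); // translate as a whole, do not enter the composite
    } else if (!node.parts.isEmpty()) {
      for (Node part : node.parts) {
        visit(part, translatable, translated); // expand the composite into its parts
      }
    } else {
      throw new UnsupportedOperationException("No translator for primitive " + node.name);
    }
  }

  public static void main(String[] args) {
    Node pipeline =
        new Node(
            "root",
            new Node("Read"),
            new Node("Count", new Node("Count/Map"), new Node("Count/Combine")));
    // Count has a translator: its parts are never visited.
    System.out.println(translate(pipeline, Set.of("Read", "Count"))); // [Read, Count]
    // Without one, Count is expanded and its primitives are translated instead.
    System.out.println(translate(pipeline, Set.of("Read", "Count/Map", "Count/Combine")));
  }
}
```

Preferring the composite translator lets the runner map a high-level transform onto a single native Spark operation where one exists. Expansion remains the fallback.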



[GitHub] [beam] echauchot commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
echauchot commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1026489847


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into its Spark counterpart, which can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.spark.sql.Dataset Dataset} might have to be collected and
+   * broadcast to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as result of translating a {@link PTransform}
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and discarded afterwards. */
+  interface TranslationState extends EncoderProvider {
+    <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
+
+    <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache);
 
-    /** Uses the batch mode. */
-    BATCH,
+    default <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      putDataset(pCollection, dataset, false);
+    }
 
-    /** Uses the streaming mode. */
-    STREAMING
+    SerializablePipelineOptions getSerializableOptions();
+
+    SparkSession getSparkSession();
   }
 
-  /** Traverses the Pipeline to determine the {@link TranslationMode} for this pipeline. */
-  private static class TranslationModeDetector extends Pipeline.PipelineVisitor.Defaults {
-    private static final Logger LOG = LoggerFactory.getLogger(TranslationModeDetector.class);
+  /**
+   * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their
+   * Spark correspondence.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.spark.sql.Dataset Dataset} might have to be collected and
+   * broadcast.
+   */
+  private class TranslatingVisitor extends PTransformVisitor implements TranslationState {
+    private final Map<PCollection<?>, TranslationResult<?>> translationResults;
+    private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
+    private final SparkSession sparkSession;
+    private final SerializablePipelineOptions serializableOptions;
+    private final StorageLevel storageLevel;
+
+    private final Set<TranslationResult<?>> leaves;
+
+    public TranslatingVisitor(
+        SparkSession sparkSession,
+        SparkCommonPipelineOptions options,
+        Map<PCollection<?>, TranslationResult<?>> translationResults) {
+      this.sparkSession = sparkSession;
+      this.translationResults = translationResults;
+      this.serializableOptions = new SerializablePipelineOptions(options);
+      this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
+      this.encoders = new HashMap<>();
+      this.leaves = new HashSet<>();
+    }
 
-    private TranslationMode translationMode;
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+          (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+      try {
+        LOG.info(
+            "Translating {}: {}",
+            node.isCompositeNode() ? "composite" : "primitive",
+            node.getFullName());
+        translator.translate(transform, appliedTransform, this);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
 
-    TranslationModeDetector(TranslationMode defaultMode) {
-      this.translationMode = defaultMode;
+    @Override
+    public <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory) {
+      return (Encoder<T>) encoders.computeIfAbsent(coder, (Factory) factory);
     }
 
-    TranslationModeDetector() {
-      this(TranslationMode.BATCH);
+    private <T> TranslationResult<T> getResult(PCollection<T> pCollection) {
+      return (TranslationResult<T>) checkStateNotNull(translationResults.get(pCollection));
     }
 
-    TranslationMode getTranslationMode() {
-      return translationMode;
+    @Override
+    public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+      return checkStateNotNull(getResult(pCollection).dataset);
     }
 
     @Override
-    public void visitValue(PValue value, TransformHierarchy.Node producer) {
-      if (translationMode.equals(TranslationMode.BATCH)) {
-        if (value instanceof PCollection
-            && ((PCollection) value).isBounded() == PCollection.IsBounded.UNBOUNDED) {
-          LOG.info(
-              "Found unbounded PCollection {}. Switching to streaming execution.", value.getName());
-          translationMode = TranslationMode.STREAMING;
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {
+      TranslationResult<T> result = getResult(pCollection);
+      if (!noCache && result.dependentTransforms.size() > 1) {

Review Comment:
   You're right, it is not a `>=`, my bad.
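The `putDataset` condition discussed here (`!noCache && result.dependentTransforms.size() > 1`) can be sketched as a small, Spark-free decision function. This is a simplified model, not the runner's code. `CacheDecision`, `Action` and `decide` are hypothetical names, and the real implementation calls Spark's `Dataset.persist(storageLevel)` rather than returning an enum. Note that a `>= 1` threshold would instead cache every dataset that has any consumer at all.

```java
import java.util.Set;

/** Hypothetical, Spark-free model of the caching decision made in putDataset. */
final class CacheDecision {
  enum Action { CACHE, PASS_THROUGH_LEAF, PASS_THROUGH }

  /**
   * @param dependentTransforms names of the transforms that consume the dataset
   * @param noCache true if the caller opted out of caching for this dataset
   */
  static Action decide(Set<String> dependentTransforms, boolean noCache) {
    if (!noCache && dependentTransforms.size() > 1) {
      // Read by several transforms: persisting avoids re-running the lazy plan per consumer.
      return Action.CACHE;
    }
    // Not cached. A dataset without consumers is a leaf that has to be evaluated at the end.
    return dependentTransforms.isEmpty() ? Action.PASS_THROUGH_LEAF : Action.PASS_THROUGH;
  }
}
```

A dataset read by exactly one transform is passed through unchanged, because caching it would only cost memory without saving any recomputation.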



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [beam] aromanenko-dev commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
aromanenko-dev commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1307344788

   @mosche Looks like a serious change - what are the main advantages of this?



[GitHub] [beam] aromanenko-dev commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

aromanenko-dev commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1307341587

   CC: @echauchot 



[GitHub] [beam] mosche commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

mosche commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1305916241

   R: @aromanenko-dev 



[GitHub] [beam] mosche commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

mosche commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1307598939

   > @mosche Looks like a serious change - what are the main advantages of this?
   
   I've updated the description of the PR; let me know if this isn't clear enough.
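The motivation given in the PR description ("prevent repeated lazy evaluation") can be illustrated without Spark. Spark Datasets are lazy, so each action generally re-executes the whole lineage unless the dataset was persisted. The sketch below uses hypothetical stand-ins (`LazyDataset`, `cached()`, `LazyEvaluationDemo`) to count how often an upstream computation runs when two consumers read the same lazy input, with and without memoization. It models `persist` as a simple memo and ignores storage levels and eviction.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

/** Hypothetical stand-in for a lazily evaluated Spark Dataset (no Spark dependency). */
final class LazyDataset<T> {
  private final Supplier<T> plan;

  LazyDataset(Supplier<T> plan) {
    this.plan = plan;
  }

  /** Like a Spark action: runs the whole plan again on every call. */
  T collect() {
    return plan.get();
  }

  /** Models persist(): the plan runs at most once and later reads reuse the result. */
  LazyDataset<T> cached() {
    return new LazyDataset<>(
        new Supplier<T>() {
          private T value;
          private boolean computed;

          @Override
          public T get() {
            if (!computed) {
              value = plan.get();
              computed = true;
            }
            return value;
          }
        });
  }
}

final class LazyEvaluationDemo {
  /** Feeds one input into two consumers and returns how often the input plan was evaluated. */
  static int upstreamEvaluations(boolean cache) {
    AtomicInteger evaluations = new AtomicInteger();
    LazyDataset<Integer> input =
        new LazyDataset<>(
            () -> {
              evaluations.incrementAndGet(); // stands in for expensive upstream work
              return 21;
            });
    LazyDataset<Integer> shared = cache ? input.cached() : input;
    int doubled = shared.collect() * 2; // first consumer
    int summed = shared.collect() + 21; // second consumer
    if (doubled != summed) {
      throw new IllegalStateException("both consumers should see the same input");
    }
    return evaluations.get();
  }
}
```

Without caching, the upstream plan runs once per consumer; with the memoized version it runs once. Spark's `persist` has the same intent, but it stores partitions according to the chosen `StorageLevel` and may still recompute partitions that were evicted.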



[GitHub] [beam] echauchot commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

echauchot commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1316807748

   > @mosche @aromanenko-dev as promised, I'm reviewing
   
   Almost done, should finish today.



[GitHub] [beam] github-actions[bot] commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

github-actions[bot] commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1305918524

   Stopping reviewer notifications for this pull request: review requested by someone other than the bot, ceding control



[GitHub] [beam] github-actions[bot] commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

github-actions[bot] commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1305792581

   Checks are failing. Will not request review until checks are succeeding. If you'd like to override that behavior, comment `assign set of reviewers`



[GitHub] [beam] mosche commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

mosche commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1025008501


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderProvider.java:
##########
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
+
+import static org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderHelpers.kvEncoder;
+
+import java.util.function.Function;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.values.KV;
+import org.apache.spark.sql.Encoder;
+
+@Internal
+public interface EncoderProvider {
+  interface Factory<T> extends Function<Coder<T>, Encoder<T>> {
+    Factory<?> INSTANCE = EncoderHelpers::encoderFor;
+  }
+
+  <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory);
+
+  default <T> Encoder<T> encoderOf(Coder<T> coder) {
+    return coder instanceof KvCoder
+        ? (Encoder<T>) kvEncoderOf((KvCoder) coder)
+        : encoderOf(coder, encoderFactory());
+  }
+
+  default <K, V> Encoder<KV<K, V>> kvEncoderOf(KvCoder<K, V> coder) {
+    return encoderOf(coder, c -> kvEncoder(keyEncoderOf(coder), valueEncoderOf(coder)));
+  }
+
+  default <K, V> Encoder<K> keyEncoderOf(KvCoder<K, V> coder) {
+    return encoderOf(coder.getKeyCoder(), encoderFactory());
+  }
+
+  default <K, V> Encoder<V> valueEncoderOf(KvCoder<K, V> coder) {
+    return encoderOf(coder.getValueCoder(), encoderFactory());
+  }
+
+  default <T> Factory<T> encoderFactory() {
+    return (Factory<T>) Factory.INSTANCE;

Review Comment:
   This avoids repeatedly creating instances of the default factory. `INSTANCE` certainly belongs in `Factory`; that's a very common pattern. I'll keep this as is.
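The pattern defended in this comment, a single shared `INSTANCE` of the default factory combined with a per-coder cache built on `Map.computeIfAbsent` (as in `TranslatingVisitor.encoderOf`), can be sketched standalone. `FakeCoder`, `FakeEncoder`, `EncoderFactory` and `EncoderCache` are hypothetical stand-ins for Beam's `Coder`, Spark's `Encoder` and the PR's `EncoderProvider.Factory`, whose real default is `EncoderHelpers::encoderFor`. The real cache is keyed by `Coder` and therefore relies on coder `equals`; the stand-in uses identity.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

/** Hypothetical stand-in for a Beam Coder (identity equality is enough here). */
final class FakeCoder<T> {
  final String name;

  FakeCoder(String name) {
    this.name = name;
  }
}

/** Hypothetical stand-in for a Spark Encoder. */
final class FakeEncoder<T> {
  final String coderName;

  FakeEncoder(String coderName) {
    this.coderName = coderName;
  }
}

interface EncoderFactory<T> extends Function<FakeCoder<T>, FakeEncoder<T>> {
  // One stateless default factory shared by all call sites,
  // instead of allocating a new lambda whenever the default is requested.
  EncoderFactory<?> INSTANCE = coder -> new FakeEncoder<>(coder.name);
}

final class EncoderCache {
  private final Map<FakeCoder<?>, FakeEncoder<?>> encoders = new HashMap<>();

  @SuppressWarnings("unchecked")
  static <T> EncoderFactory<T> defaultFactory() {
    // Unchecked but safe: INSTANCE does not depend on T.
    return (EncoderFactory<T>) EncoderFactory.INSTANCE;
  }

  @SuppressWarnings({"unchecked", "rawtypes"})
  <T> FakeEncoder<T> encoderOf(FakeCoder<T> coder, EncoderFactory<T> factory) {
    // The factory runs once per coder; later calls return the cached encoder.
    return (FakeEncoder<T>) encoders.computeIfAbsent(coder, (Function) factory);
  }
}
```

The raw `(Function)` cast mirrors the PR's `(Factory) factory` cast: the map is heterogeneous in `T`, so the generic types can only be reconciled with an unchecked conversion.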




[GitHub] [beam] mosche commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

mosche commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1025028123


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence, that can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as result of translating a {@link PTransform}
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and omitted afterwards. */
+  interface TranslationState extends EncoderProvider {
+    <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
+
+    <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache);
 
-    /** Uses the batch mode. */
-    BATCH,
+    default <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      putDataset(pCollection, dataset, false);
+    }
 
-    /** Uses the streaming mode. */
-    STREAMING
+    SerializablePipelineOptions getSerializableOptions();
+
+    SparkSession getSparkSession();
   }
 
-  /** Traverses the Pipeline to determine the {@link TranslationMode} for this pipeline. */
-  private static class TranslationModeDetector extends Pipeline.PipelineVisitor.Defaults {
-    private static final Logger LOG = LoggerFactory.getLogger(TranslationModeDetector.class);
+  /**
+   * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their
+   * Spark correspondence.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted.
+   */
+  private class TranslatingVisitor extends PTransformVisitor implements TranslationState {
+    private final Map<PCollection<?>, TranslationResult<?>> translationResults;
+    private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
+    private final SparkSession sparkSession;
+    private final SerializablePipelineOptions serializableOptions;
+    private final StorageLevel storageLevel;
+
+    private final Set<TranslationResult<?>> leaves;
+
+    public TranslatingVisitor(
+        SparkSession sparkSession,
+        SparkCommonPipelineOptions options,
+        Map<PCollection<?>, TranslationResult<?>> translationResults) {
+      this.sparkSession = sparkSession;
+      this.translationResults = translationResults;
+      this.serializableOptions = new SerializablePipelineOptions(options);
+      this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
+      this.encoders = new HashMap<>();
+      this.leaves = new HashSet<>();
+    }
 
-    private TranslationMode translationMode;
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+          (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+      try {
+        LOG.info(
+            "Translating {}: {}",
+            node.isCompositeNode() ? "composite" : "primitive",
+            node.getFullName());
+        translator.translate(transform, appliedTransform, this);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
 
-    TranslationModeDetector(TranslationMode defaultMode) {
-      this.translationMode = defaultMode;
+    @Override
+    public <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory) {
+      return (Encoder<T>) encoders.computeIfAbsent(coder, (Factory) factory);
     }
 
-    TranslationModeDetector() {
-      this(TranslationMode.BATCH);
+    private <T> TranslationResult<T> getResult(PCollection<T> pCollection) {
+      return (TranslationResult<T>) checkStateNotNull(translationResults.get(pCollection));
     }
 
-    TranslationMode getTranslationMode() {
-      return translationMode;
+    @Override
+    public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+      return checkStateNotNull(getResult(pCollection).dataset);
     }
 
     @Override
-    public void visitValue(PValue value, TransformHierarchy.Node producer) {
-      if (translationMode.equals(TranslationMode.BATCH)) {
-        if (value instanceof PCollection
-            && ((PCollection) value).isBounded() == PCollection.IsBounded.UNBOUNDED) {
-          LOG.info(
-              "Found unbounded PCollection {}. Switching to streaming execution.", value.getName());
-          translationMode = TranslationMode.STREAMING;
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {
+      TranslationResult<T> result = getResult(pCollection);
+      if (!noCache && result.dependentTransforms.size() > 1) {
+        LOG.info("Dataset {} will be cached.", result.name);
+        result.dataset = dataset.persist(storageLevel); // use NONE to disable
+      } else {
+        result.dataset = dataset;
+        if (result.dependentTransforms.isEmpty()) {
+          leaves.add(result);
         }
       }
     }
-  }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline utility methods
-  // --------------------------------------------------------------------------------------------
+    @Override
+    public SerializablePipelineOptions getSerializableOptions() {
+      return serializableOptions;
+    }
 
-  /** Get a {@link TransformTranslator} for the given {@link TransformHierarchy.Node}. */
-  protected abstract @Nullable <
-          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
-          @Nullable TransformT transform);
-
-  /** Apply the given TransformTranslator to the given node. */
-  private <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      void applyTransformTranslator(
-          TransformHierarchy.Node node,
-          TransformT transform,
-          TransformTranslator<InT, OutT, TransformT> transformTranslator) {
-    // create the applied PTransform on the translationContext
-    AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
-        (AppliedPTransform) node.toAppliedPTransform(getPipeline());
-    try {
-      transformTranslator.translate(transform, appliedTransform, translationContext);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
+    @Override
+    public SparkSession getSparkSession() {
+      return sparkSession;
     }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline visitor entry point
-  // --------------------------------------------------------------------------------------------
-
   /**
-   * Translates the pipeline by passing this class as a visitor.
+   * {@link PTransformVisitor} that analyses dependencies of supported {@link PTransform
+   * PTransforms} to help identify cache candidates.
    *
-   * @param pipeline The pipeline to be translated
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
    */
-  public void translate(Pipeline pipeline) {
-    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
-    pipeline.traverseTopologically(this);
+  private class DependencyVisitor extends PTransformVisitor {
+    private final Map<PCollection<?>, TranslationResult<?>> results = new HashMap<>();
+
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+      for (PCollection<?> pOut : node.getOutputs().values()) {
+        results.put(pOut, new TranslationResult<>(pOut));
+        for (Map.Entry<TupleTag<?>, PCollection<?>> entry : node.getInputs().entrySet()) {
+          TranslationResult<?> input = checkStateNotNull(results.get(entry.getValue()));
+          input.dependentTransforms.add(transform);
+        }
+      }
+    }

Review Comment:
   You're right, this isn't necessary anymore, thanks for catching it 👍 In an earlier version I was tracking downstream dependencies at the PCollection level rather than the transform level, which isn't actually necessary.
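The dependency pass discussed here (`DependencyVisitor`) records, for each `PCollection`, the set of transforms that consume it; `putDataset` later caches a dataset when that set has more than one entry. A simplified, Spark-free sketch follows. `Transform`, `DependencyAnalysis` and `cacheCandidates` are hypothetical names, transforms are assumed to arrive in topological order (as with `Pipeline.traverseTopologically`), and this sketch registers inputs in a loop separate from outputs so that a transform without outputs still counts as a consumer.

```java
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/** Hypothetical transform node: a name plus the names of its input and output collections. */
final class Transform {
  final String name;
  final List<String> inputs;
  final List<String> outputs;

  Transform(String name, List<String> inputs, List<String> outputs) {
    this.name = name;
    this.inputs = inputs;
    this.outputs = outputs;
  }
}

final class DependencyAnalysis {
  /** Maps each collection to the names of the transforms that consume it. */
  static Map<String, Set<String>> dependents(List<Transform> topologicallySorted) {
    Map<String, Set<String>> result = new LinkedHashMap<>();
    for (Transform t : topologicallySorted) {
      for (String in : t.inputs) {
        Set<String> consumers = result.get(in);
        if (consumers == null) {
          // Mirrors checkStateNotNull: inputs must have been produced by an earlier transform.
          throw new IllegalStateException("Input " + in + " of " + t.name + " not produced yet");
        }
        consumers.add(t.name); // a Set: reading the same input twice counts once
      }
      for (String out : t.outputs) {
        result.put(out, new LinkedHashSet<>());
      }
    }
    return result;
  }

  /** Collections read by more than one transform are cache candidates. */
  static Set<String> cacheCandidates(Map<String, Set<String>> dependents) {
    Set<String> candidates = new TreeSet<>();
    dependents.forEach(
        (pCollection, consumers) -> {
          if (consumers.size() > 1) {
            candidates.add(pCollection);
          }
        });
    return candidates;
  }
}
```

For a pipeline where `Read` produces `lines`, which feeds both `Count` and `Sample`, only `lines` is a cache candidate, while `sample` has no consumers and would be treated as a leaf.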




[GitHub] [beam] mosche commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

mosche commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1025000237


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence, that can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as result of translating a {@link PTransform}
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and omitted afterwards. */
+  interface TranslationState extends EncoderProvider {
+    <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
+
+    <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache);
 
-    /** Uses the batch mode. */
-    BATCH,
+    default <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      putDataset(pCollection, dataset, false);
+    }
 
-    /** Uses the streaming mode. */
-    STREAMING
+    SerializablePipelineOptions getSerializableOptions();
+
+    SparkSession getSparkSession();
   }
 
-  /** Traverses the Pipeline to determine the {@link TranslationMode} for this pipeline. */
-  private static class TranslationModeDetector extends Pipeline.PipelineVisitor.Defaults {
-    private static final Logger LOG = LoggerFactory.getLogger(TranslationModeDetector.class);
+  /**
+   * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their
+   * Spark correspondence.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted.
+   */
+  private class TranslatingVisitor extends PTransformVisitor implements TranslationState {
+    private final Map<PCollection<?>, TranslationResult<?>> translationResults;
+    private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
+    private final SparkSession sparkSession;
+    private final SerializablePipelineOptions serializableOptions;
+    private final StorageLevel storageLevel;
+
+    private final Set<TranslationResult<?>> leaves;
+
+    public TranslatingVisitor(
+        SparkSession sparkSession,
+        SparkCommonPipelineOptions options,
+        Map<PCollection<?>, TranslationResult<?>> translationResults) {
+      this.sparkSession = sparkSession;
+      this.translationResults = translationResults;
+      this.serializableOptions = new SerializablePipelineOptions(options);
+      this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
+      this.encoders = new HashMap<>();
+      this.leaves = new HashSet<>();
+    }
 
-    private TranslationMode translationMode;
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+          (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+      try {
+        LOG.info(
+            "Translating {}: {}",
+            node.isCompositeNode() ? "composite" : "primitive",
+            node.getFullName());
+        translator.translate(transform, appliedTransform, this);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
 
-    TranslationModeDetector(TranslationMode defaultMode) {
-      this.translationMode = defaultMode;
+    @Override
+    public <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory) {
+      return (Encoder<T>) encoders.computeIfAbsent(coder, (Factory) factory);
     }
 
-    TranslationModeDetector() {
-      this(TranslationMode.BATCH);
+    private <T> TranslationResult<T> getResult(PCollection<T> pCollection) {
+      return (TranslationResult<T>) checkStateNotNull(translationResults.get(pCollection));
     }
 
-    TranslationMode getTranslationMode() {
-      return translationMode;
+    @Override
+    public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+      return checkStateNotNull(getResult(pCollection).dataset);
     }
 
     @Override
-    public void visitValue(PValue value, TransformHierarchy.Node producer) {
-      if (translationMode.equals(TranslationMode.BATCH)) {
-        if (value instanceof PCollection
-            && ((PCollection) value).isBounded() == PCollection.IsBounded.UNBOUNDED) {
-          LOG.info(
-              "Found unbounded PCollection {}. Switching to streaming execution.", value.getName());
-          translationMode = TranslationMode.STREAMING;
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {
+      TranslationResult<T> result = getResult(pCollection);
+      if (!noCache && result.dependentTransforms.size() > 1) {
+        LOG.info("Dataset {} will be cached.", result.name);
+        result.dataset = dataset.persist(storageLevel); // use NONE to disable
+      } else {
+        result.dataset = dataset;
+        if (result.dependentTransforms.isEmpty()) {
+          leaves.add(result);
         }
       }
     }
-  }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline utility methods
-  // --------------------------------------------------------------------------------------------
+    @Override
+    public SerializablePipelineOptions getSerializableOptions() {
+      return serializableOptions;
+    }
 
-  /** Get a {@link TransformTranslator} for the given {@link TransformHierarchy.Node}. */
-  protected abstract @Nullable <
-          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
-          @Nullable TransformT transform);
-
-  /** Apply the given TransformTranslator to the given node. */
-  private <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      void applyTransformTranslator(
-          TransformHierarchy.Node node,
-          TransformT transform,
-          TransformTranslator<InT, OutT, TransformT> transformTranslator) {
-    // create the applied PTransform on the translationContext
-    AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
-        (AppliedPTransform) node.toAppliedPTransform(getPipeline());
-    try {
-      transformTranslator.translate(transform, appliedTransform, translationContext);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
+    @Override
+    public SparkSession getSparkSession() {
+      return sparkSession;
     }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline visitor entry point
-  // --------------------------------------------------------------------------------------------
-
   /**
-   * Translates the pipeline by passing this class as a visitor.
+   * {@link PTransformVisitor} that analyses dependencies of supported {@link PTransform
+   * PTransforms} to help identify cache candidates.
    *
-   * @param pipeline The pipeline to be translated
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
    */
-  public void translate(Pipeline pipeline) {
-    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
-    pipeline.traverseTopologically(this);
+  private class DependencyVisitor extends PTransformVisitor {
+    private final Map<PCollection<?>, TranslationResult<?>> results = new HashMap<>();
+
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+      for (PCollection<?> pOut : node.getOutputs().values()) {
+        results.put(pOut, new TranslationResult<>(pOut));
+        for (Map.Entry<TupleTag<?>, PCollection<?>> entry : node.getInputs().entrySet()) {
+          TranslationResult<?> input = checkStateNotNull(results.get(entry.getValue()));
+          input.dependentTransforms.add(transform);
+        }
+      }
+    }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline Visitor Methods
-  // --------------------------------------------------------------------------------------------
+  /**
+   * An abstract {@link PipelineVisitor} that visits all translatable {@link PTransform} pipeline
+   * nodes of a pipeline with the respective {@link TransformTranslator}.
+   *
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
+   */
+  private abstract class PTransformVisitor extends PipelineVisitor.Defaults {
 
-  @Override
-  public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) {
-    PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
-    TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
-        getTransformTranslator(transform);
+    /** Visit the {@link PTransform} with its respective {@link TransformTranslator}. */
+    abstract <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator);
 
-    if (transformTranslator != null) {
-      LOG.info("Translating composite: {}", node.getFullName());
-      applyTransformTranslator(node, transform, transformTranslator);
-      return CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
-    } else {
-      return CompositeBehavior.ENTER_TRANSFORM;
+    @Override
+    public final CompositeBehavior enterCompositeTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> translator =
+          getTranslator(transform);
+      if (transform != null && translator != null) {
+        visit(node, transform, translator);
+        return DO_NOT_ENTER_TRANSFORM;
+      } else {
+        return ENTER_TRANSFORM;
+      }
+    }
+
+    @Override
+    public final void visitPrimitiveTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      if (transform == null || transform.getClass().equals(View.CreatePCollectionView.class)) {
+        return; // ignore, nothing to be translated here
+      }
+      TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> translator =
+          getTranslator(transform);
+      if (translator == null) {
+        String urn = PTransformTranslation.urnForTransform(transform);
+        throw new UnsupportedOperationException("Transform " + urn + " is not supported.");
+      }
+      visit(node, transform, translator);
     }
-  }
 
-  @Override
-  public void visitPrimitiveTransform(TransformHierarchy.Node node) {
-    LOG.info("Translating primitive: {}", node.getFullName());
-    // get the transformation corresponding to the node we are
-    // currently visiting and translate it into its Spark alternative.
-    PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
-    TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
-        getTransformTranslator(transform);
-
-    if (transformTranslator == null) {
-      String transformUrn = PTransformTranslation.urnForTransform(node.getTransform());
-      throw new UnsupportedOperationException(
-          "The transform " + transformUrn + " is currently not supported.");
-    }
-    applyTransformTranslator(node, transform, transformTranslator);
+    /** {@link TransformTranslator} for {@link PTransform} if translation is known and supported. */
+    private @Nullable TransformTranslator<PInput, POutput, PTransform<PInput, POutput>>

Review Comment:
   renamed to `getSupportedTranslator`
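   For readers following the caching logic in the diff quoted above: `putDataset` persists a dataset when more than one transform depends on it, and records it as a leaf when no transform does. Below is a minimal, Spark-free sketch of that two-pass rule. It assumes collections and transforms are identified by plain strings. `CachePlanSketch`, `Node` and `Decision` are hypothetical names and are not part of the runner:

   ```java
   import java.util.HashSet;
   import java.util.LinkedHashMap;
   import java.util.List;
   import java.util.Map;
   import java.util.Set;

   /** Hypothetical sketch of the cache/leaf decision; not the runner's actual code. */
   public class CachePlanSketch {

     /** Outcome per collection, mirroring the branches of putDataset with noCache == false. */
     public enum Decision {
       CACHE, // more than one dependent transform: persist the dataset
       LEAF, // no dependent transform: handed to the evaluation step as a leaf
       PASS_THROUGH // exactly one dependent transform: keep the dataset lazy
     }

     /** A transform node with the names of the collections it consumes and produces. */
     public static final class Node {
       final String name;
       final List<String> inputs;
       final List<String> outputs;

       public Node(String name, List<String> inputs, List<String> outputs) {
         this.name = name;
         this.inputs = inputs;
         this.outputs = outputs;
       }
     }

     /** Nodes must be given in topological order, i.e. producers before consumers. */
     public static Map<String, Decision> plan(List<Node> nodes) {
       // Pass 1: collect the set of dependent transforms per collection.
       Map<String, Set<String>> dependents = new LinkedHashMap<>();
       for (Node node : nodes) {
         for (String in : node.inputs) {
           Set<String> deps = dependents.get(in);
           if (deps == null) {
             throw new IllegalStateException("Input " + in + " was not produced upstream");
           }
           deps.add(node.name);
         }
         for (String out : node.outputs) {
           dependents.put(out, new HashSet<>());
         }
       }
       // Pass 2: apply the same thresholds as putDataset.
       Map<String, Decision> decisions = new LinkedHashMap<>();
       for (Map.Entry<String, Set<String>> entry : dependents.entrySet()) {
         int count = entry.getValue().size();
         Decision decision =
             count > 1 ? Decision.CACHE : (count == 0 ? Decision.LEAF : Decision.PASS_THROUGH);
         decisions.put(entry.getKey(), decision);
       }
       return decisions;
     }
   }
   ```

   The first pass plays the role of `DependencyVisitor` and the second the thresholds in `putDataset`. One difference: this sketch records each node's inputs independently of its outputs, whereas the quoted `DependencyVisitor` records them inside its loop over outputs.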



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [beam] mosche commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
mosche commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1026213378


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java:
##########
@@ -220,8 +225,9 @@ private SideInputBroadcast createBroadcastSideInputs(
       Coder<WindowedValue<?>> windowedValueCoder =
           (Coder<WindowedValue<?>>)
               (Coder<?>) WindowedValue.getFullCoder(pc.getCoder(), windowCoder);
-      Dataset<WindowedValue<?>> broadcastSet = context.getSideInputDataset(sideInput);
-      List<WindowedValue<?>> valuesList = broadcastSet.collectAsList();
+      Dataset<WindowedValue<?>> broadcastSet = context.getDataset((PCollection) pc);
+      List<WindowedValue<?>> valuesList =

Review Comment:
   This is primarily to log the evaluation / collection of the dataset. This can be surprising, as it happens ahead of time, during "translation time".
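   The pattern described here wraps the eager collect of a side-input dataset with log statements, so that evaluation during translation is visible. Below is a minimal sketch under stated assumptions. A `Supplier` stands in for Spark's `dataset::collectAsList`, and `java.util.logging` replaces the SLF4J logger the runner uses so the example is self-contained. The helper name `collectLogged` is hypothetical:

   ```java
   import java.util.List;
   import java.util.function.Supplier;
   import java.util.logging.Logger;

   /** Hypothetical helper: logs before and after forcing an eager collect during translation. */
   public final class EagerCollectLogging {
     private static final Logger LOG = Logger.getLogger(EagerCollectLogging.class.getName());

     private EagerCollectLogging() {}

     /** Runs the (assumed) collect action and logs that evaluation happened ahead of time. */
     public static <T> List<T> collectLogged(String name, Supplier<List<T>> collect) {
       LOG.info("Collecting " + name + " at translation time; this triggers its evaluation.");
       long start = System.nanoTime();
       List<T> values = collect.get(); // e.g. dataset::collectAsList in the runner
       long millis = (System.nanoTime() - start) / 1_000_000;
       LOG.info("Collected " + values.size() + " values of " + name + " in " + millis + " ms.");
       return values;
     }
   }
   ```

   Passing a `Supplier` keeps the helper independent of Spark. The log lines then bracket exactly the point where the otherwise lazy pipeline is forced.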





[GitHub] [beam] echauchot commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
echauchot commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1026474626


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/PipelineTranslatorBatch.java:
##########
@@ -81,27 +80,13 @@ public class PipelineTranslatorBatch extends PipelineTranslator {
 
     TRANSFORM_TRANSLATORS.put(
         SplittableParDo.PrimitiveBoundedRead.class, new ReadSourceTranslatorBatch<>());
-

Review Comment:
   :+1: 





[GitHub] [beam] echauchot commented on a diff in pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
echauchot commented on code in PR #24009:
URL: https://github.com/apache/beam/pull/24009#discussion_r1026480633


##########
runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/PipelineTranslator.java:
##########
@@ -17,170 +17,336 @@
  */
 package org.apache.beam.runners.spark.structuredstreaming.translation;
 
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
+import static org.apache.beam.sdk.Pipeline.PipelineVisitor.CompositeBehavior.ENTER_TRANSFORM;
+import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
+import static org.apache.beam.sdk.values.PCollection.IsBounded.UNBOUNDED;
+
 import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
+import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
+import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.EncoderProvider;
 import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.annotations.Internal;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.StreamingOptions;
 import org.apache.beam.sdk.runners.AppliedPTransform;
-import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.PValue;
-import org.checkerframework.checker.nullness.qual.Nullable;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+import org.apache.spark.storage.StorageLevel;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * {@link Pipeline.PipelineVisitor} that translates the Beam operators to their Spark counterparts.
- * It also does the pipeline preparation: mode detection, transforms replacement, classpath
- * preparation.
+ * The pipeline translator translates a Beam {@link Pipeline} into a Spark correspondence, that can
+ * then be evaluated.
+ *
+ * <p>The translation involves traversing the hierarchy of a pipeline multiple times:
+ *
+ * <ol>
+ *   <li>Detect if {@link StreamingOptions#setStreaming streaming} mode is required.
+ *   <li>Identify datasets that are repeatedly used as input and should be cached.
+ *   <li>And finally, translate each primitive or composite {@link PTransform} that is {@link
+ *       #getTransformTranslator known} and {@link TransformTranslator#canTranslate supported} into
+ *       its Spark correspondence. If a composite is not supported, it will be expanded further into
+ *       its parts and translated then.
+ * </ol>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public abstract class PipelineTranslator extends Pipeline.PipelineVisitor.Defaults {
+@Internal
+public abstract class PipelineTranslator {
   private static final Logger LOG = LoggerFactory.getLogger(PipelineTranslator.class);
-  protected TranslationContext translationContext;
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline preparation methods
-  // --------------------------------------------------------------------------------------------
   public static void replaceTransforms(Pipeline pipeline, StreamingOptions options) {
     pipeline.replaceAll(SparkTransformOverrides.getDefaultOverrides(options.isStreaming()));
   }
 
   /**
-   * Visit the pipeline to determine the translation mode (batch/streaming) and update options
-   * accordingly.
+   * Analyse the pipeline to determine if we have to switch to streaming mode for the pipeline
+   * translation and update {@link StreamingOptions} accordingly.
    */
-  public static void detectTranslationMode(Pipeline pipeline, StreamingOptions options) {
-    TranslationModeDetector detector = new TranslationModeDetector();
+  public static void detectStreamingMode(Pipeline pipeline, StreamingOptions options) {
+    StreamingModeDetector detector = new StreamingModeDetector(options.isStreaming());
     pipeline.traverseTopologically(detector);
-    if (detector.getTranslationMode().equals(TranslationMode.STREAMING)) {
-      options.setStreaming(true);
+    options.setStreaming(detector.streaming);
+  }
+
+  /** Returns a {@link TransformTranslator} for the given {@link PTransform} if known. */
+  protected abstract @Nullable <
+          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
+      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(TransformT transform);
+
+  /**
+   * Translates a Beam pipeline into its Spark correspondence using the Spark SQL / Dataset API.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted to be able to continue with the translation.
+   *
+   * @return The result of the translation is an {@link EvaluationContext} that can trigger the
+   *     evaluation of the Spark pipeline.
+   */
+  public EvaluationContext translate(
+      Pipeline pipeline, SparkSession session, SparkCommonPipelineOptions options) {
+    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
+    DependencyVisitor dependencies = new DependencyVisitor();
+    pipeline.traverseTopologically(dependencies);
+
+    TranslatingVisitor translator = new TranslatingVisitor(session, options, dependencies.results);
+    pipeline.traverseTopologically(translator);
+
+    return new EvaluationContext(translator.leaves, session);
+  }
+
+  /**
+   * The correspondence of a {@link PCollection} as result of translating a {@link PTransform}
+   * including additional metadata (such as name and dependents).
+   */
+  private static final class TranslationResult<T> implements EvaluationContext.NamedDataset<T> {
+    private final String name;
+    private @Nullable Dataset<WindowedValue<T>> dataset = null;
+    private final Set<PTransform<?, ?>> dependentTransforms = new HashSet<>();
+
+    private TranslationResult(PCollection<?> pCol) {
+      this.name = pCol.getName();
+    }
+
+    @Override
+    public String name() {
+      return name;
+    }
+
+    @Override
+    public @Nullable Dataset<WindowedValue<T>> dataset() {
+      return dataset;
     }
   }
 
-  /** The translation mode of the Beam Pipeline. */
-  private enum TranslationMode {
+  /** Shared, mutable state during the translation of a pipeline and omitted afterwards. */
+  interface TranslationState extends EncoderProvider {
+    <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection);
+
+    <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache);
 
-    /** Uses the batch mode. */
-    BATCH,
+    default <T> void putDataset(PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset) {
+      putDataset(pCollection, dataset, false);
+    }
 
-    /** Uses the streaming mode. */
-    STREAMING
+    SerializablePipelineOptions getSerializableOptions();
+
+    SparkSession getSparkSession();
   }
 
-  /** Traverses the Pipeline to determine the {@link TranslationMode} for this pipeline. */
-  private static class TranslationModeDetector extends Pipeline.PipelineVisitor.Defaults {
-    private static final Logger LOG = LoggerFactory.getLogger(TranslationModeDetector.class);
+  /**
+   * {@link PTransformVisitor} that translates supported {@link PTransform PTransforms} into their
+   * Spark correspondence.
+   *
+   * <p>Note, in some cases this involves the early evaluation of some parts of the pipeline. For
+   * example, in order to use a side-input {@link org.apache.beam.sdk.values.PCollectionView
+   * PCollectionView} in a translation the corresponding Spark {@link
+   * org.apache.beam.runners.spark.translation.Dataset Dataset} might have to be collected and
+   * broadcasted.
+   */
+  private class TranslatingVisitor extends PTransformVisitor implements TranslationState {
+    private final Map<PCollection<?>, TranslationResult<?>> translationResults;
+    private final Map<Coder<?>, ExpressionEncoder<?>> encoders;
+    private final SparkSession sparkSession;
+    private final SerializablePipelineOptions serializableOptions;
+    private final StorageLevel storageLevel;
+
+    private final Set<TranslationResult<?>> leaves;
+
+    public TranslatingVisitor(
+        SparkSession sparkSession,
+        SparkCommonPipelineOptions options,
+        Map<PCollection<?>, TranslationResult<?>> translationResults) {
+      this.sparkSession = sparkSession;
+      this.translationResults = translationResults;
+      this.serializableOptions = new SerializablePipelineOptions(options);
+      this.storageLevel = StorageLevel.fromString(options.getStorageLevel());
+      this.encoders = new HashMap<>();
+      this.leaves = new HashSet<>();
+    }
 
-    private TranslationMode translationMode;
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+
+      AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
+          (AppliedPTransform) node.toAppliedPTransform(getPipeline());
+      try {
+        LOG.info(
+            "Translating {}: {}",
+            node.isCompositeNode() ? "composite" : "primitive",
+            node.getFullName());
+        translator.translate(transform, appliedTransform, this);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
 
-    TranslationModeDetector(TranslationMode defaultMode) {
-      this.translationMode = defaultMode;
+    @Override
+    public <T> Encoder<T> encoderOf(Coder<T> coder, Factory<T> factory) {
+      return (Encoder<T>) encoders.computeIfAbsent(coder, (Factory) factory);
     }
 
-    TranslationModeDetector() {
-      this(TranslationMode.BATCH);
+    private <T> TranslationResult<T> getResult(PCollection<T> pCollection) {
+      return (TranslationResult<T>) checkStateNotNull(translationResults.get(pCollection));
     }
 
-    TranslationMode getTranslationMode() {
-      return translationMode;
+    @Override
+    public <T> Dataset<WindowedValue<T>> getDataset(PCollection<T> pCollection) {
+      return checkStateNotNull(getResult(pCollection).dataset);
     }
 
     @Override
-    public void visitValue(PValue value, TransformHierarchy.Node producer) {
-      if (translationMode.equals(TranslationMode.BATCH)) {
-        if (value instanceof PCollection
-            && ((PCollection) value).isBounded() == PCollection.IsBounded.UNBOUNDED) {
-          LOG.info(
-              "Found unbounded PCollection {}. Switching to streaming execution.", value.getName());
-          translationMode = TranslationMode.STREAMING;
+    public <T> void putDataset(
+        PCollection<T> pCollection, Dataset<WindowedValue<T>> dataset, boolean noCache) {
+      TranslationResult<T> result = getResult(pCollection);
+      if (!noCache && result.dependentTransforms.size() > 1) {
+        LOG.info("Dataset {} will be cached.", result.name);
+        result.dataset = dataset.persist(storageLevel); // use NONE to disable
+      } else {
+        result.dataset = dataset;
+        if (result.dependentTransforms.isEmpty()) {
+          leaves.add(result);
         }
       }
     }
-  }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline utility methods
-  // --------------------------------------------------------------------------------------------
+    @Override
+    public SerializablePipelineOptions getSerializableOptions() {
+      return serializableOptions;
+    }
 
-  /** Get a {@link TransformTranslator} for the given {@link TransformHierarchy.Node}. */
-  protected abstract @Nullable <
-          InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      TransformTranslator<InT, OutT, TransformT> getTransformTranslator(
-          @Nullable TransformT transform);
-
-  /** Apply the given TransformTranslator to the given node. */
-  private <InT extends PInput, OutT extends POutput, TransformT extends PTransform<InT, OutT>>
-      void applyTransformTranslator(
-          TransformHierarchy.Node node,
-          TransformT transform,
-          TransformTranslator<InT, OutT, TransformT> transformTranslator) {
-    // create the applied PTransform on the translationContext
-    AppliedPTransform<InT, OutT, PTransform<InT, OutT>> appliedTransform =
-        (AppliedPTransform) node.toAppliedPTransform(getPipeline());
-    try {
-      transformTranslator.translate(transform, appliedTransform, translationContext);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
+    @Override
+    public SparkSession getSparkSession() {
+      return sparkSession;
     }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline visitor entry point
-  // --------------------------------------------------------------------------------------------
-
   /**
-   * Translates the pipeline by passing this class as a visitor.
+   * {@link PTransformVisitor} that analyses dependencies of supported {@link PTransform
+   * PTransforms} to help identify cache candidates.
    *
-   * @param pipeline The pipeline to be translated
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
    */
-  public void translate(Pipeline pipeline) {
-    LOG.debug("starting translation of the pipeline using {}", getClass().getName());
-    pipeline.traverseTopologically(this);
+  private class DependencyVisitor extends PTransformVisitor {
+    private final Map<PCollection<?>, TranslationResult<?>> results = new HashMap<>();
+
+    @Override
+    <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator) {
+      for (PCollection<?> pOut : node.getOutputs().values()) {
+        results.put(pOut, new TranslationResult<>(pOut));
+        for (Map.Entry<TupleTag<?>, PCollection<?>> entry : node.getInputs().entrySet()) {
+          TranslationResult<?> input = checkStateNotNull(results.get(entry.getValue()));
+          input.dependentTransforms.add(transform);
+        }
+      }
+    }
   }
 
-  // --------------------------------------------------------------------------------------------
-  //  Pipeline Visitor Methods
-  // --------------------------------------------------------------------------------------------
+  /**
+   * An abstract {@link PipelineVisitor} that visits all translatable {@link PTransform} pipeline
+   * nodes of a pipeline with the respective {@link TransformTranslator}.
+   *
+   * <p>The visitor may throw if a {@link PTransform} is observed that uses unsupported features.
+   */
+  private abstract class PTransformVisitor extends PipelineVisitor.Defaults {
 
-  @Override
-  public CompositeBehavior enterCompositeTransform(TransformHierarchy.Node node) {
-    PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
-    TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
-        getTransformTranslator(transform);
+    /** Visit the {@link PTransform} with its respective {@link TransformTranslator}. */
+    abstract <InT extends PInput, OutT extends POutput> void visit(
+        Node node,
+        PTransform<InT, OutT> transform,
+        TransformTranslator<InT, OutT, PTransform<InT, OutT>> translator);
 
-    if (transformTranslator != null) {
-      LOG.info("Translating composite: {}", node.getFullName());
-      applyTransformTranslator(node, transform, transformTranslator);
-      return CompositeBehavior.DO_NOT_ENTER_TRANSFORM;
-    } else {
-      return CompositeBehavior.ENTER_TRANSFORM;
+    @Override
+    public final CompositeBehavior enterCompositeTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> translator =
+          getTranslator(transform);
+      if (transform != null && translator != null) {
+        visit(node, transform, translator);
+        return DO_NOT_ENTER_TRANSFORM;
+      } else {
+        return ENTER_TRANSFORM;
+      }
+    }
+
+    @Override
+    public final void visitPrimitiveTransform(Node node) {
+      PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
+      if (transform == null || transform.getClass().equals(View.CreatePCollectionView.class)) {
+        return; // ignore, nothing to be translated here
+      }
+      TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> translator =
+          getTranslator(transform);
+      if (translator == null) {
+        String urn = PTransformTranslation.urnForTransform(transform);
+        throw new UnsupportedOperationException("Transform " + urn + " is not supported.");
+      }
+      visit(node, transform, translator);
     }
-  }
 
-  @Override
-  public void visitPrimitiveTransform(TransformHierarchy.Node node) {
-    LOG.info("Translating primitive: {}", node.getFullName());
-    // get the transformation corresponding to the node we are
-    // currently visiting and translate it into its Spark alternative.
-    PTransform<PInput, POutput> transform = (PTransform<PInput, POutput>) node.getTransform();
-    TransformTranslator<PInput, POutput, PTransform<PInput, POutput>> transformTranslator =
-        getTransformTranslator(transform);
-
-    if (transformTranslator == null) {
-      String transformUrn = PTransformTranslation.urnForTransform(node.getTransform());
-      throw new UnsupportedOperationException(
-          "The transform " + transformUrn + " is currently not supported.");
-    }
-    applyTransformTranslator(node, transform, transformTranslator);
+    /** {@link TransformTranslator} for {@link PTransform} if translation is known and supported. */
+    private @Nullable TransformTranslator<PInput, POutput, PTransform<PInput, POutput>>

Review Comment:
   yes good name.
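The `DependencyVisitor` in the diff above records, for every `PCollection`, which transforms consume it. Per the PR description, a dataset used by more than one transform is a cache candidate, because Spark would otherwise repeat its lazy evaluation once per consumer. Below is a minimal, dependency-free sketch of that counting step. It is not the runner's code: collections and transforms are plain string names, and `CacheCandidates`, `visit`, and `candidates` are hypothetical names. It assumes transforms arrive in topological order, as `traverseTopologically` provides.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/** Hypothetical sketch: track consumers per collection and flag multiply-used ones. */
class CacheCandidates {
  // Collection name -> names of the transforms that read it.
  private final Map<String, List<String>> dependents = new LinkedHashMap<>();

  /** Registers a transform: it becomes a consumer of each input and a producer of each output. */
  public void visit(String transform, List<String> inputs, List<String> outputs) {
    for (String input : inputs) {
      List<String> consumers = dependents.get(input);
      if (consumers == null) {
        // Same intent as checkStateNotNull: in topological order, inputs are already known.
        throw new IllegalStateException("Input " + input + " was not produced by an earlier transform");
      }
      consumers.add(transform);
    }
    for (String output : outputs) {
      dependents.put(output, new ArrayList<>());
    }
  }

  /** Collections read by more than one transform, i.e. candidates for caching. */
  public Set<String> candidates() {
    Set<String> result = new TreeSet<>();
    for (Map.Entry<String, List<String>> entry : dependents.entrySet()) {
      if (entry.getValue().size() > 1) {
        result.add(entry.getKey());
      }
    }
    return result;
  }
}
```

Actually caching a candidate (for example with Spark's `Dataset.persist()`) is deliberately left out; how the runner applies the cache is not shown in this excerpt.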


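`PTransformVisitor` in the same diff is a template method: for each node it decides whether to translate, descend, or fail, and delegates the actual work to the abstract `visit`. Below is a dependency-free sketch of that dispatch under stated assumptions. `Node` is a hypothetical tree where composites have children, a set of names stands in for "a translator is known", and the special case that skips `View.CreatePCollectionView` is omitted.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/** Hypothetical sketch of the composite/primitive dispatch in PTransformVisitor. */
class TranslatingWalker {
  /** Stand-in for a pipeline node: composites have children, primitives have none. */
  static final class Node {
    final String name;
    final List<Node> children;

    Node(String name, List<Node> children) {
      this.name = name;
      this.children = children;
    }

    boolean isComposite() {
      return !children.isEmpty();
    }
  }

  private final Set<String> translatable; // names that have a translator (assumption)
  final List<String> visited = new ArrayList<>();

  TranslatingWalker(Set<String> translatable) {
    this.translatable = translatable;
  }

  /** Counterpart of the abstract visit(node, transform, translator); here it only records the name. */
  void visit(Node node) {
    visited.add(node.name);
  }

  void traverse(Node node) {
    boolean hasTranslator = translatable.contains(node.name);
    if (node.isComposite()) {
      if (hasTranslator) {
        visit(node); // translate the composite as a whole (DO_NOT_ENTER_TRANSFORM)
      } else {
        for (Node child : node.children) {
          traverse(child); // descend into the composite (ENTER_TRANSFORM)
        }
      }
    } else if (hasTranslator) {
      visit(node);
    } else {
      // A primitive cannot be expanded further, so a missing translator is fatal.
      throw new UnsupportedOperationException("Transform " + node.name + " is not supported.");
    }
  }
}
```

Stopping at a composite that has a translator lets the runner translate it natively instead of its expansion, while an unknown primitive fails fast because there is nothing left to expand.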

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@beam.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [beam] echauchot commented on pull request #24009: [Spark dataset runner] Cache datasets if used multiple times

Posted by GitBox <gi...@apache.org>.
echauchot commented on PR #24009:
URL: https://github.com/apache/beam/pull/24009#issuecomment-1313881347

   > > Why does this PR address 2 separate tickets when the guidelines say to address only one per PR (for isolation, diagnosis, and revert reasons)? Can't the 2 issues be addressed separately? Are they too coupled?
   > 
   > @echauchot This PR addresses just one issue. But it's done in a way that #24035 can be fixed in a follow-up without introducing massive changes yet again. #24035 is not addressed at all yet.
   
   ok perfect 

