You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by tg...@apache.org on 2017/04/14 23:52:13 UTC

[1/3] beam git commit: This closes #2501

Repository: beam
Updated Branches:
  refs/heads/master 3c2b855f5 -> fdbadfc9c


This closes #2501


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/fdbadfc9
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/fdbadfc9
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/fdbadfc9

Branch: refs/heads/master
Commit: fdbadfc9cceecf645f2325777ab32f1cb3041953
Parents: 3c2b855 f3b4960
Author: Thomas Groh <tg...@google.com>
Authored: Fri Apr 14 16:52:03 2017 -0700
Committer: Thomas Groh <tg...@google.com>
Committed: Fri Apr 14 16:52:03 2017 -0700

----------------------------------------------------------------------
 .../apache/beam/runners/apex/ApexRunner.java    |  32 +++--
 .../DeduplicatedFlattenFactory.java             |  63 +++++----
 .../EmptyFlattenAsCreateFactory.java            |  20 ++-
 .../core/construction/PTransformMatchers.java   |   2 -
 .../construction/PTransformReplacements.java    |  69 ++++++++++
 .../core/construction/PrimitiveCreate.java      |  13 +-
 .../SingleInputOutputOverrideFactory.java       |   9 +-
 .../UnsupportedOverrideFactory.java             |  14 +-
 .../DeduplicatedFlattenFactoryTest.java         |  18 +--
 .../EmptyFlattenAsCreateFactoryTest.java        |  36 ++++-
 .../PTransformReplacementsTest.java             | 131 +++++++++++++++++++
 .../SingleInputOutputOverrideFactoryTest.java   |  31 ++---
 .../UnsupportedOverrideFactoryTest.java         |  11 +-
 ...ectGBKIntoKeyedWorkItemsOverrideFactory.java |  16 ++-
 .../direct/DirectGroupByKeyOverrideFactory.java |  14 +-
 .../direct/ParDoMultiOverrideFactory.java       |  22 ++--
 .../direct/TestStreamEvaluatorFactory.java      |  14 +-
 .../runners/direct/ViewOverrideFactory.java     |  18 +--
 .../direct/WriteWithShardingFactory.java        |  16 +--
 .../DirectGroupByKeyOverrideFactoryTest.java    |  12 +-
 .../direct/ParDoMultiOverrideFactoryTest.java   |  45 -------
 .../direct/TestStreamEvaluatorFactoryTest.java  |  12 --
 .../runners/direct/ViewOverrideFactoryTest.java |  42 ++++--
 .../direct/WriteWithShardingFactoryTest.java    |  23 ++--
 .../flink/FlinkStreamingPipelineTranslator.java |  56 ++++----
 .../dataflow/BatchStatefulParDoOverrides.java   |  42 +++---
 .../runners/dataflow/BatchViewOverrides.java    |  17 ++-
 .../beam/runners/dataflow/DataflowRunner.java   |  92 ++++++-------
 .../dataflow/PrimitiveParDoSingleFactory.java   |  15 ++-
 .../dataflow/ReshuffleOverrideFactory.java      |  12 +-
 .../dataflow/StreamingViewOverrides.java        |  14 +-
 .../PrimitiveParDoSingleFactoryTest.java        |  59 +++++++--
 .../beam/runners/spark/TestSparkRunner.java     |  14 +-
 .../main/java/org/apache/beam/sdk/Pipeline.java |  15 ++-
 .../sdk/runners/PTransformOverrideFactory.java  |  33 +++--
 .../beam/sdk/transforms/AppliedPTransform.java  |   5 +
 .../java/org/apache/beam/sdk/PipelineTest.java  |  33 ++---
 37 files changed, 675 insertions(+), 415 deletions(-)
----------------------------------------------------------------------



[3/3] beam git commit: Update Signature of PTransformOverrideFactory

Posted by tg...@apache.org.
Update Signature of PTransformOverrideFactory

This enables a replacement, together with its input, to be obtained from
the entire applied transform that is being replaced.

This is required when Side Inputs are part of the input of the
PTransform Application, as PTransforms are not applied to their side
inputs.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/f3b49605
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/f3b49605
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/f3b49605

Branch: refs/heads/master
Commit: f3b496053d2596ee1b2de55f6da055b478a0d6d3
Parents: 3c2b855
Author: Thomas Groh <tg...@google.com>
Authored: Wed Mar 29 15:23:21 2017 -0700
Committer: Thomas Groh <tg...@google.com>
Committed: Fri Apr 14 16:52:03 2017 -0700

----------------------------------------------------------------------
 .../apache/beam/runners/apex/ApexRunner.java    |  32 +++--
 .../DeduplicatedFlattenFactory.java             |  63 +++++----
 .../EmptyFlattenAsCreateFactory.java            |  20 ++-
 .../core/construction/PTransformMatchers.java   |   2 -
 .../construction/PTransformReplacements.java    |  69 ++++++++++
 .../core/construction/PrimitiveCreate.java      |  13 +-
 .../SingleInputOutputOverrideFactory.java       |   9 +-
 .../UnsupportedOverrideFactory.java             |  14 +-
 .../DeduplicatedFlattenFactoryTest.java         |  18 +--
 .../EmptyFlattenAsCreateFactoryTest.java        |  36 ++++-
 .../PTransformReplacementsTest.java             | 131 +++++++++++++++++++
 .../SingleInputOutputOverrideFactoryTest.java   |  31 ++---
 .../UnsupportedOverrideFactoryTest.java         |  11 +-
 ...ectGBKIntoKeyedWorkItemsOverrideFactory.java |  16 ++-
 .../direct/DirectGroupByKeyOverrideFactory.java |  14 +-
 .../direct/ParDoMultiOverrideFactory.java       |  22 ++--
 .../direct/TestStreamEvaluatorFactory.java      |  14 +-
 .../runners/direct/ViewOverrideFactory.java     |  18 +--
 .../direct/WriteWithShardingFactory.java        |  16 +--
 .../DirectGroupByKeyOverrideFactoryTest.java    |  12 +-
 .../direct/ParDoMultiOverrideFactoryTest.java   |  45 -------
 .../direct/TestStreamEvaluatorFactoryTest.java  |  12 --
 .../runners/direct/ViewOverrideFactoryTest.java |  42 ++++--
 .../direct/WriteWithShardingFactoryTest.java    |  23 ++--
 .../flink/FlinkStreamingPipelineTranslator.java |  56 ++++----
 .../dataflow/BatchStatefulParDoOverrides.java   |  42 +++---
 .../runners/dataflow/BatchViewOverrides.java    |  17 ++-
 .../beam/runners/dataflow/DataflowRunner.java   |  92 ++++++-------
 .../dataflow/PrimitiveParDoSingleFactory.java   |  15 ++-
 .../dataflow/ReshuffleOverrideFactory.java      |  12 +-
 .../dataflow/StreamingViewOverrides.java        |  14 +-
 .../PrimitiveParDoSingleFactoryTest.java        |  59 +++++++--
 .../beam/runners/spark/TestSparkRunner.java     |  14 +-
 .../main/java/org/apache/beam/sdk/Pipeline.java |  15 ++-
 .../sdk/runners/PTransformOverrideFactory.java  |  33 +++--
 .../beam/sdk/transforms/AppliedPTransform.java  |   5 +
 .../java/org/apache/beam/sdk/PipelineTest.java  |  33 ++---
 37 files changed, 675 insertions(+), 415 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java
----------------------------------------------------------------------
diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java
index 1c99f8d..1c845c6 100644
--- a/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java
+++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/ApexRunner.java
@@ -39,6 +39,7 @@ import org.apache.apex.api.Launcher.AppHandle;
 import org.apache.apex.api.Launcher.LaunchMode;
 import org.apache.beam.runners.apex.translation.ApexPipelineTranslator;
 import org.apache.beam.runners.core.construction.PTransformMatchers;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.PrimitiveCreate;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.sdk.Pipeline;
@@ -49,6 +50,7 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
 import org.apache.beam.sdk.runners.PTransformOverride;
 import org.apache.beam.sdk.runners.PipelineRunner;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Combine.GloballyAsSingletonView;
 import org.apache.beam.sdk.transforms.Create;
@@ -258,9 +260,15 @@ public class ApexRunner extends PipelineRunner<ApexRunnerResult> {
             PCollection<InputT>, PCollectionView<OutputT>,
             Combine.GloballyAsSingletonView<InputT, OutputT>> {
       @Override
-      public PTransform<PCollection<InputT>, PCollectionView<OutputT>> getReplacementTransform(
-          GloballyAsSingletonView<InputT, OutputT> transform) {
-        return new StreamingCombineGloballyAsSingletonView<>(transform);
+      public PTransformReplacement<PCollection<InputT>, PCollectionView<OutputT>>
+          getReplacementTransform(
+              AppliedPTransform<
+                      PCollection<InputT>, PCollectionView<OutputT>,
+                      GloballyAsSingletonView<InputT, OutputT>>
+                  transform) {
+        return PTransformReplacement.of(
+            PTransformReplacements.getSingletonMainInput(transform),
+            new StreamingCombineGloballyAsSingletonView<>(transform.getTransform()));
       }
     }
   }
@@ -321,9 +329,11 @@ public class ApexRunner extends PipelineRunner<ApexRunnerResult> {
         extends SingleInputOutputOverrideFactory<
             PCollection<T>, PCollectionView<T>, View.AsSingleton<T>> {
       @Override
-      public PTransform<PCollection<T>, PCollectionView<T>> getReplacementTransform(
-          AsSingleton<T> transform) {
-        return new StreamingViewAsSingleton<>(transform);
+      public PTransformReplacement<PCollection<T>, PCollectionView<T>> getReplacementTransform(
+          AppliedPTransform<PCollection<T>, PCollectionView<T>, AsSingleton<T>> transform) {
+        return PTransformReplacement.of(
+            PTransformReplacements.getSingletonMainInput(transform),
+            new StreamingViewAsSingleton<>(transform.getTransform()));
       }
     }
   }
@@ -352,9 +362,13 @@ public class ApexRunner extends PipelineRunner<ApexRunnerResult> {
         extends SingleInputOutputOverrideFactory<
             PCollection<T>, PCollectionView<Iterable<T>>, View.AsIterable<T>> {
       @Override
-      public PTransform<PCollection<T>, PCollectionView<Iterable<T>>> getReplacementTransform(
-          AsIterable<T> transform) {
-        return new StreamingViewAsIterable<>();
+      public PTransformReplacement<PCollection<T>, PCollectionView<Iterable<T>>>
+          getReplacementTransform(
+              AppliedPTransform<PCollection<T>, PCollectionView<Iterable<T>>, AsIterable<T>>
+                  transform) {
+        return PTransformReplacement.of(
+            PTransformReplacements.getSingletonMainInput(transform),
+            new StreamingViewAsIterable<T>());
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactory.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactory.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactory.java
index c12c548..13e7593 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactory.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactory.java
@@ -18,10 +18,12 @@
 
 package org.apache.beam.runners.core.construction;
 
+import com.google.common.annotations.VisibleForTesting;
 import java.util.HashMap;
 import java.util.Map;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.Flatten.PCollections;
@@ -47,32 +49,11 @@ public class DeduplicatedFlattenFactory<T>
   private DeduplicatedFlattenFactory() {}
 
   @Override
-  public PTransform<PCollectionList<T>, PCollection<T>> getReplacementTransform(
-      PCollections<T> transform) {
-    return new PTransform<PCollectionList<T>, PCollection<T>>() {
-      @Override
-      public PCollection<T> expand(PCollectionList<T> input) {
-        Map<PCollection<T>, Integer> instances = new HashMap<>();
-        for (PCollection<T> pCollection : input.getAll()) {
-          int existing = instances.get(pCollection) == null ? 0 : instances.get(pCollection);
-          instances.put(pCollection, existing + 1);
-        }
-        PCollectionList<T> output = PCollectionList.empty(input.getPipeline());
-        for (Map.Entry<PCollection<T>, Integer> instanceEntry : instances.entrySet()) {
-          if (instanceEntry.getValue().equals(1)) {
-            output = output.and(instanceEntry.getKey());
-          } else {
-            String duplicationName = String.format("Multiply %s", instanceEntry.getKey().getName());
-            PCollection<T> duplicated =
-                instanceEntry
-                    .getKey()
-                    .apply(duplicationName, ParDo.of(new DuplicateFn<T>(instanceEntry.getValue())));
-            output = output.and(duplicated);
-          }
-        }
-        return output.apply(Flatten.<T>pCollections());
-      }
-    };
+  public PTransformReplacement<PCollectionList<T>, PCollection<T>> getReplacementTransform(
+      AppliedPTransform<PCollectionList<T>, PCollection<T>, PCollections<T>> transform) {
+    return PTransformReplacement.of(
+        getInput(transform.getInputs(), transform.getPipeline()),
+        new FlattenWithoutDuplicateInputs<T>());
   }
 
   /**
@@ -80,8 +61,7 @@ public class DeduplicatedFlattenFactory<T>
    *
    * <p>The input {@link PCollectionList} that is constructed will have the same values in the same
    */
-  @Override
-  public PCollectionList<T> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
+  private PCollectionList<T> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
     PCollectionList<T> pCollections = PCollectionList.empty(p);
     for (PValue input : inputs.values()) {
       PCollection<T> pcollection = (PCollection<T>) input;
@@ -96,6 +76,33 @@ public class DeduplicatedFlattenFactory<T>
     return ReplacementOutputs.singleton(outputs, newOutput);
   }
 
+  @VisibleForTesting
+  static class FlattenWithoutDuplicateInputs<T>
+      extends PTransform<PCollectionList<T>, PCollection<T>> {
+    @Override
+    public PCollection<T> expand(PCollectionList<T> input) {
+      Map<PCollection<T>, Integer> instances = new HashMap<>();
+      for (PCollection<T> pCollection : input.getAll()) {
+        int existing = instances.get(pCollection) == null ? 0 : instances.get(pCollection);
+        instances.put(pCollection, existing + 1);
+      }
+      PCollectionList<T> output = PCollectionList.empty(input.getPipeline());
+      for (Map.Entry<PCollection<T>, Integer> instanceEntry : instances.entrySet()) {
+        if (instanceEntry.getValue().equals(1)) {
+          output = output.and(instanceEntry.getKey());
+        } else {
+          String duplicationName = String.format("Multiply %s", instanceEntry.getKey().getName());
+          PCollection<T> duplicated =
+              instanceEntry
+                  .getKey()
+                  .apply(duplicationName, ParDo.of(new DuplicateFn<T>(instanceEntry.getValue())));
+          output = output.and(duplicated);
+        }
+      }
+      return output.apply(Flatten.<T>pCollections());
+    }
+  }
+
   private static class DuplicateFn<T> extends DoFn<T, T> {
     private final int numTimes;
 

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactory.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactory.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactory.java
index 936bc08..a6982d4 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactory.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactory.java
@@ -21,11 +21,12 @@ package org.apache.beam.runners.core.construction;
 import static com.google.common.base.Preconditions.checkArgument;
 
 import java.util.Map;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.Flatten;
+import org.apache.beam.sdk.transforms.Flatten.PCollections;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionList;
@@ -49,20 +50,15 @@ public class EmptyFlattenAsCreateFactory<T>
   private EmptyFlattenAsCreateFactory() {}
 
   @Override
-  public PTransform<PCollectionList<T>, PCollection<T>> getReplacementTransform(
-      Flatten.PCollections<T> transform) {
-    return new CreateEmptyFromList<>();
-  }
-
-  @Override
-  public PCollectionList<T> getInput(
-      Map<TupleTag<?>, PValue> inputs, Pipeline p) {
+  public PTransformReplacement<PCollectionList<T>, PCollection<T>> getReplacementTransform(
+      AppliedPTransform<PCollectionList<T>, PCollection<T>, PCollections<T>> transform) {
     checkArgument(
-        inputs.isEmpty(),
+        transform.getInputs().isEmpty(),
         "Unexpected nonempty input %s for %s",
-        inputs,
+        transform.getInputs(),
         getClass().getSimpleName());
-    return PCollectionList.empty(p);
+    return PTransformReplacement.of(
+        PCollectionList.<T>empty(transform.getPipeline()), new CreateEmptyFromList<T>());
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
index 94ec38c..09946bc 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
@@ -52,8 +52,6 @@ public class PTransformMatchers {
   /**
    * Returns a {@link PTransformMatcher} that matches a {@link PTransform} if the class of the
    * {@link PTransform} is equal to the {@link Class} provided ot this matcher.
-   * @param clazz
-   * @return
    */
   public static PTransformMatcher classEqualTo(Class<? extends PTransform> clazz) {
     return new EqualClassPTransformMatcher(clazz);

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformReplacements.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformReplacements.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformReplacements.java
new file mode 100644
index 0000000..72a3425
--- /dev/null
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformReplacements.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import java.util.Map;
+import java.util.Set;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
+
+/**
+ */
+public class PTransformReplacements {
+  /**
+   * Gets the singleton input of an {@link AppliedPTransform}, ignoring any additional inputs
+   * returned by {@link PTransform#getAdditionalInputs()}.
+   */
+  public static <T> PCollection<T> getSingletonMainInput(
+      AppliedPTransform<? extends PCollection<? extends T>, ?, ?> application) {
+    return getSingletonMainInput(
+        application.getInputs(), application.getTransform().getAdditionalInputs().keySet());
+  }
+
+  private static <T> PCollection<T> getSingletonMainInput(
+      Map<TupleTag<?>, PValue> inputs, Set<TupleTag<?>> ignoredTags) {
+    PCollection<T> mainInput = null;
+    for (Map.Entry<TupleTag<?>, PValue> input : inputs.entrySet()) {
+      if (!ignoredTags.contains(input.getKey())) {
+        checkArgument(
+            mainInput == null,
+            "Got multiple inputs that are not additional inputs for a "
+                + "singleton main input: %s and %s",
+            mainInput,
+            input.getValue());
+        checkArgument(
+            input.getValue() instanceof PCollection,
+            "Unexpected input type %s",
+            input.getValue().getClass());
+        mainInput = (PCollection<T>) input.getValue();
+      }
+    }
+    checkArgument(
+        mainInput != null,
+        "No main input found in inputs: Inputs %s, Side Input tags %s",
+        inputs,
+        ignoredTags);
+    return mainInput;
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PrimitiveCreate.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PrimitiveCreate.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PrimitiveCreate.java
index 9335f3a..5a2140b 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PrimitiveCreate.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PrimitiveCreate.java
@@ -19,8 +19,8 @@
 package org.apache.beam.runners.core.construction;
 
 import java.util.Map;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.Create.Values;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -57,13 +57,10 @@ public class PrimitiveCreate<T> extends PTransform<PBegin, PCollection<T>> {
   public static class Factory<T>
       implements PTransformOverrideFactory<PBegin, PCollection<T>, Values<T>> {
     @Override
-    public PTransform<PBegin, PCollection<T>> getReplacementTransform(Values<T> transform) {
-      return new PrimitiveCreate<>(transform);
-    }
-
-    @Override
-    public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return p.begin();
+    public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform(
+        AppliedPTransform<PBegin, PCollection<T>, Values<T>> transform) {
+      return PTransformReplacement.of(
+          transform.getPipeline().begin(), new PrimitiveCreate<T>(transform.getTransform()));
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactory.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactory.java
index 6d0d571..7a59c1c 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactory.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactory.java
@@ -18,9 +18,7 @@
 
 package org.apache.beam.runners.core.construction;
 
-import com.google.common.collect.Iterables;
 import java.util.Map;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.PValue;
@@ -28,7 +26,7 @@ import org.apache.beam.sdk.values.TupleTag;
 
 /**
  * A {@link PTransformOverrideFactory} which consumes from a {@link PValue} and produces a
- * {@link PValue}. {@link #getInput(Map, Pipeline)} and {@link #mapOutputs(Map, PValue)} are
+ * {@link PValue}. {@link #mapOutputs(Map, PValue)} is
  * implemented.
  */
 public abstract class SingleInputOutputOverrideFactory<
@@ -37,11 +35,6 @@ public abstract class SingleInputOutputOverrideFactory<
         TransformT extends PTransform<InputT, OutputT>>
     implements PTransformOverrideFactory<InputT, OutputT, TransformT> {
   @Override
-  public final InputT getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-    return (InputT) Iterables.getOnlyElement(inputs.values());
-  }
-
-  @Override
   public final Map<PValue, ReplacementOutput> mapOutputs(
       Map<TupleTag<?>, PValue> outputs, OutputT newOutput) {
     return ReplacementOutputs.singleton(outputs, newOutput);

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactory.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactory.java
index 7b9d704..efafa33 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactory.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactory.java
@@ -19,8 +19,8 @@
 package org.apache.beam.runners.core.construction;
 
 import java.util.Map;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
@@ -29,8 +29,8 @@ import org.apache.beam.sdk.values.TupleTag;
 
 /**
  * A {@link PTransformOverrideFactory} that throws an exception when a call to
- * {@link #getReplacementTransform(PTransform)} is made. This is for {@link PTransform PTransforms}
- * which are not supported by a runner.
+ * {@link #getReplacementTransform(AppliedPTransform)} is made. This is for
+ * {@link PTransform PTransforms} which are not supported by a runner.
  */
 public final class UnsupportedOverrideFactory<
         InputT extends PInput,
@@ -54,12 +54,8 @@ public final class UnsupportedOverrideFactory<
   }
 
   @Override
-  public PTransform<InputT, OutputT> getReplacementTransform(TransformT transform) {
-    throw new UnsupportedOperationException(message);
-  }
-
-  @Override
-  public InputT getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
+  public PTransformReplacement<InputT, OutputT> getReplacementTransform(
+      AppliedPTransform<InputT, OutputT, TransformT> transform) {
     throw new UnsupportedOperationException(message);
   }
 

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactoryTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactoryTest.java
index 14aa1e6..4e08c21 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactoryTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/DeduplicatedFlattenFactoryTest.java
@@ -22,6 +22,7 @@ import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.not;
 import static org.junit.Assert.assertThat;
 
+import org.apache.beam.runners.core.construction.DeduplicatedFlattenFactory.FlattenWithoutDuplicateInputs;
 import org.apache.beam.sdk.Pipeline.PipelineVisitor.Defaults;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput;
 import org.apache.beam.sdk.runners.TransformHierarchy;
@@ -56,7 +57,7 @@ public class DeduplicatedFlattenFactoryTest {
   @Test
   public void duplicatesInsertsMultipliers() {
     PTransform<PCollectionList<String>, PCollection<String>> replacement =
-        factory.getReplacementTransform(Flatten.<String>pCollections());
+        new DeduplicatedFlattenFactory.FlattenWithoutDuplicateInputs<>();
     final PCollectionList<String> inputList =
         PCollectionList.of(first).and(second).and(first).and(first);
     inputList.apply(replacement);
@@ -74,10 +75,10 @@ public class DeduplicatedFlattenFactoryTest {
   @Test
   @Category(NeedsRunner.class)
   public void testOverride() {
-    PTransform<PCollectionList<String>, PCollection<String>> replacement =
-        factory.getReplacementTransform(Flatten.<String>pCollections());
     final PCollectionList<String> inputList =
         PCollectionList.of(first).and(second).and(first).and(first);
+    PTransform<PCollectionList<String>, PCollection<String>> replacement =
+        new FlattenWithoutDuplicateInputs<>();
     PCollection<String> flattened = inputList.apply(replacement);
 
     PAssert.that(flattened).containsInAnyOrder("one", "two", "one", "one");
@@ -85,21 +86,12 @@ public class DeduplicatedFlattenFactoryTest {
   }
 
   @Test
-  public void inputReconstruction() {
-    final PCollectionList<String> inputList =
-        PCollectionList.of(first).and(second).and(first).and(first);
-
-    assertThat(factory.getInput(inputList.expand(), pipeline), equalTo(inputList));
-  }
-
-  @Test
   public void outputMapping() {
     final PCollectionList<String> inputList =
         PCollectionList.of(first).and(second).and(first).and(first);
     PCollection<String> original =
         inputList.apply(Flatten.<String>pCollections());
-    PCollection<String> replacement =
-        inputList.apply(factory.getReplacementTransform(Flatten.<String>pCollections()));
+    PCollection<String> replacement = inputList.apply(new FlattenWithoutDuplicateInputs<String>());
 
     assertThat(
         factory.mapOutputs(original.expand(), replacement),

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactoryTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactoryTest.java
index 90bbee7..ae2d0a9 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactoryTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/EmptyFlattenAsCreateFactoryTest.java
@@ -18,17 +18,20 @@
 
 package org.apache.beam.runners.core.construction;
 
-import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.emptyIterable;
 import static org.junit.Assert.assertThat;
 
 import java.util.Collections;
 import java.util.Map;
 import org.apache.beam.sdk.io.CountingInput;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory.PTransformReplacement;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput;
 import org.apache.beam.sdk.testing.NeedsRunner;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Flatten;
+import org.apache.beam.sdk.transforms.Flatten.PCollections;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.PValue;
@@ -54,8 +57,15 @@ public class EmptyFlattenAsCreateFactoryTest {
 
   @Test
   public void getInputEmptySucceeds() {
-    assertThat(
-        factory.getInput(Collections.<TupleTag<?>, PValue>emptyMap(), pipeline).size(), equalTo(0));
+    PTransformReplacement<PCollectionList<Long>, PCollection<Long>> replacement =
+        factory.getReplacementTransform(
+            AppliedPTransform.<PCollectionList<Long>, PCollection<Long>, PCollections<Long>>of(
+                "nonEmptyInput",
+                Collections.<TupleTag<?>, PValue>emptyMap(),
+                Collections.<TupleTag<?>, PValue>emptyMap(),
+                Flatten.<Long>pCollections(),
+                pipeline));
+    assertThat(replacement.getInput().getAll(), emptyIterable());
   }
 
   @Test
@@ -66,7 +76,13 @@ public class EmptyFlattenAsCreateFactoryTest {
     thrown.expect(IllegalArgumentException.class);
     thrown.expectMessage(nonEmpty.expand().toString());
     thrown.expectMessage(EmptyFlattenAsCreateFactory.class.getSimpleName());
-    factory.getInput(nonEmpty.expand(), pipeline);
+    factory.getReplacementTransform(
+        AppliedPTransform.<PCollectionList<Long>, PCollection<Long>, Flatten.PCollections<Long>>of(
+            "nonEmptyInput",
+            nonEmpty.expand(),
+            Collections.<TupleTag<?>, PValue>emptyMap(),
+            Flatten.<Long>pCollections(),
+            pipeline));
   }
 
   @Test
@@ -89,7 +105,17 @@ public class EmptyFlattenAsCreateFactoryTest {
   public void testOverride() {
     PCollectionList<Long> empty = PCollectionList.empty(pipeline);
     PCollection<Long> emptyFlattened =
-        empty.apply(factory.getReplacementTransform(Flatten.<Long>pCollections()));
+        empty.apply(
+            factory
+                .getReplacementTransform(
+                    AppliedPTransform
+                        .<PCollectionList<Long>, PCollection<Long>, Flatten.PCollections<Long>>of(
+                            "nonEmptyInput",
+                            Collections.<TupleTag<?>, PValue>emptyMap(),
+                            Collections.<TupleTag<?>, PValue>emptyMap(),
+                            Flatten.<Long>pCollections(),
+                            pipeline))
+                .getTransform());
     PAssert.that(emptyFlattened).empty();
     pipeline.run();
   }

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformReplacementsTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformReplacementsTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformReplacementsTest.java
new file mode 100644
index 0000000..b065617
--- /dev/null
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformReplacementsTest.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.core.construction;
+
+import static org.hamcrest.Matchers.equalTo;
+import static org.junit.Assert.assertThat;
+
+import com.google.common.collect.ImmutableMap;
+import java.util.Collections;
+import org.apache.beam.sdk.io.CountingInput;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.View;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link PTransformReplacements}.
+ */
+@RunWith(JUnit4.class)
+public class PTransformReplacementsTest {
+  @Rule public TestPipeline pipeline = TestPipeline.create().enableAbandonedNodeEnforcement(false);
+  @Rule public ExpectedException thrown = ExpectedException.none();
+  private PCollection<Long> mainInput = pipeline.apply(CountingInput.unbounded());
+  private PCollectionView<String> sideInput =
+      pipeline.apply(Create.of("foo")).apply(View.<String>asSingleton());
+
+  private PCollection<Long> output = mainInput.apply(ParDo.of(new TestDoFn()));
+
+  @Test
+  public void getMainInputSingleOutputSingleInput() {
+    AppliedPTransform<PCollection<Long>, ?, ?> application =
+        AppliedPTransform.of(
+            "application",
+            Collections.<TupleTag<?>, PValue>singletonMap(new TupleTag<Long>(), mainInput),
+            Collections.<TupleTag<?>, PValue>singletonMap(new TupleTag<Long>(), output),
+            ParDo.of(new TestDoFn()),
+            pipeline);
+    PCollection<Long> input = PTransformReplacements.getSingletonMainInput(application);
+    assertThat(input, equalTo(mainInput));
+  }
+
+  @Test
+  public void getMainInputSingleOutputSideInputs() {
+    AppliedPTransform<PCollection<Long>, ?, ?> application =
+        AppliedPTransform.of(
+            "application",
+            ImmutableMap.<TupleTag<?>, PValue>builder()
+                .put(new TupleTag<Long>(), mainInput)
+                .put(sideInput.getTagInternal(), sideInput.getPCollection())
+                .build(),
+            Collections.<TupleTag<?>, PValue>singletonMap(new TupleTag<Long>(), output),
+            ParDo.of(new TestDoFn()).withSideInputs(sideInput),
+            pipeline);
+    PCollection<Long> input = PTransformReplacements.getSingletonMainInput(application);
+    assertThat(input, equalTo(mainInput));
+  }
+
+  @Test
+  public void getMainInputExtraMainInputsThrows() {
+    PCollection<Long> notInParDo = pipeline.apply("otherPCollection", Create.of(1L, 2L, 3L));
+    ImmutableMap<TupleTag<?>, PValue> inputs =
+        ImmutableMap.<TupleTag<?>, PValue>builder()
+            .putAll(mainInput.expand())
+            // Not represented as an input
+            .put(new TupleTag<Long>(), notInParDo)
+            .put(sideInput.getTagInternal(), sideInput.getPCollection())
+            .build();
+    AppliedPTransform<PCollection<Long>, ?, ?> application =
+        AppliedPTransform.of(
+            "application",
+            inputs,
+            Collections.<TupleTag<?>, PValue>singletonMap(new TupleTag<Long>(), output),
+            ParDo.of(new TestDoFn()).withSideInputs(sideInput),
+            pipeline);
+    thrown.expect(IllegalArgumentException.class);
+    thrown.expectMessage("multiple inputs");
+    thrown.expectMessage("not additional inputs");
+    thrown.expectMessage(mainInput.toString());
+    thrown.expectMessage(notInParDo.toString());
+    PTransformReplacements.getSingletonMainInput(application);
+  }
+
+  @Test
+  public void getMainInputNoMainInputsThrows() {
+    ImmutableMap<TupleTag<?>, PValue> inputs =
+        ImmutableMap.<TupleTag<?>, PValue>builder()
+            .put(sideInput.getTagInternal(), sideInput.getPCollection())
+            .build();
+    AppliedPTransform<PCollection<Long>, ?, ?> application =
+        AppliedPTransform.of(
+            "application",
+            inputs,
+            Collections.<TupleTag<?>, PValue>singletonMap(new TupleTag<Long>(), output),
+            ParDo.of(new TestDoFn()).withSideInputs(sideInput),
+            pipeline);
+    thrown.expect(IllegalArgumentException.class);
+    thrown.expectMessage("No main input");
+    PTransformReplacements.getSingletonMainInput(application);
+  }
+
+  private static class TestDoFn extends DoFn<Long, Long> {
+    @ProcessElement public void process(ProcessContext context) {}
+  }
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactoryTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactoryTest.java
index 07352f5..acca5cd 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactoryTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/SingleInputOutputOverrideFactoryTest.java
@@ -24,9 +24,9 @@ import java.io.Serializable;
 import java.util.Map;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.MapElements;
-import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.SimpleFunction;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionList;
@@ -55,9 +55,15 @@ public class SingleInputOutputOverrideFactoryTest implements Serializable {
               PCollection<? extends Integer>, PCollection<Integer>,
               MapElements<Integer, Integer>>() {
             @Override
-            public PTransform<PCollection<? extends Integer>, PCollection<Integer>>
-                getReplacementTransform(MapElements<Integer, Integer> transform) {
-              return transform;
+            public PTransformReplacement<PCollection<? extends Integer>, PCollection<Integer>>
+                getReplacementTransform(
+                    AppliedPTransform<
+                            PCollection<? extends Integer>, PCollection<Integer>,
+                            MapElements<Integer, Integer>>
+                        transform) {
+              return PTransformReplacement.of(
+                  PTransformReplacements.getSingletonMainInput(transform),
+                  transform.getTransform());
             }
           };
 
@@ -69,23 +75,6 @@ public class SingleInputOutputOverrideFactoryTest implements Serializable {
     };
 
   @Test
-  public void testGetInput() {
-    PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3));
-    assertThat(
-        factory.getInput(input.expand(), pipeline),
-        Matchers.<PCollection<? extends Integer>>equalTo(input));
-  }
-
-  @Test
-  public void testGetInputMultipleInputsFails() {
-    PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3));
-    PCollection<Integer> otherInput = pipeline.apply("OtherCreate", Create.of(1, 2, 3));
-
-    thrown.expect(IllegalArgumentException.class);
-    factory.getInput(PCollectionList.of(input).and(otherInput).expand(), pipeline);
-  }
-
-  @Test
   public void testMapOutputs() {
     PCollection<Integer> input = pipeline.apply(Create.of(1, 2, 3));
     PCollection<Integer> output = input.apply("Map", MapElements.via(fn));

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactoryTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactoryTest.java
index 81ce00d..6d3b263 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactoryTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/UnsupportedOverrideFactoryTest.java
@@ -19,9 +19,7 @@
 package org.apache.beam.runners.core.construction;
 
 import java.util.Collections;
-import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.testing.TestPipeline;
-import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TupleTag;
@@ -47,14 +45,7 @@ public class UnsupportedOverrideFactoryTest {
   public void getReplacementTransformThrows() {
     thrown.expect(UnsupportedOperationException.class);
     thrown.expectMessage(message);
-    factory.getReplacementTransform(Create.empty(VoidCoder.of()));
-  }
-
-  @Test
-  public void getInputThrows() {
-    thrown.expect(UnsupportedOperationException.class);
-    thrown.expectMessage(message);
-    factory.getInput(Collections.<TupleTag<?>, PValue>emptyMap(), pipeline);
+    factory.getReplacementTransform(null);
   }
 
   @Test

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java
index bb90a6c..1120243 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGBKIntoKeyedWorkItemsOverrideFactory.java
@@ -19,8 +19,9 @@ package org.apache.beam.runners.direct;
 
 import org.apache.beam.runners.core.KeyedWorkItem;
 import org.apache.beam.runners.core.SplittableParDo.GBKIntoKeyedWorkItems;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
-import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 
@@ -33,8 +34,15 @@ class DirectGBKIntoKeyedWorkItemsOverrideFactory<KeyT, InputT>
         PCollection<KV<KeyT, InputT>>, PCollection<KeyedWorkItem<KeyT, InputT>>,
         GBKIntoKeyedWorkItems<KeyT, InputT>> {
   @Override
-  public PTransform<PCollection<KV<KeyT, InputT>>, PCollection<KeyedWorkItem<KeyT, InputT>>>
-      getReplacementTransform(GBKIntoKeyedWorkItems<KeyT, InputT> transform) {
-    return new DirectGroupByKey.DirectGroupByKeyOnly<>();
+  public PTransformReplacement<
+          PCollection<KV<KeyT, InputT>>, PCollection<KeyedWorkItem<KeyT, InputT>>>
+      getReplacementTransform(
+          AppliedPTransform<
+                  PCollection<KV<KeyT, InputT>>, PCollection<KeyedWorkItem<KeyT, InputT>>,
+                  GBKIntoKeyedWorkItems<KeyT, InputT>>
+              transform) {
+    return PTransformReplacement.of(
+        PTransformReplacements.getSingletonMainInput(transform),
+        new DirectGroupByKey.DirectGroupByKeyOnly<KeyT, InputT>());
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java
index f3b718f..4eb0363 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactory.java
@@ -17,10 +17,11 @@
  */
 package org.apache.beam.runners.direct;
 
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.GroupByKey;
-import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 
@@ -29,8 +30,13 @@ final class DirectGroupByKeyOverrideFactory<K, V>
     extends SingleInputOutputOverrideFactory<
         PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupByKey<K, V>> {
   @Override
-  public PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> getReplacementTransform(
-      GroupByKey<K, V> transform) {
-    return new DirectGroupByKey<>(transform);
+  public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>>
+      getReplacementTransform(
+          AppliedPTransform<
+                  PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupByKey<K, V>>
+              transform) {
+    return PTransformReplacement.of(
+        PTransformReplacements.getSingletonMainInput(transform),
+        new DirectGroupByKey<>(transform.getTransform()));
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
index 366777b..b08aa8e 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java
@@ -19,18 +19,18 @@ package org.apache.beam.runners.direct;
 
 import static com.google.common.base.Preconditions.checkState;
 
-import com.google.common.collect.Iterables;
 import java.util.Map;
 import org.apache.beam.runners.core.KeyedWorkItem;
 import org.apache.beam.runners.core.KeyedWorkItemCoder;
 import org.apache.beam.runners.core.KeyedWorkItems;
 import org.apache.beam.runners.core.SplittableParDo;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.ReplacementOutputs;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -62,8 +62,18 @@ class ParDoMultiOverrideFactory<InputT, OutputT>
     implements PTransformOverrideFactory<
         PCollection<? extends InputT>, PCollectionTuple, MultiOutput<InputT, OutputT>> {
   @Override
+  public PTransformReplacement<PCollection<? extends InputT>, PCollectionTuple>
+      getReplacementTransform(
+          AppliedPTransform<
+                  PCollection<? extends InputT>, PCollectionTuple, MultiOutput<InputT, OutputT>>
+              transform) {
+    return PTransformReplacement.of(
+        PTransformReplacements.getSingletonMainInput(transform),
+        getReplacementTransform(transform.getTransform()));
+  }
+
   @SuppressWarnings("unchecked")
-  public PTransform<PCollection<? extends InputT>, PCollectionTuple> getReplacementTransform(
+  private PTransform<PCollection<? extends InputT>, PCollectionTuple> getReplacementTransform(
       MultiOutput<InputT, OutputT> transform) {
 
     DoFn<InputT, OutputT> fn = transform.getFn();
@@ -84,12 +94,6 @@ class ParDoMultiOverrideFactory<InputT, OutputT>
   }
 
   @Override
-  public PCollection<? extends InputT> getInput(
-      Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-    return (PCollection<? extends InputT>) Iterables.getOnlyElement(inputs.values());
-  }
-
-  @Override
   public Map<PValue, ReplacementOutput> mapOutputs(
       Map<TupleTag<?>, PValue> outputs, PCollectionTuple newOutput) {
     return ReplacementOutputs.tagged(outputs, newOutput);

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
index 6e0a4fc..cba754e 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java
@@ -31,7 +31,6 @@ import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.ReplacementOutputs;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
 import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.testing.TestStream;
 import org.apache.beam.sdk.testing.TestStream.ElementEvent;
@@ -170,14 +169,11 @@ class TestStreamEvaluatorFactory implements TransformEvaluatorFactory {
     }
 
     @Override
-    public PTransform<PBegin, PCollection<T>> getReplacementTransform(
-        TestStream<T> transform) {
-      return new DirectTestStream<>(runner, transform);
-    }
-
-    @Override
-    public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return p.begin();
+    public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform(
+        AppliedPTransform<PBegin, PCollection<T>, TestStream<T>> transform) {
+      return PTransformReplacement.of(
+          transform.getPipeline().begin(),
+          new DirectTestStream<T>(runner, transform.getTransform()));
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java
index 52dc329..d4fd18f 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ViewOverrideFactory.java
@@ -18,14 +18,14 @@
 
 package org.apache.beam.runners.direct;
 
-import com.google.common.collect.Iterables;
 import java.util.Collections;
 import java.util.Map;
 import org.apache.beam.runners.core.construction.ForwardingPTransform;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.Values;
@@ -43,15 +43,15 @@ import org.apache.beam.sdk.values.TupleTag;
 class ViewOverrideFactory<ElemT, ViewT>
     implements PTransformOverrideFactory<
         PCollection<ElemT>, PCollectionView<ViewT>, CreatePCollectionView<ElemT, ViewT>> {
-  @Override
-  public PTransform<PCollection<ElemT>, PCollectionView<ViewT>> getReplacementTransform(
-      CreatePCollectionView<ElemT, ViewT> transform) {
-    return new GroupAndWriteView<>(transform);
-  }
 
   @Override
-  public PCollection<ElemT> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-    return (PCollection<ElemT>) Iterables.getOnlyElement(inputs.values());
+  public PTransformReplacement<PCollection<ElemT>, PCollectionView<ViewT>> getReplacementTransform(
+      AppliedPTransform<
+              PCollection<ElemT>, PCollectionView<ViewT>, CreatePCollectionView<ElemT, ViewT>>
+          transform) {
+    return PTransformReplacement.of(
+        PTransformReplacements.getSingletonMainInput(transform),
+        new GroupAndWriteView<>(transform.getTransform()));
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
index b3f92ab..a23ab94 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
@@ -21,14 +21,14 @@ package org.apache.beam.runners.direct;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Supplier;
 import com.google.common.base.Suppliers;
-import com.google.common.collect.Iterables;
 import java.io.Serializable;
 import java.util.Collections;
 import java.util.Map;
 import java.util.concurrent.ThreadLocalRandom;
-import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.sdk.io.Write;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -53,14 +53,12 @@ class WriteWithShardingFactory<InputT>
   @VisibleForTesting static final int MIN_SHARDS_FOR_LOG = 3;
 
   @Override
-  public PTransform<PCollection<InputT>, PDone> getReplacementTransform(
-      Write<InputT> transform) {
-    return transform.withSharding(new LogElementShardsWithDrift<InputT>());
-  }
+  public PTransformReplacement<PCollection<InputT>, PDone> getReplacementTransform(
+      AppliedPTransform<PCollection<InputT>, PDone, Write<InputT>> transform) {
 
-  @Override
-  public PCollection<InputT> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-    return (PCollection<InputT>) Iterables.getOnlyElement(inputs.values());
+    return PTransformReplacement.of(
+        PTransformReplacements.getSingletonMainInput(transform),
+        transform.getTransform().withSharding(new LogElementShardsWithDrift<InputT>()));
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java
index c9fdda0..28fef4c 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/DirectGroupByKeyOverrideFactoryTest.java
@@ -23,8 +23,11 @@ import static org.junit.Assert.assertThat;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory.PTransformReplacement;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.hamcrest.Matchers;
@@ -45,7 +48,12 @@ public class DirectGroupByKeyOverrideFactoryTest {
         p.apply(
             Create.of(KV.of("foo", 1))
                 .withCoder(KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of())));
-    PCollection<?> reconstructed = factory.getInput(input.expand(), p);
-    assertThat(reconstructed, Matchers.<PCollection<?>>equalTo(input));
+    PCollection<KV<String, Iterable<Integer>>> grouped =
+        input.apply(GroupByKey.<String, Integer>create());
+    AppliedPTransform<?, ?, ?> producer = DirectGraphs.getProducer(grouped);
+    PTransformReplacement<
+            PCollection<KV<String, Integer>>, PCollection<KV<String, Iterable<Integer>>>>
+        replacement = factory.getReplacementTransform((AppliedPTransform) producer);
+    assertThat(replacement.getInput(), Matchers.<PCollection<?>>equalTo(input));
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactoryTest.java
deleted file mode 100644
index 4bbf924..0000000
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactoryTest.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.direct;
-
-import static org.junit.Assert.assertThat;
-
-import org.apache.beam.sdk.testing.TestPipeline;
-import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.values.PCollection;
-import org.hamcrest.Matchers;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-
-/**
- * Tests for {@link ParDoMultiOverrideFactory}.
- */
-@RunWith(JUnit4.class)
-public class ParDoMultiOverrideFactoryTest {
-  private ParDoMultiOverrideFactory factory = new ParDoMultiOverrideFactory();
-
-  @Test
-  public void getInputSucceeds() {
-    TestPipeline p = TestPipeline.create();
-    PCollection<Integer> input = p.apply(Create.of(1, 2, 3));
-    PCollection<?> reconstructed = factory.getInput(input.expand(), p);
-    assertThat(reconstructed, Matchers.<PCollection<?>>equalTo(input));
-  }
-}

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java
index 0d909c2..b9c6e64 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java
@@ -27,22 +27,17 @@ import com.google.common.collect.Iterables;
 import java.util.Collection;
 import java.util.Collections;
 import org.apache.beam.runners.direct.DirectRunner.CommittedBundle;
-import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory;
 import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.DirectTestStreamFactory.DirectTestStream;
 import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.TestClock;
 import org.apache.beam.runners.direct.TestStreamEvaluatorFactory.TestStreamIndex;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.testing.TestStream;
 import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
-import org.apache.beam.sdk.values.PValue;
 import org.apache.beam.sdk.values.TimestampedValue;
-import org.apache.beam.sdk.values.TupleTag;
 import org.hamcrest.Matchers;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
@@ -180,11 +175,4 @@ public class TestStreamEvaluatorFactoryTest {
     assertThat(fifthResult.getWatermarkHold(), equalTo(BoundedWindow.TIMESTAMP_MAX_VALUE));
     assertThat(fifthResult.getUnprocessedElements(), Matchers.emptyIterable());
   }
-
-  @Test
-  public void overrideFactoryGetInputSucceeds() {
-    DirectTestStreamFactory<?> factory = new DirectTestStreamFactory<>(runner);
-    PBegin begin = factory.getInput(Collections.<TupleTag<?>, PValue>emptyMap(), p);
-    assertThat(begin.getPipeline(), Matchers.<Pipeline>equalTo(p));
-  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java
index 258cb46..6875e1a 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ViewOverrideFactoryTest.java
@@ -30,12 +30,13 @@ import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
 import org.apache.beam.runners.direct.ViewOverrideFactory.WriteView;
 import org.apache.beam.sdk.Pipeline.PipelineVisitor;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory.PTransformReplacement;
 import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.View.CreatePCollectionView;
 import org.apache.beam.sdk.util.PCollectionViews;
@@ -62,9 +63,20 @@ public class ViewOverrideFactoryTest implements Serializable {
     PCollection<Integer> ints = p.apply("CreateContents", Create.of(1, 2, 3));
     final PCollectionView<List<Integer>> view =
         PCollectionViews.listView(ints, WindowingStrategy.globalDefault(), ints.getCoder());
-    PTransform<PCollection<Integer>, PCollectionView<List<Integer>>> replacementTransform =
-        factory.getReplacementTransform(CreatePCollectionView.<Integer, List<Integer>>of(view));
-    PCollectionView<List<Integer>> afterReplacement = ints.apply(replacementTransform);
+    PTransformReplacement<PCollection<Integer>, PCollectionView<List<Integer>>>
+        replacementTransform =
+            factory.getReplacementTransform(
+                AppliedPTransform
+                    .<PCollection<Integer>, PCollectionView<List<Integer>>,
+                        CreatePCollectionView<Integer, List<Integer>>>
+                        of(
+                            "foo",
+                            ints.expand(),
+                            view.expand(),
+                            CreatePCollectionView.<Integer, List<Integer>>of(view),
+                            p));
+    PCollectionView<List<Integer>> afterReplacement =
+        ints.apply(replacementTransform.getTransform());
     assertThat(
         "The CreatePCollectionView replacement should return the same View",
         afterReplacement,
@@ -92,9 +104,18 @@ public class ViewOverrideFactoryTest implements Serializable {
     final PCollection<Integer> ints = p.apply("CreateContents", Create.of(1, 2, 3));
     final PCollectionView<List<Integer>> view =
         PCollectionViews.listView(ints, WindowingStrategy.globalDefault(), ints.getCoder());
-    PTransform<PCollection<Integer>, PCollectionView<List<Integer>>> replacement =
-        factory.getReplacementTransform(CreatePCollectionView.<Integer, List<Integer>>of(view));
-    ints.apply(replacement);
+    PTransformReplacement<PCollection<Integer>, PCollectionView<List<Integer>>> replacement =
+        factory.getReplacementTransform(
+            AppliedPTransform
+                .<PCollection<Integer>, PCollectionView<List<Integer>>,
+                    CreatePCollectionView<Integer, List<Integer>>>
+                    of(
+                        "foo",
+                        ints.expand(),
+                        view.expand(),
+                        CreatePCollectionView.<Integer, List<Integer>>of(view),
+                        p));
+    ints.apply(replacement.getTransform());
     final AtomicBoolean writeViewVisited = new AtomicBoolean();
     p.traverseTopologically(
         new PipelineVisitor.Defaults() {
@@ -114,11 +135,4 @@ public class ViewOverrideFactoryTest implements Serializable {
 
     assertThat(writeViewVisited.get(), is(true));
   }
-
-  @Test
-  public void overrideFactoryGetInputSucceeds() {
-    ViewOverrideFactory<String, String> factory = new ViewOverrideFactory<>();
-    PCollection<String> input = p.apply(Create.of("foo", "bar"));
-    assertThat(factory.getInput(input.expand(), p), equalTo(input));
-  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
index 8720fd1..361850d 100644
--- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
@@ -38,11 +38,13 @@ import java.util.List;
 import java.util.UUID;
 import org.apache.beam.runners.direct.WriteWithShardingFactory.CalculateShardsFn;
 import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.io.Sink;
 import org.apache.beam.sdk.io.TextIO;
 import org.apache.beam.sdk.io.Write;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFnTester;
@@ -52,7 +54,9 @@ import org.apache.beam.sdk.util.PCollectionViews;
 import org.apache.beam.sdk.util.WindowingStrategy;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
-import org.hamcrest.Matchers;
+import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
@@ -118,7 +122,15 @@ public class WriteWithShardingFactoryTest {
   @Test
   public void withNoShardingSpecifiedReturnsNewTransform() {
     Write<Object> original = Write.to(new TestSink());
-    assertThat(factory.getReplacementTransform(original), not(equalTo((Object) original)));
+    PCollection<Object> objs = (PCollection) p.apply(Create.empty(VoidCoder.of()));
+
+    AppliedPTransform<PCollection<Object>, PDone, Write<Object>> originalApplication =
+        AppliedPTransform.of(
+            "write", objs.expand(), Collections.<TupleTag<?>, PValue>emptyMap(), original, p);
+
+    assertThat(
+        factory.getReplacementTransform(originalApplication).getTransform(),
+        not(equalTo((Object) original)));
   }
 
   @Test
@@ -195,13 +207,6 @@ public class WriteWithShardingFactoryTest {
     assertThat(shards, containsInAnyOrder(13));
   }
 
-  @Test
-  public void getInputSucceeds() {
-    PCollection<String> original = p.apply(Create.of("foo"));
-    PCollection<?> input = factory.getInput(original.expand(), p);
-    assertThat(input, Matchers.<PCollection<?>>equalTo(original));
-  }
-
   private static class TestSink extends Sink<Object> {
     @Override
     public void validate(PipelineOptions options) {}

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java
----------------------------------------------------------------------
diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java
index 70da2b3..0459ef7 100644
--- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java
+++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPipelineTranslator.java
@@ -18,11 +18,11 @@
 package org.apache.beam.runners.flink;
 
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Iterables;
 import java.util.List;
 import java.util.Map;
 import org.apache.beam.runners.core.SplittableParDo;
 import org.apache.beam.runners.core.construction.PTransformMatchers;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.ReplacementOutputs;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.sdk.Pipeline;
@@ -30,9 +30,9 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.runners.PTransformOverride;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.runners.TransformHierarchy;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.PTransform;
-import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.util.InstanceBuilder;
@@ -221,46 +221,50 @@ class FlinkStreamingPipelineTranslator extends FlinkPipelineTranslator {
   }
 
   private static class ReflectiveOneToOneOverrideFactory<
-      InputT extends PValue,
-      OutputT extends PValue,
-      TransformT extends PTransform<InputT, OutputT>>
-      extends SingleInputOutputOverrideFactory<InputT, OutputT, TransformT> {
-    private final Class<PTransform<InputT, OutputT>> replacement;
+          InputT, OutputT, TransformT extends PTransform<PCollection<InputT>, PCollection<OutputT>>>
+      extends SingleInputOutputOverrideFactory<
+          PCollection<InputT>, PCollection<OutputT>, TransformT> {
+    private final Class<PTransform<PCollection<InputT>, PCollection<OutputT>>> replacement;
     private final FlinkRunner runner;
 
     private ReflectiveOneToOneOverrideFactory(
-        Class<PTransform<InputT, OutputT>> replacement, FlinkRunner runner) {
+        Class<PTransform<PCollection<InputT>, PCollection<OutputT>>> replacement,
+        FlinkRunner runner) {
       this.replacement = replacement;
       this.runner = runner;
     }
 
     @Override
-    public PTransform<InputT, OutputT> getReplacementTransform(TransformT transform) {
-      return InstanceBuilder.ofType(replacement)
-          .withArg(FlinkRunner.class, runner)
-          .withArg((Class<PTransform<InputT, OutputT>>) transform.getClass(), transform)
-          .build();
+    public PTransformReplacement<PCollection<InputT>, PCollection<OutputT>> getReplacementTransform(
+        AppliedPTransform<PCollection<InputT>, PCollection<OutputT>, TransformT> transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          InstanceBuilder.ofType(replacement)
+              .withArg(FlinkRunner.class, runner)
+              .withArg(
+                  (Class<PTransform<PCollection<InputT>, PCollection<OutputT>>>)
+                      transform.getTransform().getClass(),
+                  transform.getTransform())
+              .build());
     }
   }
 
   /**
-   * A {@link PTransformOverrideFactory} that overrides a
-   * <a href="https://s.apache.org/splittable-do-fn">Splittable DoFn</a> with
-   * {@link SplittableParDo}.
+   * A {@link PTransformOverrideFactory} that overrides a <a
+   * href="https://s.apache.org/splittable-do-fn">Splittable DoFn</a> with {@link SplittableParDo}.
    */
   static class SplittableParDoOverrideFactory<InputT, OutputT>
       implements PTransformOverrideFactory<
-            PCollection<? extends InputT>, PCollectionTuple, MultiOutput<InputT, OutputT>> {
+          PCollection<InputT>, PCollectionTuple, MultiOutput<InputT, OutputT>> {
     @Override
-    @SuppressWarnings("unchecked")
-    public PTransform<PCollection<? extends InputT>, PCollectionTuple> getReplacementTransform(
-        ParDo.MultiOutput<InputT, OutputT> transform) {
-      return new SplittableParDo(transform);
-    }
-
-    @Override
-    public PCollection<? extends InputT> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return (PCollection<? extends InputT>) Iterables.getOnlyElement(inputs.values());
+    public PTransformReplacement<PCollection<InputT>, PCollectionTuple>
+        getReplacementTransform(
+            AppliedPTransform<
+                    PCollection<InputT>, PCollectionTuple, MultiOutput<InputT, OutputT>>
+                transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new SplittableParDo<>(transform.getTransform()));
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java
index 73f3728..119c9c9 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java
@@ -19,19 +19,21 @@ package org.apache.beam.runners.dataflow;
 
 import static com.google.common.base.Preconditions.checkState;
 
-import com.google.common.collect.Iterables;
 import java.util.Map;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.ReplacementOutputs;
 import org.apache.beam.runners.dataflow.BatchViewOverrides.GroupByKeyAndSortValuesOnly;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.InstantCoder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.ParDo.MultiOutput;
+import org.apache.beam.sdk.transforms.ParDo.SingleOutput;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
 import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
@@ -85,15 +87,15 @@ public class BatchStatefulParDoOverrides {
           ParDo.SingleOutput<KV<K, InputT>, OutputT>> {
 
     @Override
-    @SuppressWarnings("unchecked")
-    public PTransform<PCollection<KV<K, InputT>>, PCollection<OutputT>> getReplacementTransform(
-        ParDo.SingleOutput<KV<K, InputT>, OutputT> originalParDo) {
-      return new StatefulSingleOutputParDo<>(originalParDo);
-    }
-
-    @Override
-    public PCollection<KV<K, InputT>> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return (PCollection<KV<K, InputT>>) Iterables.getOnlyElement(inputs.values());
+    public PTransformReplacement<PCollection<KV<K, InputT>>, PCollection<OutputT>>
+        getReplacementTransform(
+            AppliedPTransform<
+                    PCollection<KV<K, InputT>>, PCollection<OutputT>,
+                    SingleOutput<KV<K, InputT>, OutputT>>
+                transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new StatefulSingleOutputParDo<>(transform.getTransform()));
     }
 
     @Override
@@ -108,15 +110,15 @@ public class BatchStatefulParDoOverrides {
           PCollection<KV<K, InputT>>, PCollectionTuple, ParDo.MultiOutput<KV<K, InputT>, OutputT>> {
 
     @Override
-    @SuppressWarnings("unchecked")
-    public PTransform<PCollection<KV<K, InputT>>, PCollectionTuple> getReplacementTransform(
-        ParDo.MultiOutput<KV<K, InputT>, OutputT> originalParDo) {
-      return new StatefulMultiOutputParDo<>(originalParDo);
-    }
-
-    @Override
-    public PCollection<KV<K, InputT>> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return (PCollection<KV<K, InputT>>) Iterables.getOnlyElement(inputs.values());
+    public PTransformReplacement<PCollection<KV<K, InputT>>, PCollectionTuple>
+        getReplacementTransform(
+            AppliedPTransform<
+                    PCollection<KV<K, InputT>>, PCollectionTuple,
+                    MultiOutput<KV<K, InputT>, OutputT>>
+                transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new StatefulMultiOutputParDo<>(transform.getTransform()));
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java
index ead2712..1565fd1 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchViewOverrides.java
@@ -42,6 +42,7 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.runners.dataflow.internal.IsmFormat;
 import org.apache.beam.runners.dataflow.internal.IsmFormat.IsmRecord;
@@ -59,6 +60,7 @@ import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.coders.StandardCoder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Combine.GloballyAsSingletonView;
 import org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn;
@@ -1404,10 +1406,17 @@ class BatchViewOverrides {
     }
 
     @Override
-    public PTransform<PCollection<ElemT>, PCollectionView<ViewT>> getReplacementTransform(
-        final GloballyAsSingletonView<ElemT, ViewT> transform) {
-      return new BatchCombineGloballyAsSingletonView<>(
-          runner, transform.getCombineFn(), transform.getFanout(), transform.getInsertDefault());
+    public PTransformReplacement<PCollection<ElemT>, PCollectionView<ViewT>>
+        getReplacementTransform(
+            AppliedPTransform<
+                    PCollection<ElemT>, PCollectionView<ViewT>,
+                    GloballyAsSingletonView<ElemT, ViewT>>
+                transform) {
+      GloballyAsSingletonView<ElemT, ViewT> combine = transform.getTransform();
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new BatchCombineGloballyAsSingletonView<>(
+              runner, combine.getCombineFn(), combine.getFanout(), combine.getInsertDefault()));
     }
 
     private static class BatchCombineGloballyAsSingletonView<ElemT, ViewT>


[2/3] beam git commit: Update Signature of PTransformOverrideFactory

Posted by tg...@apache.org.
http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index 684dc14..4eec6b8 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -61,6 +61,7 @@ import java.util.TreeSet;
 import org.apache.beam.runners.core.construction.DeduplicatedFlattenFactory;
 import org.apache.beam.runners.core.construction.EmptyFlattenAsCreateFactory;
 import org.apache.beam.runners.core.construction.PTransformMatchers;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.ReplacementOutputs;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.runners.core.construction.UnboundedReadFromBoundedSource;
@@ -96,6 +97,7 @@ import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Combine.GroupedValues;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -390,25 +392,29 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
   }
 
   private static class ReflectiveOneToOneOverrideFactory<
-          InputT extends PValue,
-          OutputT extends PValue,
-          TransformT extends PTransform<InputT, OutputT>>
-      extends SingleInputOutputOverrideFactory<InputT, OutputT, TransformT> {
-    private final Class<PTransform<InputT, OutputT>> replacement;
+          InputT, OutputT, TransformT extends PTransform<PCollection<InputT>, PCollection<OutputT>>>
+      extends SingleInputOutputOverrideFactory<
+          PCollection<InputT>, PCollection<OutputT>, TransformT> {
+    private final Class<PTransform<PCollection<InputT>, PCollection<OutputT>>> replacement;
     private final DataflowRunner runner;
 
     private ReflectiveOneToOneOverrideFactory(
-        Class<PTransform<InputT, OutputT>> replacement, DataflowRunner runner) {
+        Class<PTransform<PCollection<InputT>, PCollection<OutputT>>> replacement,
+        DataflowRunner runner) {
       this.replacement = replacement;
       this.runner = runner;
     }
 
     @Override
-    public PTransform<InputT, OutputT> getReplacementTransform(TransformT transform) {
-      return InstanceBuilder.ofType(replacement)
-          .withArg(DataflowRunner.class, runner)
-          .withArg((Class<PTransform<InputT, OutputT>>) transform.getClass(), transform)
-          .build();
+    public PTransformReplacement<PCollection<InputT>, PCollection<OutputT>> getReplacementTransform(
+        AppliedPTransform<PCollection<InputT>, PCollection<OutputT>, TransformT> transform) {
+      PTransform<PCollection<InputT>, PCollection<OutputT>> rep =
+          InstanceBuilder.ofType(replacement)
+              .withArg(DataflowRunner.class, runner)
+              .withArg(
+                  (Class<TransformT>) transform.getTransform().getClass(), transform.getTransform())
+              .build();
+      return PTransformReplacement.of(PTransformReplacements.getSingletonMainInput(transform), rep);
     }
   }
 
@@ -423,19 +429,18 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
       this.replacement = replacement;
       this.runner = runner;
     }
-    @Override
-    public PTransform<PBegin, PCollection<T>> getReplacementTransform(
-        PTransform<PInput, PCollection<T>> transform) {
-      return InstanceBuilder.ofType(replacement)
-          .withArg(DataflowRunner.class, runner)
-          .withArg(
-              (Class<? super PTransform<PInput, PCollection<T>>>) transform.getClass(), transform)
-          .build();
-    }
 
     @Override
-    public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return p.begin();
+    public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform(
+        AppliedPTransform<PBegin, PCollection<T>, PTransform<PInput, PCollection<T>>> transform) {
+      PTransform<PInput, PCollection<T>> original = transform.getTransform();
+      return PTransformReplacement.of(
+          transform.getPipeline().begin(),
+          InstanceBuilder.ofType(replacement)
+              .withArg(DataflowRunner.class, runner)
+              .withArg(
+                  (Class<? super PTransform<PInput, PCollection<T>>>) original.getClass(), original)
+              .build());
     }
 
     @Override
@@ -805,13 +810,11 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
     }
 
     @Override
-    public PTransform<PCollection<T>, PDone> getReplacementTransform(Write<T> transform) {
-      return new BatchWrite<>(runner, transform);
-    }
-
-    @Override
-    public PCollection<T> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return (PCollection<T>) Iterables.getOnlyElement(inputs.values());
+    public PTransformReplacement<PCollection<T>, PDone> getReplacementTransform(
+        AppliedPTransform<PCollection<T>, PDone, Write<T>> transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new BatchWrite<>(runner, transform.getTransform()));
     }
 
     @Override
@@ -1295,15 +1298,15 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
           PCollection<KV<K, Iterable<InputT>>>, PCollection<KV<K, OutputT>>,
           Combine.GroupedValues<K, InputT, OutputT>> {
     @Override
-    public PTransform<PCollection<KV<K, Iterable<InputT>>>, PCollection<KV<K, OutputT>>>
-        getReplacementTransform(GroupedValues<K, InputT, OutputT> transform) {
-      return new CombineGroupedValues<>(transform);
-    }
-
-    @Override
-    public PCollection<KV<K, Iterable<InputT>>> getInput(
-        Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return (PCollection<KV<K, Iterable<InputT>>>) Iterables.getOnlyElement(inputs.values());
+    public PTransformReplacement<PCollection<KV<K, Iterable<InputT>>>, PCollection<KV<K, OutputT>>>
+        getReplacementTransform(
+            AppliedPTransform<
+                    PCollection<KV<K, Iterable<InputT>>>, PCollection<KV<K, OutputT>>,
+                    GroupedValues<K, InputT, OutputT>>
+                transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new CombineGroupedValues<>(transform.getTransform()));
     }
 
     @Override
@@ -1322,14 +1325,11 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
     }
 
     @Override
-    public PTransform<PCollection<T>, PDone> getReplacementTransform(
-        PubsubUnboundedSink<T> transform) {
-      return new StreamingPubsubIOWrite<>(runner, transform);
-    }
-
-    @Override
-    public PCollection<T> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return (PCollection<T>) Iterables.getOnlyElement(inputs.values());
+    public PTransformReplacement<PCollection<T>, PDone> getReplacementTransform(
+        AppliedPTransform<PCollection<T>, PDone, PubsubUnboundedSink<T>> transform) {
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform),
+          new StreamingPubsubIOWrite<>(runner, transform.getTransform()));
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
index db50cc2..2e50cb5 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java
@@ -20,12 +20,15 @@ package org.apache.beam.runners.dataflow;
 
 import java.util.List;
 import org.apache.beam.runners.core.construction.ForwardingPTransform;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.sdk.common.runner.v1.RunnerApi.DisplayData;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.ParDo.SingleOutput;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PCollectionView;
 
@@ -38,9 +41,15 @@ public class PrimitiveParDoSingleFactory<InputT, OutputT>
     extends SingleInputOutputOverrideFactory<
         PCollection<? extends InputT>, PCollection<OutputT>, ParDo.SingleOutput<InputT, OutputT>> {
   @Override
-  public PTransform<PCollection<? extends InputT>, PCollection<OutputT>> getReplacementTransform(
-      ParDo.SingleOutput<InputT, OutputT> transform) {
-    return new ParDoSingle<>(transform);
+  public PTransformReplacement<PCollection<? extends InputT>, PCollection<OutputT>>
+      getReplacementTransform(
+          AppliedPTransform<
+                  PCollection<? extends InputT>, PCollection<OutputT>,
+                  SingleOutput<InputT, OutputT>>
+              transform) {
+    return PTransformReplacement.of(
+        PTransformReplacements.getSingletonMainInput(transform),
+        new ParDoSingle<>(transform.getTransform()));
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReshuffleOverrideFactory.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReshuffleOverrideFactory.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReshuffleOverrideFactory.java
index 2e6455d..aa9d9f8 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReshuffleOverrideFactory.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/ReshuffleOverrideFactory.java
@@ -18,8 +18,10 @@
 
 package org.apache.beam.runners.dataflow;
 
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -43,9 +45,13 @@ class ReshuffleOverrideFactory<K, V>
     extends SingleInputOutputOverrideFactory<
         PCollection<KV<K, V>>, PCollection<KV<K, V>>, Reshuffle<K, V>> {
   @Override
-  public PTransform<PCollection<KV<K, V>>, PCollection<KV<K, V>>> getReplacementTransform(
-      Reshuffle<K, V> transform) {
-    return new ReshuffleWithOnlyTrigger<>();
+  public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<K, V>>>
+      getReplacementTransform(
+          AppliedPTransform<PCollection<KV<K, V>>, PCollection<KV<K, V>>, Reshuffle<K, V>>
+              transform) {
+    return PTransformReplacement.of(
+        PTransformReplacements.getSingletonMainInput(transform),
+        new ReshuffleWithOnlyTrigger<K, V>());
   }
 
   private static class ReshuffleWithOnlyTrigger<K, V>

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java
index c407517..eb385de 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/StreamingViewOverrides.java
@@ -20,11 +20,13 @@ package org.apache.beam.runners.dataflow;
 
 import java.util.ArrayList;
 import java.util.List;
+import org.apache.beam.runners.core.construction.PTransformReplacements;
 import org.apache.beam.runners.core.construction.SingleInputOutputOverrideFactory;
 import org.apache.beam.runners.dataflow.DataflowRunner.StreamingPCollectionViewWriterFn;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderRegistry;
 import org.apache.beam.sdk.coders.ListCoder;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -42,9 +44,15 @@ class StreamingViewOverrides {
       extends SingleInputOutputOverrideFactory<
           PCollection<ElemT>, PCollectionView<ViewT>, CreatePCollectionView<ElemT, ViewT>> {
     @Override
-    public PTransform<PCollection<ElemT>, PCollectionView<ViewT>> getReplacementTransform(
-        final CreatePCollectionView<ElemT, ViewT> transform) {
-      return new StreamingCreatePCollectionView<>(transform.getView());
+    public PTransformReplacement<PCollection<ElemT>, PCollectionView<ViewT>>
+        getReplacementTransform(
+            AppliedPTransform<
+                    PCollection<ElemT>, PCollectionView<ViewT>, CreatePCollectionView<ElemT, ViewT>>
+                transform) {
+      StreamingCreatePCollectionView<ElemT, ViewT> streamingView =
+          new StreamingCreatePCollectionView<>(transform.getTransform().getView());
+      return PTransformReplacement.of(
+          PTransformReplacements.getSingletonMainInput(transform), streamingView);
     }
 
     private static class StreamingCreatePCollectionView<ElemT, ViewT>

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java
index bff46ea..e320036 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java
@@ -27,10 +27,11 @@ import java.io.Serializable;
 import java.util.List;
 import org.apache.beam.runners.dataflow.PrimitiveParDoSingleFactory.ParDoSingle;
 import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory.PTransformReplacement;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.View;
@@ -64,17 +65,27 @@ public class PrimitiveParDoSingleFactoryTest implements Serializable {
   public void getReplacementTransformPopulateDisplayData() {
     ParDo.SingleOutput<Integer, Long> originalTransform = ParDo.of(new ToLongFn());
     DisplayData originalDisplayData = DisplayData.from(originalTransform);
-
-    PTransform<PCollection<? extends Integer>, PCollection<Long>> replacement =
-        factory.getReplacementTransform(originalTransform);
-    DisplayData replacementDisplayData = DisplayData.from(replacement);
+    PCollection<? extends Integer> input = pipeline.apply(Create.of(1, 2, 3));
+    AppliedPTransform<
+        PCollection<? extends Integer>, PCollection<Long>, ParDo.SingleOutput<Integer, Long>>
+        application =
+        AppliedPTransform.of(
+            "original",
+            input.expand(),
+            input.apply(originalTransform).expand(),
+            originalTransform,
+            pipeline);
+
+    PTransformReplacement<PCollection<? extends Integer>, PCollection<Long>> replacement =
+        factory.getReplacementTransform(application);
+    DisplayData replacementDisplayData = DisplayData.from(replacement.getTransform());
 
     assertThat(replacementDisplayData, equalTo(originalDisplayData));
 
     DisplayData primitiveDisplayData =
         Iterables.getOnlyElement(
             DisplayDataEvaluator.create()
-                .displayDataForPrimitiveTransforms(replacement, VarIntCoder.of()));
+                .displayDataForPrimitiveTransforms(replacement.getTransform(), VarIntCoder.of()));
     assertThat(primitiveDisplayData, equalTo(replacementDisplayData));
   }
 
@@ -91,9 +102,21 @@ public class PrimitiveParDoSingleFactoryTest implements Serializable {
     ParDo.SingleOutput<Integer, Long> originalTransform =
         ParDo.of(new ToLongFn()).withSideInputs(sideLong, sideStrings);
 
-    PTransform<PCollection<? extends Integer>, PCollection<Long>> replacementTransform =
-        factory.getReplacementTransform(originalTransform);
-    ParDoSingle<Integer, Long> parDoSingle = (ParDoSingle<Integer, Long>) replacementTransform;
+    PCollection<? extends Integer> input = pipeline.apply(Create.of(1, 2, 3));
+    AppliedPTransform<
+        PCollection<? extends Integer>, PCollection<Long>, ParDo.SingleOutput<Integer, Long>>
+        application =
+        AppliedPTransform.of(
+            "original",
+            input.expand(),
+            input.apply(originalTransform).expand(),
+            originalTransform,
+            pipeline);
+
+    PTransformReplacement<PCollection<? extends Integer>, PCollection<Long>> replacementTransform =
+        factory.getReplacementTransform(application);
+    ParDoSingle<Integer, Long> parDoSingle =
+        (ParDoSingle<Integer, Long>) replacementTransform.getTransform();
     assertThat(parDoSingle.getSideInputs(), containsInAnyOrder(sideStrings, sideLong));
   }
 
@@ -101,9 +124,21 @@ public class PrimitiveParDoSingleFactoryTest implements Serializable {
   public void getReplacementTransformGetFn() {
     DoFn<Integer, Long> originalFn = new ToLongFn();
     ParDo.SingleOutput<Integer, Long> originalTransform = ParDo.of(originalFn);
-    PTransform<PCollection<? extends Integer>, PCollection<Long>> replacementTransform =
-        factory.getReplacementTransform(originalTransform);
-    ParDoSingle<Integer, Long> parDoSingle = (ParDoSingle<Integer, Long>) replacementTransform;
+    PCollection<? extends Integer> input = pipeline.apply(Create.of(1, 2, 3));
+    AppliedPTransform<
+            PCollection<? extends Integer>, PCollection<Long>, ParDo.SingleOutput<Integer, Long>>
+        application =
+            AppliedPTransform.of(
+                "original",
+                input.expand(),
+                input.apply(originalTransform).expand(),
+                originalTransform,
+                pipeline);
+
+    PTransformReplacement<PCollection<? extends Integer>, PCollection<Long>> replacementTransform =
+        factory.getReplacementTransform(application);
+    ParDoSingle<Integer, Long> parDoSingle =
+        (ParDoSingle<Integer, Long>) replacementTransform.getTransform();
 
     assertThat(parDoSingle.getFn(), equalTo(originalTransform.getFn()));
     assertThat(parDoSingle.getFn(), equalTo(originalFn));

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
----------------------------------------------------------------------
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
index aacb942..61fcaa9 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/TestSparkRunner.java
@@ -46,6 +46,7 @@ import org.apache.beam.sdk.runners.PTransformOverride;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
 import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.util.ValueWithRecordId;
@@ -244,14 +245,11 @@ public final class TestSparkRunner extends PipelineRunner<SparkPipelineResult> {
         implements PTransformOverrideFactory<
             PBegin, PCollection<T>, BoundedReadFromUnboundedSource<T>> {
       @Override
-      public PTransform<PBegin, PCollection<T>> getReplacementTransform(
-          BoundedReadFromUnboundedSource<T> transform) {
-        return new AdaptedBoundedAsUnbounded<>(transform);
-      }
-
-      @Override
-      public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-        return p.begin();
+      public PTransformReplacement<PBegin, PCollection<T>> getReplacementTransform(
+          AppliedPTransform<PBegin, PCollection<T>, BoundedReadFromUnboundedSource<T>> transform) {
+        return PTransformReplacement.of(
+            transform.getPipeline().begin(),
+            new AdaptedBoundedAsUnbounded<T>(transform.getTransform()));
       }
 
       @Override

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
index 791166e..1ff4c30 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/Pipeline.java
@@ -33,11 +33,13 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.runners.PTransformOverride;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory;
+import org.apache.beam.sdk.runners.PTransformOverrideFactory.PTransformReplacement;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput;
 import org.apache.beam.sdk.runners.PipelineRunner;
 import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.runners.TransformHierarchy.Node;
 import org.apache.beam.sdk.transforms.Aggregator;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.util.UserCodeException;
@@ -497,17 +499,18 @@ public class Pipeline {
       void applyReplacement(
           Node original,
           PTransformOverrideFactory<InputT, OutputT, TransformT> replacementFactory) {
-    PTransform<InputT, OutputT> replacement =
-        replacementFactory.getReplacementTransform((TransformT) original.getTransform());
-    if (replacement == original.getTransform()) {
+    PTransformReplacement<InputT, OutputT> replacement =
+        replacementFactory.getReplacementTransform(
+            (AppliedPTransform<InputT, OutputT, TransformT>) original.toAppliedPTransform());
+    if (replacement.getTransform() == original.getTransform()) {
       return;
     }
-    InputT originalInput = replacementFactory.getInput(original.getInputs(), this);
+    InputT originalInput = replacement.getInput();
 
     LOG.debug("Replacing {} with {}", original, replacement);
-    transforms.replaceNode(original, originalInput, replacement);
+    transforms.replaceNode(original, originalInput, replacement.getTransform());
     try {
-      OutputT newOutput = replacement.expand(originalInput);
+      OutputT newOutput = replacement.getTransform().expand(originalInput);
       Map<PValue, ReplacementOutput> originalToReplacement =
           replacementFactory.mapOutputs(original.getOutputs(), newOutput);
       // Ensure the internal TransformHierarchy data structures are consistent.

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java
index 57cba50..786c61c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/runners/PTransformOverrideFactory.java
@@ -21,9 +21,9 @@ package org.apache.beam.sdk.runners;
 
 import com.google.auto.value.AutoValue;
 import java.util.Map;
-import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.annotations.Experimental.Kind;
+import org.apache.beam.sdk.transforms.AppliedPTransform;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.values.PInput;
 import org.apache.beam.sdk.values.POutput;
@@ -41,14 +41,11 @@ public interface PTransformOverrideFactory<
     OutputT extends POutput,
     TransformT extends PTransform<? super InputT, OutputT>> {
   /**
-   * Returns a {@link PTransform} that produces equivalent output to the provided transform.
+   * Returns a {@link PTransformReplacement}, containing a transform that produces equivalent
+   * output to the provided {@link AppliedPTransform transform} and the input it is applied to.
    */
-  PTransform<InputT, OutputT> getReplacementTransform(TransformT transform);
-
-  /**
-   * Returns the composite type that replacement transforms consumed from an equivalent expansion.
-   */
-  InputT getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p);
+  PTransformReplacement<InputT, OutputT> getReplacementTransform(
+      AppliedPTransform<InputT, OutputT, TransformT> transform);
 
   /**
    * Returns a {@link Map} from the expanded values in {@code newOutput} to the values produced by
@@ -56,7 +53,25 @@ public interface PTransformOverrideFactory<
    */
   Map<PValue, ReplacementOutput> mapOutputs(Map<TupleTag<?>, PValue> outputs, OutputT newOutput);
 
-  /** A mapping between original {@link TaggedPValue} outputs and their replacements. */
+  /**
+   * A {@link PTransform} that replaces an {@link AppliedPTransform}, and the input required to
+   * do so. The input must be constructed from the expanded form, as the transform may not have
+   * originally been applied within this process or from within a Java SDK.
+   */
+  @AutoValue
+  abstract class PTransformReplacement<InputT extends PInput, OutputT extends POutput> {
+    public static <InputT extends PInput, OutputT extends POutput>
+        PTransformReplacement<InputT, OutputT> of(
+            InputT input, PTransform<InputT, OutputT> transform) {
+      return new AutoValue_PTransformOverrideFactory_PTransformReplacement(input, transform);
+    }
+    public abstract InputT getInput();
+    public abstract PTransform<InputT, OutputT> getTransform();
+  }
+
+  /**
+   * A mapping between original {@link TaggedPValue} outputs and their replacements.
+   */
   @AutoValue
   abstract class ReplacementOutput {
     public static ReplacementOutput of(TaggedPValue original, TaggedPValue replacement) {

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
index 8d99a62..bdb61b8 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/AppliedPTransform.java
@@ -31,6 +31,11 @@ import org.apache.beam.sdk.values.TupleTag;
  *
  * <p>For internal use.
  *
+ * <p>Inputs and outputs are stored in their expanded forms, as the condensed form of a composite
+ * {@link PInput} or {@link POutput} is a language-specific concept, and {@link AppliedPTransform}
+ * represents a possibly cross-language transform for which no appropriate composite type exists
+ * in the Java SDK.
+ *
  * @param <InputT>     transform input type
  * @param <OutputT>    transform output type
  * @param <TransformT> transform type

http://git-wip-us.apache.org/repos/asf/beam/blob/f3b49605/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java
index 6ce016d..75cabf2 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/PipelineTest.java
@@ -406,16 +406,10 @@ public class PipelineTest {
     class ReplacementOverrideFactory
         implements PTransformOverrideFactory<
             PCollection<String>, PCollection<Long>, OriginalTransform> {
-
       @Override
-      public PTransform<PCollection<String>, PCollection<Long>> getReplacementTransform(
-          OriginalTransform transform) {
-        return new ReplacementTransform();
-      }
-
-      @Override
-      public PCollection<String> getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-        return originalInput;
+      public PTransformReplacement<PCollection<String>, PCollection<Long>> getReplacementTransform(
+          AppliedPTransform<PCollection<String>, PCollection<Long>, OriginalTransform> transform) {
+        return PTransformReplacement.of(originalInput, new ReplacementTransform());
       }
 
       @Override
@@ -464,14 +458,9 @@ public class PipelineTest {
   static class BoundedCountingInputOverride
       implements PTransformOverrideFactory<PBegin, PCollection<Long>, BoundedCountingInput> {
     @Override
-    public PTransform<PBegin, PCollection<Long>> getReplacementTransform(
-        BoundedCountingInput transform) {
-      return Create.of(0L);
-    }
-
-    @Override
-    public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return p.begin();
+    public PTransformReplacement<PBegin, PCollection<Long>> getReplacementTransform(
+        AppliedPTransform<PBegin, PCollection<Long>, BoundedCountingInput> transform) {
+      return PTransformReplacement.of(transform.getPipeline().begin(), Create.of(0L));
     }
 
     @Override
@@ -489,15 +478,11 @@ public class PipelineTest {
   }
   static class UnboundedCountingInputOverride
       implements PTransformOverrideFactory<PBegin, PCollection<Long>, UnboundedCountingInput> {
-    @Override
-    public PTransform<PBegin, PCollection<Long>> getReplacementTransform(
-        UnboundedCountingInput transform) {
-      return CountingInput.upTo(100L);
-    }
 
     @Override
-    public PBegin getInput(Map<TupleTag<?>, PValue> inputs, Pipeline p) {
-      return p.begin();
+    public PTransformReplacement<PBegin, PCollection<Long>> getReplacementTransform(
+        AppliedPTransform<PBegin, PCollection<Long>, UnboundedCountingInput> transform) {
+      return PTransformReplacement.of(transform.getPipeline().begin(), CountingInput.upTo(100L));
     }
 
     @Override