You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tez.apache.org by mi...@apache.org on 2016/11/07 22:49:07 UTC

tez git commit: TEZ-3465. Support broadcast edge into cartesian product vertex and forbid other edges. (Zhiyuan Yang via mingma)

Repository: tez
Updated Branches:
  refs/heads/master ad68f7358 -> b4c949c9c


TEZ-3465. Support broadcast edge into cartesian product vertex and forbid other edges. (Zhiyuan Yang via mingma)


Project: http://git-wip-us.apache.org/repos/asf/tez/repo
Commit: http://git-wip-us.apache.org/repos/asf/tez/commit/b4c949c9
Tree: http://git-wip-us.apache.org/repos/asf/tez/tree/b4c949c9
Diff: http://git-wip-us.apache.org/repos/asf/tez/diff/b4c949c9

Branch: refs/heads/master
Commit: b4c949c9cbdcfe2c1bb3e7ffcc635f281beb9889
Parents: ad68f73
Author: Ming Ma <mi...@twitter.com>
Authored: Mon Nov 7 14:48:52 2016 -0800
Committer: Ming Ma <mi...@twitter.com>
Committed: Mon Nov 7 14:48:52 2016 -0800

----------------------------------------------------------------------
 CHANGES.txt                                     |   1 +
 .../apache/tez/examples/CartesianProduct.java   |  92 +++++---
 .../CartesianProductConfig.java                 |  12 +-
 .../CartesianProductVertexManager.java          |  41 +++-
 ...artesianProductVertexManagerPartitioned.java |  80 ++++---
 ...tesianProductVertexManagerUnpartitioned.java | 175 +++++++++------
 .../TestCartesianProductCombination.java        |   2 +-
 ...tCartesianProductEdgeManagerPartitioned.java |   2 +-
 .../TestCartesianProductVertexManager.java      | 125 +++++++++--
 ...artesianProductVertexManagerPartitioned.java | 214 +++++++++----------
 ...tesianProductVertexManagerUnpartitioned.java | 100 +++++----
 11 files changed, 524 insertions(+), 320 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/tez/blob/b4c949c9/CHANGES.txt
----------------------------------------------------------------------
diff --git a/CHANGES.txt b/CHANGES.txt
index 033291a..ecfe935 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -7,6 +7,7 @@ INCOMPATIBLE CHANGES
 
 ALL CHANGES:
 
+  TEZ-3465. Support broadcast edge into cartesian product vertex and forbid other edges.
   TEZ-3493. DAG submit timeout cannot be set to a month
   TEZ-3505. Move license to the file header for TezBytesWritableSerialization
   TEZ-3486. COMBINE_OUTPUT_RECORDS/COMBINE_INPUT_RECORDS are not correct

http://git-wip-us.apache.org/repos/asf/tez/blob/b4c949c9/tez-examples/src/main/java/org/apache/tez/examples/CartesianProduct.java
----------------------------------------------------------------------
diff --git a/tez-examples/src/main/java/org/apache/tez/examples/CartesianProduct.java b/tez-examples/src/main/java/org/apache/tez/examples/CartesianProduct.java
index 9f3d490..84367f8 100644
--- a/tez-examples/src/main/java/org/apache/tez/examples/CartesianProduct.java
+++ b/tez-examples/src/main/java/org/apache/tez/examples/CartesianProduct.java
@@ -51,20 +51,28 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.StringTokenizer;
 
 /**
- * This job has three vertices: two Tokenizers and one JoinProcessor. Each Tokenizer handles one
- * input directory and generates tokens. CustomPartitioner separates tokens into 2 partitions
- * according to the parity of token's first char. Then JoinProcessor does cartesian product of
- * partitioned token sets.
+ * This DAG does cartesian product of two text inputs and then filters results according to the
+ * third text input.
+ *
+ * V1    V2    V3
+ *  \     |    /
+ * CP\  CP|   / Broadcast
+ *    \   |  /
+ *    Vertex 4
+ *
+ * Vertices 1-3 are tokenizers, each of which tokenizes the input from one directory. In the
+ * partitioned case, CustomPartitioner splits tokens into 2 partitions according to the parity of
+ * each token's first char. Vertex 4 computes the cartesian product of the inputs from vertices 1
+ * and 2, producing KV pairs whose keys are vertex 1 tokens and values are vertex 2 tokens. It then
+ * outputs only those KV pairs whose keys appear among the vertex 3 tokens.
  */
 public class CartesianProduct extends TezExampleBase {
   private static final String INPUT = "Input1";
@@ -72,11 +80,12 @@ public class CartesianProduct extends TezExampleBase {
   private static final String VERTEX1 = "Vertex1";
   private static final String VERTEX2 = "Vertex2";
   private static final String VERTEX3 = "Vertex3";
+  private static final String VERTEX4 = "Vertex4";
   private static final String PARTITIONED = "-partitioned";
   private static final String UNPARTITIONED = "-unpartitioned";
   private static final Logger LOG = LoggerFactory.getLogger(CartesianProduct.class);
   private static final int numPartition = 2;
-  private static final String[] sourceVertices = new String[] {VERTEX1, VERTEX2};
+  private static final String[] cpSources = new String[] {VERTEX1, VERTEX2};
 
   public static class TokenProcessor extends SimpleProcessor {
     public TokenProcessor(ProcessorContext context) {
@@ -88,7 +97,7 @@ public class CartesianProduct extends TezExampleBase {
       Preconditions.checkArgument(getInputs().size() == 1);
       Preconditions.checkArgument(getOutputs().size() == 1);
       KeyValueReader kvReader = (KeyValueReader) getInputs().get(INPUT).getReader();
-      KeyValueWriter kvWriter = (KeyValueWriter) getOutputs().get(VERTEX3).getWriter();
+      KeyValueWriter kvWriter = (KeyValueWriter) getOutputs().get(VERTEX4).getWriter();
       while (kvReader.next()) {
         StringTokenizer itr = new StringTokenizer(kvReader.getCurrentValue().toString());
         while (itr.hasMoreTokens()) {
@@ -108,16 +117,23 @@ public class CartesianProduct extends TezExampleBase {
       KeyValueWriter kvWriter = (KeyValueWriter) getOutputs().get(OUTPUT).getWriter();
       KeyValueReader kvReader1 = (KeyValueReader) getInputs().get(VERTEX1).getReader();
       KeyValueReader kvReader2 = (KeyValueReader) getInputs().get(VERTEX2).getReader();
-      Set<String> rightSet = new HashSet<>();
+      KeyValueReader kvReader3 = (KeyValueReader) getInputs().get(VERTEX3).getReader();
+      Set<String> v2TokenSet = new HashSet<>();
+      Set<String> v3TokenSet = new HashSet<>();
 
       while (kvReader2.next()) {
-        rightSet.add(kvReader2.getCurrentKey().toString());
+        v2TokenSet.add(kvReader2.getCurrentKey().toString());
+      }
+      while (kvReader3.next()) {
+        v3TokenSet.add(kvReader3.getCurrentKey().toString());
       }
 
       while (kvReader1.next()) {
         String left = kvReader1.getCurrentKey().toString();
-        for (String right : rightSet) {
-          kvWriter.write(left, right);
+        if (v3TokenSet.contains(left)) {
+          for (String right : v2TokenSet) {
+            kvWriter.write(left, right);
+          }
         }
       }
     }
@@ -131,7 +147,8 @@ public class CartesianProduct extends TezExampleBase {
   }
 
   private DAG createDAG(TezConfiguration tezConf, String inputPath1, String inputPath2,
-                        String outputPath, boolean isPartitioned) throws IOException {
+                        String inputPath3, String outputPath, boolean isPartitioned)
+    throws IOException {
     Vertex v1 = Vertex.create(VERTEX1, ProcessorDescriptor.create(TokenProcessor.class.getName()));
     // turn off groupSplit so that each input file incurs one task
     v1.addDataSource(INPUT,
@@ -141,54 +158,65 @@ public class CartesianProduct extends TezExampleBase {
     v2.addDataSource(INPUT,
       MRInput.createConfigBuilder(new Configuration(tezConf), TextInputFormat.class, inputPath2)
               .groupSplits(false).build());
+    Vertex v3 = Vertex.create(VERTEX3, ProcessorDescriptor.create(TokenProcessor.class.getName()));
+    v3.addDataSource(INPUT,
+      MRInput.createConfigBuilder(new Configuration(tezConf), TextInputFormat.class, inputPath3)
+        .groupSplits(false).build());
     CartesianProductConfig cartesianProductConfig;
     if (isPartitioned) {
       Map<String, Integer> vertexPartitionMap = new HashMap<>();
-      for (String vertex : sourceVertices) {
+      for (String vertex : cpSources) {
         vertexPartitionMap.put(vertex, numPartition);
       }
       cartesianProductConfig = new CartesianProductConfig(vertexPartitionMap);
     } else {
-      cartesianProductConfig = new CartesianProductConfig(Arrays.asList(sourceVertices));
+      cartesianProductConfig = new CartesianProductConfig(Arrays.asList(cpSources));
     }
     UserPayload userPayload = cartesianProductConfig.toUserPayload(tezConf);
-    Vertex v3 = Vertex.create(VERTEX3, ProcessorDescriptor.create(JoinProcessor.class.getName()));
-    v3.addDataSink(OUTPUT,
+    Vertex v4 = Vertex.create(VERTEX4, ProcessorDescriptor.create(JoinProcessor.class.getName()));
+    v4.addDataSink(OUTPUT,
       MROutput.createConfigBuilder(new Configuration(tezConf), TextOutputFormat.class, outputPath)
               .build());
-    v3.setVertexManagerPlugin(
+    v4.setVertexManagerPlugin(
       VertexManagerPluginDescriptor.create(CartesianProductVertexManager.class.getName())
                                    .setUserPayload(userPayload));
 
-    DAG dag = DAG.create("CrossProduct").addVertex(v1).addVertex(v2).addVertex(v3);
-    EdgeManagerPluginDescriptor edgeManagerDescriptor =
+    EdgeManagerPluginDescriptor cpEdgeManager =
       EdgeManagerPluginDescriptor.create(CartesianProductEdgeManager.class.getName());
-    edgeManagerDescriptor.setUserPayload(userPayload);
-    EdgeProperty edgeProperty;
+    cpEdgeManager.setUserPayload(userPayload);
+    EdgeProperty cpEdgeProperty;
     if (isPartitioned) {
-      UnorderedPartitionedKVEdgeConfig edgeConf =
-        UnorderedPartitionedKVEdgeConfig.newBuilder(Text.class.getName(), IntWritable.class.getName(),
-          CustomPartitioner.class.getName()).build();
-      edgeProperty = edgeConf.createDefaultCustomEdgeProperty(edgeManagerDescriptor);
+      UnorderedPartitionedKVEdgeConfig cpEdgeConf =
+        UnorderedPartitionedKVEdgeConfig.newBuilder(Text.class.getName(),
+          IntWritable.class.getName(), CustomPartitioner.class.getName()).build();
+      cpEdgeProperty = cpEdgeConf.createDefaultCustomEdgeProperty(cpEdgeManager);
     } else {
       UnorderedKVEdgeConfig edgeConf =
         UnorderedKVEdgeConfig.newBuilder(Text.class.getName(), IntWritable.class.getName()).build();
-      edgeProperty = edgeConf.createDefaultCustomEdgeProperty(edgeManagerDescriptor);
+      cpEdgeProperty = edgeConf.createDefaultCustomEdgeProperty(cpEdgeManager);
     }
-    dag.addEdge(Edge.create(v1, v3, edgeProperty)).addEdge(Edge.create(v2, v3, edgeProperty));
 
-    return dag;
+    EdgeProperty broadcastEdgeProperty;
+    UnorderedKVEdgeConfig broadcastEdgeConf =
+      UnorderedKVEdgeConfig.newBuilder(Text.class.getName(), IntWritable.class.getName()).build();
+    broadcastEdgeProperty = broadcastEdgeConf.createDefaultBroadcastEdgeProperty();
+
+    return DAG.create("CartesianProduct")
+      .addVertex(v1).addVertex(v2).addVertex(v3).addVertex(v4)
+      .addEdge(Edge.create(v1, v4, cpEdgeProperty))
+      .addEdge(Edge.create(v2, v4, cpEdgeProperty))
+      .addEdge(Edge.create(v3, v4, broadcastEdgeProperty));
   }
 
   @Override
   protected void printUsage() {
     System.err.println("Usage: args: ["+PARTITIONED + "|" + UNPARTITIONED
-      + " <input_dir1> <input_dir2> <output_dir>");
+      + " <input_dir1> <input_dir2> <input_dir3> <output_dir>");
   }
 
   @Override
   protected int validateArgs(String[] otherArgs) {
-    return (otherArgs.length != 4 || (!otherArgs[0].equals(PARTITIONED)
+    return (otherArgs.length != 5 || (!otherArgs[0].equals(PARTITIONED)
       && !otherArgs[0].equals(UNPARTITIONED))) ? -1 : 0;
   }
 
@@ -196,7 +224,7 @@ public class CartesianProduct extends TezExampleBase {
   protected int runJob(String[] args, TezConfiguration tezConf,
       TezClient tezClient) throws Exception {
     DAG dag = createDAG(tezConf, args[1], args[2],
-        args[3], args[0].equals(PARTITIONED));
+        args[3], args[4], args[0].equals(PARTITIONED));
     return runDag(dag, isCountersLog(), LOG);
   }
 

http://git-wip-us.apache.org/repos/asf/tez/blob/b4c949c9/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductConfig.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductConfig.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductConfig.java
index b682182..a7a3940 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductConfig.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductConfig.java
@@ -201,17 +201,17 @@ public class CartesianProductConfig {
     }
 
     builder.setMinFraction(
-      CartesianProductVertexManager.TEZ_CAERESIAN_PRODUCT_SLOW_START_MIN_FRACTION_DEFAULT);
+      CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_SLOW_START_MIN_FRACTION_DEFAULT);
     builder.setMaxFraction(
-      CartesianProductVertexManager.TEZ_CAERESIAN_PRODUCT_SLOW_START_MAX_FRACTION_DEFAULT);
+      CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_SLOW_START_MAX_FRACTION_DEFAULT);
 
     if (conf != null) {
       builder.setMinFraction(conf.getFloat(
-        CartesianProductVertexManager.TEZ_CAERESIAN_PRODUCT_SLOW_START_MIN_FRACTION,
-        CartesianProductVertexManager.TEZ_CAERESIAN_PRODUCT_SLOW_START_MIN_FRACTION_DEFAULT));
+        CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_SLOW_START_MIN_FRACTION,
+        CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_SLOW_START_MIN_FRACTION_DEFAULT));
       builder.setMaxFraction(conf.getFloat(
-        CartesianProductVertexManager.TEZ_CAERESIAN_PRODUCT_SLOW_START_MAX_FRACTION,
-        CartesianProductVertexManager.TEZ_CAERESIAN_PRODUCT_SLOW_START_MAX_FRACTION_DEFAULT));
+        CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_SLOW_START_MAX_FRACTION,
+        CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_SLOW_START_MAX_FRACTION_DEFAULT));
     }
     Preconditions.checkArgument(builder.getMinFraction() <= builder.getMaxFraction(),
       "min fraction(" + builder.getMinFraction() + ") should be less than max fraction(" +

http://git-wip-us.apache.org/repos/asf/tez/blob/b4c949c9/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManager.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManager.java
index 659d3b7..83caac2 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManager.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManager.java
@@ -19,6 +19,7 @@ package org.apache.tez.runtime.library.cartesianproduct;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
+import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
 import org.apache.tez.dag.api.EdgeProperty;
 import org.apache.tez.dag.api.InputDescriptor;
 import org.apache.tez.dag.api.TezException;
@@ -29,29 +30,36 @@ import org.apache.tez.runtime.api.Event;
 import org.apache.tez.runtime.api.TaskAttemptIdentifier;
 import org.apache.tez.runtime.api.events.VertexManagerEvent;
 
-import java.util.Arrays;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
+import static org.apache.tez.dag.api.EdgeProperty.DataMovementType.BROADCAST;
+import static org.apache.tez.dag.api.EdgeProperty.DataMovementType.CUSTOM;
+
 /**
  * This VM wrap a real vertex manager implementation object. It choose whether it's partitioned or
  * unpartitioned implementation according to the config. All method invocations are actually
  * redirected to real implementation.
+ *
+ * Predefined parallelism isn't allowed for a cartesian product vertex; parallelism has to be
+ * determined by the vertex manager.
  */
 public class CartesianProductVertexManager extends VertexManagerPlugin {
-  public static final String TEZ_CAERESIAN_PRODUCT_SLOW_START_MIN_FRACTION =
+  public static final String TEZ_CARTESIAN_PRODUCT_SLOW_START_MIN_FRACTION =
     "tez.cartesian-product.min-src-fraction";
-  public static final float TEZ_CAERESIAN_PRODUCT_SLOW_START_MIN_FRACTION_DEFAULT = 0.25f;
-  public static final String TEZ_CAERESIAN_PRODUCT_SLOW_START_MAX_FRACTION =
+  public static final float TEZ_CARTESIAN_PRODUCT_SLOW_START_MIN_FRACTION_DEFAULT = 0.25f;
+  public static final String TEZ_CARTESIAN_PRODUCT_SLOW_START_MAX_FRACTION =
     "tez.cartesian-product.min-src-fraction";
-  public static final float TEZ_CAERESIAN_PRODUCT_SLOW_START_MAX_FRACTION_DEFAULT = 0.75f;
+  public static final float TEZ_CARTESIAN_PRODUCT_SLOW_START_MAX_FRACTION_DEFAULT = 0.75f;
 
   private CartesianProductVertexManagerReal vertexManagerReal = null;
 
   public CartesianProductVertexManager(VertexManagerPluginContext context) {
     super(context);
+    Preconditions.checkArgument(context.getVertexNumTasks(context.getVertexName()) == -1,
+      "Vertex with CartesianProductVertexManager cannot use pre-defined parallelism");
   }
 
   @Override
@@ -65,16 +73,27 @@ public class CartesianProductVertexManager extends VertexManagerPlugin {
     sourceVerticesConfig.addAll(config.getSourceVertices());
 
     for (Map.Entry<String, EdgeProperty> entry : edgePropertyMap.entrySet()) {
-      if (entry.getValue().getEdgeManagerDescriptor().getClassName()
-        .equals(CartesianProductEdgeManager.class.getName())) {
-        Preconditions.checkArgument(sourceVerticesDAG.contains(entry.getKey()),
-          entry.getKey() + " has CartesianProductEdgeManager but isn't in " +
+      String vertex = entry.getKey();
+      EdgeProperty edgeProperty = entry.getValue();
+      EdgeManagerPluginDescriptor empDescriptor = edgeProperty.getEdgeManagerDescriptor();
+      if (empDescriptor != null
+        && empDescriptor.getClassName().equals(CartesianProductEdgeManager.class.getName())) {
+        Preconditions.checkArgument(sourceVerticesConfig.contains(vertex),
+          vertex + " has CartesianProductEdgeManager but isn't in " +
             "CartesianProductVertexManagerConfig");
       } else {
-        Preconditions.checkArgument(!sourceVerticesDAG.contains(entry.getKey()),
-          entry.getKey() + " has no CartesianProductEdgeManager but is in " +
+        Preconditions.checkArgument(!sourceVerticesConfig.contains(vertex),
+          vertex + " has no CartesianProductEdgeManager but is in " +
             "CartesianProductVertexManagerConfig");
       }
+
+      if (edgeProperty.getDataMovementType() == CUSTOM) {
+        Preconditions.checkArgument(sourceVerticesConfig.contains(vertex),
+          "Only broadcast and cartesian product edges are allowed in cartesian product vertex");
+      } else {
+        Preconditions.checkArgument(edgeProperty.getDataMovementType() == BROADCAST,
+          "Only broadcast and cartesian product edges are allowed in cartesian product vertex");
+      }
     }
 
     for (String vertex : sourceVerticesConfig) {

http://git-wip-us.apache.org/repos/asf/tez/blob/b4c949c9/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerPartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerPartitioned.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerPartitioned.java
index af2abae..38ec1b1 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerPartitioned.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerPartitioned.java
@@ -17,7 +17,6 @@
  */
 package org.apache.tez.runtime.library.cartesianproduct;
 
-import com.google.common.base.Preconditions;
 import com.google.common.primitives.Ints;
 import org.apache.tez.common.ReflectionUtils;
 import org.apache.tez.dag.api.TezReflectionException;
@@ -48,7 +47,9 @@ class CartesianProductVertexManagerPartitioned extends CartesianProductVertexMan
   private int parallelism = 0;
   private boolean vertexStarted = false;
   private boolean vertexReconfigured = false;
-  private int numSourceVertexConfigured = 0;
+  private int numCPSrcNotInConfiguredState = 0;
+  private int numBroadcastSrcNotInRunningState = 0;
+
   private CartesianProductFilter filter;
   private Map<String, BitSet> sourceTaskCompleted = new HashMap<>();
   private int numFinishedSrcTasks = 0;
@@ -78,33 +79,18 @@ class CartesianProductVertexManagerPartitioned extends CartesianProductVertexMan
     for (String sourceVertex : sourceVertices) {
       sourceTaskCompleted.put(sourceVertex, new BitSet());
     }
-    for (String vertex : sourceVertices) {
-      getContext().registerForVertexStateUpdates(vertex, EnumSet.of(VertexState.CONFIGURED));
+    for (String vertex : getContext().getInputVertexEdgeProperties().keySet()) {
+      if (sourceVertices.indexOf(vertex) != -1) {
+        getContext().registerForVertexStateUpdates(vertex, EnumSet.of(VertexState.CONFIGURED));
+        numCPSrcNotInConfiguredState++;
+      } else {
+        getContext().registerForVertexStateUpdates(vertex, EnumSet.of(VertexState.RUNNING));
+        numBroadcastSrcNotInRunningState++;
+      }
     }
     getContext().vertexReconfigurationPlanned();
   }
 
-  private void reconfigureVertex() throws IOException {
-    // try all combinations, check against filter and get final parallelism
-    Map<String, Integer> vertexPartitionMap = new HashMap<>();
-
-    CartesianProductCombination combination =
-      new CartesianProductCombination(Ints.toArray(config.getNumPartitions()));
-    combination.firstTask();
-    do {
-      for (int i = 0; i < sourceVertices.size(); i++) {
-        vertexPartitionMap.put(sourceVertices.get(i), combination.getCombination().get(i));
-      }
-      if (filter == null || filter.isValidCombination(vertexPartitionMap)) {
-        parallelism++;
-      }
-    } while (combination.nextTask());
-    // no need to reconfigure EM because EM already has all necessary information via config object
-    getContext().reconfigureVertex(parallelism, null, null);
-    vertexReconfigured = true;
-    getContext().doneReconfiguringVertex();
-  }
-
   @Override
   public synchronized void onVertexStarted(List<TaskAttemptIdentifier> completions)
     throws Exception {
@@ -120,12 +106,17 @@ class CartesianProductVertexManagerPartitioned extends CartesianProductVertexMan
 
   @Override
   public synchronized void onVertexStateUpdated(VertexStateUpdate stateUpdate) throws IOException{
-    Preconditions.checkArgument(stateUpdate.getVertexState() == VertexState.CONFIGURED);
-    if (!vertexReconfigured) {
-      reconfigureVertex();
+    VertexState state = stateUpdate.getVertexState();
+
+    if (state == VertexState.CONFIGURED) {
+      if (!vertexReconfigured) {
+        reconfigureVertex();
+      }
+      numCPSrcNotInConfiguredState--;
+      totalNumSrcTasks += getContext().getVertexNumTasks(stateUpdate.getVertexName());
+    } else if (state == VertexState.RUNNING){
+      numBroadcastSrcNotInRunningState--;
     }
-    numSourceVertexConfigured++;
-    totalNumSrcTasks += getContext().getVertexNumTasks(stateUpdate.getVertexName());
     // try schedule because there may be no more vertex start and source completions
     tryScheduleTask();
   }
@@ -134,6 +125,11 @@ class CartesianProductVertexManagerPartitioned extends CartesianProductVertexMan
   public synchronized void onSourceTaskCompleted(TaskAttemptIdentifier attempt) throws Exception {
     int taskId = attempt.getTaskIdentifier().getIdentifier();
     String vertex = attempt.getTaskIdentifier().getVertexIdentifier().getName();
+
+    if (!sourceTaskCompleted.containsKey(vertex)) {
+      return;
+    }
+
     BitSet bitSet = this.sourceTaskCompleted.get(vertex);
     if (!bitSet.get(taskId)) {
       bitSet.set(taskId);
@@ -142,13 +138,33 @@ class CartesianProductVertexManagerPartitioned extends CartesianProductVertexMan
     }
   }
 
+  private void reconfigureVertex() throws IOException {
+    // try all combinations, check against filter and get final parallelism
+    Map<String, Integer> vertexPartitionMap = new HashMap<>();
+
+    CartesianProductCombination combination =
+      new CartesianProductCombination(Ints.toArray(config.getNumPartitions()));
+    combination.firstTask();
+    do {
+      for (int i = 0; i < sourceVertices.size(); i++) {
+        vertexPartitionMap.put(sourceVertices.get(i), combination.getCombination().get(i));
+      }
+      if (filter == null || filter.isValidCombination(vertexPartitionMap)) {
+        parallelism++;
+      }
+    } while (combination.nextTask());
+    // no need to reconfigure EM because EM already has all necessary information via config object
+    getContext().reconfigureVertex(parallelism, null, null);
+    vertexReconfigured = true;
+    getContext().doneReconfiguringVertex();
+  }
+
   /**
    * schedule task as the ascending order of id. Slow start has same behavior as ShuffleVertexManager
    */
   private void tryScheduleTask() {
     // only schedule task when vertex is already started and all source vertices are configured
-    if (!vertexStarted
-      || numSourceVertexConfigured != sourceVertices.size()) {
+    if (!vertexStarted || numCPSrcNotInConfiguredState > 0 || numBroadcastSrcNotInRunningState > 0) {
       return;
     }
     // determine the destination task with largest id to schedule

http://git-wip-us.apache.org/repos/asf/tez/blob/b4c949c9/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerUnpartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerUnpartitioned.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerUnpartitioned.java
index af7d15e..5114293 100644
--- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerUnpartitioned.java
+++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/cartesianproduct/CartesianProductVertexManagerUnpartitioned.java
@@ -17,7 +17,6 @@
  */
 package org.apache.tez.runtime.library.cartesianproduct;
 
-import com.google.common.base.Preconditions;
 import com.google.common.primitives.Ints;
 import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
 import org.apache.tez.dag.api.EdgeProperty;
@@ -27,32 +26,36 @@ import org.apache.tez.dag.api.VertexManagerPluginContext.ScheduleTaskRequest;
 import org.apache.tez.dag.api.event.VertexState;
 import org.apache.tez.dag.api.event.VertexStateUpdate;
 import org.apache.tez.runtime.api.TaskAttemptIdentifier;
+import org.roaringbitmap.RoaringBitmap;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
-import java.util.BitSet;
 import java.util.EnumSet;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Queue;
 
+import static org.apache.tez.dag.api.EdgeProperty.DataMovementType.CUSTOM;
 import static org.apache.tez.runtime.library.cartesianproduct.CartesianProductUserPayload.CartesianProductConfigProto;
 
 class CartesianProductVertexManagerUnpartitioned extends CartesianProductVertexManagerReal {
   List<String> sourceVertices;
   private int parallelism = 1;
-  private boolean vertexStarted = false;
   private boolean vertexReconfigured = false;
-  private int numSourceVertexConfigured = 0;
+  private boolean vertexStarted = false;
+  private boolean vertexStartSchedule = false;
+  private int numCPSrcNotInConfigureState = 0;
+  private int numBroadcastSrcNotInRunningState = 0;
   private int[] numTasks;
-  private Queue<TaskAttemptIdentifier> pendingCompletedSrcTask = new LinkedList<>();
-  private Map<String, BitSet> sourceTaskCompleted = new HashMap<>();
-  private BitSet scheduledTasks = new BitSet();
+
+  private Queue<TaskAttemptIdentifier> completedSrcTaskToProcess = new LinkedList<>();
+  private Map<String, RoaringBitmap> sourceTaskCompleted = new HashMap<>();
+  private RoaringBitmap scheduledTasks = new RoaringBitmap();
   private CartesianProductConfig config;
-  private int numSrcHasCompletedTask = 0;
 
   public CartesianProductVertexManagerUnpartitioned(VertexManagerPluginContext context) {
     super(context);
@@ -62,24 +65,97 @@ class CartesianProductVertexManagerUnpartitioned extends CartesianProductVertexM
   public void initialize(CartesianProductVertexManagerConfig config) throws Exception {
     sourceVertices = config.getSourceVertices();
     numTasks = new int[sourceVertices.size()];
-    for (String vertex : sourceVertices) {
-      sourceTaskCompleted.put(vertex, new BitSet());
-    }
-    for (String vertex : sourceVertices) {
-      getContext().registerForVertexStateUpdates(vertex, EnumSet.of(VertexState.CONFIGURED));
+
+    for (String vertex : getContext().getInputVertexEdgeProperties().keySet()) {
+      if (sourceVertices.indexOf(vertex) != -1) {
+        sourceTaskCompleted.put(vertex, new RoaringBitmap());
+        getContext().registerForVertexStateUpdates(vertex, EnumSet.of(VertexState.CONFIGURED));
+        numCPSrcNotInConfigureState++;
+      } else {
+        getContext().registerForVertexStateUpdates(vertex, EnumSet.of(VertexState.RUNNING));
+        numBroadcastSrcNotInRunningState++;
+      }
     }
     this.config = config;
     getContext().vertexReconfigurationPlanned();
   }
 
-  private void reconfigureVertex() throws IOException {
+  @Override
+  public synchronized void onVertexStarted(List<TaskAttemptIdentifier> completions)
+    throws Exception {
+    vertexStarted = true;
+    if (completions != null) {
+      for (TaskAttemptIdentifier attempt : completions) {
+        addCompletedSrcTaskToProcess(attempt);
+      }
+    }
+    tryScheduleTasks();
+  }
+
+  @Override
+  public synchronized void onVertexStateUpdated(VertexStateUpdate stateUpdate) throws IOException {
+    String vertex = stateUpdate.getVertexName();
+    VertexState state = stateUpdate.getVertexState();
+
+    if (state == VertexState.CONFIGURED) {
+      numTasks[sourceVertices.indexOf(vertex)] = getContext().getVertexNumTasks(vertex);
+      numCPSrcNotInConfigureState--;
+    } else if (state == VertexState.RUNNING) {
+      numBroadcastSrcNotInRunningState--;
+    }
+    tryScheduleTasks();
+  }
+
+  @Override
+  public synchronized void onSourceTaskCompleted(TaskAttemptIdentifier attempt) throws Exception {
+    addCompletedSrcTaskToProcess(attempt);
+    tryScheduleTasks();
+  }
+
+  private void addCompletedSrcTaskToProcess(TaskAttemptIdentifier attempt) {
+    int taskId = attempt.getTaskIdentifier().getIdentifier();
+    String vertex = attempt.getTaskIdentifier().getVertexIdentifier().getName();
+    if (sourceVertices.indexOf(vertex) == -1) {
+      return;
+    }
+    if (sourceTaskCompleted.get(vertex).contains(taskId)) {
+      return;
+    }
+    sourceTaskCompleted.get(vertex).add(taskId);
+    completedSrcTaskToProcess.add(attempt);
+  }
+
+  private boolean tryStartSchedule() {
+    if (!vertexReconfigured || !vertexStarted || numBroadcastSrcNotInRunningState > 0) {
+      return false;
+    }
+    for (RoaringBitmap bitmap: sourceTaskCompleted.values()) {
+      if (bitmap.isEmpty()) {
+        return false;
+      }
+    }
+    vertexStartSchedule = true;
+    return true;
+  }
+
+  private boolean tryReconfigure() throws IOException {
+    if (numCPSrcNotInConfigureState > 0) {
+      return false;
+    }
+
     for (int numTask : numTasks) {
       parallelism *= numTask;
     }
 
     UserPayload payload = null;
     Map<String, EdgeProperty> edgeProperties = getContext().getInputVertexEdgeProperties();
-    for (EdgeProperty edgeProperty : edgeProperties.values()) {
+    Iterator<Map.Entry<String,EdgeProperty>> iter = edgeProperties.entrySet().iterator();
+    while (iter.hasNext()) {
+      EdgeProperty edgeProperty = iter.next().getValue();
+      if (edgeProperty.getDataMovementType() != CUSTOM) {
+        iter.remove();
+        continue;
+      }
       EdgeManagerPluginDescriptor descriptor = edgeProperty.getEdgeManagerDescriptor();
       if (payload == null) {
         CartesianProductConfigProto.Builder builder = CartesianProductConfigProto.newBuilder();
@@ -92,83 +168,42 @@ class CartesianProductVertexManagerUnpartitioned extends CartesianProductVertexM
     getContext().reconfigureVertex(parallelism, null, edgeProperties);
     vertexReconfigured = true;
     getContext().doneReconfiguringVertex();
+    return true;
   }
 
-  @Override
-  public synchronized void onVertexStarted(List<TaskAttemptIdentifier> completions)
-    throws Exception {
-    vertexStarted = true;
-    // if vertex is already reconfigured, we can handle pending completions immediately
-    // otherwise we have to wait until vertex is reconfigured
-    if (vertexReconfigured) {
-      Preconditions.checkArgument(pendingCompletedSrcTask.size() == 0,
-        "Unexpected pending source completion on vertex start after vertex reconfiguration");
-      for (TaskAttemptIdentifier taId : completions) {
-        handleCompletedSrcTask(taId);
-      }
-    } else {
-      pendingCompletedSrcTask.addAll(completions);
+  private void tryScheduleTasks() throws IOException {
+    if (!vertexReconfigured && !tryReconfigure()) {
+      return;
     }
-  }
-
-  @Override
-  public synchronized void onVertexStateUpdated(VertexStateUpdate stateUpdate) throws IOException {
-    Preconditions.checkArgument(stateUpdate.getVertexState() == VertexState.CONFIGURED);
-    String vertex = stateUpdate.getVertexName();
-    numTasks[sourceVertices.indexOf(vertex)] = getContext().getVertexNumTasks(vertex);
-    // reconfigure vertex when all source vertices are CONFIGURED
-    if (++numSourceVertexConfigured == sourceVertices.size()) {
-      reconfigureVertex();
-      // handle pending source completions when vertex is started and reconfigured
-      if (vertexStarted) {
-        while (!pendingCompletedSrcTask.isEmpty()) {
-          handleCompletedSrcTask(pendingCompletedSrcTask.poll());
-        }
-      }
+    if (!vertexStartSchedule && !tryStartSchedule()) {
+      return;
     }
-  }
 
-  @Override
-  public synchronized void onSourceTaskCompleted(TaskAttemptIdentifier attempt) throws Exception {
-    if (numSourceVertexConfigured < sourceVertices.size()) {
-      pendingCompletedSrcTask.add(attempt);
-      return;
+    while (!completedSrcTaskToProcess.isEmpty()) {
+      scheduledTasksDependOnCompletion(completedSrcTaskToProcess.poll());
     }
-    Preconditions.checkArgument(pendingCompletedSrcTask.size() == 0,
-      "Unexpected pending src completion on source task completed after vertex reconfiguration");
-    handleCompletedSrcTask(attempt);
   }
 
-  private void handleCompletedSrcTask(TaskAttemptIdentifier attempt) {
+  private void scheduledTasksDependOnCompletion(TaskAttemptIdentifier attempt) {
     int taskId = attempt.getTaskIdentifier().getIdentifier();
     String vertex = attempt.getTaskIdentifier().getVertexIdentifier().getName();
-    if (sourceTaskCompleted.get(vertex).get(taskId)) {
-      return;
-    }
-
-    if (sourceTaskCompleted.get(vertex).isEmpty()) {
-      numSrcHasCompletedTask++;
-    }
-    sourceTaskCompleted.get(vertex).set(taskId);
-    if (numSrcHasCompletedTask != sourceVertices.size()) {
-      return;
-    }
 
     List<ScheduleTaskRequest> requests = new ArrayList<>();
-    CartesianProductCombination combination = new CartesianProductCombination(numTasks, sourceVertices.indexOf(vertex));
+    CartesianProductCombination combination =
+      new CartesianProductCombination(numTasks, sourceVertices.indexOf(vertex));
     combination.firstTaskWithFixedPartition(taskId);
     do {
       List<Integer> list = combination.getCombination();
       boolean readyToSchedule = true;
       for (int i = 0; i < list.size(); i++) {
-        if (!sourceTaskCompleted.get(sourceVertices.get(i)).get(list.get(i))) {
+        if (!sourceTaskCompleted.get(sourceVertices.get(i)).contains(list.get(i))) {
           readyToSchedule = false;
           break;
         }
       }
-      if (readyToSchedule && !scheduledTasks.get(combination.getTaskId())) {
+      if (readyToSchedule && !scheduledTasks.contains(combination.getTaskId())) {
         requests.add(ScheduleTaskRequest.create(combination.getTaskId(), null));
-        scheduledTasks.set(combination.getTaskId());
+        scheduledTasks.add(combination.getTaskId());
       }
     } while (combination.nextTaskWithFixedPartition());
     if (!requests.isEmpty()) {

http://git-wip-us.apache.org/repos/asf/tez/blob/b4c949c9/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductCombination.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductCombination.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductCombination.java
index 0d6a928..4a2827a 100644
--- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductCombination.java
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductCombination.java
@@ -98,7 +98,7 @@ public class TestCartesianProductCombination {
     assertFalse(combination.nextTask());
   }
 
-  @Test//(timeout = 5000)
+  @Test(timeout = 5000)
   public void testFromTaskId() {
     for (int i = 0; i < 6; i++) {
       List<Integer> list = CartesianProductCombination.fromTaskId(new int[]{2,3}, i)

http://git-wip-us.apache.org/repos/asf/tez/blob/b4c949c9/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerPartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerPartitioned.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerPartitioned.java
index 2e8697d..8710c55 100644
--- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerPartitioned.java
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductEdgeManagerPartitioned.java
@@ -135,7 +135,7 @@ public class TestCartesianProductEdgeManagerPartitioned {
    * Vertex v0 has 2 tasks which generate 3 partitions
    * Vertex v1 has 3 tasks which generate 4 partitions
    */
-  @Test//(timeout = 5000)
+  @Test(timeout = 5000)
   public void testTwoWayWithFilter() throws Exception {
     ByteBuffer buffer = ByteBuffer.allocate(2);
     buffer.putChar('>');

http://git-wip-us.apache.org/repos/asf/tez/blob/b4c949c9/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManager.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManager.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManager.java
index 755c578..f3a5851 100644
--- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManager.java
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManager.java
@@ -19,37 +19,72 @@ package org.apache.tez.runtime.library.cartesianproduct;
 
 import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
 import org.apache.tez.dag.api.EdgeProperty;
+import org.apache.tez.dag.api.EdgeProperty.DataMovementType;
 import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.UserPayload;
 import org.apache.tez.dag.api.VertexManagerPluginContext;
+import org.junit.Before;
 import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
 
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+import static org.apache.tez.dag.api.EdgeProperty.DataMovementType.BROADCAST;
+import static org.apache.tez.dag.api.EdgeProperty.DataMovementType.CUSTOM;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
 public class TestCartesianProductVertexManager {
+  private CartesianProductVertexManager vertexManager; // rebuilt in setup() for every test
+  private VertexManagerPluginContext context;
+  private final String vertexName = "cp";
+  private TezConfiguration conf;
+  private CartesianProductConfig config; // read lazily by the getUserPayload() answer
+  private Map<String, EdgeProperty> edgePropertyMap; // tests add/replace the "v2" entry
+  private final EdgeProperty cpEdge = EdgeProperty.create(EdgeManagerPluginDescriptor.create(
+    CartesianProductEdgeManager.class.getName()), null, null, null, null);
+  private final EdgeProperty customEdge = EdgeProperty.create(EdgeManagerPluginDescriptor.create(
+    "OTHER_EDGE"), null, null, null, null);
+  private final EdgeProperty broadcastEdge =
+    EdgeProperty.create(DataMovementType.BROADCAST, null, null, null, null);
+
+  @Before
+  public void setup() {
+    conf = new TezConfiguration();
+    edgePropertyMap = new HashMap<>();
+    edgePropertyMap.put("v0", cpEdge);
+    edgePropertyMap.put("v1", cpEdge);
+    context = mock(VertexManagerPluginContext.class);
+    when(context.getInputVertexEdgeProperties()).thenReturn(edgePropertyMap);
+    when(context.getVertexName()).thenReturn(vertexName);
+    when(context.getVertexNumTasks(vertexName)).thenReturn(-1); // i.e. no predefined parallelism
+    when(context.getUserPayload()).thenAnswer(new Answer<UserPayload>() { // evaluated per call
+      @Override
+      public UserPayload answer(InvocationOnMock invocationOnMock) throws Throwable {
+        return config.toUserPayload(conf); // uses whichever config the test assigned last
+      }
+    });
+    vertexManager = new CartesianProductVertexManager(context);
+  }
+
   @Test(timeout = 5000)
-  public void testInitialize() throws Exception {
-    VertexManagerPluginContext context = mock(VertexManagerPluginContext.class);
-    CartesianProductVertexManager vertexManager = new CartesianProductVertexManager(context);
-    TezConfiguration conf = new TezConfiguration();
+  public void testRejectPredefinedParallelism() throws Exception {
+    when(context.getVertexNumTasks(vertexName)).thenReturn(10); // parallelism already fixed
+    try {
+      vertexManager = new CartesianProductVertexManager(context);
+      throw new AssertionError("constructor should reject a vertex with predefined parallelism");
+    } catch (Exception expected) {} // AssertionError is an Error, so it is not swallowed here
+  }
 
+  @Test(timeout = 5000)
+  public void testChooseRealVertexManager() throws Exception {
     // partitioned case
-    CartesianProductConfig config =
-      new CartesianProductConfig(new int[]{2,3}, new String[]{"v0", "v1"}, null);
-    when(context.getUserPayload()).thenReturn(config.toUserPayload(conf));
-    EdgeProperty edgeProperty =
-      EdgeProperty.create(EdgeManagerPluginDescriptor.create(
-        CartesianProductEdgeManager.class.getName()), null, null, null, null);
-    Map<String, EdgeProperty> edgePropertyMap = new HashMap<>();
-    edgePropertyMap.put("v0", edgeProperty);
-    edgePropertyMap.put("v1", edgeProperty);
-    when(context.getInputVertexEdgeProperties()).thenReturn(edgePropertyMap);
+    config = new CartesianProductConfig(new int[]{2, 3}, new String[]{"v0", "v1"}, null);
     vertexManager.initialize();
     assertTrue(vertexManager.getVertexManagerReal()
       instanceof CartesianProductVertexManagerPartitioned);
@@ -59,9 +94,69 @@ public class TestCartesianProductVertexManager {
     sourceVertices.add("v0");
     sourceVertices.add("v1");
     config = new CartesianProductConfig(sourceVertices);
-    when(context.getUserPayload()).thenReturn(config.toUserPayload(conf));
     vertexManager.initialize();
     assertTrue(vertexManager.getVertexManagerReal()
       instanceof CartesianProductVertexManagerUnpartitioned);
   }
+
+  @Test(timeout = 5000)
+  public void testCheckDAGConfigConsistent() throws Exception {
+    // positive case
+    edgePropertyMap.put("v2", broadcastEdge);
+    config = new CartesianProductConfig(new int[]{2, 3}, new String[]{"v0", "v1"}, null);
+    vertexManager.initialize();
+
+    // cartesian product edge in dag but not in config
+    edgePropertyMap.put("v2", cpEdge);
+    try {
+      vertexManager.initialize();
+      throw new AssertionError("CP edge v2 in DAG but not in config was accepted");
+    } catch (Exception ignored) {}
+
+    // non-cartesian-product edge in dag but in config
+    edgePropertyMap.put("v2", broadcastEdge);
+    config = new CartesianProductConfig(new int[]{2, 3, 4}, new String[]{"v0", "v1", "v2"}, null);
+    try {
+      vertexManager.initialize();
+      throw new AssertionError("broadcast edge v2 listed as CP source was accepted");
+    } catch (Exception ignored) {}
+
+    edgePropertyMap.put("v2", customEdge);
+    try {
+      vertexManager.initialize();
+      throw new AssertionError("non-CP custom edge v2 listed as CP source was accepted");
+    } catch (Exception ignored) {}
+
+    // edge in config but not in dag
+    edgePropertyMap.remove("v2");
+    try {
+      vertexManager.initialize();
+      throw new AssertionError("source v2 in config but not in DAG was accepted");
+    } catch (Exception ignored) {}
+  }
+
+  @Test(timeout = 5000)
+  public void testOtherEdgeType() throws Exception {
+    // forbid other custom edge
+    edgePropertyMap.put("v2", customEdge);
+    config = new CartesianProductConfig(new int[]{2, 3}, new String[]{"v0", "v1"}, null);
+    try {
+      vertexManager.initialize();
+      throw new AssertionError("extra custom edge into cartesian product vertex was accepted");
+    } catch (Exception ignored) {}
+
+    // broadcast edge should be allowed and other non-custom edge shouldn't be allowed
+    for (DataMovementType type : DataMovementType.values()) {
+      if (type == CUSTOM) {
+        continue;
+      }
+      edgePropertyMap.put("v2", EdgeProperty.create(type, null, null, null, null));
+      try {
+        vertexManager.initialize();
+        assertTrue(type + " edge should have been rejected", type == BROADCAST);
+      } catch (Exception e) {
+        assertTrue(type + " edge should have been accepted, got: " + e, type != BROADCAST);
+      }
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/b4c949c9/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerPartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerPartitioned.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerPartitioned.java
index 9aca647..99067f1 100644
--- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerPartitioned.java
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerPartitioned.java
@@ -17,10 +17,9 @@
  */
 package org.apache.tez.runtime.library.cartesianproduct;
 
-import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
 import org.apache.tez.dag.api.EdgeProperty;
-import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.TezReflectionException;
 import org.apache.tez.dag.api.UserPayload;
 import org.apache.tez.dag.api.VertexLocationHint;
 import org.apache.tez.dag.api.VertexManagerPluginContext;
@@ -38,18 +37,17 @@ import org.junit.Test;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Captor;
 import org.mockito.Matchers;
-import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
 
 import java.util.ArrayList;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 
+import static org.apache.tez.dag.api.EdgeProperty.DataMovementType.BROADCAST;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
+import static org.mockito.Matchers.eq;
 import static org.mockito.Matchers.isNull;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -62,43 +60,53 @@ public class TestCartesianProductVertexManagerPartitioned {
   private ArgumentCaptor<Map<String, EdgeProperty>> edgePropertiesCaptor;
   @Captor
   private ArgumentCaptor<List<ScheduleTaskRequest>> scheduleTaskRequestCaptor;
-  private TezConfiguration conf = new TezConfiguration();
+  private CartesianProductVertexManagerPartitioned vertexManager;
+  private VertexManagerPluginContext context;
+  private List<TaskAttemptIdentifier> allCompletions;
 
   @Before
-  public void init() {
-    MockitoAnnotations.initMocks(this);
+  public void setup() throws TezReflectionException {
+    setupWithConfig(
+      new CartesianProductVertexManagerConfig(true, new String[]{"v0","v1"}, new int[] {2, 2},
+        CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_SLOW_START_MIN_FRACTION_DEFAULT,
+        CartesianProductVertexManager.TEZ_CARTESIAN_PRODUCT_SLOW_START_MAX_FRACTION_DEFAULT, null));
   }
 
-  public static class TestFilter extends CartesianProductFilter {
-    public TestFilter(UserPayload payload) {
-      super(payload);
-    }
-
-    @Override
-    public boolean isValidCombination(Map<String, Integer> vertexPartitionMap) {
-      return vertexPartitionMap.get("v0") > vertexPartitionMap.get("v1");
+  private void setupWithConfig(CartesianProductVertexManagerConfig config)
+    throws TezReflectionException {
+    MockitoAnnotations.initMocks(this);
+    context = mock(VertexManagerPluginContext.class);
+    vertexManager = new CartesianProductVertexManagerPartitioned(context);
+    Map<String, EdgeProperty> edgePropertyMap = new HashMap<>();
+    edgePropertyMap.put("v0", EdgeProperty.create(EdgeManagerPluginDescriptor.create(
+      CartesianProductEdgeManager.class.getName()), null, null, null, null));
+    edgePropertyMap.put("v1", EdgeProperty.create(EdgeManagerPluginDescriptor.create(
+      CartesianProductEdgeManager.class.getName()), null, null, null, null));
+    edgePropertyMap.put("v2", EdgeProperty.create(BROADCAST, null, null, null, null));
+    when(context.getInputVertexEdgeProperties()).thenReturn(edgePropertyMap);
+    when(context.getVertexNumTasks(eq("v0"))).thenReturn(4);
+    when(context.getVertexNumTasks(eq("v1"))).thenReturn(4);
+    when(context.getVertexNumTasks(eq("v2"))).thenReturn(4);
+    vertexManager.initialize(config);
+
+    allCompletions = new ArrayList<>();
+    for (int i = 0; i < 3; i++) {
+      for (int j = 0; j < 4; j++) {
+        allCompletions.add(new TaskAttemptIdentifierImpl("dag", "v" + i,
+          TezTaskAttemptID.getInstance(TezTaskID.getInstance(TezVertexID.getInstance(
+            TezDAGID.getInstance("0", 0, 0), i), j), 0)));
+      }
     }
   }
 
-  private void testReconfigureVertexHelper(CartesianProductConfig config, int parallelism)
+  private void testReconfigureVertexHelper(CartesianProductVertexManagerConfig config,
+                                           int parallelism)
     throws Exception {
-    VertexManagerPluginContext mockContext = mock(VertexManagerPluginContext.class);
-    when(mockContext.getUserPayload()).thenReturn(config.toUserPayload(conf));
-
-    EdgeProperty edgeProperty =
-      EdgeProperty.create(EdgeManagerPluginDescriptor.create(
-        CartesianProductEdgeManager.class.getName()), null, null, null, null);
-    Map<String, EdgeProperty> inputEdgeProperties = new HashMap<>();
-    for (String vertex : config.getSourceVertices()) {
-      inputEdgeProperties.put(vertex, edgeProperty);
-    }
-    when(mockContext.getInputVertexEdgeProperties()).thenReturn(inputEdgeProperties);
-    CartesianProductVertexManager vertexManager = new CartesianProductVertexManager(mockContext);
-    vertexManager.initialize();
+    setupWithConfig(config);
     ArgumentCaptor<Integer> parallelismCaptor = ArgumentCaptor.forClass(Integer.class);
 
     vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
-    verify(mockContext, times(1)).reconfigureVertex(parallelismCaptor.capture(),
+    verify(context, times(1)).reconfigureVertex(parallelismCaptor.capture(),
       isNull(VertexLocationHint.class), edgePropertiesCaptor.capture());
     assertEquals((int)parallelismCaptor.getValue(), parallelism);
     assertNull(edgePropertiesCaptor.getValue());
@@ -107,124 +115,98 @@ public class TestCartesianProductVertexManagerPartitioned {
   @Test(timeout = 5000)
   public void testReconfigureVertex() throws Exception {
     testReconfigureVertexHelper(
-      new CartesianProductConfig(new int[]{5,5}, new String[]{"v0", "v1"},
-        new CartesianProductFilterDescriptor(TestFilter.class.getName())), 10);
+      new CartesianProductVertexManagerConfig(true, new String[]{"v0", "v1"}, new int[] {5, 5}, 0,
+        0, new CartesianProductFilterDescriptor(TestFilter.class.getName())), 10);
     testReconfigureVertexHelper(
-      new CartesianProductConfig(new int[]{5,5}, new String[]{"v0", "v1"}, null), 25);
+      new CartesianProductVertexManagerConfig(true, new String[]{"v0", "v1"}, new int[] {5, 5}, 0,
+        0, null), 25);
   }
 
   @Test(timeout = 5000)
   public void testScheduling() throws Exception {
-    CartesianProductConfig config = new CartesianProductConfig(new int[]{2,2},
-      new String[]{"v0", "v1"}, null);
-    VertexManagerPluginContext mockContext = mock(VertexManagerPluginContext.class);
-    when(mockContext.getUserPayload()).thenReturn(config.toUserPayload(conf));
-    Set<String> inputVertices = new HashSet<String>();
-    inputVertices.add("v0");
-    inputVertices.add("v1");
-    when(mockContext.getVertexInputNames()).thenReturn(inputVertices);
-    when(mockContext.getVertexNumTasks("v0")).thenReturn(4);
-    when(mockContext.getVertexNumTasks("v1")).thenReturn(4);
-    EdgeProperty edgeProperty =
-      EdgeProperty.create(EdgeManagerPluginDescriptor.create(
-        CartesianProductEdgeManager.class.getName()), null, null, null, null);
-    Map<String, EdgeProperty> inputEdgeProperties = new HashMap<String, EdgeProperty>();
-    inputEdgeProperties.put("v0", edgeProperty);
-    inputEdgeProperties.put("v1", edgeProperty);
-    when(mockContext.getInputVertexEdgeProperties()).thenReturn(inputEdgeProperties);
-    CartesianProductVertexManager vertexManager = new CartesianProductVertexManager(mockContext);
-    vertexManager.initialize();
-
-    vertexManager.onVertexStarted(new ArrayList<TaskAttemptIdentifier>());
+    vertexManager.onVertexStarted(null);
     vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
     vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
 
 
-    TaskAttemptIdentifier taId = mock(TaskAttemptIdentifier.class, Mockito.RETURNS_DEEP_STUBS);
-    when(taId.getTaskIdentifier().getVertexIdentifier().getName()).thenReturn("v0", "v0", "v1",
-      "v1", "v0", "v0", "v1", "v1");
-    when(taId.getTaskIdentifier().getIdentifier()).thenReturn(0, 1, 0, 1, 2, 3, 2, 3);
-
-    for (int i = 0; i < 2; i++) {
-      vertexManager.onSourceTaskCompleted(taId);
-      verify(mockContext, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
-    }
+    vertexManager.onSourceTaskCompleted(allCompletions.get(0));
+    vertexManager.onSourceTaskCompleted(allCompletions.get(1));
+    verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
 
     List<ScheduleTaskRequest> scheduleTaskRequests;
+    vertexManager.onSourceTaskCompleted(allCompletions.get(2));
+    // shouldn't start schedule because broadcast src is not in RUNNING state
+    verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
 
-    vertexManager.onSourceTaskCompleted(taId);
-    verify(mockContext, times(1)).scheduleTasks(scheduleTaskRequestCaptor.capture());
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v2", VertexState.RUNNING));
+    verify(context, times(1)).scheduleTasks(scheduleTaskRequestCaptor.capture());
     scheduleTaskRequests = scheduleTaskRequestCaptor.getValue();
     assertEquals(1, scheduleTaskRequests.size());
     assertEquals(0, scheduleTaskRequests.get(0).getTaskIndex());
 
-    vertexManager.onSourceTaskCompleted(taId);
-    verify(mockContext, times(2)).scheduleTasks(scheduleTaskRequestCaptor.capture());
-    scheduleTaskRequests = scheduleTaskRequestCaptor.getValue();
-    assertEquals(1, scheduleTaskRequests.size());
-    assertEquals(1, scheduleTaskRequests.get(0).getTaskIndex());
-
-    vertexManager.onSourceTaskCompleted(taId);
-    verify(mockContext, times(3)).scheduleTasks(scheduleTaskRequestCaptor.capture());
-    scheduleTaskRequests = scheduleTaskRequestCaptor.getValue();
-    assertEquals(1, scheduleTaskRequests.size());
-    assertEquals(2, scheduleTaskRequests.get(0).getTaskIndex());
+    // completion from broadcast src shouldn't matter
+    vertexManager.onSourceTaskCompleted(allCompletions.get(8));
+    verify(context, times(1)).scheduleTasks(scheduleTaskRequestCaptor.capture());
 
-    vertexManager.onSourceTaskCompleted(taId);
-    verify(mockContext, times(4)).scheduleTasks(scheduleTaskRequestCaptor.capture());
-    scheduleTaskRequests = scheduleTaskRequestCaptor.getValue();
-    assertEquals(1, scheduleTaskRequests.size());
-    assertEquals(3, scheduleTaskRequests.get(0).getTaskIndex());
+    for (int i = 3; i < 6; i++) {
+      vertexManager.onSourceTaskCompleted(allCompletions.get(i));
+      verify(context, times(i-1)).scheduleTasks(scheduleTaskRequestCaptor.capture());
+      scheduleTaskRequests = scheduleTaskRequestCaptor.getValue();
+      assertEquals(1, scheduleTaskRequests.size());
+      assertEquals(i-2, scheduleTaskRequests.get(0).getTaskIndex());
+    }
 
-    for (int i = 0; i < 2; i++) {
-      vertexManager.onSourceTaskCompleted(taId);
-      verify(mockContext, times(4)).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
+    for (int i = 6; i < 8; i++) {
+      vertexManager.onSourceTaskCompleted(allCompletions.get(i));
+      verify(context, times(4)).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
     }
   }
 
   @Test(timeout = 5000)
-  public void testVertexStartWithCompletion() throws Exception {
-    CartesianProductConfig config = new CartesianProductConfig(new int[]{2,2},
-      new String[]{"v0", "v1"}, null);
-    VertexManagerPluginContext mockContext = mock(VertexManagerPluginContext.class);
-    when(mockContext.getUserPayload()).thenReturn(config.toUserPayload(conf));
-    Set<String> inputVertices = new HashSet<String>();
-    inputVertices.add("v0");
-    inputVertices.add("v1");
-    when(mockContext.getVertexInputNames()).thenReturn(inputVertices);
-    when(mockContext.getVertexNumTasks("v0")).thenReturn(4);
-    when(mockContext.getVertexNumTasks("v1")).thenReturn(4);
-    EdgeProperty edgeProperty =
-      EdgeProperty.create(EdgeManagerPluginDescriptor.create(
-        CartesianProductEdgeManager.class.getName()), null, null, null, null);
-    Map<String, EdgeProperty> inputEdgeProperties = new HashMap<String, EdgeProperty>();
-    inputEdgeProperties.put("v0", edgeProperty);
-    inputEdgeProperties.put("v1", edgeProperty);
-    when(mockContext.getInputVertexEdgeProperties()).thenReturn(inputEdgeProperties);
-    CartesianProductVertexManager vertexManager = new CartesianProductVertexManager(mockContext);
-    vertexManager.initialize();
+  public void testOnVertexStartWithBroadcastRunning() throws Exception {
+    testOnVertexStartHelper(true);
+  }
+
+  @Test(timeout = 5000)
+  public void testOnVertexStartWithoutBroadcastRunning() throws Exception {
+    testOnVertexStartHelper(false);
+  }
 
+  private void testOnVertexStartHelper(boolean broadcastRunning) throws Exception {
     vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
     vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
+    if (broadcastRunning) {
+      vertexManager.onVertexStateUpdated(new VertexStateUpdate("v2", VertexState.RUNNING));
+    }
 
     List<TaskAttemptIdentifier> completions = new ArrayList<>();
-    TezDAGID dagId = TezDAGID.getInstance(ApplicationId.newInstance(0, 0), 0);
-    TezVertexID v0Id = TezVertexID.getInstance(dagId, 0);
-    TezVertexID v1Id = TezVertexID.getInstance(dagId, 1);
-
-    completions.add(new TaskAttemptIdentifierImpl("dag", "v0",
-      TezTaskAttemptID.getInstance(TezTaskID.getInstance(v0Id, 0), 0)));
-    completions.add(new TaskAttemptIdentifierImpl("dag", "v0",
-      TezTaskAttemptID.getInstance(TezTaskID.getInstance(v0Id, 1), 0)));
-    completions.add(new TaskAttemptIdentifierImpl("dag", "v1",
-      TezTaskAttemptID.getInstance(TezTaskID.getInstance(v1Id, 0), 0)));
+    completions.add(allCompletions.get(0));
+    completions.add(allCompletions.get(1));
+    completions.add(allCompletions.get(4));
+    completions.add(allCompletions.get(8));
 
     vertexManager.onVertexStarted(completions);
 
+    if (!broadcastRunning) {
+      verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
+      vertexManager.onVertexStateUpdated(new VertexStateUpdate("v2", VertexState.RUNNING));
+    }
+
     List<ScheduleTaskRequest> scheduleTaskRequests;
-    verify(mockContext, times(1)).scheduleTasks(scheduleTaskRequestCaptor.capture());
+    verify(context, times(1)).scheduleTasks(scheduleTaskRequestCaptor.capture());
     scheduleTaskRequests = scheduleTaskRequestCaptor.getValue();
     assertEquals(1, scheduleTaskRequests.size());
     assertEquals(0, scheduleTaskRequests.get(0).getTaskIndex());
   }
+
+  public static class TestFilter extends CartesianProductFilter {
+    public TestFilter(UserPayload payload) { super(payload); }
+    // Accepts only combinations where v0's partition index is greater than v1's.
+    @Override
+    public boolean isValidCombination(Map<String, Integer> vertexPartitionMap) {
+      int v0Partition = vertexPartitionMap.get("v0");
+      int v1Partition = vertexPartitionMap.get("v1");
+      return v0Partition > v1Partition;
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/tez/blob/b4c949c9/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerUnpartitioned.java
----------------------------------------------------------------------
diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerUnpartitioned.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerUnpartitioned.java
index f76de96..dfe2830 100644
--- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerUnpartitioned.java
+++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/cartesianproduct/TestCartesianProductVertexManagerUnpartitioned.java
@@ -17,6 +17,7 @@
  */
 package org.apache.tez.runtime.library.cartesianproduct;
 
+import org.apache.tez.dag.api.EdgeManagerPluginDescriptor;
 import org.apache.tez.dag.api.EdgeProperty;
 import org.apache.tez.dag.api.UserPayload;
 import org.apache.tez.dag.api.VertexLocationHint;
@@ -38,11 +39,14 @@ import org.mockito.Matchers;
 import org.mockito.MockitoAnnotations;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
+import static org.apache.tez.dag.api.EdgeProperty.DataMovementType.BROADCAST;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyInt;
@@ -69,8 +73,17 @@ public class TestCartesianProductVertexManagerUnpartitioned {
     MockitoAnnotations.initMocks(this);
     context = mock(VertexManagerPluginContext.class);
     vertexManager = new CartesianProductVertexManagerUnpartitioned(context);
+
+    Map<String, EdgeProperty> edgePropertyMap = new HashMap<>();
+    edgePropertyMap.put("v0", EdgeProperty.create(EdgeManagerPluginDescriptor.create(
+        CartesianProductEdgeManager.class.getName()), null, null, null, null));
+    edgePropertyMap.put("v1", EdgeProperty.create(EdgeManagerPluginDescriptor.create(
+      CartesianProductEdgeManager.class.getName()), null, null, null, null));
+    edgePropertyMap.put("v2", EdgeProperty.create(BROADCAST, null, null, null, null));
+    when(context.getInputVertexEdgeProperties()).thenReturn(edgePropertyMap);
     when(context.getVertexNumTasks(eq("v0"))).thenReturn(2);
     when(context.getVertexNumTasks(eq("v1"))).thenReturn(3);
+    when(context.getVertexNumTasks(eq("v2"))).thenReturn(5);
 
     CartesianProductVertexManagerConfig config =
       new CartesianProductVertexManagerConfig(false, new String[]{"v0","v1"}, null, 0, 0, null);
@@ -81,16 +94,19 @@ public class TestCartesianProductVertexManagerUnpartitioned {
         TezDAGID.getInstance("0", 0, 0), 0), 0), 0)));
     allCompletions.add(new TaskAttemptIdentifierImpl("dag", "v0",
       TezTaskAttemptID.getInstance(TezTaskID.getInstance(TezVertexID.getInstance(
-        TezDAGID.getInstance("0", 0, 0), 0), 0), 1)));
+        TezDAGID.getInstance("0", 0, 0), 0), 1), 0)));
     allCompletions.add(new TaskAttemptIdentifierImpl("dag", "v1",
       TezTaskAttemptID.getInstance(TezTaskID.getInstance(TezVertexID.getInstance(
         TezDAGID.getInstance("0", 0, 0), 1), 0), 0)));
     allCompletions.add(new TaskAttemptIdentifierImpl("dag", "v1",
       TezTaskAttemptID.getInstance(TezTaskID.getInstance(TezVertexID.getInstance(
-        TezDAGID.getInstance("0", 0, 0), 1), 0), 1)));
+        TezDAGID.getInstance("0", 0, 0), 1), 1), 0)));
     allCompletions.add(new TaskAttemptIdentifierImpl("dag", "v1",
       TezTaskAttemptID.getInstance(TezTaskID.getInstance(TezVertexID.getInstance(
-        TezDAGID.getInstance("0", 0, 0), 1), 0), 2)));
+        TezDAGID.getInstance("0", 0, 0), 1), 2), 0)));
+    allCompletions.add(new TaskAttemptIdentifierImpl("dag", "v2",
+      TezTaskAttemptID.getInstance(TezTaskID.getInstance(TezVertexID.getInstance(
+        TezDAGID.getInstance("0", 0, 0), 3), 0), 0)));
   }
 
   @Test(timeout = 5000)
@@ -104,6 +120,7 @@ public class TestCartesianProductVertexManagerUnpartitioned {
       isNull(VertexLocationHint.class), edgePropertiesCaptor.capture());
     assertEquals(6, (int)parallelismCaptor.getValue());
     Map<String, EdgeProperty> edgeProperties = edgePropertiesCaptor.getValue();
+    assertFalse(edgeProperties.containsKey("v2"));
     for (EdgeProperty edgeProperty : edgeProperties.values()) {
       UserPayload payload = edgeProperty.getEdgeManagerDescriptor().getUserPayload();
       CartesianProductEdgeManagerConfig newConfig =
@@ -113,47 +130,54 @@ public class TestCartesianProductVertexManagerUnpartitioned {
   }
 
   @Test(timeout = 5000)
-  public void testCompletionAfterReconfigured() throws Exception {
-    vertexManager.onVertexStarted(new ArrayList<TaskAttemptIdentifier>());
+  public void testOnSourceTaskComplete() throws Exception {
     vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
     vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
+    vertexManager.onVertexStarted(null);
     verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
     vertexManager.onSourceTaskCompleted(allCompletions.get(0));
     verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
     vertexManager.onSourceTaskCompleted(allCompletions.get(2));
+    // sources are complete enough for task 0, but scheduling must wait for broadcast input v2 to be RUNNING
+    verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
+    // once v2 reports RUNNING, exactly one request (task index 0) is issued
+    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v2", VertexState.RUNNING));
     verify(context, times(1)).scheduleTasks(scheduleTaskRequestCaptor.capture());
     List<ScheduleTaskRequest> requests = scheduleTaskRequestCaptor.getValue();
     assertNotNull(requests);
     assertEquals(1, requests.size());
     assertEquals(0, requests.get(0).getTaskIndex());
-  }
 
-  @Test(timeout = 5000)
-  public void testCompletionBeforeReconfigured() throws Exception {
-    vertexManager.onVertexStarted(new ArrayList<TaskAttemptIdentifier>());
-    vertexManager.onSourceTaskCompleted(allCompletions.get(0));
-    verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
-    vertexManager.onSourceTaskCompleted(allCompletions.get(2));
-    verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
-    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
-    verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
-    vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
+    // a task completion from broadcast source v2 must not trigger any additional scheduling
+    vertexManager.onSourceTaskCompleted(allCompletions.get(5));
     verify(context, times(1)).scheduleTasks(scheduleTaskRequestCaptor.capture());
-    List<ScheduleTaskRequest> requests = scheduleTaskRequestCaptor.getValue();
+    // completing v1 task 1 makes destination task index 1 schedulable
+    vertexManager.onSourceTaskCompleted(allCompletions.get(3));
+    verify(context, times(2)).scheduleTasks(scheduleTaskRequestCaptor.capture());
+    requests = scheduleTaskRequestCaptor.getValue();
     assertNotNull(requests);
     assertEquals(1, requests.size());
-    assertEquals(0, requests.get(0).getTaskIndex());
+    assertEquals(1, requests.get(0).getTaskIndex());
   }
 
-  @Test(timeout = 5000)
-  public void testStartAfterReconfigured() throws Exception {
+  private void testOnVertexStartHelper(boolean broadcastRunning) throws Exception {
     vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
     vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
+    if (broadcastRunning) {
+      vertexManager.onVertexStateUpdated(new VertexStateUpdate("v2", VertexState.RUNNING));
+    }
+
+    List<TaskAttemptIdentifier> completions = new ArrayList<>();
+    completions.add(allCompletions.get(0));
+    completions.add(allCompletions.get(2));
+    completions.add(allCompletions.get(5));
+    vertexManager.onVertexStarted(completions);
+
+    if (!broadcastRunning) {
+      verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
+      vertexManager.onVertexStateUpdated(new VertexStateUpdate("v2", VertexState.RUNNING));
+    }
 
-    List<TaskAttemptIdentifier> completion = new ArrayList<>();
-    completion.add(allCompletions.get(0));
-    completion.add(allCompletions.get(2));
-    vertexManager.onVertexStarted(completion);
     verify(context, times(1)).scheduleTasks(scheduleTaskRequestCaptor.capture());
     List<ScheduleTaskRequest> requests = scheduleTaskRequestCaptor.getValue();
     assertNotNull(requests);
@@ -162,9 +186,14 @@ public class TestCartesianProductVertexManagerUnpartitioned {
   }
 
   @Test(timeout = 5000)
-  public void testStartBeforeReconfigured() throws Exception {
-    vertexManager.onVertexStarted(allCompletions);
-    verify(context, never()).scheduleTasks(Matchers.<List<ScheduleTaskRequest>>any());
+  public void testOnVertexStartWithBroadcastRunning() throws Exception {
+    testOnVertexStartHelper(true); // v2 reports RUNNING before onVertexStarted
+  }
+
+  @Test(timeout = 5000)
+  public void testOnVertexStartWithoutBroadcastRunning() throws Exception {
+    testOnVertexStartHelper(false); // v2 reports RUNNING only after onVertexStarted
+
   }
 
   @Test(timeout = 5000)
@@ -176,18 +205,17 @@ public class TestCartesianProductVertexManagerUnpartitioned {
 
     CartesianProductVertexManagerConfig config =
       new CartesianProductVertexManagerConfig(false, new String[]{"v0","v1"}, null, 0, 0, null);
-    vertexManager.initialize(config);
-    allCompletions = new ArrayList<>();
-    allCompletions.add(new TaskAttemptIdentifierImpl("dag", "v0",
-      TezTaskAttemptID.getInstance(TezTaskID.getInstance(TezVertexID.getInstance(
-        TezDAGID.getInstance("0", 0, 0), 0), 0), 0)));
-    allCompletions.add(new TaskAttemptIdentifierImpl("dag", "v0",
-      TezTaskAttemptID.getInstance(TezTaskID.getInstance(TezVertexID.getInstance(
-        TezDAGID.getInstance("0", 0, 0), 0), 0), 1)));
+    Map<String, EdgeProperty> edgePropertyMap = new HashMap<>();
+    edgePropertyMap.put("v0", EdgeProperty.create(EdgeManagerPluginDescriptor.create(
+      CartesianProductEdgeManager.class.getName()), null, null, null, null));
+    edgePropertyMap.put("v1", EdgeProperty.create(EdgeManagerPluginDescriptor.create(
+      CartesianProductEdgeManager.class.getName()), null, null, null, null));
+    when(context.getInputVertexEdgeProperties()).thenReturn(edgePropertyMap);
 
-    vertexManager.onVertexStarted(new ArrayList<TaskAttemptIdentifier>());
+    vertexManager.initialize(config);
     vertexManager.onVertexStateUpdated(new VertexStateUpdate("v0", VertexState.CONFIGURED));
     vertexManager.onVertexStateUpdated(new VertexStateUpdate("v1", VertexState.CONFIGURED));
+    vertexManager.onVertexStarted(null);
     vertexManager.onSourceTaskCompleted(allCompletions.get(0));
     vertexManager.onSourceTaskCompleted(allCompletions.get(1));
   }