You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hive.apache.org by an...@apache.org on 2017/11/06 04:52:29 UTC

hive git commit: HIVE-17595 : Correct DAG for updating the last.repl.id for a database during bootstrap load (Anishek Agarwal reviewed by Daniel Dai)

Repository: hive
Updated Branches:
  refs/heads/master aaacda474 -> 06dc4e968


HIVE-17595 : Correct DAG for updating the last.repl.id for a database during bootstrap load (Anishek Agarwal reviewed by Daniel Dai)


Project: http://git-wip-us.apache.org/repos/asf/hive/repo
Commit: http://git-wip-us.apache.org/repos/asf/hive/commit/06dc4e96
Tree: http://git-wip-us.apache.org/repos/asf/hive/tree/06dc4e96
Diff: http://git-wip-us.apache.org/repos/asf/hive/diff/06dc4e96

Branch: refs/heads/master
Commit: 06dc4e968e691a5e4383be91a45dac0a62279b81
Parents: aaacda4
Author: Anishek Agarwal <an...@gmail.com>
Authored: Mon Nov 6 10:19:42 2017 +0530
Committer: Anishek Agarwal <an...@gmail.com>
Committed: Mon Nov 6 10:19:42 2017 +0530

----------------------------------------------------------------------
 .../apache/hadoop/hive/ql/exec/Utilities.java   | 59 ++++++++------
 .../repl/bootstrap/AddDependencyToLeaves.java   | 51 ++++++++++++
 .../ql/exec/repl/bootstrap/ReplLoadTask.java    | 42 ++++------
 .../exec/repl/bootstrap/load/LoadFunction.java  |  5 +-
 .../bootstrap/load/table/LoadPartitions.java    |  6 +-
 .../repl/bootstrap/load/table/LoadTable.java    |  7 +-
 .../hadoop/hive/ql/exec/util/DAGTraversal.java  | 55 +++++++++++++
 .../bootstrap/AddDependencyToLeavesTest.java    | 85 ++++++++++++++++++++
 .../hive/ql/exec/util/DAGTraversalTest.java     | 82 +++++++++++++++++++
 9 files changed, 334 insertions(+), 58 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
index 1be7eab..b78c930 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
@@ -121,6 +121,7 @@ import org.apache.hadoop.hive.ql.exec.mr.MapRedTask;
 import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
 import org.apache.hadoop.hive.ql.exec.tez.DagUtils;
 import org.apache.hadoop.hive.ql.exec.tez.TezTask;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedInputFormatInterface;
 import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatchCtx;
 import org.apache.hadoop.hive.ql.io.AcidUtils;
@@ -2569,41 +2570,69 @@ public final class Utilities {
   }
 
   public static List<TezTask> getTezTasks(List<Task<? extends Serializable>> tasks) {
-    return getTasks(tasks, TezTask.class);
+    return getTasks(tasks, new TaskFilterFunction<>(TezTask.class));
   }
 
   public static List<SparkTask> getSparkTasks(List<Task<? extends Serializable>> tasks) {
-    return getTasks(tasks, SparkTask.class);
+    return getTasks(tasks, new TaskFilterFunction<>(SparkTask.class));
   }
 
   public static List<ExecDriver> getMRTasks(List<Task<? extends Serializable>> tasks) {
-    return getTasks(tasks, ExecDriver.class);
+    return getTasks(tasks, new TaskFilterFunction<>(ExecDriver.class));
   }
 
-  @SuppressWarnings("unchecked")
-  public static <T> List<T> getTasks(List<Task<? extends Serializable>> tasks, Class<T> requiredType) {
-    List<T> typeSpecificTasks = new ArrayList<>();
-    if (tasks != null) {
-      Set<Task<? extends Serializable>> visited = new HashSet<>();
-      while (!tasks.isEmpty()) {
-        List<Task<? extends Serializable>> childTasks = new ArrayList<>();
-        for (Task<? extends Serializable> task : tasks) {
-          if (visited.contains(task)) {
-            continue;
-          }
-          if (requiredType.isInstance(task) && !typeSpecificTasks.contains(task)) {
-            typeSpecificTasks.add((T) task);
-          }
-          if (task.getDependentTasks() != null) {
-            childTasks.addAll(task.getDependentTasks());
-          }
-          visited.add(task);
-        }
-        // start recursion
-        tasks = childTasks;
-      }
-    }
-    return typeSpecificTasks;
+  /**
+   * {@link DAGTraversal.Function} that collects every task of a given type reachable from the
+   * root tasks. A task is processed at most once, so shared subgraphs are not walked again.
+   */
+  static class TaskFilterFunction<T> implements DAGTraversal.Function {
+    private final Set<Task<? extends Serializable>> visited = new HashSet<>();
+    private final Class<T> requiredType;
+    private final List<T> typeSpecificTasks = new ArrayList<>();
+
+    TaskFilterFunction(Class<T> requiredType) {
+      this.requiredType = requiredType;
+    }
+
+    @Override
+    public void process(Task<? extends Serializable> task) {
+      if (requiredType.isInstance(task) && !typeSpecificTasks.contains(task)) {
+        // Class.cast is a checked conversion, so no unchecked-cast suppression is needed.
+        typeSpecificTasks.add(requiredType.cast(task));
+      }
+      visited.add(task);
+    }
+
+    List<T> getTasks() {
+      return typeSpecificTasks;
+    }
+
+    @Override
+    public boolean skipProcessing(Task<? extends Serializable> task) {
+      return visited.contains(task);
+    }
+  }
+
+  /**
+   * Collects all tasks of {@code requiredType} reachable from {@code tasks}, roots included.
+   * Kept public so existing callers of this method continue to compile.
+   *
+   * @param tasks root tasks to start from; {@code null} yields an empty list
+   * @param requiredType the task type to collect
+   * @return the matching tasks in traversal order; never {@code null}
+   */
+  public static <T> List<T> getTasks(List<Task<? extends Serializable>> tasks,
+      Class<T> requiredType) {
+    return getTasks(tasks, new TaskFilterFunction<>(requiredType));
+  }
+
+  private static <T> List<T> getTasks(List<Task<? extends Serializable>> tasks,
+      TaskFilterFunction<T> function) {
+    // A null task list has always meant "no tasks" for these lookups; guard before traversing.
+    if (tasks != null) {
+      DAGTraversal.traverse(tasks, function);
+    }
+    return function.getTasks();
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeaves.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeaves.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeaves.java
new file mode 100644
index 0000000..cf838e1
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeaves.java
@@ -0,0 +1,57 @@
+/*
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.repl.bootstrap;
+
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
+
+import java.io.Serializable;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * {@link DAGTraversal.Function} that adds the given tasks as dependents of every leaf task
+ * (a task without child tasks) reached during a traversal.
+ */
+public class AddDependencyToLeaves implements DAGTraversal.Function {
+  private final List<Task<? extends Serializable>> postDependencyCollectionTasks;
+
+  AddDependencyToLeaves(List<Task<? extends Serializable>> postDependencyCollectionTasks) {
+    this.postDependencyCollectionTasks = postDependencyCollectionTasks;
+  }
+
+  public AddDependencyToLeaves(Task<? extends Serializable> postDependencyTask) {
+    this(Collections.singletonList(postDependencyTask));
+  }
+
+  @Override
+  public void process(Task<? extends Serializable> task) {
+    // Treat a null and an empty child list alike: either way the task has no dependents yet.
+    List<Task<? extends Serializable>> children = task.getChildTasks();
+    if (children == null || children.isEmpty()) {
+      postDependencyCollectionTasks.forEach(task::addDependentTask);
+    }
+  }
+
+  @Override
+  public boolean skipProcessing(Task<? extends Serializable> task) {
+    // Do not descend into the appended tasks; they are not part of the graph being extended.
+    return postDependencyCollectionTasks.contains(task);
+  }
+}

http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
index bf5c819..bfbec45 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
@@ -39,12 +39,14 @@ import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.table.LoadPartitions;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.table.LoadTable;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.table.TableContext;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.Context;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.parse.repl.ReplLogger;
 import org.apache.hadoop.hive.ql.plan.api.StageType;
 
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 
 import static org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.LoadDatabase.AlterDatabase;
@@ -225,22 +227,28 @@ public class ReplLoadTask extends Task<ReplLoadWork> implements Serializable {
     return 0;
   }
 
-  private Task<? extends Serializable> createEndReplLogTask(Context context, Scope scope,
+  private void createEndReplLogTask(Context context, Scope scope,
                                                    ReplLogger replLogger) throws SemanticException {
     Database dbInMetadata = work.databaseEvent(context.hiveConf).dbInMetadata(work.dbNameToLoadIn);
     ReplStateLogWork replLogWork = new ReplStateLogWork(replLogger, dbInMetadata.getParameters());
     Task<ReplStateLogWork> replLogTask = TaskFactory.get(replLogWork, conf);
-    if (null == scope.rootTasks) {
+    if (scope.rootTasks.isEmpty()) {
       scope.rootTasks.add(replLogTask);
     } else {
-      dependency(scope.rootTasks, replLogTask);
+      DAGTraversal.traverse(scope.rootTasks,
+          new AddDependencyToLeaves(Collections.singletonList(replLogTask)));
     }
-    return replLogTask;
   }
 
   /**
    * There was a database update done before and we want to make sure we update the last repl
    * id on this database as we are now going to switch to processing a new database.
+   *
+   * This has to be the last task in the graph. If the last.repl.id update were a root-level
+   * task, it would be executed before the intermediate tasks. Then, even if one of the child
+   * tasks of the bootstrap load failed and the bootstrap as a whole failed, the last repl status
+   * of the target database would still report a valid value, one that does not represent the
+   * actual state of the database.
    */
   private TaskTracker updateDatabaseLastReplID(int maxTasks, Context context, Scope scope)
       throws SemanticException {
@@ -251,7 +259,15 @@ public class ReplLoadTask extends Task<ReplLoadWork> implements Serializable {
     TaskTracker taskTracker =
         new AlterDatabase(context, work.databaseEvent(context.hiveConf), work.dbNameToLoadIn,
             new TaskTracker(maxTasks)).tasks();
-    scope.rootTasks.addAll(taskTracker.tasks());
+
+    if (scope.rootTasks.isEmpty()) {
+      // There is no existing graph to attach to: add the update as a root task, otherwise it
+      // would never be linked into scope.rootTasks.
+      scope.rootTasks.addAll(taskTracker.tasks());
+    } else {
+      DAGTraversal.traverse(scope.rootTasks, new AddDependencyToLeaves(taskTracker.tasks()));
+    }
+
     return taskTracker;
   }
 
@@ -288,27 +304,8 @@ public class ReplLoadTask extends Task<ReplLoadWork> implements Serializable {
    */
     if (shouldCreateAnotherLoadTask) {
       Task<ReplLoadWork> loadTask = TaskFactory.get(work, conf, true);
-      dependency(rootTasks, loadTask);
-    }
-  }
-
-  /**
-   * add the dependency to the leaf node
-   */
-  public static boolean dependency(List<Task<? extends Serializable>> tasks, Task<?> tailTask) {
-    if (tasks == null || tasks.isEmpty()) {
-      return true;
-    }
-    for (Task<? extends Serializable> task : tasks) {
-      if (task == tailTask) {
-        continue;
-      }
-      boolean leafNode = dependency(task.getChildTasks(), tailTask);
-      if (leafNode) {
-        task.addDependentTask(tailTask);
-      }
+      DAGTraversal.traverse(rootTasks, new AddDependencyToLeaves(loadTask));
     }
-    return false;
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/LoadFunction.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/LoadFunction.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/LoadFunction.java
index 8852a60..ef4ed4d 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/LoadFunction.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/LoadFunction.java
@@ -22,12 +22,13 @@ import org.apache.hadoop.hive.ql.ErrorMsg;
 import org.apache.hadoop.hive.ql.exec.Task;
 import org.apache.hadoop.hive.ql.exec.TaskFactory;
 import org.apache.hadoop.hive.ql.exec.repl.ReplStateLogWork;
+import org.apache.hadoop.hive.ql.exec.repl.bootstrap.AddDependencyToLeaves;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.events.FunctionEvent;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.Context;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
 import org.apache.hadoop.hive.ql.parse.EximUtil;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 import org.apache.hadoop.hive.ql.parse.repl.ReplLogger;
-import org.apache.hadoop.hive.ql.exec.repl.bootstrap.ReplLoadTask;
 import org.apache.hadoop.hive.ql.parse.repl.load.message.CreateFunctionHandler;
 import org.apache.hadoop.hive.ql.parse.repl.load.message.MessageHandler;
 import org.slf4j.Logger;
@@ -61,7 +62,7 @@ public class LoadFunction {
                                          String functionName) {
     ReplStateLogWork replLogWork = new ReplStateLogWork(replLogger, functionName);
     Task<ReplStateLogWork> replLogTask = TaskFactory.get(replLogWork, context.hiveConf);
-    ReplLoadTask.dependency(functionTasks, replLogTask);
+    DAGTraversal.traverse(functionTasks, new AddDependencyToLeaves(replLogTask));
   }
 
   public TaskTracker tasks() throws IOException, SemanticException {

http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadPartitions.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadPartitions.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadPartitions.java
index 0360816..262225f 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadPartitions.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadPartitions.java
@@ -27,12 +27,13 @@ import org.apache.hadoop.hive.ql.exec.Task;
 import org.apache.hadoop.hive.ql.exec.TaskFactory;
 import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.exec.repl.ReplStateLogWork;
-import org.apache.hadoop.hive.ql.exec.repl.bootstrap.ReplLoadTask;
+import org.apache.hadoop.hive.ql.exec.repl.bootstrap.AddDependencyToLeaves;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.events.TableEvent;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.ReplicationState;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.TaskTracker;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.Context;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.PathUtils;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
 import org.apache.hadoop.hive.ql.metadata.HiveException;
 import org.apache.hadoop.hive.ql.metadata.Partition;
 import org.apache.hadoop.hive.ql.metadata.Table;
@@ -47,7 +48,6 @@ import org.apache.hadoop.hive.ql.plan.LoadTableDesc;
 import org.apache.hadoop.hive.ql.plan.LoadTableDesc.LoadFileType;
 import org.apache.hadoop.hive.ql.plan.MoveWork;
 import org.apache.hadoop.hive.ql.session.SessionState;
-import org.mortbay.jetty.servlet.AbstractSessionManager;
 import org.datanucleus.util.StringUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -118,7 +118,7 @@ public class LoadPartitions {
     if (tracker.tasks().isEmpty()) {
       tracker.addTask(replLogTask);
     } else {
-      ReplLoadTask.dependency(tracker.tasks(), replLogTask);
+      DAGTraversal.traverse(tracker.tasks(), new AddDependencyToLeaves(replLogTask));
 
       List<Task<? extends Serializable>> visited = new ArrayList<>();
       tracker.updateTaskCount(replLogTask, visited);

http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadTable.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadTable.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadTable.java
index 766a9a9..bb1f4e5 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadTable.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadTable.java
@@ -28,11 +28,12 @@ import org.apache.hadoop.hive.ql.exec.Task;
 import org.apache.hadoop.hive.ql.exec.TaskFactory;
 import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.exec.repl.ReplStateLogWork;
-import org.apache.hadoop.hive.ql.exec.repl.bootstrap.ReplLoadTask;
+import org.apache.hadoop.hive.ql.exec.repl.bootstrap.AddDependencyToLeaves;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.events.TableEvent;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.TaskTracker;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.Context;
 import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.PathUtils;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
 import org.apache.hadoop.hive.ql.metadata.Table;
 import org.apache.hadoop.hive.ql.parse.EximUtil;
 import org.apache.hadoop.hive.ql.parse.ImportSemanticAnalyzer;
@@ -78,12 +79,11 @@ public class LoadTable {
   private void createTableReplLogTask(String tableName, TableType tableType) throws SemanticException {
     ReplStateLogWork replLogWork = new ReplStateLogWork(replLogger,tableName, tableType);
     Task<ReplStateLogWork> replLogTask = TaskFactory.get(replLogWork, context.hiveConf);
-    ReplLoadTask.dependency(tracker.tasks(), replLogTask);
 
     if (tracker.tasks().isEmpty()) {
       tracker.addTask(replLogTask);
     } else {
-      ReplLoadTask.dependency(tracker.tasks(), replLogTask);
+      DAGTraversal.traverse(tracker.tasks(), new AddDependencyToLeaves(replLogTask));
 
       List<Task<? extends Serializable>> visited = new ArrayList<>();
       tracker.updateTaskCount(replLogTask, visited);

http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/java/org/apache/hadoop/hive/ql/exec/util/DAGTraversal.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/util/DAGTraversal.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/util/DAGTraversal.java
new file mode 100644
index 0000000..1e436ba
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/util/DAGTraversal.java
@@ -0,0 +1,75 @@
+/*
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.util;
+
+import org.apache.hadoop.hive.ql.exec.Task;
+
+import java.io.Serializable;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Breadth-first traversal of a task DAG. It is iterative rather than recursive because large
+ * DAGs would otherwise overflow the stack.
+ */
+public class DAGTraversal {
+
+  /**
+   * Visits {@code tasks} and every task reachable from them through
+   * {@link Task#getDependentTasks()}, one level at a time.
+   *
+   * <p>For each task {@link Function#skipProcessing} is consulted first; a skipped task is
+   * neither processed nor descended into. Otherwise the task's dependents are queued for the
+   * next level and then {@link Function#process} is called, so dependents that {@code process}
+   * adds to a task are not queued from that task. Duplicates within a level are visited once.
+   *
+   * @param tasks the root tasks; {@code null} or empty means there is nothing to traverse
+   * @param function the callback applied to each reached task
+   */
+  public static void traverse(List<Task<? extends Serializable>> tasks, Function function) {
+    if (tasks == null) {
+      return;
+    }
+    Set<Task<? extends Serializable>> currentLevel = new LinkedHashSet<>(tasks);
+    while (!currentLevel.isEmpty()) {
+      // Collect the next level in a set: when several parents share a child, a list would queue
+      // that child once per parent and the frontier could grow exponentially on large DAGs.
+      Set<Task<? extends Serializable>> nextLevel = new LinkedHashSet<>();
+      for (Task<? extends Serializable> task : currentLevel) {
+        // skip processing has to be done first before continuing
+        if (function.skipProcessing(task)) {
+          continue;
+        }
+        if (task.getDependentTasks() != null) {
+          nextLevel.addAll(task.getDependentTasks());
+        }
+        function.process(task);
+      }
+      currentLevel = nextLevel;
+    }
+  }
+
+  /** Callback applied by {@link DAGTraversal#traverse} to each task it reaches. */
+  public interface Function {
+    void process(Task<? extends Serializable> task);
+
+    boolean skipProcessing(Task<? extends Serializable> task);
+  }
+}

http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/test/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeavesTest.java
----------------------------------------------------------------------
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeavesTest.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeavesTest.java
new file mode 100644
index 0000000..a807483
--- /dev/null
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeavesTest.java
@@ -0,0 +1,83 @@
+/*
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.repl.bootstrap;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.apache.hadoop.hive.ql.exec.TaskFactory;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
+import org.apache.hadoop.hive.ql.plan.DependencyCollectionWork;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+@RunWith(MockitoJUnitRunner.class)
+public class AddDependencyToLeavesTest {
+  @Mock
+  private HiveConf hiveConf;
+
+  /**
+   * Each dependency-collection task hanging off the root is a leaf and must receive every one of
+   * the tasks being appended.
+   */
+  @Test
+  public void shouldNotSkipIntermediateDependencyCollectionTasks() {
+    Task<DependencyCollectionWork> collectionWorkTaskOne =
+        TaskFactory.get(new DependencyCollectionWork(), hiveConf);
+    Task<DependencyCollectionWork> collectionWorkTaskTwo =
+        TaskFactory.get(new DependencyCollectionWork(), hiveConf);
+    Task<DependencyCollectionWork> collectionWorkTaskThree =
+        TaskFactory.get(new DependencyCollectionWork(), hiveConf);
+
+    @SuppressWarnings("unchecked") Task<? extends Serializable> rootTask = mock(Task.class);
+    when(rootTask.getDependentTasks())
+        .thenReturn(
+            Arrays.asList(collectionWorkTaskOne, collectionWorkTaskTwo, collectionWorkTaskThree));
+    @SuppressWarnings("unchecked") List<Task<? extends Serializable>> tasksPostCurrentGraph =
+        Arrays.asList(mock(Task.class), mock(Task.class));
+
+    DAGTraversal.traverse(Collections.singletonList(rootTask),
+        new AddDependencyToLeaves(tasksPostCurrentGraph));
+
+    List<Task<? extends Serializable>> dependentTasksForOne =
+        collectionWorkTaskOne.getDependentTasks();
+    List<Task<? extends Serializable>> dependentTasksForTwo =
+        collectionWorkTaskTwo.getDependentTasks();
+    List<Task<? extends Serializable>> dependentTasksForThree =
+        collectionWorkTaskThree.getDependentTasks();
+
+    assertEquals(2, dependentTasksForOne.size());
+    assertEquals(2, dependentTasksForTwo.size());
+    assertEquals(2, dependentTasksForThree.size());
+    assertTrue(tasksPostCurrentGraph.containsAll(dependentTasksForOne));
+    assertTrue(tasksPostCurrentGraph.containsAll(dependentTasksForTwo));
+    assertTrue(tasksPostCurrentGraph.containsAll(dependentTasksForThree));
+  }
+}

http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/test/org/apache/hadoop/hive/ql/exec/util/DAGTraversalTest.java
----------------------------------------------------------------------
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/util/DAGTraversalTest.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/util/DAGTraversalTest.java
new file mode 100644
index 0000000..4bce6bc
--- /dev/null
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/util/DAGTraversalTest.java
@@ -0,0 +1,102 @@
+/*
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.util;
+
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.runners.MockitoJUnitRunner;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+@RunWith(MockitoJUnitRunner.class)
+public class DAGTraversalTest {
+
+  /** Counts the processed tasks that have no dependents. */
+  static class CountLeafFunction implements DAGTraversal.Function {
+    int count = 0;
+
+    @Override
+    public void process(Task<? extends Serializable> task) {
+      if (task.getDependentTasks() == null || task.getDependentTasks().isEmpty()) {
+        count++;
+      }
+    }
+
+    @Override
+    public boolean skipProcessing(Task<? extends Serializable> task) {
+      return false;
+    }
+  }
+
+  @Test
+  public void shouldCountNumberOfLeafNodesCorrectly() {
+    Task<? extends Serializable> taskWith5NodeTree = linearTree(5);
+    Task<? extends Serializable> taskWith1NodeTree = linearTree(1);
+    Task<? extends Serializable> taskWith3NodeTree = linearTree(3);
+    @SuppressWarnings("unchecked") Task<? extends Serializable> rootTask = mock(Task.class);
+    when(rootTask.getDependentTasks())
+        .thenReturn(Arrays.asList(taskWith1NodeTree, taskWith3NodeTree, taskWith5NodeTree));
+
+    CountLeafFunction function = new CountLeafFunction();
+    DAGTraversal.traverse(Collections.singletonList(rootTask), function);
+    assertEquals(3, function.count);
+  }
+
+  @Test
+  public void shouldNotDescendIntoSkippedTasks() {
+    final Task<? extends Serializable> skippedChain = linearTree(3);
+    Task<? extends Serializable> keptChain = linearTree(2);
+    @SuppressWarnings("unchecked") Task<? extends Serializable> rootTask = mock(Task.class);
+    when(rootTask.getDependentTasks()).thenReturn(Arrays.asList(skippedChain, keptChain));
+
+    CountLeafFunction function = new CountLeafFunction() {
+      @Override
+      public boolean skipProcessing(Task<? extends Serializable> task) {
+        return task == skippedChain;
+      }
+    };
+    DAGTraversal.traverse(Collections.singletonList(rootTask), function);
+    // Only the leaf of keptChain is reached; skippedChain's subtree is never visited.
+    assertEquals(1, function.count);
+  }
+
+  /** Builds a chain of {@code numOfNodes} mocked tasks and returns its head. */
+  private Task<? extends Serializable> linearTree(int numOfNodes) {
+    Task<? extends Serializable> current = null;
+    Task<? extends Serializable> head = null;
+    for (int i = 0; i < numOfNodes; i++) {
+      @SuppressWarnings("unchecked") Task<? extends Serializable> task = mock(Task.class);
+      if (current != null) {
+        when(current.getDependentTasks()).thenReturn(Collections.singletonList(task));
+      }
+      if (head == null) {
+        head = task;
+      }
+      current = task;
+    }
+    return head;
+  }
+}