You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@hive.apache.org by an...@apache.org on 2017/11/06 04:52:29 UTC
hive git commit: HIVE-17595 : Correct DAG for updating the
last.repl.id for a database during bootstrap load (Anishek Agarwal reviewed
by Daniel Dai)
Repository: hive
Updated Branches:
refs/heads/master aaacda474 -> 06dc4e968
HIVE-17595 : Correct DAG for updating the last.repl.id for a database during bootstrap load (Anishek Agarwal reviewed by Daniel Dai)
Project: http://git-wip-us.apache.org/repos/asf/hive/repo
Commit: http://git-wip-us.apache.org/repos/asf/hive/commit/06dc4e96
Tree: http://git-wip-us.apache.org/repos/asf/hive/tree/06dc4e96
Diff: http://git-wip-us.apache.org/repos/asf/hive/diff/06dc4e96
Branch: refs/heads/master
Commit: 06dc4e968e691a5e4383be91a45dac0a62279b81
Parents: aaacda4
Author: Anishek Agarwal <an...@gmail.com>
Authored: Mon Nov 6 10:19:42 2017 +0530
Committer: Anishek Agarwal <an...@gmail.com>
Committed: Mon Nov 6 10:19:42 2017 +0530
----------------------------------------------------------------------
.../apache/hadoop/hive/ql/exec/Utilities.java | 59 ++++++++------
.../repl/bootstrap/AddDependencyToLeaves.java | 51 ++++++++++++
.../ql/exec/repl/bootstrap/ReplLoadTask.java | 42 ++++------
.../exec/repl/bootstrap/load/LoadFunction.java | 5 +-
.../bootstrap/load/table/LoadPartitions.java | 6 +-
.../repl/bootstrap/load/table/LoadTable.java | 7 +-
.../hadoop/hive/ql/exec/util/DAGTraversal.java | 55 +++++++++++++
.../bootstrap/AddDependencyToLeavesTest.java | 85 ++++++++++++++++++++
.../hive/ql/exec/util/DAGTraversalTest.java | 82 +++++++++++++++++++
9 files changed, 334 insertions(+), 58 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
index 1be7eab..b78c930 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/Utilities.java
@@ -121,6 +121,7 @@ import org.apache.hadoop.hive.ql.exec.mr.MapRedTask;
import org.apache.hadoop.hive.ql.exec.spark.SparkTask;
import org.apache.hadoop.hive.ql.exec.tez.DagUtils;
import org.apache.hadoop.hive.ql.exec.tez.TezTask;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedInputFormatInterface;
import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatchCtx;
import org.apache.hadoop.hive.ql.io.AcidUtils;
@@ -2569,41 +2570,49 @@ public final class Utilities {
}
public static List<TezTask> getTezTasks(List<Task<? extends Serializable>> tasks) {
- return getTasks(tasks, TezTask.class);
+ return getTasks(tasks, new TaskFilterFunction<>(TezTask.class));
}
public static List<SparkTask> getSparkTasks(List<Task<? extends Serializable>> tasks) {
- return getTasks(tasks, SparkTask.class);
+ return getTasks(tasks, new TaskFilterFunction<>(SparkTask.class));
}
public static List<ExecDriver> getMRTasks(List<Task<? extends Serializable>> tasks) {
- return getTasks(tasks, ExecDriver.class);
+ return getTasks(tasks, new TaskFilterFunction<>(ExecDriver.class));
}
- @SuppressWarnings("unchecked")
- public static <T> List<T> getTasks(List<Task<? extends Serializable>> tasks, Class<T> requiredType) {
- List<T> typeSpecificTasks = new ArrayList<>();
- if (tasks != null) {
- Set<Task<? extends Serializable>> visited = new HashSet<>();
- while (!tasks.isEmpty()) {
- List<Task<? extends Serializable>> childTasks = new ArrayList<>();
- for (Task<? extends Serializable> task : tasks) {
- if (visited.contains(task)) {
- continue;
- }
- if (requiredType.isInstance(task) && !typeSpecificTasks.contains(task)) {
- typeSpecificTasks.add((T) task);
- }
- if (task.getDependentTasks() != null) {
- childTasks.addAll(task.getDependentTasks());
- }
- visited.add(task);
- }
- // start recursion
- tasks = childTasks;
+ static class TaskFilterFunction<T> implements DAGTraversal.Function {
+ private Set<Task<? extends Serializable>> visited = new HashSet<>();
+ private Class<T> requiredType;
+ private List<T> typeSpecificTasks = new ArrayList<>();
+
+ TaskFilterFunction(Class<T> requiredType) {
+ this.requiredType = requiredType;
+ }
+
+ @Override
+ public void process(Task<? extends Serializable> task) {
+ if (requiredType.isInstance(task) && !typeSpecificTasks.contains(task)) {
+ typeSpecificTasks.add((T) task);
}
+ visited.add(task);
+ }
+
+ List<T> getTasks() {
+ return typeSpecificTasks;
}
- return typeSpecificTasks;
+
+ @Override
+ public boolean skipProcessing(Task<? extends Serializable> task) {
+ return visited.contains(task);
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ private static <T> List<T> getTasks(List<Task<? extends Serializable>> tasks,
+ TaskFilterFunction<T> function) {
+ DAGTraversal.traverse(tasks, function);
+ return function.getTasks();
}
/**
http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeaves.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeaves.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeaves.java
new file mode 100644
index 0000000..cf838e1
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeaves.java
@@ -0,0 +1,51 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.repl.bootstrap;
+
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
+
+import java.io.Serializable;
+import java.util.Collections;
+import java.util.List;
+
+public class AddDependencyToLeaves implements DAGTraversal.Function {
+ private List<Task<? extends Serializable>> postDependencyCollectionTasks;
+
+ AddDependencyToLeaves(List<Task<? extends Serializable>> postDependencyCollectionTasks) {
+ this.postDependencyCollectionTasks = postDependencyCollectionTasks;
+ }
+
+ public AddDependencyToLeaves(Task<? extends Serializable> postDependencyTask) {
+ this(Collections.singletonList(postDependencyTask));
+ }
+
+
+ @Override
+ public void process(Task<? extends Serializable> task) {
+ if (task.getChildTasks() == null) {
+ postDependencyCollectionTasks.forEach(task::addDependentTask);
+ }
+ }
+
+ @Override
+ public boolean skipProcessing(Task<? extends Serializable> task) {
+ return postDependencyCollectionTasks.contains(task);
+ }
+}
http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
index bf5c819..bfbec45 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/ReplLoadTask.java
@@ -39,12 +39,14 @@ import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.table.LoadPartitions;
import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.table.LoadTable;
import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.table.TableContext;
import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.Context;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.repl.ReplLogger;
import org.apache.hadoop.hive.ql.plan.api.StageType;
import java.io.Serializable;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.List;
import static org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.LoadDatabase.AlterDatabase;
@@ -225,22 +227,28 @@ public class ReplLoadTask extends Task<ReplLoadWork> implements Serializable {
return 0;
}
- private Task<? extends Serializable> createEndReplLogTask(Context context, Scope scope,
+ private void createEndReplLogTask(Context context, Scope scope,
ReplLogger replLogger) throws SemanticException {
Database dbInMetadata = work.databaseEvent(context.hiveConf).dbInMetadata(work.dbNameToLoadIn);
ReplStateLogWork replLogWork = new ReplStateLogWork(replLogger, dbInMetadata.getParameters());
Task<ReplStateLogWork> replLogTask = TaskFactory.get(replLogWork, conf);
- if (null == scope.rootTasks) {
+ if (scope.rootTasks.isEmpty()) {
scope.rootTasks.add(replLogTask);
} else {
- dependency(scope.rootTasks, replLogTask);
+ DAGTraversal.traverse(scope.rootTasks,
+ new AddDependencyToLeaves(Collections.singletonList(replLogTask)));
}
- return replLogTask;
}
/**
* There was a database update done before and we want to make sure we update the last repl
* id on this database as we are now going to switch to processing a new database.
+ *
+ * This has to be the last task in the graph. If the last.repl.id update were a root level task
+ * while other intermediate tasks exist, the root level tasks would get executed first in the
+ * execution phase. If any child task of the bootstrap load then failed, the last repl status of
+ * the target database would still return a valid value even though the bootstrap has failed,
+ * and that value would not represent the state of the database.
*/
private TaskTracker updateDatabaseLastReplID(int maxTasks, Context context, Scope scope)
throws SemanticException {
@@ -251,7 +259,10 @@ public class ReplLoadTask extends Task<ReplLoadWork> implements Serializable {
TaskTracker taskTracker =
new AlterDatabase(context, work.databaseEvent(context.hiveConf), work.dbNameToLoadIn,
new TaskTracker(maxTasks)).tasks();
- scope.rootTasks.addAll(taskTracker.tasks());
+
+ AddDependencyToLeaves function = new AddDependencyToLeaves(taskTracker.tasks());
+ DAGTraversal.traverse(scope.rootTasks, function);
+
return taskTracker;
}
@@ -288,27 +299,8 @@ public class ReplLoadTask extends Task<ReplLoadWork> implements Serializable {
*/
if (shouldCreateAnotherLoadTask) {
Task<ReplLoadWork> loadTask = TaskFactory.get(work, conf, true);
- dependency(rootTasks, loadTask);
- }
- }
-
- /**
- * add the dependency to the leaf node
- */
- public static boolean dependency(List<Task<? extends Serializable>> tasks, Task<?> tailTask) {
- if (tasks == null || tasks.isEmpty()) {
- return true;
- }
- for (Task<? extends Serializable> task : tasks) {
- if (task == tailTask) {
- continue;
- }
- boolean leafNode = dependency(task.getChildTasks(), tailTask);
- if (leafNode) {
- task.addDependentTask(tailTask);
- }
+ DAGTraversal.traverse(rootTasks, new AddDependencyToLeaves(loadTask));
}
- return false;
}
@Override
http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/LoadFunction.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/LoadFunction.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/LoadFunction.java
index 8852a60..ef4ed4d 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/LoadFunction.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/LoadFunction.java
@@ -22,12 +22,13 @@ import org.apache.hadoop.hive.ql.ErrorMsg;
import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.TaskFactory;
import org.apache.hadoop.hive.ql.exec.repl.ReplStateLogWork;
+import org.apache.hadoop.hive.ql.exec.repl.bootstrap.AddDependencyToLeaves;
import org.apache.hadoop.hive.ql.exec.repl.bootstrap.events.FunctionEvent;
import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.Context;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
import org.apache.hadoop.hive.ql.parse.EximUtil;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.parse.repl.ReplLogger;
-import org.apache.hadoop.hive.ql.exec.repl.bootstrap.ReplLoadTask;
import org.apache.hadoop.hive.ql.parse.repl.load.message.CreateFunctionHandler;
import org.apache.hadoop.hive.ql.parse.repl.load.message.MessageHandler;
import org.slf4j.Logger;
@@ -61,7 +62,7 @@ public class LoadFunction {
String functionName) {
ReplStateLogWork replLogWork = new ReplStateLogWork(replLogger, functionName);
Task<ReplStateLogWork> replLogTask = TaskFactory.get(replLogWork, context.hiveConf);
- ReplLoadTask.dependency(functionTasks, replLogTask);
+ DAGTraversal.traverse(functionTasks, new AddDependencyToLeaves(replLogTask));
}
public TaskTracker tasks() throws IOException, SemanticException {
http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadPartitions.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadPartitions.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadPartitions.java
index 0360816..262225f 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadPartitions.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadPartitions.java
@@ -27,12 +27,13 @@ import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.TaskFactory;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.exec.repl.ReplStateLogWork;
-import org.apache.hadoop.hive.ql.exec.repl.bootstrap.ReplLoadTask;
+import org.apache.hadoop.hive.ql.exec.repl.bootstrap.AddDependencyToLeaves;
import org.apache.hadoop.hive.ql.exec.repl.bootstrap.events.TableEvent;
import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.ReplicationState;
import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.TaskTracker;
import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.Context;
import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.PathUtils;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.metadata.Partition;
import org.apache.hadoop.hive.ql.metadata.Table;
@@ -47,7 +48,6 @@ import org.apache.hadoop.hive.ql.plan.LoadTableDesc;
import org.apache.hadoop.hive.ql.plan.LoadTableDesc.LoadFileType;
import org.apache.hadoop.hive.ql.plan.MoveWork;
import org.apache.hadoop.hive.ql.session.SessionState;
-import org.mortbay.jetty.servlet.AbstractSessionManager;
import org.datanucleus.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -118,7 +118,7 @@ public class LoadPartitions {
if (tracker.tasks().isEmpty()) {
tracker.addTask(replLogTask);
} else {
- ReplLoadTask.dependency(tracker.tasks(), replLogTask);
+ DAGTraversal.traverse(tracker.tasks(), new AddDependencyToLeaves(replLogTask));
List<Task<? extends Serializable>> visited = new ArrayList<>();
tracker.updateTaskCount(replLogTask, visited);
http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadTable.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadTable.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadTable.java
index 766a9a9..bb1f4e5 100644
--- a/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadTable.java
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/repl/bootstrap/load/table/LoadTable.java
@@ -28,11 +28,12 @@ import org.apache.hadoop.hive.ql.exec.Task;
import org.apache.hadoop.hive.ql.exec.TaskFactory;
import org.apache.hadoop.hive.ql.exec.Utilities;
import org.apache.hadoop.hive.ql.exec.repl.ReplStateLogWork;
-import org.apache.hadoop.hive.ql.exec.repl.bootstrap.ReplLoadTask;
+import org.apache.hadoop.hive.ql.exec.repl.bootstrap.AddDependencyToLeaves;
import org.apache.hadoop.hive.ql.exec.repl.bootstrap.events.TableEvent;
import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.TaskTracker;
import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.Context;
import org.apache.hadoop.hive.ql.exec.repl.bootstrap.load.util.PathUtils;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
import org.apache.hadoop.hive.ql.metadata.Table;
import org.apache.hadoop.hive.ql.parse.EximUtil;
import org.apache.hadoop.hive.ql.parse.ImportSemanticAnalyzer;
@@ -78,12 +79,12 @@ public class LoadTable {
private void createTableReplLogTask(String tableName, TableType tableType) throws SemanticException {
ReplStateLogWork replLogWork = new ReplStateLogWork(replLogger,tableName, tableType);
Task<ReplStateLogWork> replLogTask = TaskFactory.get(replLogWork, context.hiveConf);
- ReplLoadTask.dependency(tracker.tasks(), replLogTask);
+ DAGTraversal.traverse(tracker.tasks(), new AddDependencyToLeaves(replLogTask));
if (tracker.tasks().isEmpty()) {
tracker.addTask(replLogTask);
} else {
- ReplLoadTask.dependency(tracker.tasks(), replLogTask);
+ DAGTraversal.traverse(tracker.tasks(), new AddDependencyToLeaves(replLogTask));
List<Task<? extends Serializable>> visited = new ArrayList<>();
tracker.updateTaskCount(replLogTask, visited);
http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/java/org/apache/hadoop/hive/ql/exec/util/DAGTraversal.java
----------------------------------------------------------------------
diff --git a/ql/src/java/org/apache/hadoop/hive/ql/exec/util/DAGTraversal.java b/ql/src/java/org/apache/hadoop/hive/ql/exec/util/DAGTraversal.java
new file mode 100644
index 0000000..1e436ba
--- /dev/null
+++ b/ql/src/java/org/apache/hadoop/hive/ql/exec/util/DAGTraversal.java
@@ -0,0 +1,55 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.util;
+
+import org.apache.hadoop.hive.ql.exec.Task;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * The DAG traversal done here is deliberately iterative rather than recursive, because recursing
+ * over a large DAG can lead to a stack overflow.
+ */
+public class DAGTraversal {
+ public static void traverse(List<Task<? extends Serializable>> tasks, Function function) {
+ List<Task<? extends Serializable>> listOfTasks = new ArrayList<>(tasks);
+ while (!listOfTasks.isEmpty()) {
+ List<Task<? extends Serializable>> children = new ArrayList<>();
+ for (Task<? extends Serializable> task : listOfTasks) {
+ // check skipProcessing first: a skipped task is neither processed nor has its children enqueued
+ if (function.skipProcessing(task)) {
+ continue;
+ }
+ if (task.getDependentTasks() != null) {
+ children.addAll(task.getDependentTasks());
+ }
+ function.process(task);
+ }
+ listOfTasks = children;
+ }
+ }
+
+ public interface Function {
+ void process(Task<? extends Serializable> task);
+
+ boolean skipProcessing(Task<? extends Serializable> task);
+ }
+}
http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/test/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeavesTest.java
----------------------------------------------------------------------
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeavesTest.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeavesTest.java
new file mode 100644
index 0000000..a807483
--- /dev/null
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/repl/bootstrap/AddDependencyToLeavesTest.java
@@ -0,0 +1,85 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.repl.bootstrap;
+
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.apache.hadoop.hive.ql.exec.TaskFactory;
+import org.apache.hadoop.hive.ql.exec.util.DAGTraversal;
+import org.apache.hadoop.hive.ql.plan.DependencyCollectionWork;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+@RunWith(MockitoJUnitRunner.class)
+public class AddDependencyToLeavesTest {
+ @Mock
+ private HiveConf hiveConf;
+
+ @Test
+ public void shouldNotSkipIntermediateDependencyCollectionTasks() {
+ Task<DependencyCollectionWork> collectionWorkTaskOne =
+ TaskFactory.get(new DependencyCollectionWork(), hiveConf);
+ Task<DependencyCollectionWork> collectionWorkTaskTwo =
+ TaskFactory.get(new DependencyCollectionWork(), hiveConf);
+ Task<DependencyCollectionWork> collectionWorkTaskThree =
+ TaskFactory.get(new DependencyCollectionWork(), hiveConf);
+
+ @SuppressWarnings("unchecked") Task<? extends Serializable> rootTask = mock(Task.class);
+ when(rootTask.getDependentTasks())
+ .thenReturn(
+ Arrays.asList(collectionWorkTaskOne, collectionWorkTaskTwo, collectionWorkTaskThree));
+ @SuppressWarnings("unchecked") List<Task<? extends Serializable>> tasksPostCurrentGraph =
+ Arrays.asList(mock(Task.class), mock(Task.class));
+
+ DAGTraversal.traverse(Collections.singletonList(rootTask),
+ new AddDependencyToLeaves(tasksPostCurrentGraph));
+
+ List<Task<? extends Serializable>> dependentTasksForOne =
+ collectionWorkTaskOne.getDependentTasks();
+ List<Task<? extends Serializable>> dependentTasksForTwo =
+ collectionWorkTaskTwo.getDependentTasks();
+ List<Task<? extends Serializable>> dependentTasksForThree =
+ collectionWorkTaskThree.getDependentTasks();
+
+ assertEquals(dependentTasksForOne.size(), 2);
+ assertEquals(dependentTasksForTwo.size(), 2);
+ assertEquals(dependentTasksForThree.size(), 2);
+ assertTrue(tasksPostCurrentGraph.containsAll(dependentTasksForOne));
+ assertTrue(tasksPostCurrentGraph.containsAll(dependentTasksForTwo));
+ assertTrue(tasksPostCurrentGraph.containsAll(dependentTasksForThree));
+
+// assertTrue(dependentTasksForOne.iterator().next() instanceof DependencyCollectionTask);
+// assertTrue(dependentTasksForTwo.iterator().next() instanceof DependencyCollectionTask);
+// assertTrue(dependentTasksForThree.iterator().next() instanceof DependencyCollectionTask);
+ }
+
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/hive/blob/06dc4e96/ql/src/test/org/apache/hadoop/hive/ql/exec/util/DAGTraversalTest.java
----------------------------------------------------------------------
diff --git a/ql/src/test/org/apache/hadoop/hive/ql/exec/util/DAGTraversalTest.java b/ql/src/test/org/apache/hadoop/hive/ql/exec/util/DAGTraversalTest.java
new file mode 100644
index 0000000..4bce6bc
--- /dev/null
+++ b/ql/src/test/org/apache/hadoop/hive/ql/exec/util/DAGTraversalTest.java
@@ -0,0 +1,82 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ */
+
+package org.apache.hadoop.hive.ql.exec.util;
+
+import org.apache.hadoop.hive.ql.exec.Task;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.runners.MockitoJUnitRunner;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+@RunWith(MockitoJUnitRunner.class)
+public class DAGTraversalTest {
+
+ static class CountLeafFunction implements DAGTraversal.Function {
+ int count = 0;
+
+ @Override
+ public void process(Task<? extends Serializable> task) {
+ if (task.getDependentTasks() == null || task.getDependentTasks().isEmpty()) {
+ count++;
+ }
+ }
+
+ @Override
+ public boolean skipProcessing(Task<? extends Serializable> task) {
+ return false;
+ }
+ }
+
+ @Test
+ public void shouldCountNumberOfLeafNodesCorrectly() {
+ Task<? extends Serializable> taskWith5NodeTree = linearTree(5);
+ Task<? extends Serializable> taskWith1NodeTree = linearTree(1);
+ Task<? extends Serializable> taskWith3NodeTree = linearTree(3);
+ @SuppressWarnings("unchecked") Task<? extends Serializable> rootTask = mock(Task.class);
+ when(rootTask.getDependentTasks())
+ .thenReturn(Arrays.asList(taskWith1NodeTree, taskWith3NodeTree, taskWith5NodeTree));
+
+ CountLeafFunction function = new CountLeafFunction();
+ DAGTraversal.traverse(Collections.singletonList(rootTask), function);
+ assertEquals(3, function.count);
+ }
+
+ private Task<? extends Serializable> linearTree(int numOfNodes) {
+ Task<? extends Serializable> current = null, head = null;
+ for (int i = 0; i < numOfNodes; i++) {
+ @SuppressWarnings("unchecked") Task<? extends Serializable> task = mock(Task.class);
+ if (current != null) {
+ when(current.getDependentTasks()).thenReturn(Collections.singletonList(task));
+ }
+ if (head == null) {
+ head = task;
+ }
+ current = task;
+ }
+ return head;
+ }
+
+}
\ No newline at end of file