You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tez.apache.org by je...@apache.org on 2020/02/04 16:44:45 UTC
[tez] branch master updated: TEZ-4106. Add Exponential Smooth
RuntimeEstimator to the speculator
This is an automated email from the ASF dual-hosted git repository.
jeagles pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/tez.git
The following commit(s) were added to refs/heads/master by this push:
new 2736788 TEZ-4106. Add Exponential Smooth RuntimeEstimator to the speculator
2736788 is described below
commit 2736788f49f17bd269b7ea64d959bbaad95421d3
Author: Ahmed Hussein <ah...@apache.org>
AuthorDate: Tue Feb 4 10:44:29 2020 -0600
TEZ-4106. Add Exponential Smooth RuntimeEstimator to the speculator
Signed-off-by: Jonathan Eagles <je...@apache.org>
---
.../org/apache/tez/dag/api/TezConfiguration.java | 59 +++-
.../app/dag/speculation/legacy/DataStatistics.java | 25 +-
.../dag/speculation/legacy/LegacySpeculator.java | 39 ++-
.../SimpleExponentialTaskRuntimeEstimator.java | 194 ++++++++++++
.../dag/speculation/legacy/StartEndTimesBase.java | 42 +--
.../speculation/legacy/TaskRuntimeEstimator.java | 32 +-
.../forecast/SimpleExponentialSmoothing.java | 336 +++++++++++++++++++++
.../speculation/legacy/forecast/package-info.java | 20 ++
.../org/apache/tez/dag/app/TestSpeculation.java | 262 ++++++++++++++--
9 files changed, 926 insertions(+), 83 deletions(-)
diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java b/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
index f087e3a..58aecda 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
@@ -26,6 +26,7 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.concurrent.TimeUnit;
import org.apache.hadoop.classification.InterfaceAudience.Private;
import org.apache.hadoop.classification.InterfaceAudience.Public;
import org.apache.hadoop.classification.InterfaceStability.Unstable;
@@ -531,14 +532,6 @@ public class TezConfiguration extends Configuration {
public static final boolean TEZ_AM_SPECULATION_ENABLED_DEFAULT = false;
/**
- * Class used to estimate task resource needs.
- */
- @ConfigurationScope(Scope.VERTEX)
- @ConfigurationProperty
- public static final String TEZ_AM_SPECULATION_ESTIMATOR_CLASS =
- TEZ_AM_PREFIX + "speculation.estimator.class";
-
- /**
* Float value. Specifies how many standard deviations away from the mean task execution time
* should be considered as an outlier/slow task.
*/
@@ -559,6 +552,10 @@ public class TezConfiguration extends Configuration {
TEZ_AM_PREFIX + "legacy.speculative.single.task.vertex.timeout";
public static final long TEZ_AM_LEGACY_SPECULATIVE_SINGLE_TASK_VERTEX_TIMEOUT_DEFAULT = -1;
+ @Private
+ public static final String TEZ_SPECULATOR_PREFIX = TEZ_AM_PREFIX + "speculator.";
+ @Private
+ public static final String TEZ_ESTIMATOR_PREFIX = TEZ_AM_PREFIX + "task.estimator.";
/**
* Long value. Specifies amount of time (in ms) that needs to elapse to do the next round of
* speculation if there is no task speculated in this round.
@@ -581,6 +578,52 @@ public class TezConfiguration extends Configuration {
TEZ_AM_PREFIX + "soonest.retry.after.speculate";
public static final long TEZ_AM_SOONEST_RETRY_AFTER_SPECULATE_DEFAULT = 1000L * 15L;
+ // NOTE(review): this patch removes TEZ_AM_SPECULATION_ESTIMATOR_CLASS
+ // ("tez.am.speculation.estimator.class"): code referencing that constant no
+ // longer compiles and configs that still set the old key are silently
+ // ignored. Consider keeping a deprecated alias for compatibility.
+ // NOTE(review): no reader of TEZ_AM_SPECULATOR_CLASS is visible in this
+ // patch; confirm it is consumed before documenting it for users.
+ /**
+ * String value. Class name of the speculator to use for speculative
+ * execution calculations.
+ */
+ @ConfigurationScope(Scope.VERTEX)
+ @ConfigurationProperty
+ public static final String TEZ_AM_SPECULATOR_CLASS =
+ TEZ_SPECULATOR_PREFIX + "class";
+ /**
+ * String value. Class name of the {@code TaskRuntimeEstimator} the legacy
+ * speculator uses to estimate task attempt runtimes. When unset, the
+ * speculator falls back to {@code LegacyTaskRuntimeEstimator}. Replaces the
+ * former key {@code tez.am.speculation.estimator.class}.
+ */
+ @ConfigurationScope(Scope.VERTEX)
+ @ConfigurationProperty
+ public static final String TEZ_AM_TASK_ESTIMATOR_CLASS =
+ TEZ_ESTIMATOR_PREFIX + "class";
+ /**
+ * Long value. Time constant (lambda), in ms, of the exponential smoothing
+ * that {@code SimpleExponentialTaskRuntimeEstimator} applies to progress
+ * readings. Larger values give more weight to older readings.
+ * Defaults to 120 seconds.
+ */
+ @Unstable
+ @ConfigurationScope(Scope.VERTEX)
+ @ConfigurationProperty(type="long")
+ public static final String TEZ_AM_ESTIMATOR_EXPONENTIAL_LAMBDA_MS =
+ TEZ_ESTIMATOR_PREFIX + "exponential.lambda.ms";
+ public static final long TEZ_AM_ESTIMATOR_EXPONENTIAL_LAMBDA_MS_DEFAULT =
+ TimeUnit.SECONDS.toMillis(120);
+
+ /**
+ * Long value. Time window, in ms, that the exponential estimator uses to
+ * decide whether the progress readings of a task attempt have stagnated.
+ * The effective window is never shorter than twice the lambda value.
+ * Defaults to 360 seconds.
+ */
+ @Unstable
+ @ConfigurationScope(Scope.VERTEX)
+ @ConfigurationProperty(type="long")
+ public static final String TEZ_AM_ESTIMATOR_EXPONENTIAL_STAGNATED_MS =
+ TEZ_ESTIMATOR_PREFIX + "exponential.stagnated.ms";
+ public static final long TEZ_AM_ESTIMATOR_EXPONENTIAL_STAGNATED_MS_DEFAULT =
+ TimeUnit.SECONDS.toMillis(360);
+
+ /**
+ * Integer value. Number of initial progress readings the exponential
+ * estimator collects before it reports a forecast. The first readings make
+ * the smoothed forecast inaccurate, so they are not used for prediction.
+ * Defaults to 24.
+ */
+ @Unstable
+ @ConfigurationScope(Scope.VERTEX)
+ @ConfigurationProperty(type="integer")
+ public static final String TEZ_AM_ESTIMATOR_EXPONENTIAL_SKIP_INITIALS =
+ TEZ_ESTIMATOR_PREFIX + "exponential.skip.initials";
+ public static final int TEZ_AM_ESTIMATOR_EXPONENTIAL_SKIP_INITIALS_DEFAULT = 24;
+
/**
* Double value. The max percent (0-1) of running tasks that can be speculatively re-executed at any time.
*/
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/DataStatistics.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/DataStatistics.java
index 7e6f1c2..bbfb950 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/DataStatistics.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/DataStatistics.java
@@ -21,6 +21,11 @@ package org.apache.tez.dag.app.dag.speculation.legacy;
import com.google.common.annotations.VisibleForTesting;
public class DataStatistics {
+ /**
+ * factor used to calculate confidence interval within 95%.
+ */
+ private static final double DEFAULT_CI_FACTOR = 1.96;
+
private int count = 0;
private double sum = 0;
private double sumSquares = 0;
@@ -79,8 +84,24 @@ public class DataStatistics {
return count;
}
+  /**
+   * Returns the upper bound of the 95% confidence interval of the mean,
+   * computed as {@code mean + 1.96 * std / sqrt(count)} (1.96 being the
+   * two-sided z-value for 95%).
+   *
+   * @return the upper 95% confidence bound of the mean, or {@code 0.0} when
+   *         fewer than two samples have been recorded
+   */
+  public synchronized double meanCI() {
+    if (count > 1) {
+      final double mu = mean();
+      final double sigma = std();
+      final double margin = DEFAULT_CI_FACTOR * sigma / Math.sqrt(count);
+      return mu + margin;
+    }
+    return 0.0;
+  }
+
public String toString() {
- return "DataStatistics: count is " + count + ", sum is " + sum +
- ", sumSquares is " + sumSquares + " mean is " + mean() + " std() is " + std();
+ return "DataStatistics: count is " + count + ", sum is " + sum
+ + ", sumSquares is " + sumSquares + " mean is " + mean()
+ + " std() is " + std() + ", meanCI() is " + meanCI();
}
}
\ No newline at end of file
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/LegacySpeculator.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/LegacySpeculator.java
index 23b057a..f21b819 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/LegacySpeculator.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/LegacySpeculator.java
@@ -33,7 +33,6 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.apache.hadoop.service.AbstractService;
-import org.apache.hadoop.service.ServiceOperations;
import org.apache.tez.common.ProgressHelper;
import org.apache.tez.dag.api.TezConfiguration;
import org.slf4j.Logger;
@@ -100,10 +99,9 @@ public class LegacySpeculator extends AbstractService {
private TaskRuntimeEstimator estimator;
private final long taskTimeout;
private final Clock clock;
- private long nextSpeculateTime = Long.MIN_VALUE;
private Thread speculationBackgroundThread = null;
private volatile boolean stopped = false;
- /* Allow the speculator to wait on a blockingQueue in case we use it for event notification */
+ /** Allow the speculator to wait on a blockingQueue in case we use it for event notification. */
private BlockingQueue<Object> scanControl = new LinkedBlockingQueue<Object>();
@VisibleForTesting
@@ -132,9 +130,8 @@ public class LegacySpeculator extends AbstractService {
static private TaskRuntimeEstimator getEstimator
(Configuration conf, Vertex vertex) {
TaskRuntimeEstimator estimator;
- // "tez.am.speculation.estimator.class"
Class<? extends TaskRuntimeEstimator> estimatorClass =
- conf.getClass(TezConfiguration.TEZ_AM_SPECULATION_ESTIMATOR_CLASS,
+ conf.getClass(TezConfiguration.TEZ_AM_TASK_ESTIMATOR_CLASS,
LegacyTaskRuntimeEstimator.class,
TaskRuntimeEstimator.class);
try {
@@ -236,6 +233,16 @@ public class LegacySpeculator extends AbstractService {
}
}
+  /**
+   * Test-only hook; not intended for production callers. Enqueues a token on
+   * {@code scanControl}, presumably to wake the background speculation loop
+   * for an extra scan (TODO confirm against createThread), then yields the
+   * current thread.
+   */
+  public void scanForSpeculationsForTesting() {
+    final boolean debugOn = LOG.isDebugEnabled();
+    if (debugOn) {
+      LOG.debug("We got asked to run a debug speculation scan.");
+      LOG.debug("There are {} speculative events stacked already.",
+          scanControl.size());
+    }
+    final Object scanToken = new Object();
+    scanControl.add(scanToken);
+    Thread.yield();
+  }
+
public Runnable createThread() {
return new Runnable() {
@Override
@@ -267,8 +274,9 @@ public class LegacySpeculator extends AbstractService {
public void notifyAttemptStarted(TezTaskAttemptID taId, long timestamp) {
estimator.enrollAttempt(taId, timestamp);
}
-
- public void notifyAttemptStatusUpdate(TezTaskAttemptID taId, TaskAttemptState reportedState,
+
+ public void notifyAttemptStatusUpdate(TezTaskAttemptID taId,
+ TaskAttemptState reportedState,
long timestamp) {
statusUpdate(taId, reportedState, timestamp);
}
@@ -293,12 +301,10 @@ public class LegacySpeculator extends AbstractService {
estimator.updateAttempt(attemptID, reportedState, timestamp);
- //if (stateString.equals(TaskAttemptState.RUNNING.name())) {
if (reportedState == TaskAttemptState.RUNNING) {
runningTasks.putIfAbsent(taskID, Boolean.TRUE);
} else {
runningTasks.remove(taskID, Boolean.TRUE);
- //if (!stateString.equals(TaskAttemptState.STARTING.name())) {
if (reportedState == TaskAttemptState.STARTING) {
runningTaskAttemptStatistics.remove(attemptID);
}
@@ -356,7 +362,7 @@ public class LegacySpeculator extends AbstractService {
}
}
- TezTaskAttemptID runningTaskAttemptID = null;
+ TezTaskAttemptID runningTaskAttemptID;
int numberRunningAttempts = 0;
for (TaskAttempt taskAttempt : attempts.values()) {
@@ -387,7 +393,8 @@ public class LegacySpeculator extends AbstractService {
return ON_SCHEDULE;
}
} else {
- long estimatedRunTime = estimator.estimatedRuntime(runningTaskAttemptID);
+ long estimatedRunTime = estimator
+ .estimatedRuntime(runningTaskAttemptID);
long estimatedEndTime = estimatedRunTime + taskAttemptStartTime;
@@ -399,12 +406,15 @@ public class LegacySpeculator extends AbstractService {
runningTaskAttemptStatistics.get(runningTaskAttemptID);
if (data == null) {
runningTaskAttemptStatistics.put(runningTaskAttemptID,
- new TaskAttemptHistoryStatistics(estimatedRunTime, progress, now));
+ new TaskAttemptHistoryStatistics(estimatedRunTime, progress,
+ now));
} else {
if (estimatedRunTime == data.getEstimatedRunTime()
&& progress == data.getProgress()) {
// Previous stats are same as same stats
- if (data.notHeartbeatedInAWhile(now)) {
+ if (data.notHeartbeatedInAWhile(now)
+ || estimator
+ .hasStagnatedProgress(runningTaskAttemptID, now)) {
// Stats have stagnated for a while, simulate heart-beat.
// Now simulate the heart-beat
statusUpdate(taskAttempt.getID(), taskAttempt.getState(),
@@ -448,7 +458,8 @@ public class LegacySpeculator extends AbstractService {
// Add attempt to a given Task.
protected void addSpeculativeAttempt(TezTaskID taskID) {
- LOG.info("DefaultSpeculator.addSpeculativeAttempt -- we are speculating " + taskID);
+ LOG.info("DefaultSpeculator.addSpeculativeAttempt -- we are speculating "
+ + taskID);
vertex.scheduleSpeculativeTask(taskID);
mayHaveSpeculated.add(taskID);
}
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/SimpleExponentialTaskRuntimeEstimator.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/SimpleExponentialTaskRuntimeEstimator.java
new file mode 100644
index 0000000..b61f153
--- /dev/null
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/SimpleExponentialTaskRuntimeEstimator.java
@@ -0,0 +1,194 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.dag.app.dag.speculation.legacy;
+
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.atomic.AtomicReference;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.oldrecords.TaskAttemptState;
+import org.apache.tez.dag.app.dag.Task;
+import org.apache.tez.dag.app.dag.TaskAttempt;
+import org.apache.tez.dag.app.dag.Vertex;
+import org.apache.tez.dag.app.dag.speculation.legacy.forecast.SimpleExponentialSmoothing;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+
+/**
+ * A task Runtime Estimator based on exponential smoothing.
+ */
+public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
+  /**
+   * Value returned by the estimator when it does not have enough information
+   * to produce an estimate.
+   */
+  private static final long DEFAULT_ESTIMATE_RUNTIME = -1L;
+
+  /**
+   * Lower bound applied to the forecast progress rate, so that an attempt
+   * whose forecast rate is 0.0 (no progress) does not cause a division by 0.
+   */
+  private static final double DEFAULT_PROGRESS_VALUE = 1E-10;
+
+  /**
+   * Factor used to calculate the confidence interval.
+   */
+  private static final double CONFIDENCE_INTERVAL_FACTOR = 0.25;
+
+  /**
+   * Constant time used to calculate the smoothing exponential factor.
+   */
+  private long constTime;
+
+  /**
+   * Number of readings before we consider the estimate stable.
+   * Otherwise, the estimate will be skewed due to the initial estimate.
+   */
+  private int skipCount;
+
+  /**
+   * Time window used to detect attempts whose progress readings have
+   * stagnated. This is needed when a task stalls without any progress, which
+   * would otherwise leave the estimator returning -1 as an estimatedRuntime.
+   */
+  private long stagnatedWindow;
+
+  /**
+   * A map of TA Id to the statistic model of smooth exponential.
+   */
+  private final ConcurrentMap<TezTaskAttemptID,
+      AtomicReference<SimpleExponentialSmoothing>> estimates =
+      new ConcurrentHashMap<>();
+
+  /**
+   * Returns the smoothing model of an attempt.
+   *
+   * @param attemptID the attempt to look up
+   * @return the model, or null if no reading has been incorporated yet
+   */
+  private SimpleExponentialSmoothing getForecastEntry(
+      final TezTaskAttemptID attemptID) {
+    AtomicReference<SimpleExponentialSmoothing> entryRef =
+        estimates.get(attemptID);
+    return entryRef == null ? null : entryRef.get();
+  }
+
+  /**
+   * Feeds a progress reading into the attempt's smoothing model, creating the
+   * model on first use. Readings are dropped until the attempt has been
+   * enrolled with a start time.
+   *
+   * @param attemptID the attempt that reported progress
+   * @param newRawData the reported progress
+   * @param newTimeStamp the time of the report
+   */
+  private void incorporateReading(final TezTaskAttemptID attemptID,
+      final float newRawData, final long newTimeStamp) {
+    SimpleExponentialSmoothing foreCastEntry = getForecastEntry(attemptID);
+    if (foreCastEntry == null) {
+      Long tStartTime = startTimes.get(attemptID);
+      // skip if the startTime is not set yet
+      if (tStartTime == null) {
+        return;
+      }
+      // The model starts 1 ms before the enrolled start time, presumably so
+      // the first rate is never computed over a zero-length interval.
+      estimates.putIfAbsent(attemptID,
+          new AtomicReference<>(SimpleExponentialSmoothing.createForecast(
+              constTime, skipCount, stagnatedWindow, tStartTime - 1)));
+      // Another thread may have won the putIfAbsent race; use whichever
+      // model ended up in the map.
+      foreCastEntry = getForecastEntry(attemptID);
+    }
+    foreCastEntry.incorporateReading(newTimeStamp, newRawData);
+  }
+
+  /**
+   * Reads the smoothing parameters from the configuration. The stagnation
+   * window is never shorter than twice the smoothing time constant.
+   */
+  @Override
+  public void contextualize(final Configuration conf, final Vertex vertex) {
+    super.contextualize(conf, vertex);
+
+    constTime = conf.getLong(
+        TezConfiguration.TEZ_AM_ESTIMATOR_EXPONENTIAL_LAMBDA_MS,
+        TezConfiguration.TEZ_AM_ESTIMATOR_EXPONENTIAL_LAMBDA_MS_DEFAULT);
+
+    stagnatedWindow = Math.max(2 * constTime, conf.getLong(
+        TezConfiguration.TEZ_AM_ESTIMATOR_EXPONENTIAL_STAGNATED_MS,
+        TezConfiguration.TEZ_AM_ESTIMATOR_EXPONENTIAL_STAGNATED_MS_DEFAULT));
+
+    skipCount = conf.getInt(
+        TezConfiguration.TEZ_AM_ESTIMATOR_EXPONENTIAL_SKIP_INITIALS,
+        TezConfiguration.TEZ_AM_ESTIMATOR_EXPONENTIAL_SKIP_INITIALS_DEFAULT);
+  }
+
+  /**
+   * Estimates the total runtime of an attempt: the time elapsed up to its
+   * last progress reading plus the remaining work divided by the smoothed
+   * progress rate.
+   *
+   * @param id the attempt to estimate
+   * @return the estimated runtime in ms, or -1 when the attempt has no model
+   *         or its model has not yet collected enough readings to forecast
+   */
+  @Override
+  public long estimatedRuntime(final TezTaskAttemptID id) {
+    SimpleExponentialSmoothing foreCastEntry = getForecastEntry(id);
+    if (foreCastEntry == null) {
+      return DEFAULT_ESTIMATE_RUNTIME;
+    }
+    double rawForecast = foreCastEntry.getForecast();
+    if (foreCastEntry.isDefaultForecast(rawForecast)) {
+      // Not enough readings yet. Clamping the -1 sentinel to
+      // DEFAULT_PROGRESS_VALUE would report an enormous runtime instead of
+      // "no estimate".
+      return DEFAULT_ESTIMATE_RUNTIME;
+    }
+    double remainingWork =
+        Math.max(0.0, Math.min(1.0, 1.0 - foreCastEntry.getRawData()));
+    // A stalled attempt has a rate of ~0; clamp it to avoid dividing by 0.
+    double forecast = Math.max(DEFAULT_PROGRESS_VALUE, rawForecast);
+    long remainingTime = (long) (remainingWork / forecast);
+    return remainingTime + foreCastEntry.getTimeStamp()
+        - foreCastEntry.getStartTime();
+  }
+
+  /**
+   * Estimates the runtime of a new attempt from the vertex-wide
+   * {@code taskStatistics}. The estimate is the upper 95% confidence bound of
+   * the mean plus the smaller of 25% of that bound and half a standard
+   * deviation.
+   */
+  @Override
+  public long newAttemptEstimatedRuntime() {
+    if (taskStatistics == null) {
+      return DEFAULT_ESTIMATE_RUNTIME;
+    }
+
+    double statsMeanCI = taskStatistics.meanCI();
+    double expectedVal =
+        statsMeanCI + Math.min(statsMeanCI * CONFIDENCE_INTERVAL_FACTOR,
+            taskStatistics.std() / 2);
+    return (long) (expectedVal);
+  }
+
+  /**
+   * Delegates to the attempt's smoothing model. Returns false when the
+   * attempt has no model yet.
+   */
+  @Override
+  public boolean hasStagnatedProgress(final TezTaskAttemptID id,
+      final long timeStamp) {
+    SimpleExponentialSmoothing foreCastEntry = getForecastEntry(id);
+    if (foreCastEntry == null) {
+      return false;
+    }
+    return foreCastEntry.isDataStagnated(timeStamp);
+  }
+
+  /**
+   * Returns -1 when no forecast is available for the attempt. Otherwise it
+   * returns 0, because the runtime variance is not modelled yet.
+   */
+  @Override
+  public long runtimeEstimateVariance(final TezTaskAttemptID id) {
+    SimpleExponentialSmoothing forecastEntry = getForecastEntry(id);
+    if (forecastEntry == null) {
+      return DEFAULT_ESTIMATE_RUNTIME;
+    }
+    double forecast = forecastEntry.getForecast();
+    if (forecastEntry.isDefaultForecast(forecast)) {
+      return DEFAULT_ESTIMATE_RUNTIME;
+    }
+    //TODO What is the best way to measure variance in runtime
+    return 0L;
+  }
+
+  /**
+   * Updates the base-class bookkeeping, then incorporates the attempt's
+   * current progress as a new reading. Unknown tasks or attempts are ignored.
+   */
+  @Override
+  public void updateAttempt(final TezTaskAttemptID attemptID,
+      final TaskAttemptState state,
+      final long timestamp) {
+    super.updateAttempt(attemptID, state, timestamp);
+    Task task = vertex.getTask(attemptID.getTaskID());
+    if (task == null) {
+      return;
+    }
+    TaskAttempt taskAttempt = task.getAttempt(attemptID);
+    if (taskAttempt == null) {
+      return;
+    }
+    float progress = taskAttempt.getProgress();
+    incorporateReading(attemptID, progress, timestamp);
+  }
+}
+
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/StartEndTimesBase.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/StartEndTimesBase.java
index d4d1a7f..3083986 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/StartEndTimesBase.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/StartEndTimesBase.java
@@ -35,13 +35,11 @@ import org.apache.tez.dag.records.TezTaskID;
/**
* Base class that uses the attempt runtime estimations from a derived class
* and uses it to determine outliers based on deviating beyond the mean
- * estimated runtime by some threshold
+ * estimated runtime by some threshold.
*/
abstract class StartEndTimesBase implements TaskRuntimeEstimator {
- static final float MINIMUM_COMPLETE_PROPORTION_TO_SPECULATE
- = 0.05F;
- static final int MINIMUM_COMPLETE_NUMBER_TO_SPECULATE
- = 1;
+ static final float MINIMUM_COMPLETE_PROPORTION_TO_SPECULATE = 0.05F;
+ static final int MINIMUM_COMPLETE_NUMBER_TO_SPECULATE = 1;
protected Vertex vertex;
@@ -50,56 +48,58 @@ abstract class StartEndTimesBase implements TaskRuntimeEstimator {
protected final DataStatistics taskStatistics = new DataStatistics();
- private float slowTaskRelativeTresholds;
+ private float slowTaskRelativeThresholds;
protected final Set<Task> doneTasks = new HashSet<Task>();
@Override
- public void enrollAttempt(TezTaskAttemptID id, long timestamp) {
+ public void enrollAttempt(final TezTaskAttemptID id, final long timestamp) {
startTimes.put(id, timestamp);
}
@Override
- public long attemptEnrolledTime(TezTaskAttemptID attemptID) {
+ public long attemptEnrolledTime(final TezTaskAttemptID attemptID) {
Long result = startTimes.get(attemptID);
return result == null ? Long.MAX_VALUE : result;
}
@Override
- public void contextualize(Configuration conf, Vertex vertex) {
- slowTaskRelativeTresholds = conf.getFloat(
+ public void contextualize(final Configuration conf, final Vertex vertexP) {
+ slowTaskRelativeThresholds = conf.getFloat(
TezConfiguration.TEZ_AM_LEGACY_SPECULATIVE_SLOWTASK_THRESHOLD, 1.0f);
- this.vertex = vertex;
+ this.vertex = vertexP;
}
- protected DataStatistics dataStatisticsForTask(TezTaskID taskID) {
+ protected DataStatistics dataStatisticsForTask(final TezTaskID taskID) {
return taskStatistics;
}
@Override
- public long thresholdRuntime(TezTaskID taskID) {
+ public long thresholdRuntime(final TezTaskID taskID) {
int completedTasks = vertex.getCompletedTasks();
int totalTasks = vertex.getTotalTasks();
-
+
if (completedTasks < MINIMUM_COMPLETE_NUMBER_TO_SPECULATE
- || (((float)completedTasks) / totalTasks)
- < MINIMUM_COMPLETE_PROPORTION_TO_SPECULATE ) {
+ || (((float) completedTasks) / totalTasks)
+ < MINIMUM_COMPLETE_PROPORTION_TO_SPECULATE) {
return Long.MAX_VALUE;
}
-
- long result = (long)taskStatistics.outlier(slowTaskRelativeTresholds);
+
+ long result = (long) taskStatistics.outlier(slowTaskRelativeThresholds);
return result;
}
@Override
public long newAttemptEstimatedRuntime() {
- return (long)taskStatistics.mean();
+ return (long) taskStatistics.mean();
}
@Override
- public void updateAttempt(TezTaskAttemptID attemptID, TaskAttemptState state, long timestamp) {
+ public void updateAttempt(final TezTaskAttemptID attemptID,
+ final TaskAttemptState state,
+ final long timestamp) {
Task task = vertex.getTask(attemptID.getTaskID());
@@ -109,7 +109,7 @@ abstract class StartEndTimesBase implements TaskRuntimeEstimator {
Long boxedStart = startTimes.get(attemptID);
long start = boxedStart == null ? Long.MIN_VALUE : boxedStart;
-
+
TaskAttempt taskAttempt = task.getAttempt(attemptID);
if (taskAttempt.getState() == TaskAttemptState.SUCCEEDED) {
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/TaskRuntimeEstimator.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/TaskRuntimeEstimator.java
index c8edd1e..4f747af 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/TaskRuntimeEstimator.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/TaskRuntimeEstimator.java
@@ -29,13 +29,14 @@ import org.apache.tez.dag.records.TezTaskID;
*
*/
public interface TaskRuntimeEstimator {
- public void enrollAttempt(TezTaskAttemptID id, long timestamp);
+ void enrollAttempt(TezTaskAttemptID id, long timestamp);
- public long attemptEnrolledTime(TezTaskAttemptID attemptID);
+ long attemptEnrolledTime(TezTaskAttemptID attemptID);
- public void updateAttempt(TezTaskAttemptID taId, TaskAttemptState reportedState, long timestamp);
+ void updateAttempt(TezTaskAttemptID taId,
+ TaskAttemptState reportedState, long timestamp);
- public void contextualize(Configuration conf, Vertex vertex);
+ void contextualize(Configuration conf, Vertex vertex);
/**
*
@@ -52,7 +53,7 @@ public interface TaskRuntimeEstimator {
* however long.
*
*/
- public long thresholdRuntime(TezTaskID id);
+ long thresholdRuntime(TezTaskID id);
/**
*
@@ -64,7 +65,7 @@ public interface TaskRuntimeEstimator {
* we don't have enough information yet to produce an estimate.
*
*/
- public long estimatedRuntime(TezTaskAttemptID id);
+ long estimatedRuntime(TezTaskAttemptID id);
/**
*
@@ -75,7 +76,7 @@ public interface TaskRuntimeEstimator {
* we don't have enough information yet to produce an estimate.
*
*/
- public long newAttemptEstimatedRuntime();
+ long newAttemptEstimatedRuntime();
/**
*
@@ -87,5 +88,20 @@ public interface TaskRuntimeEstimator {
* we don't have enough information yet to produce an estimate.
*
*/
- public long runtimeEstimateVariance(TezTaskAttemptID id);
+ long runtimeEstimateVariance(TezTaskAttemptID id);
+
+ /**
+ * Returns true if the estimator has recorded no updates for the attempt
+ * within a threshold time window. This helps to identify task attempts whose
+ * progress has stalled.
+ * <p>
+ * The default implementation always returns false, so estimators that do not
+ * override it never report stagnation.
+ *
+ * @param id the {@link TezTaskAttemptID} of the attempt we are asking about
+ * @param timeStamp the time of the report we compare with
+ * @return true if the task attempt has no progress for a given time window
+ */
+ default boolean hasStagnatedProgress(TezTaskAttemptID id, long timeStamp) {
+ return false;
+ }
}
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/forecast/SimpleExponentialSmoothing.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/forecast/SimpleExponentialSmoothing.java
new file mode 100644
index 0000000..e7b7dcd
--- /dev/null
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/forecast/SimpleExponentialSmoothing.java
@@ -0,0 +1,336 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app.dag.speculation.legacy.forecast;
+
+import java.util.concurrent.atomic.AtomicReference;
+
+/**
+ * Implementation of the static model for Simple exponential smoothing.
+ */
+public class SimpleExponentialSmoothing {
+ private static final double DEFAULT_FORECAST = -1.0d;
+ private final int kMinimumReads;
+ private final long kStagnatedWindow;
+ private final long startTime;
+ private long timeConstant;
+
+ /**
+ * Holds reference to the current forecast record.
+ */
+ private AtomicReference<ForecastRecord> forecastRefEntry;
+
+ /**
+ * Create forecast simple exponential smoothing.
+ *
+ * @param timeConstant the time constant
+ * @param skipCnt the skip cnt
+ * @param stagnatedWindow the stagnated window
+ * @param timeStamp the time stamp
+ * @return the simple exponential smoothing
+ */
+ public static SimpleExponentialSmoothing createForecast(
+ final long timeConstant,
+ final int skipCnt, final long stagnatedWindow, final long timeStamp) {
+ return new SimpleExponentialSmoothing(timeConstant, skipCnt,
+ stagnatedWindow, timeStamp);
+ }
+
+ /**
+ * Instantiates a new Simple exponential smoothing.
+ *
+ * @param ktConstant the kt constant
+ * @param skipCnt the skip cnt
+ * @param stagnatedWindow the stagnated window
+ * @param timeStamp the time stamp
+ */
+ SimpleExponentialSmoothing(final long ktConstant, final int skipCnt,
+ final long stagnatedWindow, final long timeStamp) {
+ this.kMinimumReads = skipCnt;
+ this.kStagnatedWindow = stagnatedWindow;
+ this.timeConstant = ktConstant;
+ this.startTime = timeStamp;
+ this.forecastRefEntry = new AtomicReference<ForecastRecord>(null);
+ }
+
+ private class ForecastRecord {
+ private final double alpha;
+ private final long timeStamp;
+ private final double sample;
+ private final double rawData;
+ private double forecast;
+ private final double sseError;
+ private final long myIndex;
+ private ForecastRecord prevRec;
+
+ /**
+ * Instantiates a new Forecast record.
+ *
+ * @param currForecast the curr forecast
+ * @param currRawData the curr raw data
+ * @param currTimeStamp the curr time stamp
+ */
+ ForecastRecord(final double currForecast, final double currRawData,
+ final long currTimeStamp) {
+ this(0.0, currForecast, currRawData, currForecast, currTimeStamp, 0.0, 0);
+ }
+
+ /**
+ * Instantiates a new Forecast record.
+ *
+ * @param alphaVal the alpha val
+ * @param currSample the curr sample
+ * @param currRawData the curr raw data
+ * @param currForecast the curr forecast
+ * @param currTimeStamp the curr time stamp
+ * @param accError the acc error
+ * @param index the index
+ */
+ ForecastRecord(final double alphaVal, final double currSample,
+ final double currRawData,
+ final double currForecast, final long currTimeStamp,
+ final double accError,
+ final long index) {
+ this.timeStamp = currTimeStamp;
+ this.alpha = alphaVal;
+ this.sample = currSample;
+ this.forecast = currForecast;
+ this.rawData = currRawData;
+ this.sseError = accError;
+ this.myIndex = index;
+ }
+
+ private ForecastRecord createForecastRecord(final double alphaVal,
+ final double currSample,
+ final double currRawData,
+ final double currForecast, final long currTimeStamp,
+ final double accError,
+ final long index,
+ final ForecastRecord prev) {
+ ForecastRecord forecastRec =
+ new ForecastRecord(alphaVal, currSample, currRawData, currForecast,
+ currTimeStamp, accError, index);
+ forecastRec.prevRec = prev;
+ return forecastRec;
+ }
+
+ private double preProcessRawData(final double rData, final long newTime) {
+ return processRawData(this.rawData, this.timeStamp, rData, newTime);
+ }
+
+ /**
+ * Append forecast record.
+ *
+ * @param newTimeStamp the new time stamp
+ * @param rData the r data
+ * @return the forecast record
+ */
+ public ForecastRecord append(final long newTimeStamp, final double rData) {
+ if (this.timeStamp >= newTimeStamp
+ && Double.compare(this.rawData, rData) >= 0) {
+ // progress reported twice. Do nothing.
+ return this;
+ }
+ ForecastRecord refRecord = this;
+ if (newTimeStamp == this.timeStamp) {
+ // we need to restore old value if possible
+ if (this.prevRec != null) {
+ refRecord = this.prevRec;
+ }
+ }
+ double newSample = refRecord.preProcessRawData(rData, newTimeStamp);
+ long deltaTime = this.timeStamp - newTimeStamp;
+ if (refRecord.myIndex == kMinimumReads) {
+ timeConstant = Math.max(timeConstant, newTimeStamp - startTime);
+ }
+ double smoothFactor =
+ 1 - Math.exp(((double) deltaTime) / timeConstant);
+ double forecastVal =
+ smoothFactor * newSample + (1.0 - smoothFactor) * refRecord.forecast;
+ double newSSEError =
+ refRecord.sseError + Math.pow(newSample - refRecord.forecast, 2);
+ return refRecord
+ .createForecastRecord(smoothFactor, newSample, rData, forecastVal,
+ newTimeStamp, newSSEError, refRecord.myIndex + 1, refRecord);
+ }
+ }
+
+ /**
+ * checks if the task is hanging up.
+ *
+ * @param timeStamp current time of the scan.
+ * @return true if we have number of samples > kMinimumReads and the record
+ * timestamp has expired.
+ */
+ public boolean isDataStagnated(final long timeStamp) {
+ ForecastRecord rec = forecastRefEntry.get();
+ if (rec != null && rec.myIndex > kMinimumReads) {
+ return (rec.timeStamp + kStagnatedWindow) > timeStamp;
+ }
+ return false;
+ }
+
+ /**
+ * Process raw data double.
+ *
+ * @param oldRawData the old raw data
+ * @param oldTime the old time
+ * @param newRawData the new raw data
+ * @param newTime the new time
+ * @return the double
+ */
+ static double processRawData(final double oldRawData, final long oldTime,
+ final double newRawData, final long newTime) {
+ double rate = (newRawData - oldRawData) / (newTime - oldTime);
+ return rate;
+ }
+
+ /**
+ * Incorporate reading.
+ *
+ * @param timeStamp the time stamp
+ * @param currRawData the curr raw data
+ */
+ public void incorporateReading(final long timeStamp,
+ final double currRawData) {
+ ForecastRecord oldRec = forecastRefEntry.get();
+ if (oldRec == null) {
+ double oldForecast =
+ processRawData(0, startTime, currRawData, timeStamp);
+ forecastRefEntry.compareAndSet(null,
+ new ForecastRecord(oldForecast, 0.0d, startTime));
+ incorporateReading(timeStamp, currRawData);
+ return;
+ }
+ while (!forecastRefEntry.compareAndSet(oldRec, oldRec.append(timeStamp,
+ currRawData))) {
+ oldRec = forecastRefEntry.get();
+ }
+ }
+
+ /**
+ * Gets forecast.
+ *
+ * @return the forecast
+ */
+ public double getForecast() {
+ ForecastRecord rec = forecastRefEntry.get();
+ if (rec != null && rec.myIndex > kMinimumReads) {
+ return rec.forecast;
+ }
+ return DEFAULT_FORECAST;
+ }
+
+ /**
+ * Is default forecast boolean.
+ *
+ * @param value the value
+ * @return the boolean
+ */
+ public boolean isDefaultForecast(final double value) {
+ return value == DEFAULT_FORECAST;
+ }
+
+ /**
+ * Gets sse.
+ *
+ * @return the sse
+ */
+ public double getSSE() {
+ ForecastRecord rec = forecastRefEntry.get();
+ if (rec != null) {
+ return rec.sseError;
+ }
+ return DEFAULT_FORECAST;
+ }
+
+ /**
+ * Is error within bound boolean.
+ *
+ * @param bound the bound
+ * @return the boolean
+ */
+ public boolean isErrorWithinBound(final double bound) {
+ double squaredErr = getSSE();
+ if (squaredErr < 0) {
+ return false;
+ }
+ return bound > squaredErr;
+ }
+
+ /**
+ * Gets raw data.
+ *
+ * @return the raw data
+ */
+ public double getRawData() {
+ ForecastRecord rec = forecastRefEntry.get();
+ if (rec != null) {
+ return rec.rawData;
+ }
+ return DEFAULT_FORECAST;
+ }
+
+ /**
+ * Gets time stamp.
+ *
+ * @return the time stamp
+ */
+ public long getTimeStamp() {
+ ForecastRecord rec = forecastRefEntry.get();
+ if (rec != null) {
+ return rec.timeStamp;
+ }
+ return 0L;
+ }
+
+ /**
+ * Gets start time.
+ *
+ * @return the start time
+ */
+ public long getStartTime() {
+ return startTime;
+ }
+
+ /**
+ * Gets forecast ref entry.
+ *
+ * @return the forecast ref entry
+ */
+ public AtomicReference<ForecastRecord> getForecastRefEntry() {
+ return forecastRefEntry;
+ }
+
+ @Override
+ public String toString() {
+ String res = "NULL";
+ ForecastRecord rec = forecastRefEntry.get();
+ if (rec != null) {
+ StringBuilder strB = new StringBuilder("rec.index = ").append(rec.myIndex)
+ .append(", timeStamp t: ").append(rec.timeStamp)
+ .append(", forecast: ").append(rec.forecast).append(", sample: ")
+ .append(rec.sample).append(", raw: ").append(rec.rawData)
+ .append(", error: ").append(rec.sseError).append(", alpha: ")
+ .append(rec.alpha);
+ res = strB.toString();
+ }
+ return res;
+ }
+}
+
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/forecast/package-info.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/forecast/package-info.java
new file mode 100644
index 0000000..3ed8b6a
--- /dev/null
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/forecast/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * Forecasting utilities for the legacy speculator (for example
+ * {@code SimpleExponentialSmoothing}).
+ */
+@org.apache.hadoop.classification.InterfaceAudience.Private
+package org.apache.tez.dag.app.dag.speculation.legacy.forecast;
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/TestSpeculation.java b/tez-dag/src/test/java/org/apache/tez/dag/app/TestSpeculation.java
index a81d4d3..b9a7c5a 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/TestSpeculation.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/TestSpeculation.java
@@ -19,10 +19,15 @@
package org.apache.tez.dag.app;
import java.io.IOException;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.util.Arrays;
+import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
@@ -45,39 +50,196 @@ import org.apache.tez.dag.app.dag.Task;
import org.apache.tez.dag.app.dag.TaskAttempt;
import org.apache.tez.dag.app.dag.impl.DAGImpl;
import org.apache.tez.dag.app.dag.speculation.legacy.LegacySpeculator;
+import org.apache.tez.dag.app.dag.speculation.legacy.LegacyTaskRuntimeEstimator;
+import org.apache.tez.dag.app.dag.speculation.legacy.SimpleExponentialTaskRuntimeEstimator;
+import org.apache.tez.dag.app.dag.speculation.legacy.TaskRuntimeEstimator;
import org.apache.tez.dag.library.vertexmanager.ShuffleVertexManager;
import org.apache.tez.dag.records.TaskAttemptTerminationCause;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.dag.records.TezTaskID;
import org.apache.tez.dag.records.TezVertexID;
+import org.junit.After;
import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import com.google.common.base.Joiner;
+import org.junit.rules.TestRule;
+import org.junit.runner.Description;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.model.Statement;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+/**
+ * test speculation behavior given the list of estimator classes.
+ */
+@RunWith(Parameterized.class)
public class TestSpeculation {
- static Configuration defaultConf;
- static FileSystem localFs;
-
+ private static final Logger LOG =
+ LoggerFactory.getLogger(TestSpeculation.class);
+
+ /** Assertion message that {@code RetryRule} treats as a retryable failure. */
+ private static final String ASSERT_SPECULATIONS_COUNT_MSG =
+ "Number of attempts after Speculation should be two";
+ /** Message prefix JUnit reports when a {@code @Test} timeout expires. */
+ private static final String UNIT_EXCEPTION_MESSAGE =
+ "test timed out after";
+ /** Total attempts {@code RetryRule} allows for a {@code @Retry} test. */
+ private static final int ASSERT_SPECULATIONS_COUNT_RETRIES = 3;
+ /** Configuration rebuilt before each test by {@code setDefaultConf()}. */
+ private Configuration defaultConf;
+ /** Local FileSystem obtained from {@code defaultConf}. */
+ private FileSystem localFs;
+
+ /**
+ * Mock DAG app master for the current test; closed in {@code tearDown()}.
+ */
MockDAGAppMaster mockApp;
+
+ /**
+ * Mock container launcher the tests drive (scheduling, progress and status
+ * updates); shut down in {@code tearDown()}.
+ */
MockContainerLauncher mockLauncher;
-
- static {
+
+ /**
+ * Marks a test method as eligible for re-runs by {@link RetryRule}. It is
+ * retained at runtime so the rule can see it through
+ * {@code Description#getAnnotation}.
+ */
+ @Retention(value = RetentionPolicy.RUNTIME)
+ public @interface Retry {
+ }
+
+ /**
+ * JUnit rule that re-runs a test annotated with {@link Retry} when it fails
+ * with a known flaky symptom: the speculation-count assertion
+ * ({@code ASSERT_SPECULATIONS_COUNT_MSG}) or a JUnit timeout
+ * ({@code UNIT_EXCEPTION_MESSAGE}). Any other failure, and any failure of a
+ * test without {@link Retry}, is rethrown immediately.
+ */
+ static class RetryRule implements TestRule {
+
+ /** Attempts left for the current test instance. */
+ private final AtomicInteger retryCount;
+
+ /**
+ * Instantiates a new Retry rule.
+ *
+ * @param retries total number of attempts per test; values below 1 are
+ * treated as 1 so the test body always runs at least once
+ */
+ RetryRule(int retries) {
+ this.retryCount = new AtomicInteger(Math.max(1, retries));
+ }
+
+ @Override
+ public Statement apply(final Statement base,
+ final Description description) {
+ final boolean retryable = description.getAnnotation(Retry.class) != null;
+ return new Statement() {
+ @Override
+ public void evaluate() throws Throwable {
+ while (true) {
+ try {
+ base.evaluate();
+ return;
+ } catch (Throwable t) {
+ int remaining = retryCount.decrementAndGet();
+ if (!retryable || remaining <= 0 || !isKnownFlakyFailure(t)) {
+ throw t;
+ }
+ LOG.warn("{} : Failed. Retries remaining: {}",
+ description.getDisplayName(), remaining);
+ }
+ }
+ }
+ };
+ }
+
+ /**
+ * Returns whether {@code t} is a failure this rule may retry. A
+ * {@code null} message never matches; checking it avoids an NPE that
+ * would mask the real failure.
+ */
+ private static boolean isKnownFlakyFailure(final Throwable t) {
+ final String msg = t.getMessage();
+ if (msg == null) {
+ return false;
+ }
+ return (t instanceof AssertionError
+ && msg.contains(ASSERT_SPECULATIONS_COUNT_MSG))
+ || (t instanceof Exception && msg.contains(UNIT_EXCEPTION_MESSAGE));
+ }
+ }
+
+ /**
+ * Re-runs {@link Retry}-annotated tests, up to
+ * {@code ASSERT_SPECULATIONS_COUNT_RETRIES} attempts in total, when they fail
+ * with a known flaky symptom. JUnit requires {@code @Rule} fields to be public.
+ */
+ @Rule
+ public RetryRule rule = new RetryRule(ASSERT_SPECULATIONS_COUNT_RETRIES);
+
+ /**
+ * Builds a fresh local-mode configuration with speculation enabled before
+ * each test. It installs the parameterized {@code estimatorClass} and the
+ * speculator thresholds shared by all tests, and opens the local FileSystem.
+ *
+ * @throws RuntimeException wrapping the {@link IOException} raised when the
+ * local FileSystem cannot be obtained
+ */
+ @Before
+ public void setDefaultConf() {
try {
defaultConf = new Configuration(false);
defaultConf.set("fs.defaultFS", "file:///");
defaultConf.setBoolean(TezConfiguration.TEZ_LOCAL_MODE, true);
defaultConf.setBoolean(TezConfiguration.TEZ_AM_SPECULATION_ENABLED, true);
- defaultConf.setFloat(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION, 1);
- defaultConf.setFloat(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION, 1);
+ defaultConf.setFloat(
+ ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION, 1);
+ defaultConf.setFloat(
+ ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION, 1);
localFs = FileSystem.getLocal(defaultConf);
- String stagingDir = "target" + Path.SEPARATOR + TestSpeculation.class.getName() + "-tmpDir";
+ String stagingDir =
+ "target" + Path.SEPARATOR + TestSpeculation.class.getName()
+ + "-tmpDir";
defaultConf.set(TezConfiguration.TEZ_AM_STAGING_DIR, stagingDir);
+ // Each parameterized run plugs in its own runtime estimator implementation.
+ defaultConf.setClass(TezConfiguration.TEZ_AM_TASK_ESTIMATOR_CLASS,
+ estimatorClass,
+ TaskRuntimeEstimator.class);
+ defaultConf.setInt(TezConfiguration.TEZ_AM_MINIMUM_ALLOWED_SPECULATIVE_TASKS, 20);
+ defaultConf.setDouble(TezConfiguration.TEZ_AM_PROPORTION_TOTAL_TASKS_SPECULATABLE, 0.2);
+ defaultConf.setDouble(TezConfiguration.TEZ_AM_PROPORTION_RUNNING_TASKS_SPECULATABLE, 0.25);
+ defaultConf.setLong(TezConfiguration.TEZ_AM_SOONEST_RETRY_AFTER_NO_SPECULATE, 25);
+ defaultConf.setLong(TezConfiguration.TEZ_AM_SOONEST_RETRY_AFTER_SPECULATE, 50);
+ defaultConf.setInt(TezConfiguration.TEZ_AM_ESTIMATOR_EXPONENTIAL_SKIP_INITIALS, 2);
} catch (IOException e) {
throw new RuntimeException("init failure", e);
}
}
+ /**
+ * Releases per-test resources. Each resource is released independently, so
+ * one failure, or a {@code null} field when a test failed before the mocks
+ * were created, does not leak the others. The local FileSystem is closed
+ * last because the mock components may still be using it.
+ */
+ @After
+ public void tearDown() {
+ defaultConf = null;
+ if (mockLauncher != null) {
+ try {
+ mockLauncher.shutdown();
+ } catch (Exception e) {
+ LOG.warn("Failed to shut down mock container launcher", e);
+ }
+ }
+ if (mockApp != null) {
+ try {
+ mockApp.close();
+ } catch (Exception e) {
+ LOG.warn("Failed to close mock DAG app master", e);
+ }
+ }
+ if (localFs != null) {
+ try {
+ localFs.close();
+ } catch (Exception e) {
+ LOG.warn("Failed to close local FileSystem", e);
+ }
+ }
+ }
+
+ /**
+ * Supplies one parameter set per {@link TaskRuntimeEstimator} implementation,
+ * so every test runs once with each estimator.
+ *
+ * @return single-element argument arrays, each holding an estimator class
+ */
+ @Parameterized.Parameters(name = "{index}: TaskEstimator(EstimatorClass {0})")
+ public static Collection<Object[]> getTestParameters() {
+ final Object[][] estimators = {
+ {SimpleExponentialTaskRuntimeEstimator.class},
+ {LegacyTaskRuntimeEstimator.class}
+ };
+ return Arrays.asList(estimators);
+ }
+
+ /** Runtime estimator under test, supplied by {@code getTestParameters()}. */
+ private final Class<? extends TaskRuntimeEstimator> estimatorClass;
+
+ /**
+ * Creates a test instance for one parameterized run.
+ *
+ * @param estimatorKlass the {@link TaskRuntimeEstimator} implementation
+ * installed as {@code TEZ_AM_TASK_ESTIMATOR_CLASS} by {@code setDefaultConf()}
+ */
+ public TestSpeculation(Class<? extends TaskRuntimeEstimator> estimatorKlass) {
+ this.estimatorClass = estimatorKlass;
+ }
+
+ /**
+ * Create tez session mock tez client.
+ *
+ * @return the mock tez client
+ * @throws Exception the exception
+ */
MockTezClient createTezSession() throws Exception {
TezConfiguration tezconf = new TezConfiguration(defaultConf);
AtomicBoolean mockAppLauncherGoFlag = new AtomicBoolean(false);
@@ -87,8 +249,16 @@ public class TestSpeculation {
syncWithMockAppLauncher(false, mockAppLauncherGoFlag, tezClient);
return tezClient;
}
-
- void syncWithMockAppLauncher(boolean allowScheduling, AtomicBoolean mockAppLauncherGoFlag,
+
+ /**
+ * Sync with mock app launcher.
+ *
+ * @param allowScheduling the allow scheduling
+ * @param mockAppLauncherGoFlag the mock app launcher go flag
+ * @param tezClient the tez client
+ * @throws Exception the exception
+ */
+ void syncWithMockAppLauncher(boolean allowScheduling, AtomicBoolean mockAppLauncherGoFlag,
MockTezClient tezClient) throws Exception {
synchronized (mockAppLauncherGoFlag) {
while (!mockAppLauncherGoFlag.get()) {
@@ -101,6 +271,12 @@ public class TestSpeculation {
}
}
+ /**
+ * Test single task speculation.
+ *
+ * @throws Exception the exception
+ */
+ @Retry
@Test (timeout = 10000)
public void testSingleTaskSpeculation() throws Exception {
// Map<Timeout conf value, expected number of tasks>
@@ -126,9 +302,10 @@ public class TestSpeculation {
DAGImpl dagImpl = (DAGImpl) mockApp.getContext().getCurrentDAG();
TezVertexID vertexId = TezVertexID.getInstance(dagImpl.getID(), 0);
// original attempt is killed and speculative one is successful
- TezTaskAttemptID killedTaId = TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 0);
- TezTaskAttemptID successTaId = TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 1);
-
+ TezTaskAttemptID killedTaId =
+ TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 0);
+ TezTaskAttemptID successTaId =
+ TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 1);
Thread.sleep(200);
// cause speculation trigger
mockLauncher.setStatusUpdatesForTask(killedTaId, 100);
@@ -149,16 +326,16 @@ public class TestSpeculation {
}
}
+ /**
+ * Test basic speculation.
+ *
+ * @param withProgress the with progress
+ * @throws Exception the exception
+ */
public void testBasicSpeculation(boolean withProgress) throws Exception {
-
- defaultConf.setInt(TezConfiguration.TEZ_AM_MINIMUM_ALLOWED_SPECULATIVE_TASKS, 20);
- defaultConf.setDouble(TezConfiguration.TEZ_AM_PROPORTION_TOTAL_TASKS_SPECULATABLE, 0.2);
- defaultConf.setDouble(TezConfiguration.TEZ_AM_PROPORTION_RUNNING_TASKS_SPECULATABLE, 0.25);
- defaultConf.setLong(TezConfiguration.TEZ_AM_SOONEST_RETRY_AFTER_NO_SPECULATE, 25);
- defaultConf.setLong(TezConfiguration.TEZ_AM_SOONEST_RETRY_AFTER_SPECULATE, 50);
-
DAG dag = DAG.create("test");
- Vertex vA = Vertex.create("A", ProcessorDescriptor.create("Proc.class"), 5);
+ Vertex vA = Vertex.create("A",
+ ProcessorDescriptor.create("Proc.class"), 5);
dag.addVertex(vA);
MockTezClient tezClient = createTezSession();
@@ -166,8 +343,10 @@ public class TestSpeculation {
DAGImpl dagImpl = (DAGImpl) mockApp.getContext().getCurrentDAG();
TezVertexID vertexId = TezVertexID.getInstance(dagImpl.getID(), 0);
// original attempt is killed and speculative one is successful
- TezTaskAttemptID killedTaId = TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 0);
- TezTaskAttemptID successTaId = TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 1);
+ TezTaskAttemptID killedTaId =
+ TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 0);
+ TezTaskAttemptID successTaId =
+ TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 1);
mockLauncher.updateProgress(withProgress);
// cause speculation trigger
@@ -175,9 +354,11 @@ public class TestSpeculation {
mockLauncher.startScheduling(true);
dagClient.waitForCompletion();
- Assert.assertEquals(DAGStatus.State.SUCCEEDED, dagClient.getDAGStatus(null).getState());
+ Assert.assertEquals(DAGStatus.State.SUCCEEDED,
+ dagClient.getDAGStatus(null).getState());
Task task = dagImpl.getTask(killedTaId.getTaskID());
- Assert.assertEquals(2, task.getAttempts().size());
+ Assert.assertEquals(ASSERT_SPECULATIONS_COUNT_MSG, 2,
+ task.getAttempts().size());
Assert.assertEquals(successTaId, task.getSuccessfulAttempt().getID());
TaskAttempt killedAttempt = task.getAttempt(killedTaId);
Joiner.on(",").join(killedAttempt.getDiagnostics()).contains("Killed as speculative attempt");
@@ -204,18 +385,36 @@ public class TestSpeculation {
tezClient.stop();
}
-
+
+ /**
+ * Runs {@code testBasicSpeculation(boolean)} with progress updates enabled
+ * on the mock launcher. Retried by {@code RetryRule} on known flaky failures.
+ *
+ * @throws Exception propagated from the underlying scenario
+ */
+ @Retry
@Test (timeout=10000)
public void testBasicSpeculationWithProgress() throws Exception {
testBasicSpeculation(true);
}
+ /**
+ * Runs {@code testBasicSpeculation(boolean)} with progress updates disabled
+ * on the mock launcher. Retried by {@code RetryRule} on known flaky failures.
+ *
+ * @throws Exception propagated from the underlying scenario
+ */
+ @Retry
@Test (timeout=10000)
public void testBasicSpeculationWithoutProgress() throws Exception {
testBasicSpeculation(false);
}
- @Test (timeout=100000)
+ /**
+ * Test basic speculation per vertex conf.
+ *
+ * @throws Exception the exception
+ */
+ @Retry
+ @Test (timeout=10000)
public void testBasicSpeculationPerVertexConf() throws Exception {
DAG dag = DAG.create("test");
String vNameNoSpec = "A";
@@ -224,8 +423,6 @@ public class TestSpeculation {
Vertex vA = Vertex.create(vNameNoSpec, ProcessorDescriptor.create("Proc.class"), 5);
Vertex vB = Vertex.create(vNameSpec, ProcessorDescriptor.create("Proc.class"), 5);
vA.setConf(TezConfiguration.TEZ_AM_SPECULATION_ENABLED, "false");
- vB.setConf(TezConfiguration.TEZ_AM_SOONEST_RETRY_AFTER_NO_SPECULATE,
- speculatorSleepTime);
dag.addVertex(vA);
dag.addVertex(vB);
// min/max src fraction is set to 1. So vertices will run sequentially
@@ -273,6 +470,12 @@ public class TestSpeculation {
tezClient.stop();
}
+ /**
+ * Test basic speculation not useful.
+ *
+ * @throws Exception the exception
+ */
+ @Retry
@Test (timeout=10000)
public void testBasicSpeculationNotUseful() throws Exception {
DAG dag = DAG.create("test");
@@ -310,5 +513,4 @@ public class TestSpeculation {
.getValue());
tezClient.stop();
}
-
}