You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text rendering.
Posted to commits@tez.apache.org by je...@apache.org on 2020/02/04 16:44:45 UTC

[tez] branch master updated: TEZ-4106. Add Exponential Smooth RuntimeEstimator to the speculator

This is an automated email from the ASF dual-hosted git repository.

jeagles pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/tez.git


The following commit(s) were added to refs/heads/master by this push:
     new 2736788  TEZ-4106. Add Exponential Smooth RuntimeEstimator to the speculator
2736788 is described below

commit 2736788f49f17bd269b7ea64d959bbaad95421d3
Author: Ahmed Hussein <ah...@apache.org>
AuthorDate: Tue Feb 4 10:44:29 2020 -0600

    TEZ-4106. Add Exponential Smooth RuntimeEstimator to the speculator
    
    Signed-off-by: Jonathan Eagles <je...@apache.org>
---
 .../org/apache/tez/dag/api/TezConfiguration.java   |  59 +++-
 .../app/dag/speculation/legacy/DataStatistics.java |  25 +-
 .../dag/speculation/legacy/LegacySpeculator.java   |  39 ++-
 .../SimpleExponentialTaskRuntimeEstimator.java     | 194 ++++++++++++
 .../dag/speculation/legacy/StartEndTimesBase.java  |  42 +--
 .../speculation/legacy/TaskRuntimeEstimator.java   |  32 +-
 .../forecast/SimpleExponentialSmoothing.java       | 336 +++++++++++++++++++++
 .../speculation/legacy/forecast/package-info.java  |  20 ++
 .../org/apache/tez/dag/app/TestSpeculation.java    | 262 ++++++++++++++--
 9 files changed, 926 insertions(+), 83 deletions(-)

diff --git a/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java b/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
index f087e3a..58aecda 100644
--- a/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
+++ b/tez-api/src/main/java/org/apache/tez/dag/api/TezConfiguration.java
@@ -26,6 +26,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
+import java.util.concurrent.TimeUnit;
 import org.apache.hadoop.classification.InterfaceAudience.Private;
 import org.apache.hadoop.classification.InterfaceAudience.Public;
 import org.apache.hadoop.classification.InterfaceStability.Unstable;
@@ -531,14 +532,6 @@ public class TezConfiguration extends Configuration {
   public static final boolean TEZ_AM_SPECULATION_ENABLED_DEFAULT = false;
 
   /**
-   * Class used to estimate task resource needs.
-   */
-  @ConfigurationScope(Scope.VERTEX)
-  @ConfigurationProperty
-  public static final String TEZ_AM_SPECULATION_ESTIMATOR_CLASS =
-          TEZ_AM_PREFIX + "speculation.estimator.class";
-
-  /**
    * Float value. Specifies how many standard deviations away from the mean task execution time
    * should be considered as an outlier/slow task.
    */
@@ -559,6 +552,10 @@ public class TezConfiguration extends Configuration {
                                      TEZ_AM_PREFIX + "legacy.speculative.single.task.vertex.timeout";
   public static final long TEZ_AM_LEGACY_SPECULATIVE_SINGLE_TASK_VERTEX_TIMEOUT_DEFAULT = -1;
 
+  @Private
+  public static final String TEZ_SPECULATOR_PREFIX = TEZ_AM_PREFIX + "speculator.";
+  @Private
+  public static final String TEZ_ESTIMATOR_PREFIX = TEZ_AM_PREFIX + "task.estimator.";
   /**
    * Long value. Specifies amount of time (in ms) that needs to elapse to do the next round of
    * speculation if there is no task speculated in this round.
@@ -581,6 +578,52 @@ public class TezConfiguration extends Configuration {
           TEZ_AM_PREFIX + "soonest.retry.after.speculate";
   public static final long TEZ_AM_SOONEST_RETRY_AFTER_SPECULATE_DEFAULT = 1000L * 15L;
 
+  /** The class that should be used for speculative execution calculations. */
+  @ConfigurationScope(Scope.VERTEX)
+  @ConfigurationProperty
+  public static final String TEZ_AM_SPECULATOR_CLASS =
+      TEZ_SPECULATOR_PREFIX + "class";
+  /** The class that should be used for task runtime estimation. */
+  @ConfigurationScope(Scope.VERTEX)
+  @ConfigurationProperty
+  public static final String TEZ_AM_TASK_ESTIMATOR_CLASS =
+      TEZ_ESTIMATOR_PREFIX + "class";
+  /**
+   * Long value. Specifies amount of time (in ms) of the lambda value in the
+   * smoothing function of the task estimator
+   */
+  @Unstable
+  @ConfigurationScope(Scope.VERTEX)
+  @ConfigurationProperty(type="long")
+  public static final String TEZ_AM_ESTIMATOR_EXPONENTIAL_LAMBDA_MS =
+      TEZ_ESTIMATOR_PREFIX + "exponential.lambda.ms";
+  public static final long TEZ_AM_ESTIMATOR_EXPONENTIAL_LAMBDA_MS_DEFAULT =
+      TimeUnit.SECONDS.toMillis(120);
+
+  /**
+   * The window length in the simple exponential smoothing that considers the
+   * task attempt is stagnated.
+   */
+  @Unstable
+  @ConfigurationScope(Scope.VERTEX)
+  @ConfigurationProperty(type="long")
+  public static final String TEZ_AM_ESTIMATOR_EXPONENTIAL_STAGNATED_MS =
+      TEZ_ESTIMATOR_PREFIX + "exponential.stagnated.ms";
+  public static final long TEZ_AM_ESTIMATOR_EXPONENTIAL_STAGNATED_MS_DEFAULT =
+      TimeUnit.SECONDS.toMillis(360);
+
+  /**
+   * The number of initial readings that the estimator ignores before giving a
+   * prediction. At the beginning the smooth estimator won't be accurate in
+   * prediction
+   */
+  @Unstable
+  @ConfigurationScope(Scope.VERTEX)
+  @ConfigurationProperty(type="integer")
+  public static final String TEZ_AM_ESTIMATOR_EXPONENTIAL_SKIP_INITIALS =
+      TEZ_ESTIMATOR_PREFIX + "exponential.skip.initials";
+  public static final int TEZ_AM_ESTIMATOR_EXPONENTIAL_SKIP_INITIALS_DEFAULT = 24;
+
   /**
    * Double value. The max percent (0-1) of running tasks that can be speculatively re-executed at any time.
    */
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/DataStatistics.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/DataStatistics.java
index 7e6f1c2..bbfb950 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/DataStatistics.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/DataStatistics.java
@@ -21,6 +21,11 @@ package org.apache.tez.dag.app.dag.speculation.legacy;
 import com.google.common.annotations.VisibleForTesting;
 
 public class DataStatistics {
+  /**
+   * factor used to calculate confidence interval within 95%.
+   */
+  private static final double DEFAULT_CI_FACTOR = 1.96;
+
   private int count = 0;
   private double sum = 0;
   private double sumSquares = 0;
@@ -79,8 +84,24 @@ public class DataStatistics {
     return count;
   }
 
+  /**
+   * calculates the mean value within 95% ConfidenceInterval. 1.96 is standard
+   * for 95%.
+   *
+   * @return the mean value adding 95% confidence interval.
+   */
+  public synchronized double meanCI() {
+    if (count <= 1) {
+      return 0.0;
+    }
+    double currMean = mean();
+    double currStd = std();
+    return currMean + (DEFAULT_CI_FACTOR * currStd / Math.sqrt(count));
+  }
+
   public String toString() {
-    return "DataStatistics: count is " + count + ", sum is " + sum +
-    ", sumSquares is " + sumSquares + " mean is " + mean() + " std() is " + std();
+    return "DataStatistics: count is " + count + ", sum is " + sum
+        + ", sumSquares is " + sumSquares + " mean is " + mean()
+        + " std() is " + std() + ", meanCI() is " + meanCI();
   }
 }
\ No newline at end of file
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/LegacySpeculator.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/LegacySpeculator.java
index 23b057a..f21b819 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/LegacySpeculator.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/LegacySpeculator.java
@@ -33,7 +33,6 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.locks.ReadWriteLock;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
 import org.apache.hadoop.service.AbstractService;
-import org.apache.hadoop.service.ServiceOperations;
 import org.apache.tez.common.ProgressHelper;
 import org.apache.tez.dag.api.TezConfiguration;
 import org.slf4j.Logger;
@@ -100,10 +99,9 @@ public class LegacySpeculator extends AbstractService {
   private TaskRuntimeEstimator estimator;
   private final long taskTimeout;
   private final Clock clock;
-  private long nextSpeculateTime = Long.MIN_VALUE;
   private Thread speculationBackgroundThread = null;
   private volatile boolean stopped = false;
-  /* Allow the speculator to wait on a blockingQueue in case we use it for event notification */
+  /** Allow the speculator to wait on a blockingQueue in case we use it for event notification. */
   private BlockingQueue<Object> scanControl = new LinkedBlockingQueue<Object>();
 
   @VisibleForTesting
@@ -132,9 +130,8 @@ public class LegacySpeculator extends AbstractService {
   static private TaskRuntimeEstimator getEstimator
       (Configuration conf, Vertex vertex) {
     TaskRuntimeEstimator estimator;
-    // "tez.am.speculation.estimator.class"
     Class<? extends TaskRuntimeEstimator> estimatorClass =
-        conf.getClass(TezConfiguration.TEZ_AM_SPECULATION_ESTIMATOR_CLASS,
+        conf.getClass(TezConfiguration.TEZ_AM_TASK_ESTIMATOR_CLASS,
             LegacyTaskRuntimeEstimator.class,
             TaskRuntimeEstimator.class);
     try {
@@ -236,6 +233,16 @@ public class LegacySpeculator extends AbstractService {
     }
   }
 
+  // This interface is intended to be used only for test cases.
+  public void scanForSpeculationsForTesting() {
+    if (LOG.isDebugEnabled()) {
+      LOG.debug("We got asked to run a debug speculation scan.");
+      LOG.debug("There are {} speculative events stacked already.", scanControl.size());
+    }
+    scanControl.add(new Object());
+    Thread.yield();
+  }
+
   public Runnable createThread() {
     return new Runnable() {
       @Override
@@ -267,8 +274,9 @@ public class LegacySpeculator extends AbstractService {
   public void notifyAttemptStarted(TezTaskAttemptID taId, long timestamp) {
     estimator.enrollAttempt(taId, timestamp);    
   }
-  
-  public void notifyAttemptStatusUpdate(TezTaskAttemptID taId, TaskAttemptState reportedState,
+
+  public void notifyAttemptStatusUpdate(TezTaskAttemptID taId,
+      TaskAttemptState reportedState,
       long timestamp) {
     statusUpdate(taId, reportedState, timestamp);
   }
@@ -293,12 +301,10 @@ public class LegacySpeculator extends AbstractService {
 
     estimator.updateAttempt(attemptID, reportedState, timestamp);
 
-    //if (stateString.equals(TaskAttemptState.RUNNING.name())) {
     if (reportedState == TaskAttemptState.RUNNING) {
       runningTasks.putIfAbsent(taskID, Boolean.TRUE);
     } else {
       runningTasks.remove(taskID, Boolean.TRUE);
-      //if (!stateString.equals(TaskAttemptState.STARTING.name())) {
       if (reportedState == TaskAttemptState.STARTING) {
         runningTaskAttemptStatistics.remove(attemptID);
       }
@@ -356,7 +362,7 @@ public class LegacySpeculator extends AbstractService {
       }
     }
 
-    TezTaskAttemptID runningTaskAttemptID = null;
+    TezTaskAttemptID runningTaskAttemptID;
     int numberRunningAttempts = 0;
 
     for (TaskAttempt taskAttempt : attempts.values()) {
@@ -387,7 +393,8 @@ public class LegacySpeculator extends AbstractService {
             return ON_SCHEDULE;
           }
         } else {
-          long estimatedRunTime = estimator.estimatedRuntime(runningTaskAttemptID);
+          long estimatedRunTime = estimator
+              .estimatedRuntime(runningTaskAttemptID);
 
           long estimatedEndTime = estimatedRunTime + taskAttemptStartTime;
 
@@ -399,12 +406,15 @@ public class LegacySpeculator extends AbstractService {
                   runningTaskAttemptStatistics.get(runningTaskAttemptID);
           if (data == null) {
             runningTaskAttemptStatistics.put(runningTaskAttemptID,
-                    new TaskAttemptHistoryStatistics(estimatedRunTime, progress, now));
+                new TaskAttemptHistoryStatistics(estimatedRunTime, progress,
+                    now));
           } else {
             if (estimatedRunTime == data.getEstimatedRunTime()
                     && progress == data.getProgress()) {
               // Previous stats are same as same stats
-              if (data.notHeartbeatedInAWhile(now)) {
+              if (data.notHeartbeatedInAWhile(now)
+                  || estimator
+                  .hasStagnatedProgress(runningTaskAttemptID, now)) {
                 // Stats have stagnated for a while, simulate heart-beat.
                 // Now simulate the heart-beat
                 statusUpdate(taskAttempt.getID(), taskAttempt.getState(),
@@ -448,7 +458,8 @@ public class LegacySpeculator extends AbstractService {
 
   // Add attempt to a given Task.
   protected void addSpeculativeAttempt(TezTaskID taskID) {
-    LOG.info("DefaultSpeculator.addSpeculativeAttempt -- we are speculating " + taskID);
+    LOG.info("DefaultSpeculator.addSpeculativeAttempt -- we are speculating "
+        + taskID);
     vertex.scheduleSpeculativeTask(taskID);
     mayHaveSpeculated.add(taskID);
   }
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/SimpleExponentialTaskRuntimeEstimator.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/SimpleExponentialTaskRuntimeEstimator.java
new file mode 100644
index 0000000..b61f153
--- /dev/null
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/SimpleExponentialTaskRuntimeEstimator.java
@@ -0,0 +1,194 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.tez.dag.app.dag.speculation.legacy;
+
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.atomic.AtomicReference;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.tez.dag.api.TezConfiguration;
+import org.apache.tez.dag.api.oldrecords.TaskAttemptState;
+import org.apache.tez.dag.app.dag.Task;
+import org.apache.tez.dag.app.dag.TaskAttempt;
+import org.apache.tez.dag.app.dag.Vertex;
+import org.apache.tez.dag.app.dag.speculation.legacy.forecast.SimpleExponentialSmoothing;
+import org.apache.tez.dag.records.TezTaskAttemptID;
+
+/**
+ * A task Runtime Estimator based on exponential smoothing.
+ */
+public class SimpleExponentialTaskRuntimeEstimator extends StartEndTimesBase {
+  /**
+   * The default value returned by the estimator when no records exist.
+   */
+  private static final long DEFAULT_ESTIMATE_RUNTIME = -1L;
+
+  /**
+   * Given a forecast of value 0.0, it is getting replaced by the default value
+   * to avoid division by 0.
+   */
+  private static final double DEFAULT_PROGRESS_VALUE = 1E-10;
+
+  /**
+   * Factor used to calculate the confidence interval.
+   */
+  private static final double CONFIDENCE_INTERVAL_FACTOR = 0.25;
+  /**
+   * Constant time used to calculate the smoothing exponential factor.
+   */
+  private long constTime;
+
+  /**
+   * Number of readings before we consider the estimate stable.
+   * Otherwise, the estimate will be skewed due to the initial estimate
+   */
+  private int skipCount;
+
+  /**
+   * Time window to automatically update the count of the skipCount. This is
+   * needed when a task stalls without any progress, causing the estimator to
+   * return -1 as an estimatedRuntime.
+   */
+  private long stagnatedWindow;
+
+  /**
+   * A map of TA Id to the statistic model of smooth exponential.
+   */
+  private final ConcurrentMap<TezTaskAttemptID,
+      AtomicReference<SimpleExponentialSmoothing>>
+      estimates = new ConcurrentHashMap<>();
+
+  private SimpleExponentialSmoothing getForecastEntry(
+      final TezTaskAttemptID attemptID) {
+    AtomicReference<SimpleExponentialSmoothing> entryRef = estimates
+        .get(attemptID);
+    if (entryRef == null) {
+      return null;
+    }
+    return entryRef.get();
+  }
+
+  private void incorporateReading(final TezTaskAttemptID attemptID,
+      final float newRawData, final long newTimeStamp) {
+    SimpleExponentialSmoothing foreCastEntry = getForecastEntry(attemptID);
+    if (foreCastEntry == null) {
+      Long tStartTime = startTimes.get(attemptID);
+      // skip if the startTime is not set yet
+      if (tStartTime == null) {
+        return;
+      }
+      estimates.putIfAbsent(attemptID,
+          new AtomicReference<>(SimpleExponentialSmoothing.createForecast(
+              constTime, skipCount, stagnatedWindow,
+              tStartTime - 1)));
+      incorporateReading(attemptID, newRawData, newTimeStamp);
+      return;
+    }
+    foreCastEntry.incorporateReading(newTimeStamp, newRawData);
+  }
+
+  @Override
+  public void contextualize(final Configuration conf, final Vertex vertex) {
+    super.contextualize(conf, vertex);
+
+    constTime
+        = conf.getLong(TezConfiguration.TEZ_AM_ESTIMATOR_EXPONENTIAL_LAMBDA_MS,
+        TezConfiguration.TEZ_AM_ESTIMATOR_EXPONENTIAL_LAMBDA_MS_DEFAULT);
+
+    stagnatedWindow = Math.max(2 * constTime, conf.getLong(
+        TezConfiguration.TEZ_AM_ESTIMATOR_EXPONENTIAL_STAGNATED_MS,
+        TezConfiguration.TEZ_AM_ESTIMATOR_EXPONENTIAL_STAGNATED_MS_DEFAULT));
+
+    skipCount = conf
+        .getInt(TezConfiguration.TEZ_AM_ESTIMATOR_EXPONENTIAL_SKIP_INITIALS,
+            TezConfiguration
+                .TEZ_AM_ESTIMATOR_EXPONENTIAL_SKIP_INITIALS_DEFAULT);
+  }
+
+  @Override
+  public long estimatedRuntime(final TezTaskAttemptID id) {
+    SimpleExponentialSmoothing foreCastEntry = getForecastEntry(id);
+    if (foreCastEntry == null) {
+      return DEFAULT_ESTIMATE_RUNTIME;
+    }
+    double remainingWork =
+        Math.max(0.0, Math.min(1.0, 1.0 - foreCastEntry.getRawData()));
+    double forecast =
+        Math.max(DEFAULT_PROGRESS_VALUE, foreCastEntry.getForecast());
+    long remainingTime = (long) (remainingWork / forecast);
+    long estimatedRuntime =
+        remainingTime + foreCastEntry.getTimeStamp() - foreCastEntry.getStartTime();
+    return estimatedRuntime;
+  }
+
+  @Override
+  public long newAttemptEstimatedRuntime() {
+    if (taskStatistics == null) {
+      return DEFAULT_ESTIMATE_RUNTIME;
+    }
+
+    double statsMeanCI = taskStatistics.meanCI();
+    double expectedVal =
+        statsMeanCI + Math.min(statsMeanCI * CONFIDENCE_INTERVAL_FACTOR,
+            taskStatistics.std() / 2);
+    return (long) (expectedVal);
+  }
+
+  @Override
+  public boolean hasStagnatedProgress(final TezTaskAttemptID id,
+      final long timeStamp) {
+    SimpleExponentialSmoothing foreCastEntry = getForecastEntry(id);
+    if (foreCastEntry == null) {
+      return false;
+    }
+    return foreCastEntry.isDataStagnated(timeStamp);
+  }
+
+  @Override
+  public long runtimeEstimateVariance(final TezTaskAttemptID id) {
+    SimpleExponentialSmoothing forecastEntry = getForecastEntry(id);
+    if (forecastEntry == null) {
+      return DEFAULT_ESTIMATE_RUNTIME;
+    }
+    double forecast = forecastEntry.getForecast();
+    if (forecastEntry.isDefaultForecast(forecast)) {
+      return DEFAULT_ESTIMATE_RUNTIME;
+    }
+    //TODO What is the best way to measure variance in runtime
+    return 0L;
+  }
+
+
+  @Override
+  public void updateAttempt(final TezTaskAttemptID attemptID,
+      final TaskAttemptState state,
+      final long timestamp) {
+    super.updateAttempt(attemptID, state, timestamp);
+    Task task = vertex.getTask(attemptID.getTaskID());
+    if (task == null) {
+      return;
+    }
+    TaskAttempt taskAttempt = task.getAttempt(attemptID);
+    if (taskAttempt == null) {
+      return;
+    }
+    float progress = taskAttempt.getProgress();
+    incorporateReading(attemptID, progress, timestamp);
+  }
+}
+
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/StartEndTimesBase.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/StartEndTimesBase.java
index d4d1a7f..3083986 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/StartEndTimesBase.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/StartEndTimesBase.java
@@ -35,13 +35,11 @@ import org.apache.tez.dag.records.TezTaskID;
 /**
  * Base class that uses the attempt runtime estimations from a derived class
  * and uses it to determine outliers based on deviating beyond the mean
- * estimated runtime by some threshold
+ * estimated runtime by some threshold.
  */
 abstract class StartEndTimesBase implements TaskRuntimeEstimator {
-  static final float MINIMUM_COMPLETE_PROPORTION_TO_SPECULATE
-      = 0.05F;
-  static final int MINIMUM_COMPLETE_NUMBER_TO_SPECULATE
-      = 1;
+  static final float MINIMUM_COMPLETE_PROPORTION_TO_SPECULATE = 0.05F;
+  static final int MINIMUM_COMPLETE_NUMBER_TO_SPECULATE = 1;
 
   protected Vertex vertex;
 
@@ -50,56 +48,58 @@ abstract class StartEndTimesBase implements TaskRuntimeEstimator {
 
   protected final DataStatistics taskStatistics = new DataStatistics();
 
-  private float slowTaskRelativeTresholds;
+  private float slowTaskRelativeThresholds;
 
   protected final Set<Task> doneTasks = new HashSet<Task>();
 
   @Override
-  public void enrollAttempt(TezTaskAttemptID id, long timestamp) {
+  public void enrollAttempt(final TezTaskAttemptID id, final long timestamp) {
     startTimes.put(id, timestamp);
   }
 
   @Override
-  public long attemptEnrolledTime(TezTaskAttemptID attemptID) {
+  public long attemptEnrolledTime(final TezTaskAttemptID attemptID) {
     Long result = startTimes.get(attemptID);
 
     return result == null ? Long.MAX_VALUE : result;
   }
 
   @Override
-  public void contextualize(Configuration conf, Vertex vertex) {
-    slowTaskRelativeTresholds = conf.getFloat(
+  public void contextualize(final Configuration conf, final Vertex vertexP) {
+    slowTaskRelativeThresholds = conf.getFloat(
         TezConfiguration.TEZ_AM_LEGACY_SPECULATIVE_SLOWTASK_THRESHOLD, 1.0f);
-    this.vertex = vertex;
+    this.vertex = vertexP;
   }
 
-  protected DataStatistics dataStatisticsForTask(TezTaskID taskID) {
+  protected DataStatistics dataStatisticsForTask(final TezTaskID taskID) {
     return taskStatistics;
   }
 
   @Override
-  public long thresholdRuntime(TezTaskID taskID) {
+  public long thresholdRuntime(final TezTaskID taskID) {
     int completedTasks = vertex.getCompletedTasks();
 
     int totalTasks = vertex.getTotalTasks();
-    
+
     if (completedTasks < MINIMUM_COMPLETE_NUMBER_TO_SPECULATE
-        || (((float)completedTasks) / totalTasks)
-              < MINIMUM_COMPLETE_PROPORTION_TO_SPECULATE ) {
+        || (((float) completedTasks) / totalTasks)
+        < MINIMUM_COMPLETE_PROPORTION_TO_SPECULATE) {
       return Long.MAX_VALUE;
     }
-    
-    long result = (long)taskStatistics.outlier(slowTaskRelativeTresholds);
+
+    long result = (long) taskStatistics.outlier(slowTaskRelativeThresholds);
     return result;
   }
 
   @Override
   public long newAttemptEstimatedRuntime() {
-    return (long)taskStatistics.mean();
+    return (long) taskStatistics.mean();
   }
 
   @Override
-  public void updateAttempt(TezTaskAttemptID attemptID, TaskAttemptState state, long timestamp) {
+  public void updateAttempt(final TezTaskAttemptID attemptID,
+      final TaskAttemptState state,
+      final long timestamp) {
 
     Task task = vertex.getTask(attemptID.getTaskID());
 
@@ -109,7 +109,7 @@ abstract class StartEndTimesBase implements TaskRuntimeEstimator {
 
     Long boxedStart = startTimes.get(attemptID);
     long start = boxedStart == null ? Long.MIN_VALUE : boxedStart;
-    
+
     TaskAttempt taskAttempt = task.getAttempt(attemptID);
 
     if (taskAttempt.getState() == TaskAttemptState.SUCCEEDED) {
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/TaskRuntimeEstimator.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/TaskRuntimeEstimator.java
index c8edd1e..4f747af 100644
--- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/TaskRuntimeEstimator.java
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/TaskRuntimeEstimator.java
@@ -29,13 +29,14 @@ import org.apache.tez.dag.records.TezTaskID;
  * 
  */
 public interface TaskRuntimeEstimator {
-  public void enrollAttempt(TezTaskAttemptID id, long timestamp);
+  void enrollAttempt(TezTaskAttemptID id, long timestamp);
 
-  public long attemptEnrolledTime(TezTaskAttemptID attemptID);
+  long attemptEnrolledTime(TezTaskAttemptID attemptID);
 
-  public void updateAttempt(TezTaskAttemptID taId, TaskAttemptState reportedState, long timestamp);
+  void updateAttempt(TezTaskAttemptID taId,
+      TaskAttemptState reportedState, long timestamp);
 
-  public void contextualize(Configuration conf, Vertex vertex);
+  void contextualize(Configuration conf, Vertex vertex);
 
   /**
    *
@@ -52,7 +53,7 @@ public interface TaskRuntimeEstimator {
    *         however long.
    *
    */
-  public long thresholdRuntime(TezTaskID id);
+  long thresholdRuntime(TezTaskID id);
 
   /**
    *
@@ -64,7 +65,7 @@ public interface TaskRuntimeEstimator {
    *         we don't have enough information yet to produce an estimate.
    *
    */
-  public long estimatedRuntime(TezTaskAttemptID id);
+  long estimatedRuntime(TezTaskAttemptID id);
 
   /**
    *
@@ -75,7 +76,7 @@ public interface TaskRuntimeEstimator {
    *         we don't have enough information yet to produce an estimate.
    *
    */
-  public long newAttemptEstimatedRuntime();
+  long newAttemptEstimatedRuntime();
 
   /**
    *
@@ -87,5 +88,20 @@ public interface TaskRuntimeEstimator {
    *         we don't have enough information yet to produce an estimate.
    *
    */
-  public long runtimeEstimateVariance(TezTaskAttemptID id);
+  long runtimeEstimateVariance(TezTaskAttemptID id);
+
+  /**
+   *
+   * Returns true if the estimator has no updates records for a threshold time
+   * window. This helps to identify task attempts that are stalled at the
+   * beginning of execution.
+   *
+   * @param id the {@link TezTaskAttemptID} of the attempt we are asking about
+   * @param timeStamp the time of the report we compare with
+   * @return true if the task attempt has no progress for a given time window
+   *
+   */
+  default boolean hasStagnatedProgress(TezTaskAttemptID id, long timeStamp) {
+    return false;
+  }
 }
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/forecast/SimpleExponentialSmoothing.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/forecast/SimpleExponentialSmoothing.java
new file mode 100644
index 0000000..e7b7dcd
--- /dev/null
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/forecast/SimpleExponentialSmoothing.java
@@ -0,0 +1,347 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.tez.dag.app.dag.speculation.legacy.forecast;
+
+import java.util.concurrent.atomic.AtomicReference;
+
+/**
+ * Implementation of the static model for Simple exponential smoothing.
+ * <p>
+ * Every reading is converted into a rate (change of the raw value per unit of
+ * time since the previous reading) and folded into the forecast with the
+ * smoothing factor {@code 1 - exp(-elapsed / timeConstant)}. The latest state
+ * is an immutable {@code ForecastRecord} published through an
+ * {@link AtomicReference} and replaced with compare-and-set.
+ */
+public class SimpleExponentialSmoothing {
+  private static final double DEFAULT_FORECAST = -1.0d;
+  private final int kMinimumReads;
+  private final long kStagnatedWindow;
+  private final long startTime;
+  private long timeConstant;
+
+  /**
+   * Holds reference to the current forecast record.
+   */
+  private final AtomicReference<ForecastRecord> forecastRefEntry;
+
+  /**
+   * Create forecast simple exponential smoothing.
+   *
+   * @param timeConstant the time constant used to derive the smoothing factor
+   * @param skipCnt number of readings to exceed before a forecast is reported
+   * @param stagnatedWindow time without readings after which data is stagnated
+   * @param timeStamp the start time of the observed series
+   * @return the simple exponential smoothing
+   */
+  public static SimpleExponentialSmoothing createForecast(
+      final long timeConstant,
+      final int skipCnt, final long stagnatedWindow, final long timeStamp) {
+    return new SimpleExponentialSmoothing(timeConstant, skipCnt,
+        stagnatedWindow, timeStamp);
+  }
+
+  /**
+   * Instantiates a new Simple exponential smoothing.
+   *
+   * @param ktConstant the time constant used to derive the smoothing factor
+   * @param skipCnt number of readings to exceed before a forecast is reported
+   * @param stagnatedWindow time without readings after which data is stagnated
+   * @param timeStamp the start time of the observed series
+   */
+  SimpleExponentialSmoothing(final long ktConstant, final int skipCnt,
+      final long stagnatedWindow, final long timeStamp) {
+    this.kMinimumReads = skipCnt;
+    this.kStagnatedWindow = stagnatedWindow;
+    this.timeConstant = ktConstant;
+    this.startTime = timeStamp;
+    this.forecastRefEntry = new AtomicReference<ForecastRecord>(null);
+  }
+
+  /**
+   * Immutable snapshot of the model after one reading.
+   */
+  private class ForecastRecord {
+    private final double alpha;
+    private final long timeStamp;
+    private final double sample;
+    private final double rawData;
+    private final double forecast;
+    private final double sseError;
+    private final long myIndex;
+    private final ForecastRecord prevRec;
+
+    /**
+     * Instantiates the seed record that has no predecessor.
+     *
+     * @param currForecast the initial forecast (also used as the sample)
+     * @param currRawData the raw value at the seed time
+     * @param currTimeStamp the seed time
+     */
+    ForecastRecord(final double currForecast, final double currRawData,
+        final long currTimeStamp) {
+      this(0.0, currForecast, currRawData, currForecast, currTimeStamp, 0.0, 0,
+          null);
+    }
+
+    /**
+     * Instantiates a new Forecast record.
+     *
+     * @param alphaVal the smoothing factor applied to produce this record
+     * @param currSample the rate computed for this reading
+     * @param currRawData the raw value of this reading
+     * @param currForecast the smoothed forecast after this reading
+     * @param currTimeStamp the time of this reading
+     * @param accError the accumulated sum of squared errors
+     * @param index the number of readings incorporated so far
+     * @param prev the record this one was derived from, or null
+     */
+    ForecastRecord(final double alphaVal, final double currSample,
+        final double currRawData,
+        final double currForecast, final long currTimeStamp,
+        final double accError,
+        final long index, final ForecastRecord prev) {
+      this.timeStamp = currTimeStamp;
+      this.alpha = alphaVal;
+      this.sample = currSample;
+      this.forecast = currForecast;
+      this.rawData = currRawData;
+      this.sseError = accError;
+      this.myIndex = index;
+      this.prevRec = prev;
+    }
+
+    private double preProcessRawData(final double rData, final long newTime) {
+      return processRawData(this.rawData, this.timeStamp, rData, newTime);
+    }
+
+    /**
+     * Returns the record that results from incorporating a new reading.
+     *
+     * @param newTimeStamp the time of the new reading
+     * @param rData the raw value of the new reading
+     * @return the new record, or {@code this} when the reading is ignored
+     */
+    public ForecastRecord append(final long newTimeStamp, final double rData) {
+      if (this.timeStamp >= newTimeStamp
+          && Double.compare(this.rawData, rData) >= 0) {
+        // progress reported twice. Do nothing.
+        return this;
+      }
+      ForecastRecord refRecord = this;
+      if (newTimeStamp == this.timeStamp && this.prevRec != null) {
+        // Same timestamp with newer data: recompute from the previous record
+        // so this record is replaced rather than followed.
+        refRecord = this.prevRec;
+      }
+      if (newTimeStamp <= refRecord.timeStamp) {
+        // The rate over a zero or negative interval is infinite, NaN or
+        // negative and would poison the forecast; drop the reading.
+        return this;
+      }
+      double newSample = refRecord.preProcessRawData(rData, newTimeStamp);
+      // Measure elapsed time from the record the sample was computed against.
+      long deltaTime = refRecord.timeStamp - newTimeStamp;
+      if (refRecord.myIndex == kMinimumReads) {
+        timeConstant = Math.max(timeConstant, newTimeStamp - startTime);
+      }
+      double smoothFactor =
+          1 - Math.exp(((double) deltaTime) / timeConstant);
+      double forecastVal =
+          smoothFactor * newSample + (1.0 - smoothFactor) * refRecord.forecast;
+      double newSSEError =
+          refRecord.sseError + Math.pow(newSample - refRecord.forecast, 2);
+      return new ForecastRecord(smoothFactor, newSample, rData, forecastVal,
+          newTimeStamp, newSSEError, refRecord.myIndex + 1, refRecord);
+    }
+  }
+
+  /**
+   * Checks whether readings have stopped arriving.
+   *
+   * @param timeStamp current time of the scan.
+   * @return true if we have number of samples > kMinimumReads and the latest
+   * record is older than the stagnation window, i.e. the record timestamp
+   * has expired.
+   */
+  public boolean isDataStagnated(final long timeStamp) {
+    ForecastRecord rec = forecastRefEntry.get();
+    if (rec != null && rec.myIndex > kMinimumReads) {
+      // Expired: the window that started at the last reading is already over.
+      return (rec.timeStamp + kStagnatedWindow) < timeStamp;
+    }
+    return false;
+  }
+
+  /**
+   * Computes the rate of change between two readings.
+   *
+   * @param oldRawData the old raw data
+   * @param oldTime the old time
+   * @param newRawData the new raw data
+   * @param newTime the new time; callers must ensure it is after oldTime
+   * @return (newRawData - oldRawData) / (newTime - oldTime)
+   */
+  static double processRawData(final double oldRawData, final long oldTime,
+      final double newRawData, final long newTime) {
+    double rate = (newRawData - oldRawData) / (newTime - oldTime);
+    return rate;
+  }
+
+  /**
+   * Incorporates a new reading into the forecast.
+   * <p>
+   * A first reading at or before the start time is ignored because no rate
+   * can be computed for it.
+   *
+   * @param timeStamp the time stamp
+   * @param currRawData the curr raw data
+   */
+  public void incorporateReading(final long timeStamp,
+      final double currRawData) {
+    ForecastRecord oldRec = forecastRefEntry.get();
+    if (oldRec == null) {
+      if (timeStamp <= startTime) {
+        // No elapsed time yet: the initial rate would be NaN or infinite.
+        return;
+      }
+      double oldForecast =
+          processRawData(0, startTime, currRawData, timeStamp);
+      // Only one thread installs the seed; every caller continues from it.
+      forecastRefEntry.compareAndSet(null,
+          new ForecastRecord(oldForecast, 0.0d, startTime));
+      oldRec = forecastRefEntry.get();
+    }
+    while (!forecastRefEntry.compareAndSet(oldRec, oldRec.append(timeStamp,
+        currRawData))) {
+      oldRec = forecastRefEntry.get();
+    }
+  }
+
+  /**
+   * Gets forecast.
+   *
+   * @return the forecast
+   */
+  public double getForecast() {
+    ForecastRecord rec = forecastRefEntry.get();
+    if (rec != null && rec.myIndex > kMinimumReads) {
+      return rec.forecast;
+    }
+    return DEFAULT_FORECAST;
+  }
+
+  /**
+   * Is default forecast boolean.
+   *
+   * @param value the value
+   * @return the boolean
+   */
+  public boolean isDefaultForecast(final double value) {
+    return value == DEFAULT_FORECAST;
+  }
+
+  /**
+   * Gets sse.
+   *
+   * @return the sse
+   */
+  public double getSSE() {
+    ForecastRecord rec = forecastRefEntry.get();
+    if (rec != null) {
+      return rec.sseError;
+    }
+    return DEFAULT_FORECAST;
+  }
+
+  /**
+   * Is error within bound boolean.
+   *
+   * @param bound the bound
+   * @return the boolean
+   */
+  public boolean isErrorWithinBound(final double bound) {
+    double squaredErr = getSSE();
+    if (squaredErr < 0) {
+      return false;
+    }
+    return bound > squaredErr;
+  }
+
+  /**
+   * Gets raw data.
+   *
+   * @return the raw data
+   */
+  public double getRawData() {
+    ForecastRecord rec = forecastRefEntry.get();
+    if (rec != null) {
+      return rec.rawData;
+    }
+    return DEFAULT_FORECAST;
+  }
+
+  /**
+   * Gets time stamp.
+   *
+   * @return the time stamp
+   */
+  public long getTimeStamp() {
+    ForecastRecord rec = forecastRefEntry.get();
+    if (rec != null) {
+      return rec.timeStamp;
+    }
+    return 0L;
+  }
+
+  /**
+   * Gets start time.
+   *
+   * @return the start time
+   */
+  public long getStartTime() {
+    return startTime;
+  }
+
+  /**
+   * Gets forecast ref entry.
+   *
+   * @return the forecast ref entry
+   */
+  public AtomicReference<ForecastRecord> getForecastRefEntry() {
+    return forecastRefEntry;
+  }
+
+  @Override
+  public String toString() {
+    String res = "NULL";
+    ForecastRecord rec = forecastRefEntry.get();
+    if (rec != null) {
+      StringBuilder strB = new StringBuilder("rec.index = ").append(rec.myIndex)
+          .append(", timeStamp t: ").append(rec.timeStamp)
+          .append(", forecast: ").append(rec.forecast).append(", sample: ")
+          .append(rec.sample).append(", raw: ").append(rec.rawData)
+          .append(", error: ").append(rec.sseError).append(", alpha: ")
+          .append(rec.alpha);
+      res = strB.toString();
+    }
+    return res;
+  }
+}
+
diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/forecast/package-info.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/forecast/package-info.java
new file mode 100644
index 0000000..3ed8b6a
--- /dev/null
+++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/speculation/legacy/forecast/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+@InterfaceAudience.Private
+package org.apache.tez.dag.app.dag.speculation.legacy.forecast;
+import org.apache.hadoop.classification.InterfaceAudience;
diff --git a/tez-dag/src/test/java/org/apache/tez/dag/app/TestSpeculation.java b/tez-dag/src/test/java/org/apache/tez/dag/app/TestSpeculation.java
index a81d4d3..b9a7c5a 100644
--- a/tez-dag/src/test/java/org/apache/tez/dag/app/TestSpeculation.java
+++ b/tez-dag/src/test/java/org/apache/tez/dag/app/TestSpeculation.java
@@ -19,10 +19,15 @@
 package org.apache.tez.dag.app;
 
 import java.io.IOException;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicBoolean;
 
+import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
@@ -45,39 +50,196 @@ import org.apache.tez.dag.app.dag.Task;
 import org.apache.tez.dag.app.dag.TaskAttempt;
 import org.apache.tez.dag.app.dag.impl.DAGImpl;
 import org.apache.tez.dag.app.dag.speculation.legacy.LegacySpeculator;
+import org.apache.tez.dag.app.dag.speculation.legacy.LegacyTaskRuntimeEstimator;
+import org.apache.tez.dag.app.dag.speculation.legacy.SimpleExponentialTaskRuntimeEstimator;
+import org.apache.tez.dag.app.dag.speculation.legacy.TaskRuntimeEstimator;
 import org.apache.tez.dag.library.vertexmanager.ShuffleVertexManager;
 import org.apache.tez.dag.records.TaskAttemptTerminationCause;
 import org.apache.tez.dag.records.TezTaskAttemptID;
 import org.apache.tez.dag.records.TezTaskID;
 import org.apache.tez.dag.records.TezVertexID;
+import org.junit.After;
 import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 
 import com.google.common.base.Joiner;
+import org.junit.rules.TestRule;
+import org.junit.runner.Description;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.model.Statement;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
+/**
+ * test speculation behavior given the list of estimator classes.
+ */
+@RunWith(Parameterized.class)
 public class TestSpeculation {
-  static Configuration defaultConf;
-  static FileSystem localFs;
-  
+  private static final Logger LOG = LoggerFactory.getLogger(TestSpeculation.class);
+  // Failure messages that RetryRule treats as transient for @Retry tests.
+  private static final String ASSERT_SPECULATIONS_COUNT_MSG =
+      "Number of attempts after Speculation should be two";
+  private static final String UNIT_EXCEPTION_MESSAGE =
+      "test timed out after";
+  private static final int ASSERT_SPECULATIONS_COUNT_RETRIES = 3;
+  private Configuration defaultConf;
+  private FileSystem localFs;
+
+  /**
+   * The Mock app.
+   */
   MockDAGAppMaster mockApp;
+
+  /**
+   * The Mock launcher.
+   */
   MockContainerLauncher mockLauncher;
-  
-  static {
+
+  /**
+   * Marks a test whose known-flaky failures may be retried by {@link RetryRule}.
+   */
+  @Retention(RetentionPolicy.RUNTIME)
+  public @interface Retry {}
+
+  /**
+   * Re-runs a {@link Retry}-annotated test that fails with a timeout or with the
+   * speculation-count assertion, up to a fixed total number of attempts.
+   */
+  static class RetryRule implements TestRule {
+
+    private final AtomicInteger retryCount;
+
+    /**
+     * Instantiates a new Retry rule.
+     *
+     * @param retries total number of attempts; values below 1 still run once
+     */
+    RetryRule(int retries) {
+      this.retryCount = new AtomicInteger(Math.max(1, retries));
+    }
+
+    /** Returns whether {@code t} is a known transient speculation failure. */
+    private static boolean isRetriable(Throwable t) {
+      String msg = t.getMessage();
+      if (msg == null) {
+        return false;
+      }
+      return (t instanceof AssertionError
+          && msg.contains(ASSERT_SPECULATIONS_COUNT_MSG))
+          || (t instanceof Exception && msg.contains(UNIT_EXCEPTION_MESSAGE));
+    }
+
+    @Override
+    public Statement apply(final Statement base,
+        final Description description) {
+      return new Statement() {
+        @Override
+        public void evaluate() throws Throwable {
+          while (true) {
+            try {
+              base.evaluate();
+              return;
+            } catch (Throwable t) {
+              int remaining = retryCount.decrementAndGet();
+              if (remaining <= 0
+                  || description.getAnnotation(Retry.class) == null
+                  || !isRetriable(t)) {
+                throw t;
+              }
+              LOG.warn("{} : Failed. Retries remaining: {}",
+                  description.getDisplayName(), remaining);
+            }
+          }
+        }
+      };
+    }
+  }
+
+  /** Retries flaky speculation tests; see {@link Retry}. */
+  @Rule
+  public RetryRule rule = new RetryRule(ASSERT_SPECULATIONS_COUNT_RETRIES);
+
+  /**
+   * Builds the per-test configuration, including the estimator under test.
+   */
+  @Before
+  public void setDefaultConf() {
     try {
       defaultConf = new Configuration(false);
       defaultConf.set("fs.defaultFS", "file:///");
       defaultConf.setBoolean(TezConfiguration.TEZ_LOCAL_MODE, true);
       defaultConf.setBoolean(TezConfiguration.TEZ_AM_SPECULATION_ENABLED, true);
-      defaultConf.setFloat(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION, 1);
-      defaultConf.setFloat(ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION, 1);
+      defaultConf.setFloat(
+          ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION, 1);
+      defaultConf.setFloat(
+          ShuffleVertexManager.TEZ_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION, 1);
       localFs = FileSystem.getLocal(defaultConf);
-      String stagingDir = "target" + Path.SEPARATOR + TestSpeculation.class.getName() + "-tmpDir";
+      String stagingDir =
+          "target" + Path.SEPARATOR + TestSpeculation.class.getName()
+              + "-tmpDir";
       defaultConf.set(TezConfiguration.TEZ_AM_STAGING_DIR, stagingDir);
+      defaultConf.setClass(TezConfiguration.TEZ_AM_TASK_ESTIMATOR_CLASS,
+          estimatorClass,
+          TaskRuntimeEstimator.class);
+      defaultConf.setInt(TezConfiguration.TEZ_AM_MINIMUM_ALLOWED_SPECULATIVE_TASKS, 20);
+      defaultConf.setDouble(TezConfiguration.TEZ_AM_PROPORTION_TOTAL_TASKS_SPECULATABLE, 0.2);
+      defaultConf.setDouble(TezConfiguration.TEZ_AM_PROPORTION_RUNNING_TASKS_SPECULATABLE, 0.25);
+      defaultConf.setLong(TezConfiguration.TEZ_AM_SOONEST_RETRY_AFTER_NO_SPECULATE, 25);
+      defaultConf.setLong(TezConfiguration.TEZ_AM_SOONEST_RETRY_AFTER_SPECULATE, 50);
+      defaultConf.setInt(TezConfiguration.TEZ_AM_ESTIMATOR_EXPONENTIAL_SKIP_INITIALS, 2);
     } catch (IOException e) {
       throw new RuntimeException("init failure", e);
     }
   }
 
+  /** Releases per-test resources; a failure in one does not skip the rest. */
+  @After
+  public void tearDown() {
+    defaultConf = null;
+    AutoCloseable[] resources = {localFs,
+        () -> mockLauncher.shutdown(), () -> mockApp.close()};
+    for (AutoCloseable resource : resources) {
+      try {
+        resource.close();
+      } catch (Exception e) {
+        LOG.warn("Failed to release test resource", e);
+      }
+    }
+  }
+
+  /**
+   * Supplies the estimator implementations every test is run against.
+   *
+   * @return one single-element parameter array per estimator class
+   */
+  @Parameterized.Parameters(name = "{index}: TaskEstimator(EstimatorClass {0})")
+  public static Collection<Object[]> getTestParameters() {
+    Object[] exponential = {SimpleExponentialTaskRuntimeEstimator.class};
+    Object[] legacy = {LegacyTaskRuntimeEstimator.class};
+    // Order determines the "{index}" shown in the generated test names.
+    return Arrays.asList(exponential, legacy);
+  }
+
+  private final Class<? extends TaskRuntimeEstimator> estimatorClass;
+
+  /**
+   * Creates a test instance bound to one estimator implementation.
+   *
+   * @param estimatorKlass the {@link TaskRuntimeEstimator} class under test
+   */
+  public TestSpeculation(Class<? extends TaskRuntimeEstimator> estimatorKlass) {
+    this.estimatorClass = estimatorKlass;
+  }
+
+  /**
+   * Create tez session mock tez client.
+   *
+   * @return the mock tez client
+   * @throws Exception the exception
+   */
   MockTezClient createTezSession() throws Exception {
     TezConfiguration tezconf = new TezConfiguration(defaultConf);
     AtomicBoolean mockAppLauncherGoFlag = new AtomicBoolean(false);
@@ -87,8 +249,16 @@ public class TestSpeculation {
     syncWithMockAppLauncher(false, mockAppLauncherGoFlag, tezClient);
     return tezClient;
   }
-  
-  void syncWithMockAppLauncher(boolean allowScheduling, AtomicBoolean mockAppLauncherGoFlag, 
+
+  /**
+   * Sync with mock app launcher.
+   *
+   * @param allowScheduling the allow scheduling
+   * @param mockAppLauncherGoFlag the mock app launcher go flag
+   * @param tezClient the tez client
+   * @throws Exception the exception
+   */
+  void syncWithMockAppLauncher(boolean allowScheduling, AtomicBoolean mockAppLauncherGoFlag,
       MockTezClient tezClient) throws Exception {
     synchronized (mockAppLauncherGoFlag) {
       while (!mockAppLauncherGoFlag.get()) {
@@ -101,6 +271,12 @@ public class TestSpeculation {
     }     
   }
 
+  /**
+   * Test single task speculation.
+   *
+   * @throws Exception the exception
+   */
+  @Retry
   @Test (timeout = 10000)
   public void testSingleTaskSpeculation() throws Exception {
     // Map<Timeout conf value, expected number of tasks>
@@ -126,9 +302,10 @@ public class TestSpeculation {
       DAGImpl dagImpl = (DAGImpl) mockApp.getContext().getCurrentDAG();
       TezVertexID vertexId = TezVertexID.getInstance(dagImpl.getID(), 0);
       // original attempt is killed and speculative one is successful
-      TezTaskAttemptID killedTaId = TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 0);
-      TezTaskAttemptID successTaId = TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 1);
-
+      TezTaskAttemptID killedTaId =
+          TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 0);
+      TezTaskAttemptID successTaId =
+          TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 1);
       Thread.sleep(200);
       // cause speculation trigger
       mockLauncher.setStatusUpdatesForTask(killedTaId, 100);
@@ -149,16 +326,16 @@ public class TestSpeculation {
     }
   }
 
+  /**
+   * Test basic speculation.
+   *
+   * @param withProgress the with progress
+   * @throws Exception the exception
+   */
   public void testBasicSpeculation(boolean withProgress) throws Exception {
-
-    defaultConf.setInt(TezConfiguration.TEZ_AM_MINIMUM_ALLOWED_SPECULATIVE_TASKS, 20);
-    defaultConf.setDouble(TezConfiguration.TEZ_AM_PROPORTION_TOTAL_TASKS_SPECULATABLE, 0.2);
-    defaultConf.setDouble(TezConfiguration.TEZ_AM_PROPORTION_RUNNING_TASKS_SPECULATABLE, 0.25);
-    defaultConf.setLong(TezConfiguration.TEZ_AM_SOONEST_RETRY_AFTER_NO_SPECULATE, 25);
-    defaultConf.setLong(TezConfiguration.TEZ_AM_SOONEST_RETRY_AFTER_SPECULATE, 50);
-
     DAG dag = DAG.create("test");
-    Vertex vA = Vertex.create("A", ProcessorDescriptor.create("Proc.class"), 5);
+    Vertex vA = Vertex.create("A",
+        ProcessorDescriptor.create("Proc.class"), 5);
     dag.addVertex(vA);
 
     MockTezClient tezClient = createTezSession();
@@ -166,8 +343,10 @@ public class TestSpeculation {
     DAGImpl dagImpl = (DAGImpl) mockApp.getContext().getCurrentDAG();
     TezVertexID vertexId = TezVertexID.getInstance(dagImpl.getID(), 0);
     // original attempt is killed and speculative one is successful
-    TezTaskAttemptID killedTaId = TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 0);
-    TezTaskAttemptID successTaId = TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 1);
+    TezTaskAttemptID killedTaId =
+        TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 0);
+    TezTaskAttemptID successTaId =
+        TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 1);
 
     mockLauncher.updateProgress(withProgress);
     // cause speculation trigger
@@ -175,9 +354,11 @@ public class TestSpeculation {
 
     mockLauncher.startScheduling(true);
     dagClient.waitForCompletion();
-    Assert.assertEquals(DAGStatus.State.SUCCEEDED, dagClient.getDAGStatus(null).getState());
+    Assert.assertEquals(DAGStatus.State.SUCCEEDED,
+        dagClient.getDAGStatus(null).getState());
     Task task = dagImpl.getTask(killedTaId.getTaskID());
-    Assert.assertEquals(2, task.getAttempts().size());
+    Assert.assertEquals(ASSERT_SPECULATIONS_COUNT_MSG, 2,
+        task.getAttempts().size());
     Assert.assertEquals(successTaId, task.getSuccessfulAttempt().getID());
     TaskAttempt killedAttempt = task.getAttempt(killedTaId);
     Joiner.on(",").join(killedAttempt.getDiagnostics()).contains("Killed as speculative attempt");
@@ -204,18 +385,36 @@ public class TestSpeculation {
 
     tezClient.stop();
   }
-  
+
+  /**
+   * Runs the basic speculation scenario with mock progress updates enabled.
+   *
+   * @throws Exception the exception
+   */
+  @Retry
   @Test (timeout=10000)
   public void testBasicSpeculationWithProgress() throws Exception {
     testBasicSpeculation(true);
   }
 
+  /**
+   * Runs the basic speculation scenario with mock progress updates disabled.
+   *
+   * @throws Exception the exception
+   */
+  @Retry
   @Test (timeout=10000)
   public void testBasicSpeculationWithoutProgress() throws Exception {
     testBasicSpeculation(false);
   }
 
-  @Test (timeout=100000)
+  /**
+   * Test basic speculation per vertex conf.
+   *
+   * @throws Exception the exception
+   */
+  @Retry
+  @Test (timeout=10000)
   public void testBasicSpeculationPerVertexConf() throws Exception {
     DAG dag = DAG.create("test");
     String vNameNoSpec = "A";
@@ -224,8 +423,6 @@ public class TestSpeculation {
     Vertex vA = Vertex.create(vNameNoSpec, ProcessorDescriptor.create("Proc.class"), 5);
     Vertex vB = Vertex.create(vNameSpec, ProcessorDescriptor.create("Proc.class"), 5);
     vA.setConf(TezConfiguration.TEZ_AM_SPECULATION_ENABLED, "false");
-    vB.setConf(TezConfiguration.TEZ_AM_SOONEST_RETRY_AFTER_NO_SPECULATE,
-        speculatorSleepTime);
     dag.addVertex(vA);
     dag.addVertex(vB);
     // min/max src fraction is set to 1. So vertices will run sequentially
@@ -273,6 +470,12 @@ public class TestSpeculation {
     tezClient.stop();
   }
 
+  /**
+   * Test basic speculation not useful.
+   *
+   * @throws Exception the exception
+   */
+  @Retry
   @Test (timeout=10000)
   public void testBasicSpeculationNotUseful() throws Exception {
     DAG dag = DAG.create("test");
@@ -310,5 +513,4 @@ public class TestSpeculation {
         .getValue());
     tezClient.stop();
   }
-
 }