You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@commons.apache.org by ah...@apache.org on 2019/10/14 16:28:12 UTC

[commons-rng] 03/03: Estimate time remaining when submitting tasks.

This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git

commit d0d237d9442c8af9f9ef267badecd3deed4b33df
Author: aherbert <ah...@apache.org>
AuthorDate: Mon Oct 14 17:28:05 2019 +0100

    Estimate time remaining when submitting tasks.
---
 .../rng/examples/stress/StressTestCommand.java     | 218 ++++++++++++---------
 1 file changed, 121 insertions(+), 97 deletions(-)

diff --git a/commons-rng-examples/examples-stress/src/main/java/org/apache/commons/rng/examples/stress/StressTestCommand.java b/commons-rng-examples/examples-stress/src/main/java/org/apache/commons/rng/examples/stress/StressTestCommand.java
index 300b605..8780094 100644
--- a/commons-rng-examples/examples-stress/src/main/java/org/apache/commons/rng/examples/stress/StressTestCommand.java
+++ b/commons-rng-examples/examples-stress/src/main/java/org/apache/commons/rng/examples/stress/StressTestCommand.java
@@ -37,6 +37,7 @@ import java.time.Instant;
 import java.time.LocalDateTime;
 import java.time.ZoneId;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Date;
 import java.util.Formatter;
 import java.util.List;
@@ -342,7 +343,7 @@ class StressTestCommand implements Callable<Void> {
         final String basePath = fileOutputPrefix.getAbsolutePath();
         checkExistingOutputFiles(basePath, stressTestData);
 
-        final ProgressTracker progressTracker = new ProgressTracker(countTrials(stressTestData), taskCount);
+        final ProgressTracker progressTracker = new ProgressTracker(taskCount);
         final List<Runnable> tasks = createTasks(command, basePath, stressTestData, progressTracker);
 
         // Run tasks with parallel execution.
@@ -350,7 +351,7 @@ class StressTestCommand implements Callable<Void> {
 
         LogUtils.info("Running stress test ...");
         LogUtils.info("Shutdown by creating stop file: %s",  stopFile);
-        progressTracker.start();
+        progressTracker.setTotal(tasks.size());
         final List<Future<?>> taskList = submitTasks(service, tasks);
 
         // Wait for completion (ignoring return value).
@@ -451,7 +452,6 @@ class StressTestCommand implements Callable<Void> {
                     // Log the decision
                     LogUtils.info("%s existing output file: %s", outputMode, output);
                     if (outputMode == StressTestCommand.OutputMode.SKIP) {
-                        progressTracker.incrementProgress(0);
                         continue;
                     }
                 }
@@ -498,20 +498,6 @@ class StressTestCommand implements Callable<Void> {
     }
 
     /**
-     * Count the total number of trials.
-     *
-     * @param stressTestData List of generators to be tested.
-     * @return the count
-     */
-    private static int countTrials(Iterable<StressTestData> stressTestData) {
-        int count = 0;
-        for (final StressTestData testData : stressTestData) {
-            count += Math.max(0, testData.getTrials());
-        }
-        return count;
-    }
-
-    /**
      * Submit the tasks to the executor service.
      *
      * @param service The executor service.
@@ -526,80 +512,81 @@ class StressTestCommand implements Callable<Void> {
     }
 
     /**
-     * Class for reporting total progress to the console.
+     * Class for reporting total progress of tasks to the console.
+     *
+     * <p>This stores the start and end time of tasks to allow it to estimate the time remaining
+     * for all the tests.
      */
     static class ProgressTracker {
-        /** The reporting interval. */
-        private static final long REPORT_INTERVAL = 100;
         /** The total. */
-        private final int total;
-        /** The count. */
-        private int count;
-        /** The timestamp of the last progress report. */
-        private long timestamp;
+        private int total;
         /** The level of parallelisation. */
         private final int parallelTasks;
-        /** The total time of all completed tasks (in milliseconds). */
-        private long totalTime;
-        /** The number of tasks completed with a time (i.e. were not skipped). */
+        /** The task id. */
+        private int taskId;
+        /** The start time of tasks (in milliseconds from the epoch). */
+        private long[] startTimes;
+        /** The durations of all completed tasks (in milliseconds). This is sorted. */
+        private long[] sortedDurations;
+        /** The number of completed tasks. */
         private int completed;
-        /** The estimated time of arrival (in milliseconds from the epoch). */
-        private long eta;
 
         /**
-         * Create a new instance.
+         * Create a new instance. The total number of tasks must be initialised before use.
          *
-         * @param total The total progress.
          * @param parallelTasks The number of parallel tasks.
          */
-        ProgressTracker(int total, int parallelTasks) {
-            this.total = total;
+        ProgressTracker(int parallelTasks) {
             this.parallelTasks = parallelTasks;
         }
 
         /**
-         * Start the tracker. This will show progress as 0% complete.
+         * Sets the total number of tasks to track.
+         *
+         * @param total The total tasks.
          */
-        void start() {
-            showProgress();
+        void setTotal(int total) {
+            this.total = total;
+            startTimes = new long[total];
+            sortedDurations = new long[total];
         }
 
         /**
-         * Signal that a task has completed in a specified time.
+         * Submit a task for progress tracking. The task start time is recorded and the
+         * task is allocated an identifier.
          *
-         * @param taskTime The time for the task (milliseconds).
+         * @return the task Id
          */
-        void incrementProgress(long taskTime) {
+        int submitTask() {
+            int id;
             synchronized (this) {
-                count++;
-                // Used to compute the average task time
-                if (taskTime != 0) {
-                    totalTime += taskTime;
-                    completed++;
+                final long current = System.currentTimeMillis();
+                id = taskId++;
+                startTimes[id] = current;
+                final StringBuilder sb = createStringBuilderWithTimestamp(current);
+                try (Formatter formatter = new Formatter(sb)) {
+                    formatter.format(" (%.2f%%)", 100.0 * taskId / total);
+                    appendRemaining(sb);
+                    LogUtils.info(sb.toString());
                 }
-                showProgress();
             }
+            return id;
         }
 
         /**
-         * Show the progress. This will occur incrementally based on the current time
-         * or if the progress is complete.
+         * Signal that a task has completed. The task duration will be returned.
+         *
+         * @param id Task Id.
+         * @return the task time in milliseconds
          */
-        private void showProgress() {
+        long endTask(int id) {
             final long current = System.currentTimeMillis();
-            // Edge case. This handles 0 / 0 as 100%.
-            if (count >= total) {
-                final StringBuilder sb = createStringBuilderWithTimestamp(current);
-                LogUtils.info(sb.append(" (100%)").toString());
-            } else if (current - timestamp > REPORT_INTERVAL) {
-                timestamp = current;
-                final StringBuilder sb = createStringBuilderWithTimestamp(current);
-                try (Formatter formatter = new Formatter(sb)) {
-                    formatter.format(" (%.2f%%)", 100.0 * count / total);
-                    appendRemaining(sb);
-                    LogUtils.info(sb.toString());
-                }
+            final long duration = current - startTimes[id];
+            synchronized (this) {
+                sortedDurations[completed++] = duration;
+                Arrays.sort(sortedDurations, 0, completed);
             }
+            return duration;
         }
 
         /**
@@ -621,25 +608,20 @@ class StressTestCommand implements Callable<Void> {
             append00(sb, time.getHour()).append(':');
             append00(sb, time.getMinute()).append(':');
             append00(sb, time.getSecond());
-            return sb.append("] Progress ").append(count).append(" / ").append(total);
+            return sb.append("] Running ").append(taskId).append(" / ").append(total);
         }
 
         /**
-         * Compute an estimate of the time remaining and append to the progress. Updates the
-         * estimated time of arrival (ETA).
+         * Compute an estimate of the time remaining and append it to the progress message.
+         * Nothing is appended if no estimate is available.
          *
          * @param sb String Builder.
          * @return the string builder
          */
         private StringBuilder appendRemaining(StringBuilder sb) {
-            if (completed == 0) {
-                // No estimate possible.
-                return sb;
-            }
-
             final long millis = getRemainingTime();
             if (millis == 0) {
-                // This is an over-run of the ETA. Must be close to completion now.
+                // Unknown.
                 return sb;
             }
 
@@ -650,34 +632,76 @@ class StressTestCommand implements Callable<Void> {
         }
 
         /**
-         * Gets the remaining time (in milliseconds). Uses or updates the estimated time of
-         * arrival (ETA), depending on the estimation method.
+         * Gets the remaining time (in milliseconds).
          *
          * @return the remaining time
          */
         private long getRemainingTime() {
-            final int remainingTasks = total - count;
-
-            if (remainingTasks < parallelTasks) {
-                // No more tasks to submit so the last estimate was as good as we can make it.
-                // Return the difference between the ETA and the current timestamp.
-                return Math.max(0, eta - timestamp);
+            final long taskTime = getEstimatedTaskTime();
+            if (taskTime == 0) {
+                return 0;
             }
 
-            // Estimate time remaining using the average runtime per task
-            // multiplied by the number of parallel remaining tasks (rounded down).
-            // Parallel remaining is the number of batches required to execute the
-            // remaining tasks in parallel.
-            final long parallelRemaining = remainingTasks / parallelTasks;
-            final long millis = (totalTime * parallelRemaining) / completed;
-
-            // Update the ETA
-            eta = timestamp + millis;
+            // There is at least 1 task left.
+            // The remaining time is at least the length of the task estimate.
+            long millis = taskTime;
+
+            // If additional tasks must also be submitted then the time must include
+            // the estimated time for running tasks to finish before new submissions
+            // in the batch can be made.
+            //                   now
+            // s1 --------------->|
+            //      s2 -----------|-------->
+            //          s3 -------|------------>
+            //                    s4 -------------->
+            //
+
+            // Compute the number of additional tasks after this one to finish.
+            // E.g. 4 tasks left is 3 additional tasks.
+            final int additionalTasks = total - taskId;
+
+            // Assume parallel batch execution.
+            // E.g. 3 additional tasks with parallelisation 4 is 0 batches
+            int batches = additionalTasks / parallelTasks;
+            millis += batches * taskTime;
+
+            // Compute the expected end time of the final batch based on it starting when
+            // a currently running task ends.
+            // E.g. 3 remaining tasks requires the end time of the 3rd oldest running task.
+            int remainder = additionalTasks % parallelTasks;
+            if (remainder != 0) {
+                // The start times are sorted. This assumes the most recent start times are
+                // still running tasks.
+                // If this is wrong (more recently submitted tasks finished early) then the
+                // estimate is too high. This could be corrected by storing the tasks
+                // that have finished and finding the time corresponding to the N'th oldest
+                // task that is still running.
+                final int id = taskId - 1;
+                // This should not index-out-of-bounds unless a task ends before the first
+                // set of parallel tasks has been submitted, i.e. during a dry-run.
+                // Guard with a minimum index of zero to get a valid start time.
+                final int nthOldest = Math.max(0, id - parallelTasks + remainder);
+                final long endTime = startTimes[nthOldest] + taskTime;
+                // Note: The current time is the most recent entry in the startTimes array.
+                millis += endTime - startTimes[id];
+            }
 
             return millis;
         }
 
         /**
+         * Gets the estimated task time.
+         *
+         * @return the estimated task time
+         */
+        private long getEstimatedTaskTime() {
+            // Use the median. This is less sensitive to outliers than the average.
+            // For example PractRand may fail very fast for bad generators and this
+            // will skew the average.
+            return sortedDurations[completed / 2];
+        }
+
+        /**
          * Append the milliseconds using {@code HH::mm:ss} format.
          *
          * @param sb String Builder.
@@ -785,27 +809,27 @@ class StressTestCommand implements Callable<Void> {
                 return;
             }
 
-            long nanoTime = 0;
             try {
                 printHeader();
 
                 Object exitValue;
+                long millis;
+                final int taskId = progressTracker.submitTask();
                 if (cmd.dryRun) {
-                    // Do not do anything
+                    // Do not do anything. Ignore the runtime.
                     exitValue = "N/A";
+                    progressTracker.endTask(taskId);
+                    millis = 0;
                 } else {
                     // Run the sub-process
-                    final long startTime = System.nanoTime();
                     exitValue = runSubProcess();
-                    nanoTime = System.nanoTime() - startTime;
+                    millis = progressTracker.endTask(taskId);
                 }
 
-                printFooter(nanoTime, exitValue);
+                printFooter(millis, exitValue);
 
             } catch (final IOException ex) {
                 throw new ApplicationException("Failed to run task: " + ex.getMessage(), ex);
-            } finally {
-                progressTracker.incrementProgress(TimeUnit.NANOSECONDS.toMillis(nanoTime));
             }
         }
 
@@ -890,12 +914,12 @@ class StressTestCommand implements Callable<Void> {
         /**
          * Prints the footer.
          *
-         * @param nanoTime Duration of the run.
+         * @param millis Duration of the run (in milliseconds).
          * @param exitValue The process exit value.
          * @throws IOException if there was a problem opening or writing to the
          * {@code output} file.
          */
-        private void printFooter(long nanoTime,
+        private void printFooter(long millis,
                                  Object exitValue) throws IOException {
             final StringBuilder sb = new StringBuilder(200);
             sb.append(C).append(N);
@@ -908,7 +932,7 @@ class StressTestCommand implements Callable<Void> {
                           .append(" (").append(bytesToString(bytesUsed)).append(')').append(N)
                 .append(C).append(N);
 
-            final double duration = nanoTime * 1e-9 / 60;
+            final double duration = millis * 1e-3 / 60;
             sb.append(C).append("Test duration: ").append(duration).append(" minutes").append(N)
                 .append(C).append(N);