You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by mx...@apache.org on 2015/07/21 16:35:30 UTC

[2/2] flink git commit: [FLINK-2371] improve AccumulatorLiveITCase

[FLINK-2371] improve AccumulatorLiveITCase

Instead of using Thread.sleep() to synchronize the checks of the
accumulator values, we rely on message passing here to synchronize the
task process.

Therefore, we let the task process signal to the task manager that it
has updated its accumulator values. The task manager lets the job
manager know and sends out the heartbeat which contains the
accumulators. When the job manager receives the accumulators and has
been notified previously, it sends a message to the subscribed test case
with the current accumulators.

This ensures that all processes are always synchronized correctly and we
can verify the live accumulator results correctly.

In the course of rewriting the test, I had to change two things in the
implementation:

a) Flink's internal accumulators are now immediately serialized as well
(like the user accumulators). Otherwise, Akka does not serialize messages
in local single-JVM setups and passes the live accumulator map through.

b) The asynchronous update of the accumulators is disabled for
tests (via the dispatcher flag of the TestingCluster). This was
necessary because we cannot guarantee when the Future for updating the
accumulators is executed. In real setups this is negligible.

This closes #925.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/0f589aad
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/0f589aad
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/0f589aad

Branch: refs/heads/master
Commit: 0f589aad899ca2ee230ff21655a38842928005f1
Parents: 0f45f2b
Author: Maximilian Michels <mx...@apache.org>
Authored: Mon Jul 20 11:55:11 2015 +0200
Committer: Maximilian Michels <mx...@apache.org>
Committed: Tue Jul 21 16:31:51 2015 +0200

----------------------------------------------------------------------
 .../accumulators/AccumulatorSnapshot.java       |  10 +-
 .../runtime/executiongraph/ExecutionGraph.java  |   5 +-
 .../io/network/api/writer/RecordWriter.java     |  17 --
 .../flink/runtime/jobmanager/JobManager.scala   |  28 +-
 .../flink/runtime/taskmanager/TaskManager.scala |   4 +-
 .../testingUtils/TestingJobManager.scala        |  53 +++-
 .../TestingJobManagerMessages.scala             |  16 +-
 .../testingUtils/TestingTaskManager.scala       |  17 +-
 .../TestingTaskManagerMessages.scala            |   9 +-
 .../accumulators/AccumulatorLiveITCase.java     | 274 ++++++++++++-------
 10 files changed, 277 insertions(+), 156 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/0f589aad/flink-runtime/src/main/java/org/apache/flink/runtime/accumulators/AccumulatorSnapshot.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/accumulators/AccumulatorSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/accumulators/AccumulatorSnapshot.java
index a6288f0..0f1911d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/accumulators/AccumulatorSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/accumulators/AccumulatorSnapshot.java
@@ -40,9 +40,9 @@ public class AccumulatorSnapshot implements Serializable {
 	private final ExecutionAttemptID executionAttemptID;
 
 	/**
-	 * Flink internal accumulators which can be serialized using the system class loader.
+	 * Flink internal accumulators which can be deserialized using the system class loader.
 	 */
-	private final Map<AccumulatorRegistry.Metric, Accumulator<?, ?>> flinkAccumulators;
+	private final SerializedValue<Map<AccumulatorRegistry.Metric, Accumulator<?, ?>>> flinkAccumulators;
 
 	/**
 	 * Serialized user accumulators which may require the custom user class loader.
@@ -54,7 +54,7 @@ public class AccumulatorSnapshot implements Serializable {
 							Map<String, Accumulator<?, ?>> userAccumulators) throws IOException {
 		this.jobID = jobID;
 		this.executionAttemptID = executionAttemptID;
-		this.flinkAccumulators = flinkAccumulators;
+		this.flinkAccumulators = new SerializedValue<Map<AccumulatorRegistry.Metric, Accumulator<?, ?>>>(flinkAccumulators);
 		this.userAccumulators = new SerializedValue<Map<String, Accumulator<?, ?>>>(userAccumulators);
 	}
 
@@ -70,8 +70,8 @@ public class AccumulatorSnapshot implements Serializable {
 	 * Gets the Flink (internal) accumulators values.
 	 * @return the serialized map
 	 */
-	public Map<AccumulatorRegistry.Metric, Accumulator<?, ?>> getFlinkAccumulators() {
-		return flinkAccumulators;
+	public Map<AccumulatorRegistry.Metric, Accumulator<?, ?>> deserializeFlinkAccumulators() throws IOException, ClassNotFoundException {
+		return flinkAccumulators.deserializeValue(ClassLoader.getSystemClassLoader());
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/0f589aad/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
index 6d2262c..70dc89c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
@@ -141,9 +141,10 @@ public class ExecutionGraph implements Serializable {
 	 * @param accumulatorSnapshot The serialized flink and user-defined accumulators
 	 */
 	public void updateAccumulators(AccumulatorSnapshot accumulatorSnapshot) {
-		Map<AccumulatorRegistry.Metric, Accumulator<?, ?>> flinkAccumulators = accumulatorSnapshot.getFlinkAccumulators();
+		Map<AccumulatorRegistry.Metric, Accumulator<?, ?>> flinkAccumulators;
 		Map<String, Accumulator<?, ?>> userAccumulators;
 		try {
+			flinkAccumulators = accumulatorSnapshot.deserializeFlinkAccumulators();
 			userAccumulators = accumulatorSnapshot.deserializeUserAccumulators(userClassLoader);
 
 			ExecutionAttemptID execID = accumulatorSnapshot.getExecutionAttemptID();
@@ -889,7 +890,7 @@ public class ExecutionGraph implements Serializable {
 					Map<String, Accumulator<?, ?>> userAccumulators = null;
 					try {
 						AccumulatorSnapshot accumulators = state.getAccumulators();
-						flinkAccumulators = accumulators.getFlinkAccumulators();
+						flinkAccumulators = accumulators.deserializeFlinkAccumulators();
 						userAccumulators = accumulators.deserializeUserAccumulators(userClassLoader);
 					} catch (Exception e) {
 						// Exceptions would be thrown in the future here

http://git-wip-us.apache.org/repos/asf/flink/blob/0f589aad/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java
index 5bc705d..885c316 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/RecordWriter.java
@@ -50,12 +50,6 @@ public class RecordWriter<T extends IOReadableWritable> {
 
 	private final int numChannels;
 
-	/**
-	 * Counter for the number of records emitted and for the number of bytes written.
-	 * @param counter
-	 */
-	private AccumulatorRegistry.Reporter reporter;
-
 	/** {@link RecordSerializer} per outgoing channel */
 	private final RecordSerializer<T>[] serializers;
 
@@ -88,7 +82,6 @@ public class RecordWriter<T extends IOReadableWritable> {
 
 			synchronized (serializer) {
 				SerializationResult result = serializer.addRecord(record);
-
 				while (result.isFullBuffer()) {
 					Buffer buffer = serializer.getCurrentBuffer();
 
@@ -98,18 +91,8 @@ public class RecordWriter<T extends IOReadableWritable> {
 					}
 
 					buffer = writer.getBufferProvider().requestBufferBlocking();
-					if (reporter != null) {
-						// increase the number of written bytes by the memory segment's size
-						reporter.reportNumBytesOut(buffer.getSize());
-					}
-
 					result = serializer.setNextBuffer(buffer);
 				}
-
-				if(reporter != null) {
-					// count number of emitted records
-					reporter.reportNumRecordsOut(1);
-				}
 			}
 		}
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/0f589aad/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
index 8823041..5cf69ec 100644
--- a/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
+++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
@@ -28,7 +28,7 @@ import grizzled.slf4j.Logger
 import org.apache.flink.api.common.{ExecutionConfig, JobID}
 import org.apache.flink.configuration.{ConfigConstants, Configuration, GlobalConfiguration}
 import org.apache.flink.core.io.InputSplitAssigner
-import org.apache.flink.runtime.accumulators.StringifiedAccumulatorResult
+import org.apache.flink.runtime.accumulators.{AccumulatorSnapshot, StringifiedAccumulatorResult}
 import org.apache.flink.runtime.blob.BlobServer
 import org.apache.flink.runtime.client._
 import org.apache.flink.runtime.executiongraph.{ExecutionGraph, ExecutionJobVertex}
@@ -404,15 +404,7 @@ class JobManager(
       log.debug(s"Received hearbeat message from $instanceID.")
 
       Future {
-        accumulators foreach {
-          case accumulators =>
-              currentJobs.get(accumulators.getJobID) match {
-                case Some((jobGraph, jobInfo)) =>
-                  jobGraph.updateAccumulators(accumulators)
-                case None =>
-                  // ignore accumulator values for old job
-              }
-        }
+        updateAccumulators(accumulators)
       }(context.dispatcher)
 
       instanceManager.reportHeartBeat(instanceID, metricsReport)
@@ -770,6 +762,22 @@ class JobManager(
         log.error(s"Could not properly unregister job $jobID form the library cache.", t)
     }
   }
+
+  /**
+   * Updates the accumulators reported from a task manager via the Heartbeat message.
+   * @param accumulators list of accumulator snapshots
+   */
+  private def updateAccumulators(accumulators : Seq[AccumulatorSnapshot]) = {
+    accumulators foreach {
+      case accumulatorEvent =>
+        currentJobs.get(accumulatorEvent.getJobID) match {
+          case Some((jobGraph, jobInfo)) =>
+            jobGraph.updateAccumulators(accumulatorEvent)
+          case None =>
+          // ignore accumulator values for old job
+        }
+    }
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/flink/blob/0f589aad/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala
index f07fa0c..d78a594 100644
--- a/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala
+++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/taskmanager/TaskManager.scala
@@ -155,7 +155,7 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging {
 
   private var blobService: Option[BlobService] = None
   private var libraryCacheManager: Option[LibraryCacheManager] = None
-  private var currentJobManager: Option[ActorRef] = None
+  protected var currentJobManager: Option[ActorRef] = None
 
   private var instanceID: InstanceID = null
 
@@ -936,7 +936,7 @@ extends Actor with ActorLogMessages with ActorSynchronousLogging {
    * Sends a heartbeat message to the JobManager (if connected) with the current
    * metrics report.
    */
-  private def sendHeartbeatToJobManager(): Unit = {
+  protected def sendHeartbeatToJobManager(): Unit = {
     try {
       log.debug("Sending heartbeat to JobManager")
       val metricsReport: Array[Byte] = metricRegistryMapper.writeValueAsBytes(metricRegistry)

http://git-wip-us.apache.org/repos/asf/flink/blob/0f589aad/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingJobManager.scala
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingJobManager.scala b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingJobManager.scala
index 6d316ca..ca5927d 100644
--- a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingJobManager.scala
+++ b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingJobManager.scala
@@ -27,8 +27,10 @@ import org.apache.flink.runtime.jobgraph.JobStatus
 import org.apache.flink.runtime.jobmanager.{JobManager, MemoryArchivist}
 import org.apache.flink.runtime.messages.ExecutionGraphMessages.JobStatusChanged
 import org.apache.flink.runtime.messages.Messages.Disconnect
+import org.apache.flink.runtime.messages.TaskManagerMessages.Heartbeat
 import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages._
 import org.apache.flink.runtime.testingUtils.TestingMessages.DisableDisconnect
+import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages.AccumulatorsChanged
 
 import scala.collection.convert.WrapAsScala
 import scala.concurrent.Future
@@ -55,6 +57,8 @@ trait TestingJobManager extends ActorLogMessages with WrapAsScala {
   val waitForJobStatus = scala.collection.mutable.HashMap[JobID,
     collection.mutable.HashMap[JobStatus, Set[ActorRef]]]()
 
+  val waitForAccumulatorUpdate = scala.collection.mutable.HashMap[JobID, (Boolean, Set[ActorRef])]()
+
   var disconnectDisabled = false
 
   abstract override def receiveWithLogMessages: Receive = {
@@ -130,6 +134,46 @@ trait TestingJobManager extends ActorLogMessages with WrapAsScala {
         }
       }
 
+    case NotifyWhenAccumulatorChange(jobID) =>
+
+      val (updated, registered) = waitForAccumulatorUpdate.
+        getOrElse(jobID, (false, Set[ActorRef]()))
+      waitForAccumulatorUpdate += jobID -> (updated, registered + sender)
+      sender ! true
+
+    /**
+     * Notification from the task manager that changed accumulator are transferred on next
+     * Hearbeat. We need to keep this state to notify the listeners on next Heartbeat report.
+     */
+    case AccumulatorsChanged(jobID: JobID) =>
+      waitForAccumulatorUpdate.get(jobID) match {
+        case Some((updated, registered)) =>
+          waitForAccumulatorUpdate.put(jobID, (true, registered))
+        case None =>
+      }
+
+    /**
+     * Disabled async processing of accumulator values and send accumulators to the listeners if
+     * we previously received an [[AccumulatorsChanged]] message.
+     */
+    case msg : Heartbeat =>
+      super.receiveWithLogMessages(msg)
+
+      waitForAccumulatorUpdate foreach {
+        case (jobID, (updated, actors)) if updated =>
+          currentJobs.get(jobID) match {
+            case Some((graph, jobInfo)) =>
+              val flinkAccumulators = graph.getFlinkAccumulators
+              val userAccumulators = graph.aggregateUserAccumulators
+              actors foreach {
+                actor => actor ! UpdatedAccumulators(jobID, flinkAccumulators, userAccumulators)
+              }
+            case None =>
+          }
+          waitForAccumulatorUpdate.put(jobID, (false, actors))
+        case _ =>
+      }
+
     case RequestWorkingTaskManager(jobID) =>
       currentJobs.get(jobID) match {
         case Some((eg, _)) =>
@@ -147,15 +191,6 @@ trait TestingJobManager extends ActorLogMessages with WrapAsScala {
         case None => sender ! WorkingTaskManager(None)
       }
 
-    case RequestAccumulatorValues(jobID) =>
-
-      val (flinkAccumulators, userAccumulators) = currentJobs.get(jobID) match {
-        case Some((graph, jobInfo)) =>
-          (graph.getFlinkAccumulators, graph.aggregateUserAccumulators)
-        case None => null
-      }
-
-      sender ! RequestAccumulatorValuesResponse(jobID, flinkAccumulators, userAccumulators)
 
     case NotifyWhenJobStatus(jobID, state) =>
       val jobStatusListener = waitForJobStatus.getOrElseUpdate(jobID,

http://git-wip-us.apache.org/repos/asf/flink/blob/0f589aad/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingJobManagerMessages.scala
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingJobManagerMessages.scala b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingJobManagerMessages.scala
index 46e8486..bbf7b2d 100644
--- a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingJobManagerMessages.scala
+++ b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingJobManagerMessages.scala
@@ -20,12 +20,12 @@ package org.apache.flink.runtime.testingUtils
 
 import akka.actor.ActorRef
 import org.apache.flink.api.common.JobID
+import org.apache.flink.api.common.accumulators.Accumulator
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry
 import org.apache.flink.runtime.executiongraph.{ExecutionAttemptID, ExecutionGraph}
 import org.apache.flink.runtime.instance.InstanceGateway
 import org.apache.flink.runtime.jobgraph.JobStatus
 import java.util.Map
-import org.apache.flink.api.common.accumulators.Accumulator
 
 object TestingJobManagerMessages {
 
@@ -57,8 +57,18 @@ object TestingJobManagerMessages {
   case class NotifyWhenTaskManagerTerminated(taskManager: ActorRef)
   case class TaskManagerTerminated(taskManager: ActorRef)
 
-  case class RequestAccumulatorValues(jobID: JobID)
-  case class RequestAccumulatorValuesResponse(jobID: JobID,
+  /* Registers a listener to receive a message when accumulators changed.
+   * The change must be explicitly triggered by the TestingTaskManager which can receive an
+   * [[AccumulatorChanged]] message by a task that changed the accumulators. This message is then
+   * forwarded to the JobManager which will send the accumulators in the [[UpdatedAccumulators]]
+   * message when the next Heartbeat occurs.
+   * */
+  case class NotifyWhenAccumulatorChange(jobID: JobID)
+
+  /**
+   * Reports updated accumulators back to the listener.
+   */
+  case class UpdatedAccumulators(jobID: JobID,
     flinkAccumulators: Map[ExecutionAttemptID, Map[AccumulatorRegistry.Metric, Accumulator[_,_]]],
     userAccumulators: Map[String, Accumulator[_,_]])
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/0f589aad/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingTaskManager.scala
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingTaskManager.scala b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingTaskManager.scala
index 220e6ca..6f8ddd3 100644
--- a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingTaskManager.scala
+++ b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingTaskManager.scala
@@ -94,7 +94,7 @@ class TestingTaskManager(config: TaskManagerConfiguration,
               waitForRemoval += (executionID -> (set + sender))
           }
       }
-      
+
     case TaskInFinalState(executionID) =>
       super.receiveWithLogMessages(TaskInFinalState(executionID))
       waitForRemoval.remove(executionID) match {
@@ -144,6 +144,21 @@ class TestingTaskManager(config: TaskManagerConfiguration,
       val waiting = waitForJobManagerToBeTerminated.getOrElse(jobManager.path.name, Set())
       waitForJobManagerToBeTerminated += jobManager.path.name -> (waiting + sender)
 
+    /**
+     * Message from task manager that accumulator values changed and need to be reported immediately
+     * instead of lazily through the
+     * [[org.apache.flink.runtime.messages.TaskManagerMessages.Heartbeat]] message. We forward this
+     * message to the job manager that it knows it should report to the listeners.
+     */
+    case msg: AccumulatorsChanged =>
+      currentJobManager match {
+        case Some(jobManager) =>
+          jobManager.forward(msg)
+          sendHeartbeatToJobManager()
+          sender ! true
+        case None =>
+      }
+
     case msg@Terminated(jobManager) =>
       super.receiveWithLogMessages(msg)
 

http://git-wip-us.apache.org/repos/asf/flink/blob/0f589aad/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingTaskManagerMessages.scala
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingTaskManagerMessages.scala b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingTaskManagerMessages.scala
index c9a2f73..1c428cc 100644
--- a/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingTaskManagerMessages.scala
+++ b/flink-runtime/src/test/scala/org/apache/flink/runtime/testingUtils/TestingTaskManagerMessages.scala
@@ -51,7 +51,14 @@ object TestingTaskManagerMessages {
   case class NotifyWhenJobManagerTerminated(jobManager: ActorRef)
 
   case class JobManagerTerminated(jobManager: ActorRef)
-  
+
+  /**
+   * Message to give a hint to the task manager that accumulator values were updated in the task.
+   * This message is forwarded to the job manager which knows that it needs to notify listeners
+   * of accumulator updates.
+   */
+  case class AccumulatorsChanged(jobID: JobID)
+
   // --------------------------------------------------------------------------
   // Utility methods to allow simpler case object access from Java
   // --------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/0f589aad/flink-tests/src/test/java/org/apache/flink/test/accumulators/AccumulatorLiveITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/accumulators/AccumulatorLiveITCase.java b/flink-tests/src/test/java/org/apache/flink/test/accumulators/AccumulatorLiveITCase.java
index 3d80157..f979aea 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/accumulators/AccumulatorLiveITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/accumulators/AccumulatorLiveITCase.java
@@ -21,19 +21,21 @@ package org.apache.flink.test.accumulators;
 import akka.actor.ActorRef;
 import akka.actor.ActorSystem;
 import akka.actor.Status;
+import akka.pattern.Patterns;
 import akka.testkit.JavaTestKit;
+import akka.util.Timeout;
 import org.apache.flink.api.common.JobExecutionResult;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.Plan;
 import org.apache.flink.api.common.accumulators.Accumulator;
 import org.apache.flink.api.common.accumulators.IntCounter;
 import org.apache.flink.api.common.accumulators.LongCounter;
-import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.io.OutputFormat;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
 import org.apache.flink.api.java.LocalEnvironment;
+import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.optimizer.DataStatistics;
 import org.apache.flink.optimizer.Optimizer;
@@ -44,50 +46,66 @@ import org.apache.flink.runtime.akka.AkkaUtils;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.messages.JobManagerMessages;
-import org.apache.flink.runtime.taskmanager.TaskManager;
 import org.apache.flink.runtime.testingUtils.TestingCluster;
-import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.*;
+import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages;
+import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages;
 import org.apache.flink.runtime.testingUtils.TestingUtils;
 import org.apache.flink.util.Collector;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.concurrent.Await;
+import scala.concurrent.Future;
+import scala.concurrent.duration.Duration;
+import scala.concurrent.duration.FiniteDuration;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.TimeUnit;
 
 import static org.junit.Assert.*;
 
 
 /**
- * Test the availability of accumulator results during runtime.
+ * Tests the availability of accumulator results during runtime. The test case tests a user-defined
+ * accumulator and Flink's internal accumulators for two consecutive tasks.
  */
-@SuppressWarnings("serial")
 public class AccumulatorLiveITCase {
 
+	private static final Logger LOG = LoggerFactory.getLogger(AccumulatorLiveITCase.class);
+
 	private static ActorSystem system;
 	private static ActorRef jobManager;
+	private static ActorRef taskManager;
+	private static JobID jobID;
 
-	// name of accumulator
+	// name of user accumulator
 	private static String NAME = "test";
-	// time to wait between changing the accumulator value
-	private static long WAIT_TIME = TaskManager.HEARTBEAT_INTERVAL().toMillis() + 500;
 
 	// number of heartbeat intervals to check
-	private static int NUM_ITERATIONS = 3;
-	// numer of retries in case the expected value is not seen
-	private static int NUM_RETRIES = 10;
+	private static final int NUM_ITERATIONS = 5;
 
 	private static List<String> inputData = new ArrayList<String>(NUM_ITERATIONS);
 
+	private static final FiniteDuration TIMEOUT = new FiniteDuration(10, TimeUnit.SECONDS);
+
 
 	@Before
 	public void before() throws Exception {
 		system = AkkaUtils.createLocalActorSystem(new Configuration());
-		TestingCluster testingCluster = TestingUtils.startTestingCluster(1, 1, TestingUtils.DEFAULT_AKKA_ASK_TIMEOUT());
+
+		Configuration config = new Configuration();
+		config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, 1);
+		config.setInteger(ConfigConstants.LOCAL_INSTANCE_MANAGER_NUMBER_TASK_MANAGER, 1);
+		config.setString(ConfigConstants.AKKA_ASK_TIMEOUT, TestingUtils.DEFAULT_AKKA_ASK_TIMEOUT());
+		TestingCluster testingCluster = new TestingCluster(config, false, true);
+
 		jobManager = testingCluster.getJobManager();
+		taskManager = testingCluster.getTaskManagersAsJava().get(0);
 
 		// generate test data
 		for (int i=0; i < NUM_ITERATIONS; i++) {
@@ -109,125 +127,157 @@ public class AccumulatorLiveITCase {
 			ExecutionEnvironment env = new PlanExtractor();
 			DataSet<String> input = env.fromCollection(inputData);
 			input
-					.flatMap(new Tokenizer())
 					.flatMap(new WaitingUDF())
 					.output(new WaitingOutputFormat());
 			env.execute();
 
-			/** Extract job graph **/
+			// Extract the job graph and remember its job id so the tasks can report accumulator changes for it.
 			JobGraph jobGraph = getOptimizedPlan(((PlanExtractor) env).plan);
+			jobID = jobGraph.getJobID();
 
+			// register for accumulator changes
+			jobManager.tell(new TestingJobManagerMessages.NotifyWhenAccumulatorChange(jobID), getRef());
+			expectMsgEquals(TIMEOUT, true);
+
+			// submit job
 			jobManager.tell(new JobManagerMessages.SubmitJob(jobGraph, false), getRef());
-			expectMsgClass(Status.Success.class);
+			expectMsgClass(TIMEOUT, Status.Success.class);
+
+
+			ExecutionAttemptID mapperTaskID = null;
+
+			TestingJobManagerMessages.UpdatedAccumulators msg = (TestingJobManagerMessages.UpdatedAccumulators) receiveOne(TIMEOUT);
+			Map<ExecutionAttemptID, Map<AccumulatorRegistry.Metric, Accumulator<?, ?>>> flinkAccumulators = msg.flinkAccumulators();
+			Map<String, Accumulator<?, ?>> userAccumulators = msg.userAccumulators();
+
+			// find out the first task's execution attempt id
+			for (Map.Entry<ExecutionAttemptID, ?> entry : flinkAccumulators.entrySet()) {
+				if (entry.getValue() != null) {
+					mapperTaskID = entry.getKey();
+					break;
+				}
+			}
 
 			/* Check for accumulator values */
-			int i = 0, retries = 0;
+			if(checkUserAccumulators(0, userAccumulators) && checkFlinkAccumulators(mapperTaskID, 0, 0, 0, 0, flinkAccumulators)) {
+				LOG.info("Passed initial check for map task.");
+			} else {
+				fail("Wrong accumulator results when map task begins execution.");
+			}
+
+
 			int expectedAccVal = 0;
-			while(i <= NUM_ITERATIONS) {
-				if (retries > 0) {
-					// retry fast
-					Thread.sleep(WAIT_TIME / NUM_RETRIES);
+			ExecutionAttemptID sinkTaskID = null;
+
+			/* for mapper task */
+			for (int i = 1; i <= NUM_ITERATIONS; i++) {
+				expectedAccVal += i;
+
+				// receive message
+				msg = (TestingJobManagerMessages.UpdatedAccumulators) receiveOne(TIMEOUT);
+				flinkAccumulators = msg.flinkAccumulators();
+				userAccumulators = msg.userAccumulators();
+
+				LOG.info("{}", flinkAccumulators);
+				LOG.info("{}", userAccumulators);
+
+				if (checkUserAccumulators(expectedAccVal, userAccumulators) && checkFlinkAccumulators(mapperTaskID, 0, i, 0, i * 4, flinkAccumulators)) {
+					LOG.info("Passed round " + i);
 				} else {
-					// wait for heartbeat interval
-					Thread.sleep(WAIT_TIME);
+					fail("Failed in round #" + i);
+				}
+			}
+
+			msg = (TestingJobManagerMessages.UpdatedAccumulators) receiveOne(TIMEOUT);
+			flinkAccumulators = msg.flinkAccumulators();
+			userAccumulators = msg.userAccumulators();
+
+			// find the second task's id
+			for (ExecutionAttemptID key : flinkAccumulators.keySet()) {
+				if (key != mapperTaskID) {
+					sinkTaskID = key;
+					break;
 				}
+			}
+
+			if(checkUserAccumulators(expectedAccVal, userAccumulators) && checkFlinkAccumulators(sinkTaskID, 0, 0, 0, 0, flinkAccumulators)) {
+				LOG.info("Passed initial check for sink task.");
+			} else {
+				fail("Wrong accumulator results when sink task begins execution.");
+			}
 
-				jobManager.tell(new RequestAccumulatorValues(jobGraph.getJobID()), getRef());
-				RequestAccumulatorValuesResponse response =
-						expectMsgClass(RequestAccumulatorValuesResponse.class);
+			/* for sink task */
+			for (int i = 1; i <= NUM_ITERATIONS; i++) {
 
-				Map<String, Accumulator<?, ?>> userAccumulators = response.userAccumulators();
-				Map<ExecutionAttemptID, Map<AccumulatorRegistry.Metric, Accumulator<?,?>>> flinkAccumulators =
-						response.flinkAccumulators();
+				// receive message
+				msg = (TestingJobManagerMessages.UpdatedAccumulators) receiveOne(TIMEOUT);
 
-				if (checkUserAccumulators(expectedAccVal, userAccumulators) && checkFlinkAccumulators(i == NUM_ITERATIONS, i, i * 4, flinkAccumulators)) {
-//					System.out.println("Passed round " + i);
-					// We passed this round
-					i += 1;
-					expectedAccVal += i;
-					retries = 0;
+				flinkAccumulators = msg.flinkAccumulators();
+				userAccumulators = msg.userAccumulators();
+
+				LOG.info("{}", flinkAccumulators);
+				LOG.info("{}", userAccumulators);
+
+				if (checkUserAccumulators(expectedAccVal, userAccumulators) && checkFlinkAccumulators(sinkTaskID, i, 0, i*4, 0, flinkAccumulators)) {
+					LOG.info("Passed round " + i);
 				} else {
-					if (retries < NUM_RETRIES) {
-//						System.out.println("retrying for the " + retries + " time.");
-						// try again
-						retries += 1;
-					} else {
-						fail("Failed in round #" + i + " after " + retries + " retries.");
-					}
+					fail("Failed in round #" + i);
 				}
 			}
 
+			expectMsgClass(TIMEOUT, JobManagerMessages.JobResultSuccess.class);
+
 		}};
 	}
 
 	private static boolean checkUserAccumulators(int expected, Map<String, Accumulator<?,?>> accumulatorMap) {
-//		System.out.println("checking user accumulators");
+		LOG.info("checking user accumulators");
 		return accumulatorMap.containsKey(NAME) && expected == ((IntCounter)accumulatorMap.get(NAME)).getLocalValue();
 	}
 
-	private static boolean checkFlinkAccumulators(boolean lastRound, int expectedRecords, int expectedBytes,
+	private static boolean checkFlinkAccumulators(ExecutionAttemptID taskKey, int expectedRecordsIn, int expectedRecordsOut, int expectedBytesIn, int expectedBytesOut,
 												  Map<ExecutionAttemptID, Map<AccumulatorRegistry.Metric, Accumulator<?,?>>> accumulatorMap) {
-//		System.out.println("checking flink accumulators");
-
-		for(Map<AccumulatorRegistry.Metric, Accumulator<?,?>> taskMap : accumulatorMap.values()) {
-			if (taskMap != null) {
-				for (Map.Entry<AccumulatorRegistry.Metric, Accumulator<?, ?>> entry : taskMap.entrySet()) {
-					switch (entry.getKey()) {
-						/**
-						 * The following two cases are for the DataSource and Map task
-						 */
-						case NUM_RECORDS_OUT:
-							if (!lastRound) {
-								if(((LongCounter) entry.getValue()).getLocalValue() != expectedRecords) {
-									return false;
-								}
-							}
-							break;
-						case NUM_BYTES_OUT:
-							if (!lastRound) {
-								if (((LongCounter) entry.getValue()).getLocalValue() != expectedBytes) {
-									return false;
-								}
-							}
-							break;
-						/**
-						 * The following two cases are for the DataSink task
-						 */
-						case NUM_RECORDS_IN:
-							// check if we are in last round and in current task accumulator map
-							if (lastRound && ((LongCounter)taskMap.get(AccumulatorRegistry.Metric.NUM_RECORDS_OUT)).getLocalValue() == 0) {
-								if (((LongCounter) entry.getValue()).getLocalValue() != expectedRecords) {
-									return false;
-								}
-							}
-							break;
-						case NUM_BYTES_IN:
-							if (lastRound && ((LongCounter)taskMap.get(AccumulatorRegistry.Metric.NUM_RECORDS_OUT)).getLocalValue() == 0) {
-								if (((LongCounter) entry.getValue()).getLocalValue() != expectedBytes) {
-									return false;
-								}
-							}
-							break;
-						default:
-							fail("Unknown accumulator found.");
+		LOG.info("checking flink accumulators");
+
+		Map<AccumulatorRegistry.Metric, Accumulator<?, ?>> taskMap = accumulatorMap.get(taskKey);
+		assertTrue(accumulatorMap.size() > 0);
+
+		for (Map.Entry<AccumulatorRegistry.Metric, Accumulator<?, ?>> entry : taskMap.entrySet()) {
+			switch (entry.getKey()) {
+				/**
+				 * The following two cases are for the DataSource and Map task
+				 */
+				case NUM_RECORDS_OUT:
+					if(((LongCounter) entry.getValue()).getLocalValue() != expectedRecordsOut) {
+						return false;
 					}
-				}
+					break;
+				case NUM_BYTES_OUT:
+					if (((LongCounter) entry.getValue()).getLocalValue() != expectedBytesOut) {
+						return false;
+					}
+					break;
+				/**
+				 * The following two cases are for the DataSink task
+				 */
+				case NUM_RECORDS_IN:
+					if (((LongCounter) entry.getValue()).getLocalValue() != expectedRecordsIn) {
+						return false;
+					}
+					break;
+				case NUM_BYTES_IN:
+					if (((LongCounter) entry.getValue()).getLocalValue() != expectedBytesIn) {
+						return false;
+					}
+					break;
+				default:
+					fail("Unknown accumulator found.");
 			}
 		}
 		return true;
 	}
 
 
-	public static class Tokenizer implements FlatMapFunction<String, String> {
-
-		@Override
-		public void flatMap(String value, Collector<String> out) throws Exception {
-			for (String str : value.split("\n")) {
-				out.collect(str);
-			}
-		}
-	}
-
 	/**
 	 * UDF that waits for at least the heartbeat interval's duration.
 	 */
@@ -238,43 +288,55 @@ public class AccumulatorLiveITCase {
 		@Override
 		public void open(Configuration parameters) throws Exception {
 			getRuntimeContext().addAccumulator(NAME, counter);
+			notifyTaskManagerOfAccumulatorUpdate();
 		}
 
 		@Override
 		public void flatMap(String value, Collector<Integer> out) throws Exception {
-			/* Wait here to check the accumulator value in the meantime */
-			Thread.sleep(WAIT_TIME);
 			int val = Integer.valueOf(value);
 			counter.add(val);
 			out.collect(val);
+			LOG.debug("Emitting value {}.", value);
+			notifyTaskManagerOfAccumulatorUpdate();
 		}
+
 	}
 
 	private static class WaitingOutputFormat implements OutputFormat<Integer> {
 
 		@Override
 		public void configure(Configuration parameters) {
-
 		}
 
 		@Override
 		public void open(int taskNumber, int numTasks) throws IOException {
-
+			notifyTaskManagerOfAccumulatorUpdate();
 		}
 
 		@Override
 		public void writeRecord(Integer record) throws IOException {
+			notifyTaskManagerOfAccumulatorUpdate();
 		}
 
 		@Override
 		public void close() throws IOException {
+		}
+	}
+
+	/**
+	 * Notifies the task manager of an accumulator update and waits until the heartbeat containing
+	 * the updated accumulators has been reported.
+	 */
+	public static void notifyTaskManagerOfAccumulatorUpdate() {
+		new JavaTestKit(system) {{
+			Timeout timeout = new Timeout(Duration.create(5, "seconds"));
+			Future<Object> ask = Patterns.ask(taskManager, new TestingTaskManagerMessages.AccumulatorsChanged(jobID), timeout);
 			try {
-//				System.out.println("starting output task");
-				Thread.sleep(WAIT_TIME);
-			} catch (InterruptedException e) {
-				fail("Interrupted test.");
+				Await.result(ask, timeout.duration());
+			} catch (Exception e) {
+				fail("Failed to notify task manager of accumulator update.");
 			}
-		}
+		}};
 	}
 
 	/**