Posted to commits@flink.apache.org by tr...@apache.org on 2017/02/05 20:57:46 UTC

[1/2] flink git commit: [FLINK-2211] [ml] Generalize ALS API

Repository: flink
Updated Branches:
  refs/heads/master 31e157346 -> 215776b81


[FLINK-2211] [ml] Generalize ALS API

Allows user and item IDs to be of type Long

This closes #3265.
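
With the added Long-based fit and predict operations, user and item IDs no longer have to fit into an Int; the previous Int-based tuples keep working through the new wrapper implicits. A minimal usage sketch with made-up ratings and parameter values:

    import org.apache.flink.api.scala._
    import org.apache.flink.ml.recommendation.ALS

    val env = ExecutionEnvironment.getExecutionEnvironment

    // user and item IDs may now exceed the Int range
    val ratings: DataSet[(Long, Long, Double)] = env.fromCollection(Seq(
      (1L, 100000000000L, 4.5),
      (2L, 100000000001L, 3.0),
      (1L, 100000000001L, 2.5)))

    val als = ALS()
      .setIterations(10)
      .setNumFactors(10)
      .setBlocks(4)
      .setLambda(0.9)

    als.fit(ratings)

    // predictions are again (user, item, rating) triples with Long IDs
    val predictions = als.predict(env.fromCollection(Seq((2L, 100000000000L))))
    predictions.print()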


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/43d2fd23
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/43d2fd23
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/43d2fd23

Branch: refs/heads/master
Commit: 43d2fd23a75a5ac7769d37cb5c2559803bd65800
Parents: 31e1573
Author: Till Rohrmann <tr...@apache.org>
Authored: Fri Feb 3 18:22:13 2017 +0100
Committer: Till Rohrmann <tr...@apache.org>
Committed: Sun Feb 5 21:56:25 2017 +0100

----------------------------------------------------------------------
 .../apache/flink/ml/recommendation/ALS.scala    | 125 +++++++++++++------
 .../flink/ml/recommendation/ALSITSuite.scala    |  51 +++++++-
 .../ml/recommendation/Recommendation.scala      |  63 ++++++++++
 3 files changed, 198 insertions(+), 41 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/43d2fd23/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala
index d8af42f..0454381 100644
--- a/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala
+++ b/flink-libraries/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala
@@ -194,7 +194,7 @@ class ALS extends Predictor[ALS] {
     * @return
     */
   def empiricalRisk(
-      labeledData: DataSet[(Int, Int, Double)],
+      labeledData: DataSet[(Long, Long, Double)],
       riskParameters: ParameterMap = ParameterMap.Empty)
     : DataSet[Double] = {
     val resultingParameters = parameters ++ riskParameters
@@ -293,20 +293,20 @@ object ALS {
     * @param item Item iD of the rated item
     * @param rating Rating value
     */
-  case class Rating(user: Int, item: Int, rating: Double)
+  case class Rating(user: Long, item: Long, rating: Double)
 
   /** Latent factor model vector
     *
     * @param id
     * @param factors
     */
-  case class Factors(id: Int, factors: Array[Double]) {
+  case class Factors(id: Long, factors: Array[Double]) {
     override def toString = s"($id, ${factors.mkString(",")})"
   }
 
   case class Factorization(userFactors: DataSet[Factors], itemFactors: DataSet[Factors])
 
-  case class OutBlockInformation(elementIDs: Array[Int], outLinks: OutLinks) {
+  case class OutBlockInformation(elementIDs: Array[Long], outLinks: OutLinks) {
     override def toString: String = {
       s"OutBlockInformation:((${elementIDs.mkString(",")}), ($outLinks))"
     }
@@ -349,7 +349,7 @@ object ALS {
     def apply(idx: Int) = links(idx)
   }
 
-  case class InBlockInformation(elementIDs: Array[Int], ratingsForBlock: Array[BlockRating]) {
+  case class InBlockInformation(elementIDs: Array[Long], ratingsForBlock: Array[BlockRating]) {
 
     override def toString: String = {
       s"InBlockInformation:((${elementIDs.mkString(",")}), (${ratingsForBlock.mkString("\n")}))"
@@ -376,8 +376,8 @@ object ALS {
   }
 
   class BlockIDGenerator(blocks: Int) extends Serializable {
-    def apply(id: Int): Int = {
-      id % blocks
+    def apply(id: Long): Int = {
+      (id % blocks).toInt
     }
   }
 
@@ -390,12 +390,15 @@ object ALS {
   // ===================================== Operations ==============================================
 
   /** Predict operation which calculates the matrix entry for the given indices  */
-  implicit val predictRating = new PredictDataSetOperation[ALS, (Int, Int), (Int ,Int, Double)] {
+  implicit val predictRating = new PredictDataSetOperation[
+      ALS,
+      (Long, Long),
+      (Long, Long, Double)] {
     override def predictDataSet(
         instance: ALS,
         predictParameters: ParameterMap,
-        input: DataSet[(Int, Int)])
-      : DataSet[(Int, Int, Double)] = {
+        input: DataSet[(Long, Long)])
+      : DataSet[(Long, Long, Double)] = {
 
       instance.factorsOption match {
         case Some((userFactors, itemFactors)) => {
@@ -425,16 +428,34 @@ object ALS {
     }
   }
 
+  implicit val predictRatingInt = new PredictDataSetOperation[ALS, (Int, Int), (Int, Int, Double)] {
+    override def predictDataSet(
+      instance: ALS,
+      predictParameters: ParameterMap,
+      input: DataSet[(Int, Int)])
+    : DataSet[(Int, Int, Double)] = {
+      val longInput = input.map { x => (x._1.toLong, x._2.toLong)}
+
+      val longResult = implicitly[PredictDataSetOperation[ALS, (Long, Long), (Long, Long, Double)]]
+        .predictDataSet(
+          instance,
+          predictParameters,
+          longInput)
+
+      longResult.map{ x => (x._1.toInt, x._2.toInt, x._3)}
+    }
+  }
+
   /** Calculates the matrix factorization for the given ratings. A rating is defined as
     * a tuple of user ID, item ID and the corresponding rating.
     *
     * @return Factorization containing the user and item matrix
     */
-  implicit val fitALS =  new FitOperation[ALS, (Int, Int, Double)] {
+  implicit val fitALS =  new FitOperation[ALS, (Long, Long, Double)] {
     override def fit(
         instance: ALS,
         fitParameters: ParameterMap,
-        input: DataSet[(Int, Int, Double)])
+        input: DataSet[(Long, Long, Double)])
       : Unit = {
       val resultParameters = instance.parameters ++ fitParameters
 
@@ -457,13 +478,13 @@ object ALS {
 
       val ratingsByUserBlock = ratings.map{
         rating =>
-          val blockID = rating.user % userBlocks
+          val blockID = (rating.user % userBlocks).toInt
           (blockID, rating)
       } partitionCustom(blockIDPartitioner, 0)
 
       val ratingsByItemBlock = ratings map {
         rating =>
-          val blockID = rating.item % itemBlocks
+          val blockID = (rating.item % itemBlocks).toInt
           (blockID, new Rating(rating.item, rating.user, rating.rating))
       } partitionCustom(blockIDPartitioner, 0)
 
@@ -518,6 +539,19 @@ object ALS {
     }
   }
 
+  implicit val fitALSInt =  new FitOperation[ALS, (Int, Int, Double)] {
+    override def fit(
+      instance: ALS,
+      fitParameters: ParameterMap,
+      input: DataSet[(Int, Int, Double)])
+    : Unit = {
+
+      val longInput = input.map { x => (x._1.toLong, x._2.toLong, x._3)}
+
+      implicitly[FitOperation[ALS, (Long, Long, Double)]].fit(instance, fitParameters, longInput)
+    }
+  }
+
   /** Calculates a single half step of the ALS optimization. The result is the new value for
     * either the user or item matrix, depending with which matrix the method was called.
     *
@@ -706,13 +740,13 @@ object ALS {
     * @param ratings
     * @return
     */
-  def createUsersPerBlock(ratings: DataSet[(Int, Rating)]): DataSet[(Int, Array[Int])] = {
+  def createUsersPerBlock(ratings: DataSet[(Int, Rating)]): DataSet[(Int, Array[Long])] = {
     ratings.map{ x => (x._1, x._2.user)}.withForwardedFields("0").groupBy(0).
       sortGroup(1, Order.ASCENDING).reduceGroup {
       users => {
-        val result = ArrayBuffer[Int]()
+        val result = ArrayBuffer[Long]()
         var id = -1
-        var oldUser = -1
+        var oldUser = -1L
 
         while(users.hasNext) {
           val user = users.next()
@@ -746,7 +780,7 @@ object ALS {
     * @return
     */
   def createOutBlockInformation(ratings: DataSet[(Int, Rating)],
-    usersPerBlock: DataSet[(Int, Array[Int])],
+    usersPerBlock: DataSet[(Int, Array[Long])],
     itemBlocks: Int, blockIDGenerator: BlockIDGenerator):
   DataSet[(Int, OutBlockInformation)] = {
     ratings.coGroup(usersPerBlock).where(0).equalTo(0).apply {
@@ -795,7 +829,7 @@ object ALS {
     * @return
     */
   def createInBlockInformation(ratings: DataSet[(Int, Rating)],
-    usersPerBlock: DataSet[(Int, Array[Int])],
+    usersPerBlock: DataSet[(Int, Array[Long])],
     blockIDGenerator: BlockIDGenerator):
   DataSet[(Int, InBlockInformation)] = {
     // Group for every user block the users which have rated the same item and collect their ratings
@@ -803,8 +837,8 @@ object ALS {
       .withForwardedFields("0").groupBy(0, 1).reduceGroup {
       x =>
         var userBlockID = -1
-        var itemID = -1
-        val userIDs = ArrayBuffer[Int]()
+        var itemID = -1L
+        val userIDs = ArrayBuffer[Long]()
         val ratings = ArrayBuffer[Double]()
 
         while (x.hasNext) {
@@ -824,12 +858,12 @@ object ALS {
     // accordingly.
     val collectedPartialInfos = partialInInfos.groupBy(0, 1).sortGroup(2, Order.ASCENDING).
       reduceGroup {
-      new GroupReduceFunction[(Int, Int, Int, (Array[Int], Array[Double])), (Int,
-        Int, Array[(Array[Int], Array[Double])])](){
-        val buffer = new ArrayBuffer[(Array[Int], Array[Double])]
+      new GroupReduceFunction[(Int, Int, Long, (Array[Long], Array[Double])), (Int,
+        Int, Array[(Array[Long], Array[Double])])](){
+        val buffer = new ArrayBuffer[(Array[Long], Array[Double])]
 
-        override def reduce(iterable: lang.Iterable[(Int, Int, Int, (Array[Int],
-          Array[Double]))], collector: Collector[(Int, Int, Array[(Array[Int],
+        override def reduce(iterable: lang.Iterable[(Int, Int, Long, (Array[Long],
+          Array[Double]))], collector: Collector[(Int, Int, Array[(Array[Long],
           Array[Double])])]): Unit = {
 
           val infos = iterable.iterator()
@@ -858,7 +892,7 @@ object ALS {
             counter += 1
           }
 
-          val array = new Array[(Array[Int], Array[Double])](counter)
+          val array = new Array[(Array[Long], Array[Double])](counter)
 
           buffer.copyToArray(array)
 
@@ -871,13 +905,13 @@ object ALS {
     // respect to their itemBlockID, because the block update messages are sorted the same way
     collectedPartialInfos.coGroup(usersPerBlock).where(0).equalTo(0).
       sortFirstGroup(1, Order.ASCENDING).apply{
-      new CoGroupFunction[(Int, Int, Array[(Array[Int], Array[Double])]),
-        (Int, Array[Int]), (Int, InBlockInformation)] {
+      new CoGroupFunction[(Int, Int, Array[(Array[Long], Array[Double])]),
+        (Int, Array[Long]), (Int, InBlockInformation)] {
         val buffer = ArrayBuffer[BlockRating]()
 
         override def coGroup(partialInfosIterable:
-        lang.Iterable[(Int, Int,  Array[(Array[Int], Array[Double])])],
-          userIterable: lang.Iterable[(Int, Array[Int])],
+        lang.Iterable[(Int, Int,  Array[(Array[Long], Array[Double])])],
+          userIterable: lang.Iterable[(Int, Array[Long])],
           collector: Collector[(Int, InBlockInformation)]): Unit = {
 
           val users = userIterable.iterator()
@@ -895,12 +929,21 @@ object ALS {
             // entry contains the ratings and userIDs of a complete item block
             val entry = partialInfo._3
 
+            val blockRelativeIndicesRatings = new Array[(Array[Int], Array[Double])](entry.size)
+
             // transform userIDs to positional indices
-            for (row <- 0 until entry.length; col <- 0 until entry(row)._1.length) {
-              entry(row)._1(col) = userIDToPos(entry(row)._1(col))
+            for (row <- 0 until entry.length) {
+              val rowEntries = entry(row)._1
+              val rowIndices = new Array[Int](rowEntries.length)
+
+              for (col <- 0 until rowEntries.length) {
+                rowIndices(col) = userIDToPos(rowEntries(col))
+              }
+
+              blockRelativeIndicesRatings(row) = (rowIndices, entry(row)._2)
             }
 
-            buffer(counter).ratings = entry
+            buffer(counter).ratings = blockRelativeIndicesRatings
 
             counter += 1
           }
@@ -909,13 +952,21 @@ object ALS {
             val partialInfo = partialInfos.next()
             // entry contains the ratings and userIDs of a complete item block
             val entry = partialInfo._3
+            val blockRelativeIndicesRatings = new Array[(Array[Int], Array[Double])](entry.size)
 
             // transform userIDs to positional indices
-            for (row <- 0 until entry.length; col <- 0 until entry(row)._1.length) {
-              entry(row)._1(col) = userIDToPos(entry(row)._1(col))
+            for (row <- 0 until entry.length) {
+              val rowEntries = entry(row)._1
+              val rowIndices = new Array[Int](rowEntries.length)
+
+              for (col <- 0 until rowEntries.length) {
+                rowIndices(col) = userIDToPos(rowEntries(col))
+              }
+
+              blockRelativeIndicesRatings(row) = (rowIndices, entry(row)._2)
             }
 
-            buffer += new BlockRating(entry)
+            buffer += new BlockRating(blockRelativeIndicesRatings)
 
             counter += 1
           }

http://git-wip-us.apache.org/repos/asf/flink/blob/43d2fd23/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/ALSITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/ALSITSuite.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/ALSITSuite.scala
index 043d8cb..0c85c46 100644
--- a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/ALSITSuite.scala
+++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/ALSITSuite.scala
@@ -42,19 +42,19 @@ class ALSITSuite extends FlatSpec with Matchers with FlinkTestBase {
       .setBlocks(4)
       .setNumFactors(numFactors)
 
-    val inputDS = env.fromCollection(data)
+    val inputDS = env.fromCollection(dataLong)
 
     als.fit(inputDS)
 
-    val testData = env.fromCollection(expectedResult.map {
+    val testData = env.fromCollection(expectedResultLong.map {
       case (userID, itemID, rating) => (userID, itemID)
     })
 
     val predictions = als.predict(testData).collect()
 
-    predictions.length should equal(expectedResult.length)
+    predictions.length should equal(expectedResultLong.length)
 
-    val resultMap = expectedResult.map {
+    val resultMap = expectedResultLong.map {
       case (uID, iID, value) => (uID, iID) -> value
     }.toMap
 
@@ -70,4 +70,47 @@ class ALSITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     risk should be(expectedEmpiricalRisk +- 1)
   }
+
+  it should "properly factorize a matrix (integer indices)" in {
+    import Recommendation._
+
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    val als = ALS()
+      .setIterations(iterations)
+      .setLambda(lambda)
+      .setBlocks(4)
+      .setNumFactors(numFactors)
+
+    val inputDS = env.fromCollection(data)
+
+    als.fit(inputDS)
+
+
+    val testData = env.fromCollection(expectedResult.map {
+     case (userID, itemID, rating) => (userID, itemID)
+   })
+
+    val predictions = als.predict(testData).collect()
+
+    predictions.length should equal(expectedResult.length)
+
+    val resultMap = expectedResultLong.map {
+      case (uID, iID, value) => (uID, iID) -> value
+    }.toMap
+
+    predictions foreach {
+      case (uID, iID, value) => {
+        resultMap.isDefinedAt((uID, iID)) should be(true)
+
+        value should be(resultMap((uID, iID)) +- 0.1)
+      }
+    }
+
+    val risk = als.empiricalRisk(
+        inputDS.map( x => (x._1.toLong, x._2.toLong, x._3)))
+      .collect().head
+
+    risk should be(expectedEmpiricalRisk +- 1)
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/43d2fd23/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/Recommendation.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/Recommendation.scala b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/Recommendation.scala
index 8d8e4b9..3b466fd 100644
--- a/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/Recommendation.scala
+++ b/flink-libraries/flink-ml/src/test/scala/org/apache/flink/ml/recommendation/Recommendation.scala
@@ -86,5 +86,68 @@ object Recommendation {
     )
   }
 
+  val dataLong: Seq[(Long, Long, Double)] = {
+    Seq(
+      (2,13,534.3937734561154),
+      (6,14,509.63176469621936),
+      (4,14,515.8246770897443),
+      (7,3,495.05234565105),
+      (2,3,532.3281786219485),
+      (5,3,497.1906356844367),
+      (3,3,512.0640508585093),
+      (10,3,500.2906742233019),
+      (1,4,521.9189079662882),
+      (2,4,515.0734651491396),
+      (1,7,522.7532725967008),
+      (8,4,492.65683825096403),
+      (4,8,492.65683825096403),
+      (10,8,507.03319667905413),
+      (7,1,522.7532725967008),
+      (1,1,572.2230209271174),
+      (2,1,563.5849190220224),
+      (6,1,518.4844061038742),
+      (9,1,529.2443732217674),
+      (8,1,543.3202505434103),
+      (7,2,516.0188923307859),
+      (1,2,563.5849190220224),
+      (1,11,515.1023793011227),
+      (8,2,536.8571133978352),
+      (2,11,507.90776961762225),
+      (3,2,532.3281786219485),
+      (5,11,476.24185144363304),
+      (4,2,515.0734651491396),
+      (4,11,469.92049343738233),
+      (3,12,509.4713776280098),
+      (4,12,494.6533165132021),
+      (7,5,482.2907867916308),
+      (6,5,477.5940040923741),
+      (4,5,480.9040684364228),
+      (1,6,518.4844061038742),
+      (6,6,470.6605085832807),
+      (8,6,489.6360564705307),
+      (4,6,472.74052954447046),
+      (7,9,482.5837650471611),
+      (5,9,487.00175463269863),
+      (9,9,500.69514584780944),
+      (4,9,477.71644808419325),
+      (7,10,485.3852917539852),
+      (8,10,507.03319667905413),
+      (3,10,500.2906742233019),
+      (5,15,488.08215944254437),
+      (6,15,480.16929757607346)
+    )
+  }
+
+  val expectedResultLong: Seq[(Long, Long, Double)] = {
+    Seq(
+      (2, 2, 526.1037),
+      (5, 9, 468.5680),
+      (10, 3, 484.8975),
+      (5, 13, 451.6228),
+      (1, 15, 493.4956),
+      (4, 11, 456.3862)
+    )
+  }
+
   val expectedEmpiricalRisk = 505374.1877
 }


[2/2] flink git commit: [FLINK-5652] [asyncIO] Cancel timers when completing a StreamRecordQueueEntry

Posted by tr...@apache.org.
[FLINK-5652] [asyncIO] Cancel timers when completing a StreamRecordQueueEntry

Whenever a StreamRecordQueueEntry has been completed, we no longer need the registered timeout.
Therefore, we have to cancel the corresponding ScheduledFuture so that the system knows it
can remove the associated TriggerTask. This is important because the TriggerTask holds a
reference to the StreamRecordQueueEntry; left in place, it would prevent the
StreamRecordQueueEntry from being garbage collected.

This closes #3264.
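
In essence, the operator now keeps the ScheduledFuture returned by registerTimer and cancels it as soon as the queue entry completes, so the timer task and its reference to the entry can be dropped. Below is a self-contained Scala analogue of that pattern, using plain JDK scheduling and a CompletableFuture in place of Flink's ProcessingTimeService and StreamRecordQueueEntry; it is an illustration only, the actual change is the AsyncWaitOperator hunk further down.

    import java.util.concurrent.{CompletableFuture, ScheduledFuture, ScheduledThreadPoolExecutor, TimeUnit, TimeoutException}

    val timerService = new ScheduledThreadPoolExecutor(1)
    // make sure a cancelled timeout task does not linger in the scheduler queue
    timerService.setRemoveOnCancelPolicy(true)

    // stands in for the StreamRecordQueueEntry holding the async result
    val entry = new CompletableFuture[String]()

    // the scheduled timeout task references the entry until it either fires or is cancelled
    val timerFuture: ScheduledFuture[_] = timerService.schedule(new Runnable {
      override def run(): Unit =
        entry.completeExceptionally(new TimeoutException("Async function call has timed out."))
    }, 100000L, TimeUnit.MILLISECONDS)

    // once the entry is completed, the timeout is obsolete; cancelling it allows the
    // timer task (and the entry it references) to be garbage collected right away
    entry.thenRun(new Runnable {
      override def run(): Unit = timerFuture.cancel(true)
    })

    entry.complete("result")
    timerService.shutdown()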


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/215776b8
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/215776b8
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/215776b8

Branch: refs/heads/master
Commit: 215776b81a52cd380e8ccabd65da612f77da25e6
Parents: 43d2fd2
Author: Till Rohrmann <tr...@apache.org>
Authored: Fri Feb 3 16:02:55 2017 +0100
Committer: Till Rohrmann <tr...@apache.org>
Committed: Sun Feb 5 21:57:01 2017 +0100

----------------------------------------------------------------------
 .../api/operators/async/AsyncWaitOperator.java  | 13 ++-
 .../operators/async/AsyncWaitOperatorTest.java  | 86 ++++++++++++++++++++
 2 files changed, 98 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/215776b8/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperator.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperator.java
index 6793620..a70d825 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperator.java
@@ -22,6 +22,7 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.concurrent.AcceptFunction;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.datastream.AsyncDataStream;
@@ -50,6 +51,7 @@ import org.apache.flink.util.Preconditions;
 import java.util.Collection;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledFuture;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 
@@ -203,7 +205,7 @@ public class AsyncWaitOperator<IN, OUT>
 			// register a timeout for this AsyncStreamRecordBufferEntry
 			long timeoutTimestamp = timeout + getProcessingTimeService().getCurrentProcessingTime();
 
-			getProcessingTimeService().registerTimer(
+			final ScheduledFuture<?> timerFuture = getProcessingTimeService().registerTimer(
 				timeoutTimestamp,
 				new ProcessingTimeCallback() {
 					@Override
@@ -212,6 +214,15 @@ public class AsyncWaitOperator<IN, OUT>
 							new TimeoutException("Async function call has timed out."));
 					}
 				});
+
+			// Cancel the timer once we've completed the stream record buffer entry. This will remove
+			// the registered trigger task
+			streamRecordBufferEntry.onComplete(new AcceptFunction<StreamElementQueueEntry<Collection<OUT>>>() {
+				@Override
+				public void accept(StreamElementQueueEntry<Collection<OUT>> value) {
+					timerFuture.cancel(true);
+				}
+			}, executor);
 		}
 
 		addAsyncBufferEntry(streamRecordBufferEntry);

http://git-wip-us.apache.org/repos/asf/flink/blob/215776b8/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java
index 4558e06..c2b0803 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/async/AsyncWaitOperatorTest.java
@@ -49,10 +49,13 @@ import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.async.queue.StreamElementQueue;
 import org.apache.flink.streaming.api.operators.async.queue.StreamElementQueueEntry;
+import org.apache.flink.streaming.api.operators.async.queue.StreamRecordQueueEntry;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
 import org.apache.flink.streaming.runtime.tasks.OneInputStreamTaskTestHarness;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeCallback;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
 import org.apache.flink.streaming.runtime.tasks.StreamMockEnvironment;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.streaming.runtime.tasks.TestProcessingTimeService;
@@ -76,12 +79,15 @@ import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledFuture;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyLong;
+import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
@@ -801,4 +807,84 @@ public class AsyncWaitOperatorTest extends TestLogger {
 			super.close();
 		}
 	}
+
+	/**
+	 * FLINK-5652
+	 * Tests that registered timers are properly canceled upon completion of a
+	 * {@link StreamRecordQueueEntry} in order to avoid resource leaks because TriggerTasks hold
+	 * a reference on the StreamRecordQueueEntry.
+	 */
+	@Test
+	public void testTimeoutCleanup() throws Exception {
+		final Object lock = new Object();
+
+		final long timeout = 100000L;
+		final long timestamp = 1L;
+
+		Environment environment = mock(Environment.class);
+		when(environment.getMetricGroup()).thenReturn(new UnregisteredTaskMetricsGroup());
+		when(environment.getTaskManagerInfo()).thenReturn(new TestingTaskManagerRuntimeInfo());
+		when(environment.getUserClassLoader()).thenReturn(getClass().getClassLoader());
+		when(environment.getTaskInfo()).thenReturn(new TaskInfo(
+			"testTask",
+			1,
+			0,
+			1,
+			0));
+
+		ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+
+		ProcessingTimeService processingTimeService = mock(ProcessingTimeService.class);
+		when(processingTimeService.getCurrentProcessingTime()).thenReturn(timestamp);
+		doReturn(scheduledFuture).when(processingTimeService).registerTimer(anyLong(), any(ProcessingTimeCallback.class));
+
+		StreamTask<?, ?> containingTask = mock(StreamTask.class);
+		when(containingTask.getEnvironment()).thenReturn(environment);
+		when(containingTask.getCheckpointLock()).thenReturn(lock);
+		when(containingTask.getProcessingTimeService()).thenReturn(processingTimeService);
+
+		StreamConfig streamConfig = mock(StreamConfig.class);
+		doReturn(IntSerializer.INSTANCE).when(streamConfig).getTypeSerializerIn1(any(ClassLoader.class));
+
+		Output<StreamRecord<Integer>> output = mock(Output.class);
+
+		AsyncWaitOperator<Integer, Integer> operator = new AsyncWaitOperator<>(
+			new AsyncFunction<Integer, Integer>() {
+				private static final long serialVersionUID = -3718276118074877073L;
+
+				@Override
+				public void asyncInvoke(Integer input, AsyncCollector<Integer> collector) throws Exception {
+					collector.collect(Collections.singletonList(input));
+				}
+			},
+			timeout,
+			1,
+			AsyncDataStream.OutputMode.UNORDERED);
+
+		operator.setup(
+			containingTask,
+			streamConfig,
+			output);
+
+		operator.open();
+
+		final StreamRecord<Integer> streamRecord = new StreamRecord<>(42, timestamp);
+
+		synchronized (lock) {
+			// processing an element will register a timeout
+			operator.processElement(streamRecord);
+		}
+
+		synchronized (lock) {
+			// closing the operator waits until all inputs have been processed
+			operator.close();
+		}
+
+		// check that we actually outputted the result of the single input
+		verify(output).collect(eq(streamRecord));
+		verify(processingTimeService).registerTimer(eq(processingTimeService.getCurrentProcessingTime() + timeout), any(ProcessingTimeCallback.class));
+
+		// check that we have cancelled our registered timeout
+		verify(scheduledFuture).cancel(eq(true));
+	}
 }