You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by dl...@apache.org on 2015/10/20 07:37:12 UTC

[29/32] mahout git commit: MAHOUT-1747: Flink: support for different key types

MAHOUT-1747: Flink: support for different key types


Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/19708f46
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/19708f46
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/19708f46

Branch: refs/heads/flink-binding
Commit: 19708f46a41d0bc437f904a383876858a3fef35d
Parents: 3112f3c
Author: Alexey Grigorev <al...@gmail.com>
Authored: Thu Aug 27 13:58:12 2015 +0200
Committer: Alexey Grigorev <al...@gmail.com>
Committed: Fri Sep 25 17:47:05 2015 +0200

----------------------------------------------------------------------
 .../mahout/flinkbindings/FlinkEngine.scala      | 15 +++-----
 .../mahout/flinkbindings/blas/FlinkOpAewB.scala | 25 ++++++++------
 .../mahout/flinkbindings/blas/FlinkOpAt.scala   |  5 +--
 .../mahout/flinkbindings/blas/FlinkOpAtB.scala  | 24 ++++++-------
 .../flinkbindings/blas/FlinkOpCBind.scala       | 24 ++++++-------
 .../mahout/flinkbindings/blas/package.scala     | 36 +++++++++++++++++---
 .../mahout/flinkbindings/drm/FlinkDrm.scala     |  6 ++++
 .../apache/mahout/flinkbindings/package.scala   |  8 ++++-
 .../mahout/flinkbindings/RLikeOpsSuite.scala    | 20 +++++++++++
 .../math/drm/logical/CheckpointAction.scala     |  1 +
 10 files changed, 112 insertions(+), 52 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
index ab35e78..2c07681 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
@@ -89,15 +89,6 @@ object FlinkEngine extends DistributedEngine {
   override def toPhysical[K: ClassTag](plan: DrmLike[K], ch: CacheHint.CacheHint): CheckpointedDrm[K] = {
     // Flink-specific Physical Plan translation.
     val drm = flinkTranslate(plan)
-
-    // to Help Flink's type inference had to use just one specific type - Int 
-    // see org.apache.mahout.flinkbindings.blas classes with TODO: casting inside
-    // see MAHOUT-1747 and MAHOUT-1748
-    val cls = implicitly[ClassTag[K]]
-    if (!cls.runtimeClass.equals(classOf[Int])) {
-      throw new IllegalArgumentException(s"At the moment only Int indexes are supported. Got $cls")
-    }
-
     val newcp = new CheckpointedFlinkDrm(ds = drm.deblockify.ds, _nrow = plan.nrow, _ncol = plan.ncol)
     newcp.cache()
   }
@@ -149,6 +140,8 @@ object FlinkEngine extends DistributedEngine {
       FlinkOpCBind.cbindScalar(op, flinkTranslate(a)(op.classTagA), x)
     case op @ OpRowRange(a, _) => 
       FlinkOpRowRange.slice(op, flinkTranslate(a)(op.classTagA))
+    case op @ OpABAnyKey(a, b) if extractRealClassTag(a) != extractRealClassTag(b) =>
+      throw new IllegalArgumentException("DRMs A and B have different indices, cannot multiply them")
     case op: OpMapBlock[K, _] => 
       FlinkOpMapBlock.apply(flinkTranslate(op.A)(op.classTagA), op.ncol, op.bmf)
     case cp: CheckpointedFlinkDrm[K] => new RowsFlinkDrm(cp.ds, cp.ncol)
@@ -243,7 +236,9 @@ object FlinkEngine extends DistributedEngine {
 
   /** Parallelize in-core matrix as spark distributed matrix, using row labels as a data set keys. */
   override def drmParallelizeWithRowLabels(m: Matrix, numPartitions: Int = 1)
-                                          (implicit dc: DistributedContext): CheckpointedDrm[String] = ???
+                                          (implicit dc: DistributedContext): CheckpointedDrm[String] = {
+    ???
+  }
 
   /** This creates an empty DRM with specified number of partitions and cardinality. */
   override def drmParallelizeEmpty(nrow: Int, ncol: Int, numPartitions: Int = 10)

http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAewB.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAewB.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAewB.scala
index 2b35685..460199e 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAewB.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAewB.scala
@@ -2,16 +2,17 @@ package org.apache.mahout.flinkbindings.blas
 
 import java.lang.Iterable
 
-import scala.collection.JavaConverters.asScalaBufferConverter
+import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
 
 import org.apache.flink.api.common.functions.CoGroupFunction
 import org.apache.flink.api.java.DataSet
 import org.apache.flink.util.Collector
+import org.apache.mahout.flinkbindings._
 import org.apache.mahout.flinkbindings.drm.FlinkDrm
 import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
 import org.apache.mahout.math.Vector
-import org.apache.mahout.math.drm.logical._
+import org.apache.mahout.math.drm.logical.OpAewB
 import org.apache.mahout.math.scalabindings.RLikeOps._
 
 import com.google.common.collect.Lists
@@ -25,22 +26,24 @@ object FlinkOpAewB {
   def rowWiseJoinNoSideEffect[K: ClassTag](op: OpAewB[K], A: FlinkDrm[K], B: FlinkDrm[K]): FlinkDrm[K] = {
     val function = AewBOpsCloning.strToFunction(op.op)
 
-    // TODO: get rid of casts!
-    val rowsA = A.deblockify.ds.asInstanceOf[DataSet[(Int, Vector)]]
-    val rowsB = B.deblockify.ds.asInstanceOf[DataSet[(Int, Vector)]]
+    val classTag = extractRealClassTag(op.A)
+    val joiner = selector[Vector, Any](classTag.asInstanceOf[ClassTag[Any]]) 
 
-    val res: DataSet[(Int, Vector)] = 
-      rowsA.coGroup(rowsB).where(tuple_1[Vector]).equalTo(tuple_1[Vector])
-        .`with`(new CoGroupFunction[(Int, Vector), (Int, Vector), (Int, Vector)] {
-      def coGroup(it1java: Iterable[(Int, Vector)], it2java: Iterable[(Int, Vector)], 
-                  out: Collector[(Int, Vector)]): Unit = {
+    val rowsA = A.deblockify.ds.asInstanceOf[DrmDataSet[Any]]
+    val rowsB = B.deblockify.ds.asInstanceOf[DrmDataSet[Any]]
+
+    val res: DataSet[(Any, Vector)] = 
+      rowsA.coGroup(rowsB).where(joiner).equalTo(joiner)
+        .`with`(new CoGroupFunction[(_, Vector), (_, Vector), (_, Vector)] {
+      def coGroup(it1java: Iterable[(_, Vector)], it2java: Iterable[(_, Vector)], 
+                  out: Collector[(_, Vector)]): Unit = {
         val it1 = Lists.newArrayList(it1java).asScala
         val it2 = Lists.newArrayList(it2java).asScala
 
         if (!it1.isEmpty && !it2.isEmpty) {
           val (idx, a) = it1.head
           val (_, b) = it2.head
-          out.collect(idx -> function(a, b))
+          out.collect((idx, function(a, b)))
         } else if (it1.isEmpty && !it2.isEmpty) {
           out.collect(it2.head)
         } else if (!it1.isEmpty && it2.isEmpty) {

http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala
index ac6837c..b859e1f 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala
@@ -27,6 +27,7 @@ import org.apache.flink.api.common.functions.FlatMapFunction
 import org.apache.flink.api.common.functions.GroupReduceFunction
 import org.apache.flink.shaded.com.google.common.collect.Lists
 import org.apache.flink.util.Collector
+import org.apache.mahout.flinkbindings._
 import org.apache.mahout.flinkbindings.drm.FlinkDrm
 import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
 import org.apache.mahout.math.Matrix
@@ -66,14 +67,14 @@ object FlinkOpAt {
       }
     })
 
-    val regrouped = sparseParts.groupBy(tuple_1[Vector])
+    val regrouped = sparseParts.groupBy(selector[Vector, Int])
 
     val sparseTotal = regrouped.reduceGroup(new GroupReduceFunction[(Int, Vector), DrmTuple[Int]] {
       def reduce(values: Iterable[(Int, Vector)], out: Collector[DrmTuple[Int]]): Unit = {
         val it = Lists.newArrayList(values).asScala
         val (idx, _) = it.head
         val vector = it map { case (idx, vec) => vec } reduce (_ + _)
-        out.collect(idx -> vector)
+        out.collect((idx, vector))
       }
     })
 

http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
index c54e6de..ebb1064 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
@@ -26,17 +26,16 @@ import scala.reflect.ClassTag
 import org.apache.flink.api.common.functions.FlatMapFunction
 import org.apache.flink.api.common.functions.GroupReduceFunction
 import org.apache.flink.api.common.functions.MapFunction
+import org.apache.flink.api.java.DataSet
 import org.apache.flink.api.java.tuple.Tuple2
 import org.apache.flink.util.Collector
-import org.apache.mahout.flinkbindings.BlockifiedDrmDataSet
-import org.apache.mahout.flinkbindings.DrmDataSet
+import org.apache.mahout.flinkbindings._
 import org.apache.mahout.flinkbindings.drm.BlockifiedFlinkDrm
 import org.apache.mahout.flinkbindings.drm.FlinkDrm
 import org.apache.mahout.math.Matrix
 import org.apache.mahout.math.Vector
-import org.apache.mahout.math.drm.BlockifiedDrmTuple
+import org.apache.mahout.math.drm._
 import org.apache.mahout.math.drm.logical.OpAtB
-import org.apache.mahout.math.drm.safeToNonNegInt
 import org.apache.mahout.math.scalabindings.RLikeOps._
 
 import com.google.common.collect.Lists
@@ -50,20 +49,21 @@ import com.google.common.collect.Lists
 object FlinkOpAtB {
 
   def notZippable[K: ClassTag](op: OpAtB[K], At: FlinkDrm[K], B: FlinkDrm[K]): FlinkDrm[Int] = {
-    // TODO: to help Flink's type inference
-    // only Int is supported now 
-    val rowsAt = At.deblockify.ds.asInstanceOf[DrmDataSet[Int]]
-    val rowsB = B.deblockify.ds.asInstanceOf[DrmDataSet[Int]]
-    val joined = rowsAt.join(rowsB).where(tuple_1[Vector]).equalTo(tuple_1[Vector])
+    val classTag = extractRealClassTag(op.A)
+    val joiner = selector[Vector, Any](classTag.asInstanceOf[ClassTag[Any]]) 
+
+    val rowsAt = At.deblockify.ds.asInstanceOf[DrmDataSet[Any]]
+    val rowsB = B.deblockify.ds.asInstanceOf[DrmDataSet[Any]]
+    val joined = rowsAt.join(rowsB).where(joiner).equalTo(joiner)
 
     val ncol = op.ncol
     val nrow = op.nrow
     val blockHeight = 10
     val blockCount = safeToNonNegInt((ncol - 1) / blockHeight + 1)
 
-    val preProduct = joined.flatMap(new FlatMapFunction[Tuple2[(Int, Vector), (Int, Vector)], 
+    val preProduct: DataSet[(Int, Matrix)] = joined.flatMap(new FlatMapFunction[Tuple2[(_, Vector), (_, Vector)], 
                                                         (Int, Matrix)] {
-      def flatMap(in: Tuple2[(Int, Vector), (Int, Vector)],
+      def flatMap(in: Tuple2[(_, Vector), (_, Vector)],
                   out: Collector[(Int, Matrix)]): Unit = {
         val avec = in.f0._2
         val bvec = in.f1._2
@@ -79,7 +79,7 @@ object FlinkOpAtB {
       }
     })
 
-    val res: BlockifiedDrmDataSet[Int] = preProduct.groupBy(tuple_1[Matrix]).reduceGroup(
+    val res: BlockifiedDrmDataSet[Int] = preProduct.groupBy(selector[Matrix, Int]).reduceGroup(
             new GroupReduceFunction[(Int, Matrix), BlockifiedDrmTuple[Int]] {
       def reduce(values: Iterable[(Int, Matrix)], out: Collector[BlockifiedDrmTuple[Int]]): Unit = {
         val it = Lists.newArrayList(values).asScala

http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala
index 27237d6..234937b 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala
@@ -19,23 +19,21 @@
 package org.apache.mahout.flinkbindings.blas
 
 import java.lang.Iterable
-
-
 import scala.collection.JavaConverters._
 import scala.collection.JavaConversions._
 import scala.reflect.ClassTag
-
 import org.apache.flink.api.common.functions.CoGroupFunction
 import org.apache.flink.api.common.functions.MapFunction
 import org.apache.flink.api.java.DataSet
 import org.apache.flink.util.Collector
+import org.apache.mahout.flinkbindings._
 import org.apache.mahout.flinkbindings.drm._
 import org.apache.mahout.math._
 import org.apache.mahout.math.drm.logical.OpCbind
 import org.apache.mahout.math.drm.logical.OpCbindScalar
 import org.apache.mahout.math.scalabindings.RLikeOps._
-
 import com.google.common.collect.Lists
+import org.apache.mahout.flinkbindings.DrmDataSet
 
 
 /**
@@ -49,15 +47,17 @@ object FlinkOpCBind {
     val n1 = op.A.ncol
     val n2 = op.B.ncol
 
-    // TODO: cast!
-    val rowsA = A.deblockify.ds.asInstanceOf[DataSet[(Int, Vector)]]
-    val rowsB = B.deblockify.ds.asInstanceOf[DataSet[(Int, Vector)]]
+    val classTag = extractRealClassTag(op.A)
+    val joiner = selector[Vector, Any](classTag.asInstanceOf[ClassTag[Any]]) 
+
+    val rowsA = A.deblockify.ds.asInstanceOf[DrmDataSet[Any]]
+    val rowsB = B.deblockify.ds.asInstanceOf[DrmDataSet[Any]]
 
-    val res: DataSet[(Int, Vector)] = 
-      rowsA.coGroup(rowsB).where(tuple_1[Vector]).equalTo(tuple_1[Vector])
-        .`with`(new CoGroupFunction[(Int, Vector), (Int, Vector), (Int, Vector)] {
-      def coGroup(it1java: Iterable[(Int, Vector)], it2java: Iterable[(Int, Vector)], 
-                  out: Collector[(Int, Vector)]): Unit = {
+    val res: DataSet[(Any, Vector)] = 
+      rowsA.coGroup(rowsB).where(joiner).equalTo(joiner)
+        .`with`(new CoGroupFunction[(_, Vector), (_, Vector), (_, Vector)] {
+      def coGroup(it1java: Iterable[(_, Vector)], it2java: Iterable[(_, Vector)], 
+                  out: Collector[(_, Vector)]): Unit = {
         val it1 = Lists.newArrayList(it1java).asScala
         val it2 = Lists.newArrayList(it2java).asScala
 

http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala
index 6868a83..27f552c 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala
@@ -18,15 +18,43 @@
  */
 package org.apache.mahout.flinkbindings
 
-import org.apache.flink.api.java.functions.KeySelector
-import org.apache.mahout.math.Vector
 import scala.reflect.ClassTag
+import org.apache.flink.api.java.functions.KeySelector
+import org.apache.flink.api.java.typeutils.ResultTypeQueryable
+import org.apache.flink.api.scala.createTypeInformation
+import org.apache.flink.api.common.typeinfo.TypeInformation
 
 package object blas {
 
   // TODO: remove it once figure out how to make Flink accept interfaces (Vector here)
-  def tuple_1[K: ClassTag] = new KeySelector[(Int, K), Integer] {
-    def getKey(tuple: Tuple2[Int, K]): Integer = tuple._1
+  def selector[V, K: ClassTag]: KeySelector[(K, V), K] = {
+    val tag = implicitly[ClassTag[K]]
+    if (tag.runtimeClass.equals(classOf[Int])) {
+      tuple_1_int.asInstanceOf[KeySelector[(K, V), K]]
+    } else if (tag.runtimeClass.equals(classOf[Long])) {
+      tuple_1_long.asInstanceOf[KeySelector[(K, V), K]]
+    } else if (tag.runtimeClass.equals(classOf[String])) {
+      tuple_1_string.asInstanceOf[KeySelector[(K, V), K]]
+    } else {
+      throw new IllegalArgumentException(s"index type $tag is not supported")
+    }
+  }
+
+  private def tuple_1_int[K: ClassTag] = new KeySelector[(Int, _), Int] 
+                  with ResultTypeQueryable[Int] {
+    def getKey(tuple: Tuple2[Int, _]): Int = tuple._1
+    def getProducedType: TypeInformation[Int] = createTypeInformation[Int]
   }
 
+  private def tuple_1_long[K: ClassTag] = new KeySelector[(Long, _), Long] 
+                  with ResultTypeQueryable[Long] {
+    def getKey(tuple: Tuple2[Long, _]): Long = tuple._1
+    def getProducedType: TypeInformation[Long] = createTypeInformation[Long]
+  }
+
+  private def tuple_1_string[K: ClassTag] = new KeySelector[(String, _), String] 
+                  with ResultTypeQueryable[String] {
+    def getKey(tuple: Tuple2[String, _]): String = tuple._1
+    def getProducedType: TypeInformation[String] = createTypeInformation[String]
+  }
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/FlinkDrm.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/FlinkDrm.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/FlinkDrm.scala
index 82c0d29..27eac4e 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/FlinkDrm.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/FlinkDrm.scala
@@ -45,6 +45,8 @@ trait FlinkDrm[K] {
 
   def blockify: BlockifiedFlinkDrm[K]
   def deblockify: RowsFlinkDrm[K]
+
+  def classTag: ClassTag[K]
 }
 
 class RowsFlinkDrm[K: ClassTag](val ds: DrmDataSet[K], val ncol: Int) extends FlinkDrm[K] {
@@ -81,6 +83,8 @@ class RowsFlinkDrm[K: ClassTag](val ds: DrmDataSet[K], val ncol: Int) extends Fl
 
   def deblockify = this
 
+  def classTag = implicitly[ClassTag[K]]
+
 }
 
 class BlockifiedFlinkDrm[K: ClassTag](val ds: BlockifiedDrmDataSet[K], val ncol: Int) extends FlinkDrm[K] {
@@ -104,4 +108,6 @@ class BlockifiedFlinkDrm[K: ClassTag](val ds: BlockifiedDrmDataSet[K], val ncol:
     })
     new RowsFlinkDrm(out, ncol)
   }
+
+  def classTag = implicitly[ClassTag[K]]
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala
index aa253ab..57d2f48 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala
@@ -20,7 +20,6 @@ package org.apache.mahout
 
 import scala.Array._
 import scala.reflect.ClassTag
-
 import org.apache.flink.api.common.functions.FilterFunction
 import org.apache.flink.api.common.functions.MapFunction
 import org.apache.flink.api.java.DataSet
@@ -42,6 +41,7 @@ import org.apache.mahout.math.drm.CheckpointedDrm
 import org.apache.mahout.math.drm.DistributedContext
 import org.apache.mahout.math.drm.DrmTuple
 import org.slf4j.LoggerFactory
+import org.apache.mahout.math.drm.logical.CheckpointAction
 
 package object flinkbindings {
 
@@ -108,5 +108,11 @@ package object flinkbindings {
     new CheckpointedFlinkDrm[K](dataset)
   }
 
+  private[flinkbindings] def extractRealClassTag[K: ClassTag](drm: DrmLike[K]): ClassTag[_] = drm match {
+    case d: CheckpointAction[K] => d.classTag
+    case d: CheckpointedFlinkDrm[K] => d.keyClassTag
+    // does not always return the correct result; often results in Any
+    case _ => implicitly[ClassTag[K]]
+  }
 
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
----------------------------------------------------------------------
diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
index 800218b..225a956 100644
--- a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
@@ -294,4 +294,24 @@ class RLikeOpsSuite extends FunSuite with DistributedFlinkSuite {
     assert((res.collect(::, 0) - expected).norm(2) < 1e-6)
   }
 
+  test("A.t %*% B with Long keys") {
+    val inCoreA = dense((1, 2), (3, 4), (3, 5))
+    val inCoreB = dense((3, 5), (4, 6), (0, 1))
+
+    val A = drmParallelize(inCoreA, numPartitions = 2).mapBlock()({
+      case (keys, block) => (keys.map(_.toLong), block)
+    })
+
+    val B = drmParallelize(inCoreB, numPartitions = 2).mapBlock()({
+      case (keys, block) => (keys.map(_.toLong), block)
+    })
+
+    val C = A.t %*% B
+    val inCoreC = C.collect
+    val expected = inCoreA.t %*% inCoreB
+
+    (inCoreC - expected).norm should be < 1E-10
+  }
+
+
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/math-scala/src/main/scala/org/apache/mahout/math/drm/logical/CheckpointAction.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/math/drm/logical/CheckpointAction.scala b/math-scala/src/main/scala/org/apache/mahout/math/drm/logical/CheckpointAction.scala
index a7934a3..2324ca2 100644
--- a/math-scala/src/main/scala/org/apache/mahout/math/drm/logical/CheckpointAction.scala
+++ b/math-scala/src/main/scala/org/apache/mahout/math/drm/logical/CheckpointAction.scala
@@ -44,5 +44,6 @@ abstract class CheckpointAction[K: ClassTag] extends DrmLike[K] {
     case Some(cp) => cp
   }
 
+  val classTag = implicitly[ClassTag[K]]
 }