You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by dl...@apache.org on 2015/10/20 07:37:12 UTC
[29/32] mahout git commit: MAHOUT-1747: Flink: support for different key types
MAHOUT-1747: Flink: support for different key types
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/19708f46
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/19708f46
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/19708f46
Branch: refs/heads/flink-binding
Commit: 19708f46a41d0bc437f904a383876858a3fef35d
Parents: 3112f3c
Author: Alexey Grigorev <al...@gmail.com>
Authored: Thu Aug 27 13:58:12 2015 +0200
Committer: Alexey Grigorev <al...@gmail.com>
Committed: Fri Sep 25 17:47:05 2015 +0200
----------------------------------------------------------------------
.../mahout/flinkbindings/FlinkEngine.scala | 15 +++-----
.../mahout/flinkbindings/blas/FlinkOpAewB.scala | 25 ++++++++------
.../mahout/flinkbindings/blas/FlinkOpAt.scala | 5 +--
.../mahout/flinkbindings/blas/FlinkOpAtB.scala | 24 ++++++-------
.../flinkbindings/blas/FlinkOpCBind.scala | 24 ++++++-------
.../mahout/flinkbindings/blas/package.scala | 36 +++++++++++++++++---
.../mahout/flinkbindings/drm/FlinkDrm.scala | 6 ++++
.../apache/mahout/flinkbindings/package.scala | 8 ++++-
.../mahout/flinkbindings/RLikeOpsSuite.scala | 20 +++++++++++
.../math/drm/logical/CheckpointAction.scala | 1 +
10 files changed, 112 insertions(+), 52 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
index ab35e78..2c07681 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
@@ -89,15 +89,6 @@ object FlinkEngine extends DistributedEngine {
override def toPhysical[K: ClassTag](plan: DrmLike[K], ch: CacheHint.CacheHint): CheckpointedDrm[K] = {
// Flink-specific Physical Plan translation.
val drm = flinkTranslate(plan)
-
- // to Help Flink's type inference had to use just one specific type - Int
- // see org.apache.mahout.flinkbindings.blas classes with TODO: casting inside
- // see MAHOUT-1747 and MAHOUT-1748
- val cls = implicitly[ClassTag[K]]
- if (!cls.runtimeClass.equals(classOf[Int])) {
- throw new IllegalArgumentException(s"At the moment only Int indexes are supported. Got $cls")
- }
-
val newcp = new CheckpointedFlinkDrm(ds = drm.deblockify.ds, _nrow = plan.nrow, _ncol = plan.ncol)
newcp.cache()
}
@@ -149,6 +140,8 @@ object FlinkEngine extends DistributedEngine {
FlinkOpCBind.cbindScalar(op, flinkTranslate(a)(op.classTagA), x)
case op @ OpRowRange(a, _) =>
FlinkOpRowRange.slice(op, flinkTranslate(a)(op.classTagA))
+ case op @ OpABAnyKey(a, b) if extractRealClassTag(a) != extractRealClassTag(b) =>
+ throw new IllegalArgumentException("DRMs A and B have different indices, cannot multiply them")
case op: OpMapBlock[K, _] =>
FlinkOpMapBlock.apply(flinkTranslate(op.A)(op.classTagA), op.ncol, op.bmf)
case cp: CheckpointedFlinkDrm[K] => new RowsFlinkDrm(cp.ds, cp.ncol)
@@ -243,7 +236,9 @@ object FlinkEngine extends DistributedEngine {
/** Parallelize in-core matrix as spark distributed matrix, using row labels as a data set keys. */
override def drmParallelizeWithRowLabels(m: Matrix, numPartitions: Int = 1)
- (implicit dc: DistributedContext): CheckpointedDrm[String] = ???
+ (implicit dc: DistributedContext): CheckpointedDrm[String] = {
+ ???
+ }
/** This creates an empty DRM with specified number of partitions and cardinality. */
override def drmParallelizeEmpty(nrow: Int, ncol: Int, numPartitions: Int = 10)
http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAewB.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAewB.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAewB.scala
index 2b35685..460199e 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAewB.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAewB.scala
@@ -2,16 +2,17 @@ package org.apache.mahout.flinkbindings.blas
import java.lang.Iterable
-import scala.collection.JavaConverters.asScalaBufferConverter
+import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import org.apache.flink.api.common.functions.CoGroupFunction
import org.apache.flink.api.java.DataSet
import org.apache.flink.util.Collector
+import org.apache.mahout.flinkbindings._
import org.apache.mahout.flinkbindings.drm.FlinkDrm
import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
import org.apache.mahout.math.Vector
-import org.apache.mahout.math.drm.logical._
+import org.apache.mahout.math.drm.logical.OpAewB
import org.apache.mahout.math.scalabindings.RLikeOps._
import com.google.common.collect.Lists
@@ -25,22 +26,24 @@ object FlinkOpAewB {
def rowWiseJoinNoSideEffect[K: ClassTag](op: OpAewB[K], A: FlinkDrm[K], B: FlinkDrm[K]): FlinkDrm[K] = {
val function = AewBOpsCloning.strToFunction(op.op)
- // TODO: get rid of casts!
- val rowsA = A.deblockify.ds.asInstanceOf[DataSet[(Int, Vector)]]
- val rowsB = B.deblockify.ds.asInstanceOf[DataSet[(Int, Vector)]]
+ val classTag = extractRealClassTag(op.A)
+ val joiner = selector[Vector, Any](classTag.asInstanceOf[ClassTag[Any]])
- val res: DataSet[(Int, Vector)] =
- rowsA.coGroup(rowsB).where(tuple_1[Vector]).equalTo(tuple_1[Vector])
- .`with`(new CoGroupFunction[(Int, Vector), (Int, Vector), (Int, Vector)] {
- def coGroup(it1java: Iterable[(Int, Vector)], it2java: Iterable[(Int, Vector)],
- out: Collector[(Int, Vector)]): Unit = {
+ val rowsA = A.deblockify.ds.asInstanceOf[DrmDataSet[Any]]
+ val rowsB = B.deblockify.ds.asInstanceOf[DrmDataSet[Any]]
+
+ val res: DataSet[(Any, Vector)] =
+ rowsA.coGroup(rowsB).where(joiner).equalTo(joiner)
+ .`with`(new CoGroupFunction[(_, Vector), (_, Vector), (_, Vector)] {
+ def coGroup(it1java: Iterable[(_, Vector)], it2java: Iterable[(_, Vector)],
+ out: Collector[(_, Vector)]): Unit = {
val it1 = Lists.newArrayList(it1java).asScala
val it2 = Lists.newArrayList(it2java).asScala
if (!it1.isEmpty && !it2.isEmpty) {
val (idx, a) = it1.head
val (_, b) = it2.head
- out.collect(idx -> function(a, b))
+ out.collect((idx, function(a, b)))
} else if (it1.isEmpty && !it2.isEmpty) {
out.collect(it2.head)
} else if (!it1.isEmpty && it2.isEmpty) {
http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala
index ac6837c..b859e1f 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala
@@ -27,6 +27,7 @@ import org.apache.flink.api.common.functions.FlatMapFunction
import org.apache.flink.api.common.functions.GroupReduceFunction
import org.apache.flink.shaded.com.google.common.collect.Lists
import org.apache.flink.util.Collector
+import org.apache.mahout.flinkbindings._
import org.apache.mahout.flinkbindings.drm.FlinkDrm
import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
import org.apache.mahout.math.Matrix
@@ -66,14 +67,14 @@ object FlinkOpAt {
}
})
- val regrouped = sparseParts.groupBy(tuple_1[Vector])
+ val regrouped = sparseParts.groupBy(selector[Vector, Int])
val sparseTotal = regrouped.reduceGroup(new GroupReduceFunction[(Int, Vector), DrmTuple[Int]] {
def reduce(values: Iterable[(Int, Vector)], out: Collector[DrmTuple[Int]]): Unit = {
val it = Lists.newArrayList(values).asScala
val (idx, _) = it.head
val vector = it map { case (idx, vec) => vec } reduce (_ + _)
- out.collect(idx -> vector)
+ out.collect((idx, vector))
}
})
http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
index c54e6de..ebb1064 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
@@ -26,17 +26,16 @@ import scala.reflect.ClassTag
import org.apache.flink.api.common.functions.FlatMapFunction
import org.apache.flink.api.common.functions.GroupReduceFunction
import org.apache.flink.api.common.functions.MapFunction
+import org.apache.flink.api.java.DataSet
import org.apache.flink.api.java.tuple.Tuple2
import org.apache.flink.util.Collector
-import org.apache.mahout.flinkbindings.BlockifiedDrmDataSet
-import org.apache.mahout.flinkbindings.DrmDataSet
+import org.apache.mahout.flinkbindings._
import org.apache.mahout.flinkbindings.drm.BlockifiedFlinkDrm
import org.apache.mahout.flinkbindings.drm.FlinkDrm
import org.apache.mahout.math.Matrix
import org.apache.mahout.math.Vector
-import org.apache.mahout.math.drm.BlockifiedDrmTuple
+import org.apache.mahout.math.drm._
import org.apache.mahout.math.drm.logical.OpAtB
-import org.apache.mahout.math.drm.safeToNonNegInt
import org.apache.mahout.math.scalabindings.RLikeOps._
import com.google.common.collect.Lists
@@ -50,20 +49,21 @@ import com.google.common.collect.Lists
object FlinkOpAtB {
def notZippable[K: ClassTag](op: OpAtB[K], At: FlinkDrm[K], B: FlinkDrm[K]): FlinkDrm[Int] = {
- // TODO: to help Flink's type inference
- // only Int is supported now
- val rowsAt = At.deblockify.ds.asInstanceOf[DrmDataSet[Int]]
- val rowsB = B.deblockify.ds.asInstanceOf[DrmDataSet[Int]]
- val joined = rowsAt.join(rowsB).where(tuple_1[Vector]).equalTo(tuple_1[Vector])
+ val classTag = extractRealClassTag(op.A)
+ val joiner = selector[Vector, Any](classTag.asInstanceOf[ClassTag[Any]])
+
+ val rowsAt = At.deblockify.ds.asInstanceOf[DrmDataSet[Any]]
+ val rowsB = B.deblockify.ds.asInstanceOf[DrmDataSet[Any]]
+ val joined = rowsAt.join(rowsB).where(joiner).equalTo(joiner)
val ncol = op.ncol
val nrow = op.nrow
val blockHeight = 10
val blockCount = safeToNonNegInt((ncol - 1) / blockHeight + 1)
- val preProduct = joined.flatMap(new FlatMapFunction[Tuple2[(Int, Vector), (Int, Vector)],
+ val preProduct: DataSet[(Int, Matrix)] = joined.flatMap(new FlatMapFunction[Tuple2[(_, Vector), (_, Vector)],
(Int, Matrix)] {
- def flatMap(in: Tuple2[(Int, Vector), (Int, Vector)],
+ def flatMap(in: Tuple2[(_, Vector), (_, Vector)],
out: Collector[(Int, Matrix)]): Unit = {
val avec = in.f0._2
val bvec = in.f1._2
@@ -79,7 +79,7 @@ object FlinkOpAtB {
}
})
- val res: BlockifiedDrmDataSet[Int] = preProduct.groupBy(tuple_1[Matrix]).reduceGroup(
+ val res: BlockifiedDrmDataSet[Int] = preProduct.groupBy(selector[Matrix, Int]).reduceGroup(
new GroupReduceFunction[(Int, Matrix), BlockifiedDrmTuple[Int]] {
def reduce(values: Iterable[(Int, Matrix)], out: Collector[BlockifiedDrmTuple[Int]]): Unit = {
val it = Lists.newArrayList(values).asScala
http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala
index 27237d6..234937b 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpCBind.scala
@@ -19,23 +19,21 @@
package org.apache.mahout.flinkbindings.blas
import java.lang.Iterable
-
-
import scala.collection.JavaConverters._
import scala.collection.JavaConversions._
import scala.reflect.ClassTag
-
import org.apache.flink.api.common.functions.CoGroupFunction
import org.apache.flink.api.common.functions.MapFunction
import org.apache.flink.api.java.DataSet
import org.apache.flink.util.Collector
+import org.apache.mahout.flinkbindings._
import org.apache.mahout.flinkbindings.drm._
import org.apache.mahout.math._
import org.apache.mahout.math.drm.logical.OpCbind
import org.apache.mahout.math.drm.logical.OpCbindScalar
import org.apache.mahout.math.scalabindings.RLikeOps._
-
import com.google.common.collect.Lists
+import org.apache.mahout.flinkbindings.DrmDataSet
/**
@@ -49,15 +47,17 @@ object FlinkOpCBind {
val n1 = op.A.ncol
val n2 = op.B.ncol
- // TODO: cast!
- val rowsA = A.deblockify.ds.asInstanceOf[DataSet[(Int, Vector)]]
- val rowsB = B.deblockify.ds.asInstanceOf[DataSet[(Int, Vector)]]
+ val classTag = extractRealClassTag(op.A)
+ val joiner = selector[Vector, Any](classTag.asInstanceOf[ClassTag[Any]])
+
+ val rowsA = A.deblockify.ds.asInstanceOf[DrmDataSet[Any]]
+ val rowsB = B.deblockify.ds.asInstanceOf[DrmDataSet[Any]]
- val res: DataSet[(Int, Vector)] =
- rowsA.coGroup(rowsB).where(tuple_1[Vector]).equalTo(tuple_1[Vector])
- .`with`(new CoGroupFunction[(Int, Vector), (Int, Vector), (Int, Vector)] {
- def coGroup(it1java: Iterable[(Int, Vector)], it2java: Iterable[(Int, Vector)],
- out: Collector[(Int, Vector)]): Unit = {
+ val res: DataSet[(Any, Vector)] =
+ rowsA.coGroup(rowsB).where(joiner).equalTo(joiner)
+ .`with`(new CoGroupFunction[(_, Vector), (_, Vector), (_, Vector)] {
+ def coGroup(it1java: Iterable[(_, Vector)], it2java: Iterable[(_, Vector)],
+ out: Collector[(_, Vector)]): Unit = {
val it1 = Lists.newArrayList(it1java).asScala
val it2 = Lists.newArrayList(it2java).asScala
http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala
index 6868a83..27f552c 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/package.scala
@@ -18,15 +18,43 @@
*/
package org.apache.mahout.flinkbindings
-import org.apache.flink.api.java.functions.KeySelector
-import org.apache.mahout.math.Vector
import scala.reflect.ClassTag
+import org.apache.flink.api.java.functions.KeySelector
+import org.apache.flink.api.java.typeutils.ResultTypeQueryable
+import org.apache.flink.api.scala.createTypeInformation
+import org.apache.flink.api.common.typeinfo.TypeInformation
package object blas {
// TODO: remove it once figure out how to make Flink accept interfaces (Vector here)
- def tuple_1[K: ClassTag] = new KeySelector[(Int, K), Integer] {
- def getKey(tuple: Tuple2[Int, K]): Integer = tuple._1
+ def selector[V, K: ClassTag]: KeySelector[(K, V), K] = {
+ val tag = implicitly[ClassTag[K]]
+ if (tag.runtimeClass.equals(classOf[Int])) {
+ tuple_1_int.asInstanceOf[KeySelector[(K, V), K]]
+ } else if (tag.runtimeClass.equals(classOf[Long])) {
+ tuple_1_long.asInstanceOf[KeySelector[(K, V), K]]
+ } else if (tag.runtimeClass.equals(classOf[String])) {
+ tuple_1_string.asInstanceOf[KeySelector[(K, V), K]]
+ } else {
+ throw new IllegalArgumentException(s"index type $tag is not supported")
+ }
+ }
+
+ private def tuple_1_int[K: ClassTag] = new KeySelector[(Int, _), Int]
+ with ResultTypeQueryable[Int] {
+ def getKey(tuple: Tuple2[Int, _]): Int = tuple._1
+ def getProducedType: TypeInformation[Int] = createTypeInformation[Int]
}
+ private def tuple_1_long[K: ClassTag] = new KeySelector[(Long, _), Long]
+ with ResultTypeQueryable[Long] {
+ def getKey(tuple: Tuple2[Long, _]): Long = tuple._1
+ def getProducedType: TypeInformation[Long] = createTypeInformation[Long]
+ }
+
+ private def tuple_1_string[K: ClassTag] = new KeySelector[(String, _), String]
+ with ResultTypeQueryable[String] {
+ def getKey(tuple: Tuple2[String, _]): String = tuple._1
+ def getProducedType: TypeInformation[String] = createTypeInformation[String]
+ }
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/FlinkDrm.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/FlinkDrm.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/FlinkDrm.scala
index 82c0d29..27eac4e 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/FlinkDrm.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/FlinkDrm.scala
@@ -45,6 +45,8 @@ trait FlinkDrm[K] {
def blockify: BlockifiedFlinkDrm[K]
def deblockify: RowsFlinkDrm[K]
+
+ def classTag: ClassTag[K]
}
class RowsFlinkDrm[K: ClassTag](val ds: DrmDataSet[K], val ncol: Int) extends FlinkDrm[K] {
@@ -81,6 +83,8 @@ class RowsFlinkDrm[K: ClassTag](val ds: DrmDataSet[K], val ncol: Int) extends Fl
def deblockify = this
+ def classTag = implicitly[ClassTag[K]]
+
}
class BlockifiedFlinkDrm[K: ClassTag](val ds: BlockifiedDrmDataSet[K], val ncol: Int) extends FlinkDrm[K] {
@@ -104,4 +108,6 @@ class BlockifiedFlinkDrm[K: ClassTag](val ds: BlockifiedDrmDataSet[K], val ncol:
})
new RowsFlinkDrm(out, ncol)
}
+
+ def classTag = implicitly[ClassTag[K]]
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala
index aa253ab..57d2f48 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala
@@ -20,7 +20,6 @@ package org.apache.mahout
import scala.Array._
import scala.reflect.ClassTag
-
import org.apache.flink.api.common.functions.FilterFunction
import org.apache.flink.api.common.functions.MapFunction
import org.apache.flink.api.java.DataSet
@@ -42,6 +41,7 @@ import org.apache.mahout.math.drm.CheckpointedDrm
import org.apache.mahout.math.drm.DistributedContext
import org.apache.mahout.math.drm.DrmTuple
import org.slf4j.LoggerFactory
+import org.apache.mahout.math.drm.logical.CheckpointAction
package object flinkbindings {
@@ -108,5 +108,11 @@ package object flinkbindings {
new CheckpointedFlinkDrm[K](dataset)
}
+ private[flinkbindings] def extractRealClassTag[K: ClassTag](drm: DrmLike[K]): ClassTag[_] = drm match {
+ case d: CheckpointAction[K] => d.classTag
+ case d: CheckpointedFlinkDrm[K] => d.keyClassTag
+ // will not always return the correct result; often results in Any
+ case _ => implicitly[ClassTag[K]]
+ }
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
----------------------------------------------------------------------
diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
index 800218b..225a956 100644
--- a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
@@ -294,4 +294,24 @@ class RLikeOpsSuite extends FunSuite with DistributedFlinkSuite {
assert((res.collect(::, 0) - expected).norm(2) < 1e-6)
}
+ test("A.t %*% B with Long keys") {
+ val inCoreA = dense((1, 2), (3, 4), (3, 5))
+ val inCoreB = dense((3, 5), (4, 6), (0, 1))
+
+ val A = drmParallelize(inCoreA, numPartitions = 2).mapBlock()({
+ case (keys, block) => (keys.map(_.toLong), block)
+ })
+
+ val B = drmParallelize(inCoreB, numPartitions = 2).mapBlock()({
+ case (keys, block) => (keys.map(_.toLong), block)
+ })
+
+ val C = A.t %*% B
+ val inCoreC = C.collect
+ val expected = inCoreA.t %*% inCoreB
+
+ (inCoreC - expected).norm should be < 1E-10
+ }
+
+
}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/19708f46/math-scala/src/main/scala/org/apache/mahout/math/drm/logical/CheckpointAction.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/math/drm/logical/CheckpointAction.scala b/math-scala/src/main/scala/org/apache/mahout/math/drm/logical/CheckpointAction.scala
index a7934a3..2324ca2 100644
--- a/math-scala/src/main/scala/org/apache/mahout/math/drm/logical/CheckpointAction.scala
+++ b/math-scala/src/main/scala/org/apache/mahout/math/drm/logical/CheckpointAction.scala
@@ -44,5 +44,6 @@ abstract class CheckpointAction[K: ClassTag] extends DrmLike[K] {
case Some(cp) => cp
}
+ val classTag = implicitly[ClassTag[K]]
}