You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by dl...@apache.org on 2015/10/20 07:37:14 UTC
[31/32] mahout git commit: MAHOUT-1751: Flink: fat AtA
MAHOUT-1751: Flink: fat AtA
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/78c9ac2e
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/78c9ac2e
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/78c9ac2e
Branch: refs/heads/flink-binding
Commit: 78c9ac2ea3b14cdf20d38fbc3f7bbee1b059d610
Parents: 137b3b8
Author: Alexey Grigorev <al...@gmail.com>
Authored: Fri Sep 25 17:27:44 2015 +0200
Committer: Alexey Grigorev <al...@gmail.com>
Committed: Fri Sep 25 17:48:16 2015 +0200
----------------------------------------------------------------------
.../mahout/flinkbindings/FlinkEngine.scala | 1 +
.../mahout/flinkbindings/blas/FlinkOpAtA.scala | 141 +++++++++++++++----
.../drm/CheckpointedFlinkDrm.scala | 7 +-
.../mahout/flinkbindings/blas/LATestSuite.scala | 27 ++++
4 files changed, 147 insertions(+), 29 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/mahout/blob/78c9ac2e/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
index 2c07681..fee3d73 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
@@ -128,6 +128,7 @@ object FlinkEngine extends DistributedEngine {
FlinkOpAewScalar.opUnaryFunction(op, flinkTranslate(a)(op.classTagA))
case op @ OpAewUnaryFuncFusion(a, _) =>
FlinkOpAewScalar.opUnaryFunction(op, flinkTranslate(a)(op.classTagA))
+ // deprecated
case op @ OpAewScalar(a, scalar, _) =>
FlinkOpAewScalar.opScalarNoSideEffect(op, flinkTranslate(a)(op.classTagA), scalar)
case op @ OpAewB(a, b, _) =>
http://git-wip-us.apache.org/repos/asf/mahout/blob/78c9ac2e/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala
index 63d1845..0bda805 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala
@@ -1,43 +1,34 @@
package org.apache.mahout.flinkbindings.blas
import java.lang.Iterable
-import scala.collection.JavaConverters.asScalaBufferConverter
-import scala.reflect.ClassTag
-import org.apache.mahout.math.drm._
-import org.apache.flink.api.common.functions.CoGroupFunction
+
+import scala.collection.JavaConverters._
+
+import org.apache.flink.api.common.functions._
import org.apache.flink.api.java.DataSet
-import org.apache.flink.util.Collector
-import org.apache.mahout.flinkbindings._
-import org.apache.mahout.flinkbindings.drm.FlinkDrm
-import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
-import org.apache.mahout.math._
-import org.apache.mahout.math.Vector
-import org.apache.mahout.math.drm._
-import org.apache.mahout.math.drm.logical._
-import org.apache.mahout.math.scalabindings.RLikeOps._
-import com.google.common.collect.Lists
+import org.apache.flink.configuration.Configuration
import org.apache.flink.shaded.com.google.common.collect.Lists
import org.apache.flink.util.Collector
-import org.apache.mahout.flinkbindings.drm.FlinkDrm
-import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
+import org.apache.mahout.flinkbindings._
+import org.apache.mahout.flinkbindings.drm._
import org.apache.mahout.math.Matrix
-import org.apache.mahout.math.SequentialAccessSparseVector
-import org.apache.mahout.math.Vector
-import org.apache.mahout.math.drm.DrmTuple
-import org.apache.mahout.math.drm.logical.OpAt
+import org.apache.mahout.math.drm._
+import org.apache.mahout.math.drm.BlockifiedDrmTuple
+import org.apache.mahout.math.drm.logical.OpAtA
+import org.apache.mahout.math.scalabindings._
import org.apache.mahout.math.scalabindings.RLikeOps._
-import org.apache.flink.api.common.functions.MapFunction
-import org.apache.flink.api.common.functions.ReduceFunction
/**
+ * Inspired by Spark's implementation from
+ * https://github.com/apache/mahout/blob/master/spark/src/main/scala/org/apache/mahout/sparkbindings/blas/AtA.scala
+ *
*/
object FlinkOpAtA {
final val PROPERTY_ATA_MAXINMEMNCOL = "mahout.math.AtA.maxInMemNCol"
final val PROPERTY_ATA_MAXINMEMNCOL_DEFAULT = "200"
-
def at_a(op: OpAtA[_], A: FlinkDrm[_]): FlinkDrm[Int] = {
val maxInMemStr = System.getProperty(PROPERTY_ATA_MAXINMEMNCOL, PROPERTY_ATA_MAXINMEMNCOL_DEFAULT)
val maxInMemNCol = maxInMemStr.toInt
@@ -49,7 +40,7 @@ object FlinkOpAtA {
val result = drmParallelize(inCoreAtA, numPartitions = 1)
result
} else {
- fat(op, A)
+ fat(op.asInstanceOf[OpAtA[Any]], A.asInstanceOf[FlinkDrm[Any]])
}
}
@@ -68,7 +59,103 @@ object FlinkOpAtA {
res.asScala.head
}
- def fat(op: OpAtA[_], A: FlinkDrm[_]): FlinkDrm[Int] = {
- throw new NotImplementedError("fat matrices are not yet supported")
+  /** Broadcast-set name under which the input block count is shipped to the A'A tasks. */
+  private final val BROADCAST_NUMBER_OF_PARTITIONS = "numberOfPartitions"
+
+  /**
+   * Computes A'A when A has too many columns for the product to be assembled in memory.
+   *
+   * The ncol rows of the ncol x ncol result are split into contiguous ranges (see
+   * `estimatePartitions` and `computeEvenSplits`). Each input block emits, for every range,
+   * the partial product `block(::, range)' %*% block`; partials that share a range index are
+   * summed, yielding the rows of A'A in that range as one output block.
+   *
+   * @param op the logical A'A operator; only `op.A.nrow` and `op.A.ncol` are read
+   * @param A  physical input DRM; it is blockified before use
+   * @return a blockified DRM with `ncol` columns whose row keys are the row indices of A'A
+   */
+  def fat(op: OpAtA[Any], A: FlinkDrm[Any]): FlinkDrm[Int] = {
+    val nrow = op.A.nrow
+    val ncol = op.A.ncol
+    val ds = A.blockify.ds
+
+    // Number of input blocks. It is broadcast to every task so that all of them derive the
+    // same split of the result rows.
+    val numberOfPartitions: DataSet[Int] = ds.map(new MapFunction[(Array[Any], Matrix), Int] {
+      def map(a: (Array[Any], Matrix)): Int = 1
+    }).reduce(new ReduceFunction[Int] {
+      def reduce(a: Int, b: Int): Int = a + b
+    })
+
+    val subresults: DataSet[(Int, Matrix)] =
+      ds.flatMap(new RichFlatMapFunction[(Array[Any], Matrix), (Int, Matrix)] {
+
+        var ranges: Array[Range] = null
+
+        override def open(params: Configuration): Unit = {
+          ranges = productRanges(getRuntimeContext(), nrow, ncol)
+        }
+
+        def flatMap(tuple: (Array[Any], Matrix), out: Collector[(Int, Matrix)]): Unit = {
+          val block = tuple._2
+          // One partial product per result row range, tagged with the range index.
+          ranges.zipWithIndex.foreach { case (range, idx) =>
+            out.collect(idx -> block(::, range).t %*% block)
+          }
+        }
+
+      }).withBroadcastSet(numberOfPartitions, BROADCAST_NUMBER_OF_PARTITIONS)
+
+    val res = subresults.groupBy(selector[Matrix, Int])
+      .reduceGroup(new RichGroupReduceFunction[(Int, Matrix), BlockifiedDrmTuple[Int]] {
+
+        var ranges: Array[Range] = null
+
+        override def open(params: Configuration): Unit = {
+          ranges = productRanges(getRuntimeContext(), nrow, ncol)
+        }
+
+        def reduce(values: Iterable[(Int, Matrix)], out: Collector[BlockifiedDrmTuple[Int]]): Unit = {
+          // Every value in the group carries the same range index; sum the partials as they
+          // stream in instead of materializing the whole group first.
+          val it = values.iterator().asScala
+          val (blockKey, first) = it.next()
+          val block = it.foldLeft(first) { (acc, next) => acc + next._2 }
+
+          val blockStart = ranges(blockKey).start
+          val rowKeys = Array.tabulate(block.nrow)(blockStart + _)
+
+          out.collect(rowKeys -> block)
+        }
+      }).withBroadcastSet(numberOfPartitions, BROADCAST_NUMBER_OF_PARTITIONS)
+
+    new BlockifiedFlinkDrm(res, ncol)
+  }
+
+  /**
+   * Row ranges of A'A handled by each product partition, derived from the broadcast input
+   * block count. Every task calls this with identical inputs, so all tasks agree on the split.
+   */
+  private def productRanges(runtime: RuntimeContext, nrow: Long, ncol: Int): Array[Range] = {
+    val counts: java.util.List[Int] = runtime.getBroadcastVariable(BROADCAST_NUMBER_OF_PARTITIONS)
+    // A non-grouped `reduce` emits nothing for an empty input, so the broadcast set may be
+    // empty; fall back to a single partition instead of failing in open().
+    val parts = if (counts.isEmpty) 1 else counts.get(0)
+    computeEvenSplits(ncol, estimatePartitions(nrow, ncol, parts))
+  }
+
+  /**
+   * Estimates how many partitions (row ranges) the ncol x ncol product A'A should be split into.
+   *
+   * Keeps the per-partition element density of the product close to that of the input A:
+   * `ncol * ncol / (nrow * ncol / parts)`.
+   *
+   * @param nrow  number of rows of A
+   * @param ncol  number of columns of A (= number of rows and columns of A'A)
+   * @param parts number of partitions (blocks) of A
+   * @return a partition count in `[1, max(ncol, 1)]`, so that `computeEvenSplits(ncol, _)`
+   *         is never asked for more splits than A'A has rows
+   */
+  def estimatePartitions(nrow: Long, ncol: Int, parts: Int): Int = {
+    // Per-partition element density of the input A.
+    val epp = nrow.toDouble * ncol / parts
+
+    // Partitions needed for the ncol x ncol product to have the same density.
+    val prodParts = ncol.toDouble * ncol / epp
+
+    // A'A has only ncol rows, so more partitions would be guaranteed empty and would make
+    // computeEvenSplits(ncol, _) fail. Clamp while still a Long so huge or infinite estimates
+    // cannot overflow the Int conversion; NaN rounds to 0 and is lifted to 1.
+    (math.round(prodParts) min ncol.toLong).toInt max 1
+  }
+
+  /**
+   * Splits the index space `[0, nrow)` into `numSplits` contiguous ranges whose sizes differ by
+   * at most one: the first `nrow % numSplits` ranges receive one extra element.
+   *
+   * @param nrow      size of the index space to split
+   * @param numSplits number of ranges to produce; must be in `[1, nrow]`
+   * @return the ranges in ascending order, together covering `[0, nrow)` exactly once
+   */
+  def computeEvenSplits(nrow: Long, numSplits: Int): Array[Range] = {
+    require(numSplits <= nrow, "Requested amount of splits greater than number of data points.")
+    require(nrow >= 1)
+    require(numSplits >= 1)
+
+    val baseSize = safeToNonNegInt(nrow / numSplits)
+    // The first `extra` ranges are one element longer than the others.
+    val extra = safeToNonNegInt(nrow % numSplits)
+
+    // Start of split i: each earlier split contributes baseSize, plus one more for each earlier
+    // split that is among the first `extra` (longer) ones.
+    def offset(i: Int): Int = i * baseSize + math.min(i, extra)
+
+    Array.tabulate[Range](numSplits)(i => offset(i) until offset(i + 1))
+  }
-}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/78c9ac2e/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
index f58e05b..ee392b0 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
@@ -82,7 +82,10 @@ class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K],
this
}
- def uncache = ???
+  /**
+   * Currently a no-op: nothing is released, and this DRM is returned unchanged.
+   *
+   * TODO(review): release any cached Flink data here if/when this binding starts caching.
+   */
+  def uncache() = this
// Members declared in org.apache.mahout.math.drm.DrmLike
@@ -160,7 +163,7 @@ class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K],
(x: K) => new LongWritable(x.asInstanceOf[Long])
} else if (classOf[Writable].isAssignableFrom(keyTag.runtimeClass)) {
(x: K) => x.asInstanceOf[Writable]
- } else {
+ } else {
throw new IllegalArgumentException("Do not know how to convert class tag %s to Writable.".format(keyTag))
}
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/78c9ac2e/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuite.scala
----------------------------------------------------------------------
diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuite.scala
index 42c1f63..786ab5f 100644
--- a/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuite.scala
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuite.scala
@@ -185,4 +185,31 @@ class LATestSuite extends FunSuite with DistributedFlinkSuite {
assert((output - expected).norm < 1e-6)
}
+  test("At A slim") {
+    val inCoreA = dense((1, 2, 3), (2, 3, 1), (3, 4, 4), (4, 4, 5), (5, 5, 7), (6, 7, 11))
+    val drmA = drmParallelize(m = inCoreA, numPartitions = 2)
+
+    // The distributed result must match the in-core A'A.
+    val actual = FlinkOpAtA.slim(new OpAtA(drmA), drmA)
+    val expected = inCoreA.t %*% inCoreA
+
+    assert((actual - expected).norm < 1e-6)
+  }
+
+  test("At A fat") {
+    val inCoreA = dense((1, 2, 3, 2, 3, 1), (3, 4, 4, 4, 4, 5), (5, 5, 7, 6, 7, 11))
+    val A = drmParallelize(m = inCoreA, numPartitions = 2)
+    // fat() is typed over Any row keys, so view the Int-keyed DRM as CheckpointedDrm[Any].
+    val Aany = A.asInstanceOf[CheckpointedDrm[Any]]
+
+    val op = new OpAtA(Aany)
+
+    // fat() yields a blockified DRM; deblockify it into a checkpointed DRM so it can be collected.
+    val res = FlinkOpAtA.fat(op, Aany)
+    val drm = new CheckpointedFlinkDrm(res.deblockify.ds, _nrow = op.nrow, _ncol = op.ncol)
+    val output = drm.collect
+
+    val expected = inCoreA.t %*% inCoreA
+    assert((output - expected).norm < 1e-6)
+  }
+
}
\ No newline at end of file