You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by dl...@apache.org on 2015/10/20 07:37:05 UTC
[22/32] mahout git commit: MAHOUT-1751: Flink: AtA slim
MAHOUT-1751: Flink: AtA slim
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/9d48487c
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/9d48487c
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/9d48487c
Branch: refs/heads/flink-binding
Commit: 9d48487cf2a7060193b305c152b8ce191f7d15be
Parents: ceb1f05
Author: Alexey Grigorev <al...@gmail.com>
Authored: Fri Aug 21 15:33:32 2015 +0200
Committer: Alexey Grigorev <al...@gmail.com>
Committed: Fri Sep 25 17:42:50 2015 +0200
----------------------------------------------------------------------
.../mahout/flinkbindings/FlinkEngine.scala | 12 +---
.../mahout/flinkbindings/blas/FlinkOpAtA.scala | 74 ++++++++++++++++++++
.../mahout/flinkbindings/blas/FlinkOpAtB.scala | 1 +
.../drm/CheckpointedFlinkDrm.scala | 27 ++++---
4 files changed, 95 insertions(+), 19 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/mahout/blob/9d48487c/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
index 3076933..8e47629 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
@@ -19,10 +19,8 @@
package org.apache.mahout.flinkbindings
import java.util.Collection
-
import scala.collection.JavaConverters._
import scala.reflect.ClassTag
-
import org.apache.flink.api.common.functions.MapFunction
import org.apache.flink.api.common.functions.ReduceFunction
import org.apache.flink.api.java.tuple.Tuple2
@@ -79,6 +77,7 @@ import org.apache.mahout.math.indexeddataset.IndexedDataset
import org.apache.mahout.math.indexeddataset.Schema
import org.apache.mahout.math.scalabindings._
import org.apache.mahout.math.scalabindings.RLikeOps._
+import org.apache.mahout.flinkbindings.blas.FlinkOpAtA
object FlinkEngine extends DistributedEngine {
@@ -165,14 +164,7 @@ object FlinkEngine extends DistributedEngine {
FlinkOpAtB.notZippable(OpAtB(c, d), flinkTranslate(c), flinkTranslate(d))
.asInstanceOf[FlinkDrm[K]]
}
- case op @ OpAtA(a) => {
- // express AtA via AtB
- // TODO: create specific implementation of AtA, see MAHOUT-1751
- val aInt = a.asInstanceOf[DrmLike[Int]] // TODO: casts!
- val opAtB = OpAtB(aInt, aInt)
- val aTranslated = flinkTranslate(aInt)
- FlinkOpAtB.notZippable(opAtB, aTranslated, aTranslated)
- }
+ case op @ OpAtA(a) => FlinkOpAtA.at_a(op, flinkTranslate(a)(op.classTagA))
case op @ OpTimesRightMatrix(a, b) =>
FlinkOpTimesRightMatrix.drmTimesInCore(op, flinkTranslate(a)(op.classTagA), b)
case op @ OpAewUnaryFunc(a, f, _) =>
http://git-wip-us.apache.org/repos/asf/mahout/blob/9d48487c/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala
new file mode 100644
index 0000000..63d1845
--- /dev/null
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala
@@ -0,0 +1,74 @@
+package org.apache.mahout.flinkbindings.blas
+
+import java.lang.Iterable
+import scala.collection.JavaConverters.asScalaBufferConverter
+import scala.reflect.ClassTag
+import org.apache.mahout.math.drm._
+import org.apache.flink.api.common.functions.CoGroupFunction
+import org.apache.flink.api.java.DataSet
+import org.apache.flink.util.Collector
+import org.apache.mahout.flinkbindings._
+import org.apache.mahout.flinkbindings.drm.FlinkDrm
+import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
+import org.apache.mahout.math._
+import org.apache.mahout.math.Vector
+import org.apache.mahout.math.drm._
+import org.apache.mahout.math.drm.logical._
+import org.apache.mahout.math.scalabindings.RLikeOps._
+import com.google.common.collect.Lists
+import org.apache.flink.shaded.com.google.common.collect.Lists
+import org.apache.flink.util.Collector
+import org.apache.mahout.flinkbindings.drm.FlinkDrm
+import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
+import org.apache.mahout.math.Matrix
+import org.apache.mahout.math.SequentialAccessSparseVector
+import org.apache.mahout.math.Vector
+import org.apache.mahout.math.drm.DrmTuple
+import org.apache.mahout.math.drm.logical.OpAt
+import org.apache.mahout.math.scalabindings.RLikeOps._
+import org.apache.flink.api.common.functions.MapFunction
+import org.apache.flink.api.common.functions.ReduceFunction
+
+
+/**
+ * Flink physical operator for the logical `A' %*% A` operation ([[OpAtA]]).
+ *
+ * Only the "slim" case is implemented: when `A` has at most
+ * `mahout.math.AtA.maxInMemNCol` columns (default 200), the `ncol x ncol`
+ * product is accumulated in memory and re-distributed as a single-partition DRM.
+ * Wider ("fat") inputs are not supported yet.
+ */
+object FlinkOpAtA {
+
+  /** System property holding the max column count for the in-memory (slim) A'A path. */
+  final val PROPERTY_ATA_MAXINMEMNCOL = "mahout.math.AtA.maxInMemNCol"
+
+  /** Default value used when [[PROPERTY_ATA_MAXINMEMNCOL]] is not set. */
+  final val PROPERTY_ATA_MAXINMEMNCOL_DEFAULT = "200"
+
+  /**
+   * Computes `A' %*% A`.
+   *
+   * @param op logical A'A operator; its `ncol` selects the slim or the fat path
+   * @param A  translated input DRM
+   * @return `A' %*% A` as an Int-keyed DRM parallelized into a single partition
+   * @throws IllegalArgumentException if the configured threshold is not a positive integer
+   * @throws NotImplementedError if `op.ncol` exceeds the in-memory threshold
+   */
+  def at_a(op: OpAtA[_], A: FlinkDrm[_]): FlinkDrm[Int] = {
+    val maxInMemNCol = readMaxInMemNCol()
+
+    if (op.ncol <= maxInMemNCol) {
+      implicit val ctx = A.context
+      val inCoreAtA = slim(op, A)
+      val result = drmParallelize(inCoreAtA, numPartitions = 1)
+      result
+    } else {
+      fat(op, A)
+    }
+  }
+
+  /** Reads the slim-path column threshold from system properties and validates it. */
+  private def readMaxInMemNCol(): Int = {
+    val raw = System.getProperty(PROPERTY_ATA_MAXINMEMNCOL, PROPERTY_ATA_MAXINMEMNCOL_DEFAULT)
+    val parsed =
+      try raw.trim.toInt
+      catch {
+        // NumberFormatException is an IllegalArgumentException; rethrow with context
+        case _: NumberFormatException =>
+          throw new IllegalArgumentException(
+            s"Invalid A'A in-memory setting for optimizer: $PROPERTY_ATA_MAXINMEMNCOL=$raw")
+      }
+    // `require` (unlike `ensuring`) cannot be elided and signals a bad argument, not a bug
+    require(parsed > 0, "Invalid A'A in-memory setting for optimizer")
+    parsed
+  }
+
+  /**
+   * Computes `A' %*% A` in memory.
+   *
+   * Each row block `B_i` of `A` contributes `B_i' %*% B_i`. The partial products are
+   * summed with a global reduce and collected to the client.
+   *
+   * @return the `ncol x ncol` product; a zero matrix if `A` yields no blocks
+   */
+  def slim(op: OpAtA[_], A: FlinkDrm[_]): Matrix = {
+    // the row-key type is erased here; only the matrix part of each block is used
+    val ds = A.blockify.ds.asInstanceOf[DataSet[(Array[Any], Matrix)]]
+
+    val partialSums = ds.map(new MapFunction[(Array[Any], Matrix), Matrix] {
+      // TODO: optimize it: use upper-triangle matrices like in Spark
+      def map(block: (Array[Any], Matrix)): Matrix = block match {
+        case (_, m) => m.t %*% m
+      }
+    }).reduce(new ReduceFunction[Matrix] {
+      def reduce(m1: Matrix, m2: Matrix): Matrix = m1 + m2
+    }).collect()
+
+    // A reduce over a whole (non-grouped) empty DataSet emits no element, so `.head`
+    // would throw for an input without rows; A'A is the zero matrix in that case.
+    partialSums.asScala.headOption.getOrElse(new DenseMatrix(op.ncol, op.ncol))
+  }
+
+  /** Distributed A'A for inputs wider than the in-memory threshold; not implemented yet. */
+  def fat(op: OpAtA[_], A: FlinkDrm[_]): FlinkDrm[Int] = {
+    throw new NotImplementedError("fat matrices are not yet supported")
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/9d48487c/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
index 297f676..f02cd84 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
@@ -42,6 +42,7 @@ import org.apache.mahout.math.scalabindings.RLikeOps._
import com.google.common.collect.Lists
+
/**
* Implementation is taken from Spark's AtB
* https://github.com/apache/mahout/blob/master/spark/src/main/scala/org/apache/mahout/sparkbindings/blas/AtB.scala
http://git-wip-us.apache.org/repos/asf/mahout/blob/9d48487c/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
index e29b80c..f58e05b 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
@@ -94,26 +94,35 @@ class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K],
val data = ds.collect().asScala.toList
val isDense = data.forall(_._2.isDense)
+ val cols = ncol
+ val rows = safeToNonNegInt(nrow)
+
val m = if (isDense) {
- val cols = data.head._2.size()
- val rows = data.length
new DenseMatrix(rows, cols)
} else {
- val cols = ncol
- val rows = safeToNonNegInt(nrow)
new SparseMatrix(rows, cols)
}
val intRowIndices = keyClassTag == implicitly[ClassTag[Int]]
- if (intRowIndices)
- data.foreach(t => m(t._1.asInstanceOf[Int], ::) := t._2)
- else {
+ if (intRowIndices) {
+ data.foreach { case (t, vec) =>
+ val idx = t.asInstanceOf[Int]
+ m(idx, ::) := vec
+ }
+
+ println(m.ncol, m.nrow)
+ } else {
// assign all rows sequentially
val d = data.zipWithIndex
- d.foreach(t => m(t._2, ::) := t._1._2)
+ d.foreach {
+ case ((_, vec), idx) => m(idx, ::) := vec
+ }
+
+ val rowBindings = d.map {
+ case ((t, _), idx) => (t.toString, idx: java.lang.Integer)
+ }.toMap.asJava
- val rowBindings = d.map(t => (t._1._1.toString, t._2: java.lang.Integer)).toMap.asJava
m.setRowLabelBindings(rowBindings)
}