Posted to commits@mahout.apache.org by dl...@apache.org on 2015/10/20 07:37:05 UTC

[22/32] mahout git commit: MAHOUT-1751: Flink: AtA slim

MAHOUT-1751: Flink: AtA slim


Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/9d48487c
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/9d48487c
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/9d48487c

Branch: refs/heads/flink-binding
Commit: 9d48487cf2a7060193b305c152b8ce191f7d15be
Parents: ceb1f05
Author: Alexey Grigorev <al...@gmail.com>
Authored: Fri Aug 21 15:33:32 2015 +0200
Committer: Alexey Grigorev <al...@gmail.com>
Committed: Fri Sep 25 17:42:50 2015 +0200

----------------------------------------------------------------------
 .../mahout/flinkbindings/FlinkEngine.scala      | 12 +---
 .../mahout/flinkbindings/blas/FlinkOpAtA.scala  | 58 ++++++++++++++++++++
 .../mahout/flinkbindings/blas/FlinkOpAtB.scala  |  1 +
 .../drm/CheckpointedFlinkDrm.scala              | 25 ++++---
 4 files changed, 77 insertions(+), 19 deletions(-)
----------------------------------------------------------------------
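
In short: instead of rewriting A'A as A'B with B = A (the previous
workaround), this commit adds a dedicated "slim" A'A path that computes the
Gramian B'B of every vertical block B of A and sums the partial products
into one in-core matrix. A minimal in-core sketch of that identity using
Mahout's scalabindings (the two-block split of `a` is illustrative, not the
actual Flink partitioning):

    import org.apache.mahout.math.scalabindings._
    import org.apache.mahout.math.scalabindings.RLikeOps._

    val a = dense((1, 2), (3, 4), (5, 6))
    // Split A's rows into blocks, as a distributed backend would.
    val blocks = Seq(a(0 until 2, ::), a(2 until 3, ::))
    // A'A is the sum of the per-block Gramians: sum_i B_i' B_i.
    val ata = blocks.map(b => b.t %*% b).reduce(_ + _) // equals a.t %*% a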


http://git-wip-us.apache.org/repos/asf/mahout/blob/9d48487c/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
index 3076933..8e47629 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
@@ -19,10 +19,8 @@
 package org.apache.mahout.flinkbindings
 
 import java.util.Collection
-
 import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
-
 import org.apache.flink.api.common.functions.MapFunction
 import org.apache.flink.api.common.functions.ReduceFunction
 import org.apache.flink.api.java.tuple.Tuple2
@@ -79,6 +77,7 @@ import org.apache.mahout.math.indexeddataset.IndexedDataset
 import org.apache.mahout.math.indexeddataset.Schema
 import org.apache.mahout.math.scalabindings._
 import org.apache.mahout.math.scalabindings.RLikeOps._
+import org.apache.mahout.flinkbindings.blas.FlinkOpAtA
 
 object FlinkEngine extends DistributedEngine {
 
@@ -165,14 +164,7 @@ object FlinkEngine extends DistributedEngine {
       FlinkOpAtB.notZippable(OpAtB(c, d), flinkTranslate(c), flinkTranslate(d))
                 .asInstanceOf[FlinkDrm[K]]
     }
-    case op @ OpAtA(a) => {
-      // express AtA via AtB
-      // TODO: create specific implementation of AtA, see MAHOUT-1751 
-      val aInt = a.asInstanceOf[DrmLike[Int]] // TODO: casts!
-      val opAtB = OpAtB(aInt, aInt)
-      val aTranslated = flinkTranslate(aInt)
-      FlinkOpAtB.notZippable(opAtB, aTranslated, aTranslated)
-    }
+    case op @ OpAtA(a) => FlinkOpAtA.at_a(op, flinkTranslate(a)(op.classTagA))
     case op @ OpTimesRightMatrix(a, b) => 
       FlinkOpTimesRightMatrix.drmTimesInCore(op, flinkTranslate(a)(op.classTagA), b)
     case op @ OpAewUnaryFunc(a, f, _) =>

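The translator now dispatches OpAtA straight to the new FlinkOpAtA.at_a
instead of routing it through FlinkOpAtB with both operands set to A. A
hedged sketch of how this code path is reached from the R-like DSL
(assumes an implicit Flink DistributedContext is in scope; `drmA` is
illustrative):

    import org.apache.mahout.math.drm._
    import org.apache.mahout.math.drm.RLikeDrmOps._
    import org.apache.mahout.math.scalabindings._

    val drmA = drmParallelize(dense((1, 2), (3, 4)), numPartitions = 2)
    // The optimizer emits OpAtA(drmA), which FlinkEngine now maps to FlinkOpAtA.
    val ata = (drmA.t %*% drmA).collect
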
http://git-wip-us.apache.org/repos/asf/mahout/blob/9d48487c/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala
new file mode 100644
index 0000000..63d1845
--- /dev/null
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtA.scala
@@ -0,0 +1,58 @@
+package org.apache.mahout.flinkbindings.blas
+
+import scala.collection.JavaConverters.asScalaBufferConverter
+
+import org.apache.flink.api.common.functions.MapFunction
+import org.apache.flink.api.common.functions.ReduceFunction
+import org.apache.flink.api.java.DataSet
+
+import org.apache.mahout.flinkbindings._
+import org.apache.mahout.flinkbindings.drm.FlinkDrm
+import org.apache.mahout.math.Matrix
+import org.apache.mahout.math.drm._
+import org.apache.mahout.math.drm.logical.OpAtA
+import org.apache.mahout.math.scalabindings.RLikeOps._
+
+/**
+ * Dedicated "slim" A'A (Gramian) implementation for the Flink backend, see MAHOUT-1751.
+ */
+object FlinkOpAtA {
+
+  final val PROPERTY_ATA_MAXINMEMNCOL = "mahout.math.AtA.maxInMemNCol"
+  final val PROPERTY_ATA_MAXINMEMNCOL_DEFAULT = "200"
+
+
+  def at_a(op: OpAtA[_], A: FlinkDrm[_]): FlinkDrm[Int] = {
+    val maxInMemStr = System.getProperty(PROPERTY_ATA_MAXINMEMNCOL, PROPERTY_ATA_MAXINMEMNCOL_DEFAULT)
+    val maxInMemNCol = maxInMemStr.toInt
+    maxInMemNCol.ensuring(_ > 0, "Invalid A'A in-memory setting for optimizer")
+
+    if (op.ncol <= maxInMemNCol) {
+      implicit val ctx = A.context
+      val inCoreAtA = slim(op, A)
+      val result = drmParallelize(inCoreAtA, numPartitions = 1)
+      result
+    } else {
+      fat(op, A)
+    }
+  }
+
+  def slim(op: OpAtA[_], A: FlinkDrm[_]): Matrix = {
+    val ds = A.blockify.ds.asInstanceOf[DataSet[(Array[Any], Matrix)]]
+
+    val res = ds.map(new MapFunction[(Array[Any], Matrix), Matrix] {
+      // TODO: optimize it: use upper-triangle matrices like in Spark
+      def map(block: (Array[Any], Matrix)): Matrix = block match {
+        case (_, m) => m.t %*% m
+      }
+    }).reduce(new ReduceFunction[Matrix] {
+      def reduce(m1: Matrix, m2: Matrix) = m1 + m2
+    }).collect()
+
+    res.asScala.head
+  }
+
+  def fat(op: OpAtA[_], A: FlinkDrm[_]): FlinkDrm[Int] = {
+    throw new NotImplementedError("fat matrices are not yet supported")
+  }
+}
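
Note on the slim/fat split above: the slim (in-core) path is taken whenever
A has at most mahout.math.AtA.maxInMemNCol columns (default 200); since
fat() is still a stub that throws NotImplementedError, wider inputs fail in
this commit. The threshold can be raised for wider but still memory-resident
results, e.g. (illustrative value):

    // The resulting A'A is ncol x ncol, so it must fit in memory on one node.
    System.setProperty("mahout.math.AtA.maxInMemNCol", "1000")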

http://git-wip-us.apache.org/repos/asf/mahout/blob/9d48487c/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
index 297f676..f02cd84 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
@@ -42,6 +42,7 @@ import org.apache.mahout.math.scalabindings.RLikeOps._
 import com.google.common.collect.Lists
 
 
+
 /**
  * Implementation is taken from Spark's AtB
  * https://github.com/apache/mahout/blob/master/spark/src/main/scala/org/apache/mahout/sparkbindings/blas/AtB.scala

http://git-wip-us.apache.org/repos/asf/mahout/blob/9d48487c/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
index e29b80c..f58e05b 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
@@ -94,26 +94,33 @@ class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K],
     val data = ds.collect().asScala.toList
     val isDense = data.forall(_._2.isDense)
 
+    val cols = ncol
+    val rows = safeToNonNegInt(nrow)
+
     val m = if (isDense) {
-      val cols = data.head._2.size()
-      val rows = data.length
       new DenseMatrix(rows, cols)
     } else {
-      val cols = ncol
-      val rows = safeToNonNegInt(nrow)
       new SparseMatrix(rows, cols)
     }
 
     val intRowIndices = keyClassTag == implicitly[ClassTag[Int]]
 
-    if (intRowIndices)
-      data.foreach(t => m(t._1.asInstanceOf[Int], ::) := t._2)
-    else {
+    if (intRowIndices) {
+      data.foreach { case (t, vec) =>
+        val idx = t.asInstanceOf[Int]
+        m(idx, ::) := vec
+      }
+    } else {
       // assign all rows sequentially
       val d = data.zipWithIndex
-      d.foreach(t => m(t._2, ::) := t._1._2)
+      d.foreach {
+        case ((_, vec), idx) => m(idx, ::) := vec
+      }
+
+      val rowBindings = d.map {
+        case ((t, _), idx) => (t.toString, idx: java.lang.Integer) 
+      }.toMap.asJava
 
-      val rowBindings = d.map(t => (t._1._1.toString, t._2: java.lang.Integer)).toMap.asJava
       m.setRowLabelBindings(rowBindings)
     }
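
The collect() change above takes the matrix dimensions from the DRM's own
nrow/ncol in both the dense and the sparse case, and keeps non-Int row keys
as row label bindings instead of dropping them. A small in-core sketch of
the non-Int branch (the data values are illustrative):

    import scala.collection.JavaConverters._
    import org.apache.mahout.math.DenseMatrix
    import org.apache.mahout.math.scalabindings._
    import org.apache.mahout.math.scalabindings.RLikeOps._

    val data = List(("rowA", dvec(1, 2)), ("rowB", dvec(3, 4)))
    val m = new DenseMatrix(2, 2)
    // Assign rows sequentially, remembering each original key's position.
    data.zipWithIndex.foreach { case ((_, vec), idx) => m(idx, ::) := vec }
    val bindings = data.zipWithIndex
      .map { case ((k, _), idx) => (k, idx: java.lang.Integer) }
      .toMap.asJava
    m.setRowLabelBindings(bindings)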