Posted to commits@mahout.apache.org by dl...@apache.org on 2015/10/20 07:36:45 UTC

[02/32] mahout git commit: MAHOUT-1712: Flink: Ax, At, Atx operators

MAHOUT-1712: Flink: Ax, At, Atx operators


Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/98d4ff03
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/98d4ff03
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/98d4ff03

Branch: refs/heads/flink-binding
Commit: 98d4ff0312e6228c7f34d6030a493d2aec794e03
Parents: bb4c4bc
Author: Alexey Grigorev <al...@gmail.com>
Authored: Tue Jun 16 16:46:15 2015 +0200
Committer: Alexey Grigorev <al...@gmail.com>
Committed: Fri Sep 25 17:41:37 2015 +0200

----------------------------------------------------------------------
 flink/pom.xml                                   |  12 +-
 .../flinkbindings/FlinkDistributedContext.scala |   5 +-
 .../mahout/flinkbindings/FlinkEngine.scala      | 110 ++++++++++++-------
 .../mahout/flinkbindings/blas/FlinkOpAt.scala   |  69 ++++++++++++
 .../mahout/flinkbindings/blas/FlinkOpAx.scala   |  42 +++++++
 .../drm/CheckpointedFlinkDrm.scala              | 103 +++++++++++++++++
 .../mahout/flinkbindings/drm/FlinkDrm.scala     |  91 +++++++++++++++
 .../apache/mahout/flinkbindings/package.scala   |  44 +++++++-
 .../flinkbindings/DistributedFlinkSuit.scala    |  27 +++++
 .../mahout/flinkbindings/RLikeOpsSuite.scala    |  96 ++++++++++++++++
 .../mahout/flinkbindings/blas/LATestSuit.scala  |  45 ++++++++
 pom.xml                                         |   1 +
 12 files changed, 599 insertions(+), 46 deletions(-)
----------------------------------------------------------------------
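
At the DSL level the three new operators cover A %*% x, A.t and A.t %*% x. A minimal
sketch of their use, following the tests added in this commit (assumes a Flink-backed
DistributedContext in implicit scope):

  import org.apache.mahout.math._
  import scalabindings._
  import RLikeOps._
  import org.apache.mahout.math.drm._
  import RLikeDrmOps._
  import org.apache.mahout.flinkbindings._

  val inCoreA = dense((1, 2, 3), (2, 3, 4), (3, 4, 5))
  val A = drmParallelize(m = inCoreA, numPartitions = 2)
  val x = dvec(0, 1, 2)

  val ax  = (A %*% x).collect(::, 0)    // Ax  -> FlinkOpAx
  val at  = A.t.collect                 // At  -> FlinkOpAt
  val atx = (A.t %*% x).collect(::, 0)  // Atx -> FlinkOpAt + FlinkOpAx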


http://git-wip-us.apache.org/repos/asf/mahout/blob/98d4ff03/flink/pom.xml
----------------------------------------------------------------------
diff --git a/flink/pom.xml b/flink/pom.xml
index 9f62b5f..80ff05d 100644
--- a/flink/pom.xml
+++ b/flink/pom.xml
@@ -24,7 +24,7 @@
   <parent>
     <groupId>org.apache.mahout</groupId>
     <artifactId>mahout</artifactId>
-    <version>0.10.0-SNAPSHOT</version>
+    <version>0.11.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
 
@@ -136,6 +136,16 @@
       <artifactId>flink-scala</artifactId>
       <version>${flink.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.flink</groupId>
+      <artifactId>flink-java</artifactId>
+      <version>${flink.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.flink</groupId>
+      <artifactId>flink-clients</artifactId>
+      <version>${flink.version}</version>
+    </dependency>
 
     <dependency>
       <groupId>org.apache.mahout</groupId>

http://git-wip-us.apache.org/repos/asf/mahout/blob/98d4ff03/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkDistributedContext.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkDistributedContext.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkDistributedContext.scala
index 1124126..e9130dd 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkDistributedContext.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkDistributedContext.scala
@@ -17,8 +17,9 @@
 
 package org.apache.mahout.flinkbindings
 
-import org.apache.mahout.math.drm.{ DistributedEngine, BCast, DistributedContext }
-import org.apache.flink.api.scala.ExecutionEnvironment
+import org.apache.flink.api.java.ExecutionEnvironment
+import org.apache.mahout.math.drm.DistributedContext
+import org.apache.mahout.math.drm.DistributedEngine
 
 class FlinkDistributedContext(val env: ExecutionEnvironment) extends DistributedContext {
 

http://git-wip-us.apache.org/repos/asf/mahout/blob/98d4ff03/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
index 66c1089..17bf0b6 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
@@ -1,7 +1,10 @@
 package org.apache.mahout.flinkbindings
 
 import scala.reflect.ClassTag
-
+import org.apache.flink.api.scala.DataSet
+import org.apache.mahout.flinkbindings._
+import org.apache.mahout.flinkbindings.drm.CheckpointedFlinkDrm
+import org.apache.mahout.math._
 import org.apache.mahout.math.Matrix
 import org.apache.mahout.math.Vector
 import org.apache.mahout.math.drm.BCast
@@ -14,45 +17,74 @@ import org.apache.mahout.math.indexeddataset.DefaultIndexedDatasetElementReadSch
 import org.apache.mahout.math.indexeddataset.DefaultIndexedDatasetReadSchema
 import org.apache.mahout.math.indexeddataset.IndexedDataset
 import org.apache.mahout.math.indexeddataset.Schema
-
+import org.apache.mahout.math.scalabindings._
+import org.apache.mahout.math.scalabindings.RLikeOps._
 import com.google.common.collect.BiMap
 import com.google.common.collect.HashBiMap
+import scala.collection.JavaConverters._
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.TypeExtractor
+import org.apache.mahout.math.drm.DrmTuple
+import java.util.Collection
+import org.apache.mahout.flinkbindings.drm.FlinkDrm
+import org.apache.mahout.flinkbindings.blas._
+import org.apache.mahout.math.drm.logical.OpAx
+import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
+import org.apache.mahout.math.drm.logical.OpAt
+import org.apache.mahout.math.drm.logical.OpAtx
 
 object FlinkEngine extends DistributedEngine {
 
   /** Second optimizer pass. Translate previously rewritten logical pipeline into physical engine plan. */
   override def toPhysical[K: ClassTag](plan: DrmLike[K], ch: CacheHint.CacheHint): CheckpointedDrm[K] = {
-    null
+    // Flink-specific Physical Plan translation.
+    val drm = flinkTranslate(plan)
+
+    val newcp = new CheckpointedFlinkDrm(
+      ds = drm.deblockify.ds, // TODO: make it lazy!
+      _nrow = plan.nrow,
+      _ncol = plan.ncol
+//      _cacheStorageLevel = cacheHint2Spark(ch),
+//      partitioningTag = plan.partitioningTag
+    )
+
+    newcp.cache()
   }
 
-  /** Engine-specific colSums implementation based on a checkpoint. */
-  override def colSums[K: ClassTag](drm: CheckpointedDrm[K]): Vector = {
-    null
+  private def flinkTranslate[K: ClassTag](oper: DrmLike[K]): FlinkDrm[K] = oper match {
+    case op @ OpAx(a, x) => FlinkOpAx.blockifiedBroadcastAx(op, flinkTranslate(a)(op.classTagA))
+    case op @ OpAt(a) => FlinkOpAt.sparseTrick(op, flinkTranslate(a)(op.classTagA))
+    case op @ OpAtx(a, x) => {
+      val opAt = OpAt(a)
+      val at = FlinkOpAt.sparseTrick(opAt, flinkTranslate(a)(op.classTagA))
+      val atCast = new CheckpointedFlinkDrm(at.deblockify.ds, _nrow=opAt.nrow, _ncol=opAt.ncol)
+      val opAx = OpAx(atCast, x)
+      FlinkOpAx.blockifiedBroadcastAx(opAx, flinkTranslate(atCast)(op.classTagA))
+    }
+    case cp: CheckpointedFlinkDrm[K] => new RowsFlinkDrm(cp.ds, cp.ncol)
+    case _ => ???
   }
+  
+
+  def translate[K: ClassTag](oper: DrmLike[K]): DataSet[K] = ???
+
+  /** Engine-specific colSums implementation based on a checkpoint. */
+  override def colSums[K: ClassTag](drm: CheckpointedDrm[K]): Vector = ???
 
   /** Engine-specific numNonZeroElementsPerColumn implementation based on a checkpoint. */
-  override def numNonZeroElementsPerColumn[K: ClassTag](drm: CheckpointedDrm[K]): Vector = {
-    null
-  }
+  override def numNonZeroElementsPerColumn[K: ClassTag](drm: CheckpointedDrm[K]): Vector = ???
 
   /** Engine-specific colMeans implementation based on a checkpoint. */
-  override def colMeans[K: ClassTag](drm: CheckpointedDrm[K]): Vector = {
-    null
-  }
+  override def colMeans[K: ClassTag](drm: CheckpointedDrm[K]): Vector = ???
 
-  override def norm[K: ClassTag](drm: CheckpointedDrm[K]): Double = {
-    0.0d
-  }
+  override def norm[K: ClassTag](drm: CheckpointedDrm[K]): Double = ???
 
   /** Broadcast support */
-  override def drmBroadcast(v: Vector)(implicit dc: DistributedContext): BCast[Vector] = {
-    null
-  }
+  override def drmBroadcast(v: Vector)(implicit dc: DistributedContext): BCast[Vector] = ???
 
   /** Broadcast support */
-  override def drmBroadcast(m: Matrix)(implicit dc: DistributedContext): BCast[Matrix] = {
-    null
-  }
+  override def drmBroadcast(m: Matrix)(implicit dc: DistributedContext): BCast[Matrix] = ???
 
   /**
    * Load DRM from hdfs (as in Mahout DRM format).
@@ -61,33 +93,35 @@ object FlinkEngine extends DistributedEngine {
    * @param parMin Minimum parallelism after load (equivalent to #par(min=...)).
    */
   override def drmDfsRead(path: String, parMin: Int = 0)
-                         (implicit sc: DistributedContext): CheckpointedDrm[_] = {
-    null
-  }
+                         (implicit sc: DistributedContext): CheckpointedDrm[_] = ???
 
   /** Parallelize in-core matrix as spark distributed matrix, using row ordinal indices as data set keys. */
   override def drmParallelizeWithRowIndices(m: Matrix, numPartitions: Int = 1)
                                            (implicit sc: DistributedContext): CheckpointedDrm[Int] = {
-    null
+    val parallelDrm = parallelize(m, numPartitions)
+    new CheckpointedFlinkDrm(ds=parallelDrm, _nrow=m.numRows(), _ncol=m.numCols())
+  }
+
+  private[flinkbindings] def parallelize(m: Matrix, parallelismDegree: Int)
+      (implicit sc: DistributedContext): DrmDataSet[Int] = {
+    val rows = (0 until m.nrow).map(i => (i, m(i, ::)))
+    val rowsJava: Collection[DrmTuple[Int]]  = rows.asJava
+
+    val dataSetType = TypeExtractor.getForObject(rows.head)
+    sc.env.fromCollection(rowsJava, dataSetType).setParallelism(parallelismDegree)
   }
 
   /** Parallelize in-core matrix as spark distributed matrix, using row labels as a data set keys. */
   override def drmParallelizeWithRowLabels(m: Matrix, numPartitions: Int = 1)
-                                          (implicit sc: DistributedContext): CheckpointedDrm[String] = {
-    null
-  }
+                                          (implicit sc: DistributedContext): CheckpointedDrm[String] = ???
 
   /** This creates an empty DRM with specified number of partitions and cardinality. */
   override def drmParallelizeEmpty(nrow: Int, ncol: Int, numPartitions: Int = 10)
-                                  (implicit sc: DistributedContext): CheckpointedDrm[Int] = {
-    null
-  }
+                                  (implicit sc: DistributedContext): CheckpointedDrm[Int] = ???
 
   /** Creates empty DRM with non-trivial height */
   override def drmParallelizeEmptyLong(nrow: Long, ncol: Int, numPartitions: Int = 10)
-                                      (implicit sc: DistributedContext): CheckpointedDrm[Long] = {
-    null
-  }
+                                      (implicit sc: DistributedContext): CheckpointedDrm[Long] = ???
 
   /**
    * Load IndexedDataset from text delimited format.
@@ -97,9 +131,7 @@ object FlinkEngine extends DistributedEngine {
   override def indexedDatasetDFSRead(src: String,
                    schema: Schema = DefaultIndexedDatasetReadSchema, 
                    existingRowIDs: BiMap[String, Int] = HashBiMap.create())
-            (implicit sc: DistributedContext): IndexedDataset = {
-    null
-  }
+            (implicit sc: DistributedContext): IndexedDataset = ???
 
   /**
    * Load IndexedDataset from text delimited format, one element per line
@@ -109,7 +141,5 @@ object FlinkEngine extends DistributedEngine {
   override def indexedDatasetDFSReadElements(src: String,
                     schema: Schema = DefaultIndexedDatasetElementReadSchema,
                     existingRowIDs: BiMap[String, Int] = HashBiMap.create())
-             (implicit sc: DistributedContext): IndexedDataset = {
-    null
-  }
+             (implicit sc: DistributedContext): IndexedDataset = ???
 }
\ No newline at end of file
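
Note the OpAtx branch of flinkTranslate above: the fused A'x operator is not implemented
directly but decomposed into the two physical pieces. A sketch of the same decomposition
outside the optimizer (A a CheckpointedDrm[Int], x an in-core Vector; relies on the
implicit CheckpointedDrm-to-FlinkDrm conversion from the package object):

  val opAt  = OpAt(A)
  val at    = FlinkOpAt.sparseTrick(opAt, A)    // physical A'
  val atDrm = new CheckpointedFlinkDrm(at.deblockify.ds, _nrow = opAt.nrow, _ncol = opAt.ncol)
  val res   = FlinkOpAx.blockifiedBroadcastAx(OpAx(atDrm, x), atDrm)  // then (A')x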

http://git-wip-us.apache.org/repos/asf/mahout/blob/98d4ff03/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala
new file mode 100644
index 0000000..be7fc8f
--- /dev/null
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAt.scala
@@ -0,0 +1,69 @@
+package org.apache.mahout.flinkbindings.blas
+
+import org.apache.mahout.math.drm.logical.OpAt
+import org.apache.mahout.flinkbindings.DrmDataSet
+import org.apache.mahout.flinkbindings.drm.FlinkDrm
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.mahout.math.Matrix
+import scala.reflect.ClassTag
+import org.apache.flink.util.Collector
+import org.apache.mahout.flinkbindings._
+import org.apache.mahout.math._
+import scalabindings._
+import RLikeOps._
+import org.apache.flink.api.common.functions.GroupReduceFunction
+import org.apache.mahout.math.drm.DrmTuple
+import java.lang.Iterable
+import scala.collection.JavaConverters._
+import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
+import org.apache.flink.api.java.functions.KeySelector
+import java.util.ArrayList
+import org.apache.flink.shaded.com.google.common.collect.Lists
+
+/**
+ * Transpose operator (A') for Flink-backed DRMs.
+ */
+object FlinkOpAt {
+
+  /**
+   * The idea here is simple: compile vertical column vectors of every partition block as sparse
+   * vectors of the <code>A.nrow</code> length; then group them by their column index and sum the
+   * groups into final rows of the transposed matrix.
+   */
+  def sparseTrick(op: OpAt, A: FlinkDrm[Int]): FlinkDrm[Int] = {
+    val ncol = op.ncol // # of rows of A, i.e. # of columns of A^T
+
+    val sparseParts = A.blockify.ds.flatMap(new FlatMapFunction[(Array[Int], Matrix), DrmTuple[Int]] {
+      def flatMap(tuple: (Array[Int], Matrix), out: Collector[DrmTuple[Int]]): Unit = tuple match {
+        case (keys, block) =>
+          (0 until block.ncol).foreach { columnIdx =>
+            // compile this block's column as a sparse vector of length ncol(A') = nrow(A)
+            val columnVector: Vector = new SequentialAccessSparseVector(ncol)
+
+            keys.zipWithIndex.foreach { case (key, idx) =>
+              columnVector(key) = block(idx, columnIdx)
+            }
+
+            out.collect((columnIdx, columnVector))
+          }
+      }
+    })
+
+    val regrouped = sparseParts.groupBy(new KeySelector[Tuple2[Int, Vector], Integer] {
+      def getKey(tuple: Tuple2[Int, Vector]): Integer = tuple._1
+    })
+
+    val sparseTotal = regrouped.reduceGroup(new GroupReduceFunction[Tuple2[Int, Vector], DrmTuple[Int]] {
+      def reduce(values: Iterable[DrmTuple[Int]], out: Collector[DrmTuple[Int]]): Unit = {
+        val it = Lists.newArrayList(values).asScala
+        val (idx, _) = it.head
+        val vector = it map { case (_, vec) => vec } reduce (_ + _)
+        out.collect(idx -> vector)
+      }
+    })
+
+    // TODO: densify or not?
+    new RowsFlinkDrm(sparseTotal, ncol)
+  }
+
+}
\ No newline at end of file
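
A small in-core illustration of the sparse trick (a sketch, no Flink involved): every
column j of a block becomes a sparse vector v with v(keys(i)) = block(i, j); grouping by
j and summing across blocks yields row j of A'.

  import org.apache.mahout.math._
  import scalabindings._
  import RLikeOps._

  // a single block of a 2 x 3 matrix A, holding row keys 0 and 1
  val keys  = Array(0, 1)
  val block = dense((1, 2, 3), (2, 3, 4))
  val ncolAt = 2  // ncol of A' = nrow of A = length of each emitted vector

  val emitted = (0 until block.ncol).map { j =>
    val v: Vector = new SequentialAccessSparseVector(ncolAt)
    keys.zipWithIndex.foreach { case (key, i) => v(key) = block(i, j) }
    j -> v
  }
  // with one block the grouped sums are the emitted vectors themselves:
  // (0, (1, 2)), (1, (2, 3)), (2, (3, 4)), i.e. the rows of A.t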

http://git-wip-us.apache.org/repos/asf/mahout/blob/98d4ff03/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAx.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAx.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAx.scala
new file mode 100644
index 0000000..d401abf
--- /dev/null
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAx.scala
@@ -0,0 +1,42 @@
+package org.apache.mahout.flinkbindings.blas
+
+import scala.reflect.ClassTag
+import org.apache.mahout.flinkbindings.drm.FlinkDrm
+import org.apache.mahout.flinkbindings._
+import org.apache.mahout.math.drm.drmBroadcast
+import org.apache.mahout.math.drm.logical.OpAx
+import org.apache.mahout.math.Matrix
+import org.apache.flink.api.common.functions.MapFunction
+import org.apache.mahout.flinkbindings.drm.BlockifiedFlinkDrm
+import org.apache.mahout.math._
+import scalabindings._
+import RLikeOps._
+import org.apache.flink.api.common.functions.RichMapFunction
+import org.apache.flink.configuration.Configuration
+import java.util.List
+
+object FlinkOpAx {
+
+  def blockifiedBroadcastAx[K: ClassTag](op: OpAx[K], A: FlinkDrm[K]): FlinkDrm[K] = {
+    implicit val ctx = A.context
+    //    val x = drmBroadcast(op.x)
+
+    val singletonDataSetX = ctx.env.fromElements(op.x)
+
+    val out = A.blockify.ds.map(new RichMapFunction[(Array[K], Matrix), (Array[K], Matrix)] {
+      var x: Vector = null
+
+      override def open(params: Configuration): Unit = {
+        val runtime = this.getRuntimeContext()
+        val dsX: List[Vector] = runtime.getBroadcastVariable("vector")
+        x = dsX.get(0)
+      }
+
+      override def map(tuple: (Array[K], Matrix)): (Array[K], Matrix) = tuple match {
+        case (keys, mat) => (keys, (mat %*% x).toColMatrix)
+      }
+    }).withBroadcastSet(singletonDataSetX, "vector")
+
+    new BlockifiedFlinkDrm(out, op.nrow.toInt)
+  }
+}
\ No newline at end of file
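
blockifiedBroadcastAx ships the in-core x to every task through a Flink broadcast set
(drmBroadcast is still commented out above). A standalone sketch of that mechanism
against the Java DataSet API; all names here are illustrative:

  import java.util.{List => JList}
  import org.apache.flink.api.common.functions.RichMapFunction
  import org.apache.flink.api.java.ExecutionEnvironment
  import org.apache.flink.configuration.Configuration

  val env = ExecutionEnvironment.getExecutionEnvironment
  val words  = env.fromElements("a", "b", "c")
  val prefix = env.fromElements(">> ")  // singleton data set, playing the role of op.x

  val out = words.map(new RichMapFunction[String, String] {
    private var p: String = _
    override def open(params: Configuration): Unit = {
      // the broadcast set is materialized on each worker under its registered name
      val bc: JList[String] = getRuntimeContext.getBroadcastVariable[String]("prefix")
      p = bc.get(0)
    }
    override def map(w: String): String = p + w
  }).withBroadcastSet(prefix, "prefix")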

http://git-wip-us.apache.org/repos/asf/mahout/blob/98d4ff03/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
new file mode 100644
index 0000000..c19920f
--- /dev/null
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
@@ -0,0 +1,103 @@
+package org.apache.mahout.flinkbindings.drm
+
+import scala.reflect.ClassTag
+import org.apache.mahout.math.drm._
+import org.apache.mahout.math.scalabindings._
+import RLikeOps._
+import org.apache.mahout.flinkbindings._
+
+import org.apache.mahout.math.drm.CheckpointedDrm
+import org.apache.mahout.math.Matrix
+import org.apache.mahout.flinkbindings.FlinkDistributedContext
+import org.apache.flink.api.scala.ExecutionEnvironment
+import org.apache.mahout.math.drm.CacheHint
+import scala.util.Random
+import org.apache.mahout.math.drm.DistributedContext
+import org.apache.mahout.math.DenseMatrix
+import org.apache.mahout.math.SparseMatrix
+import org.apache.flink.api.java.io.LocalCollectionOutputFormat
+import java.util.ArrayList
+
+import scala.collection.JavaConverters._
+
+class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K],
+  private var _nrow: Long = CheckpointedFlinkDrm.UNKNOWN,
+  private var _ncol: Int = CheckpointedFlinkDrm.UNKNOWN,
+  // private val _cacheStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY,
+  override protected[mahout] val partitioningTag: Long = Random.nextLong(),
+  private var _canHaveMissingRows: Boolean = false) extends CheckpointedDrm[K] {
+
+  lazy val nrow = if (_nrow >= 0) _nrow else computeNRow
+  lazy val ncol = if (_ncol >= 0) _ncol else computeNCol
+
+  protected def computeNRow = ???
+  protected def computeNCol = ??? /*{
+  TODO: find out how to get one value
+    val max = ds.map(new MapFunction[DrmTuple[K], Int] {
+      def map(value: DrmTuple[K]): Int = value._2.length
+    }).reduce(new ReduceFunction[Int] {
+      def reduce(a1: Int, a2: Int) = Math.max(a1, a2)
+    })
+    
+    max
+  }*/
+  def keyClassTag: ClassTag[K] = implicitly[ClassTag[K]]
+
+  def cache() = {
+    // TODO
+    this
+  }
+
+  def uncache = ???
+
+  // Members declared in org.apache.mahout.math.drm.DrmLike   
+
+  protected[mahout] def canHaveMissingRows: Boolean = _canHaveMissingRows
+
+  def checkpoint(cacheHint: CacheHint.CacheHint): CheckpointedDrm[K] = this
+
+  def collect: Matrix = {
+    val dataJavaList = new ArrayList[DrmTuple[K]]
+    val outputFormat = new LocalCollectionOutputFormat[DrmTuple[K]](dataJavaList)
+    ds.output(outputFormat)
+    val data = dataJavaList.asScala
+    ds.getExecutionEnvironment.execute("Checkpointed Flink Drm collect()")
+
+    val isDense = data.forall(_._2.isDense)
+
+    val m = if (isDense) {
+      val cols = data.head._2.size()
+      val rows = data.length
+      new DenseMatrix(rows, cols)
+    } else {
+      val cols = ncol
+      val rows = safeToNonNegInt(nrow)
+      new SparseMatrix(rows, cols)
+    }
+
+    val intRowIndices = keyClassTag == implicitly[ClassTag[Int]]
+
+    if (intRowIndices)
+      data.foreach(t => m(t._1.asInstanceOf[Int], ::) := t._2)
+    else {
+      // assign all rows sequentially
+      val d = data.zipWithIndex
+      d.foreach(t => m(t._2, ::) := t._1._2)
+
+      val rowBindings = d.map(t => (t._1._1.toString, t._2: java.lang.Integer)).toMap.asJava
+      m.setRowLabelBindings(rowBindings)
+    }
+
+    m
+  }
+
+  def dfsWrite(path: String) = ???
+  def newRowCardinality(n: Int): CheckpointedDrm[K] = ???
+
+  override val context: DistributedContext = ds.getExecutionEnvironment
+
+}
+
+object CheckpointedFlinkDrm {
+  val UNKNOWN = -1
+}
\ No newline at end of file
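
collect above stages the rows through a LocalCollectionOutputFormat and runs a dedicated
job before assembling the in-core Matrix. Usage-wise it behaves like the other backends;
a sketch (imports as in the earlier DSL sketch, implicit Flink context assumed):

  val inCore = dense((1, 2), (3, 4))
  val drm = drmParallelize(inCore, numPartitions = 2)

  // triggers a Flink job, then copies the rows into an in-core Matrix
  val roundTripped = drm.collect
  assert((roundTripped - inCore).norm < 1e-6)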

http://git-wip-us.apache.org/repos/asf/mahout/blob/98d4ff03/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/FlinkDrm.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/FlinkDrm.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/FlinkDrm.scala
new file mode 100644
index 0000000..3dc5684
--- /dev/null
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/FlinkDrm.scala
@@ -0,0 +1,91 @@
+package org.apache.mahout.flinkbindings.drm
+
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.api.java.DataSet
+import org.apache.flink.api.java.ExecutionEnvironment
+import org.apache.flink.util.Collector
+import org.apache.mahout.flinkbindings.FlinkDistributedContext
+import org.apache.mahout.math.Matrix
+import org.apache.mahout.math.drm._
+import org.apache.mahout.math.scalabindings._
+import RLikeOps._
+import org.apache.mahout.flinkbindings._
+import org.apache.flink.api.common.functions.MapPartitionFunction
+import org.apache.mahout.math.Vector
+import java.lang.Iterable
+import scala.collection.JavaConverters._
+import org.apache.mahout.math.DenseMatrix
+import scala.reflect.ClassTag
+import org.apache.mahout.math.SparseRowMatrix
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.scala.codegen.TypeInformationGen
+import org.apache.flink.api.java.typeutils.TypeExtractor
+
+trait FlinkDrm[K] {
+  def executionEnvironment: ExecutionEnvironment
+  def context: FlinkDistributedContext
+  def isBlockified: Boolean
+
+  def blockify: BlockifiedFlinkDrm[K]
+  def deblockify: RowsFlinkDrm[K]
+}
+
+class RowsFlinkDrm[K: ClassTag](val ds: DrmDataSet[K], val ncol: Int) extends FlinkDrm[K] {
+
+  def executionEnvironment = ds.getExecutionEnvironment
+  def context: FlinkDistributedContext = ds.getExecutionEnvironment
+
+  def isBlockified = false
+
+  def blockify(): BlockifiedFlinkDrm[K] = {
+    val ncolLocal = ncol
+    val classTag = implicitly[ClassTag[K]]
+
+    val parts = ds.mapPartition(new MapPartitionFunction[DrmTuple[K], (Array[K], Matrix)] {
+      def mapPartition(values: Iterable[DrmTuple[K]], out: Collector[(Array[K], Matrix)]): Unit = {
+        val it = values.asScala.seq
+
+        val (keys, vectors) = it.unzip
+        val isDense = vectors.head.isDense
+
+        if (isDense) {
+          val matrix = new DenseMatrix(vectors.size, ncolLocal)
+          vectors.zipWithIndex.foreach { case (vec, idx) => matrix(idx, ::) := vec }
+          out.collect((keys.toArray(classTag), matrix))
+        } else {
+          val matrix = new SparseRowMatrix(vectors.size, ncolLocal, vectors.toArray)
+          out.collect((keys.toArray(classTag), matrix))
+        }
+      }
+    })
+
+    new BlockifiedFlinkDrm(parts, ncol)
+  }
+
+  def deblockify = this
+
+}
+
+class BlockifiedFlinkDrm[K: ClassTag](val ds: BlockifiedDrmDataSet[K], val ncol: Int) extends FlinkDrm[K] {
+
+  def executionEnvironment = ds.getExecutionEnvironment
+  def context: FlinkDistributedContext = ds.getExecutionEnvironment
+
+  def isBlockified = true
+
+  def blockify = this
+
+  def deblockify = {
+    val out = ds.flatMap(new FlatMapFunction[(Array[K], Matrix), DrmTuple[K]] {
+      def flatMap(tuple: (Array[K], Matrix), out: Collector[DrmTuple[K]]): Unit = tuple match {
+        case (keys, block) => keys.view.zipWithIndex.foreach {
+          case (key, idx) => {
+            out.collect((key, block(idx, ::)))
+          }
+        }
+      }
+    })
+    new RowsFlinkDrm(out, ncol)
+  }
+}
\ No newline at end of file
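
FlinkDrm is thus the physical-plan view of a DRM with two interchangeable forms: row-wise
(one (key, Vector) tuple per row) and blockified (one (keys, Matrix) pair per partition).
Operators convert between them as needed; a sketch using this commit's types:

  import org.apache.mahout.flinkbindings.drm.{BlockifiedFlinkDrm, FlinkDrm, RowsFlinkDrm}

  // round-trips between the two physical forms; drm may start out as either one
  def asBlocks[K](drm: FlinkDrm[K]): BlockifiedFlinkDrm[K] = drm.blockify
  def asRows[K](drm: FlinkDrm[K]): RowsFlinkDrm[K] = drm.deblockify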

http://git-wip-us.apache.org/repos/asf/mahout/blob/98d4ff03/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala
index fb0780e..0b26781 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala
@@ -1,11 +1,49 @@
 package org.apache.mahout
 
+import org.apache.flink.api.java.DataSet
+import org.apache.flink.api.java.ExecutionEnvironment
+import org.apache.mahout.flinkbindings.FlinkDistributedContext
+import org.apache.mahout.flinkbindings.drm.BlockifiedFlinkDrm
+import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
+import org.apache.mahout.math.drm._
 import org.slf4j.LoggerFactory
+import scala.reflect.ClassTag
+import org.apache.mahout.flinkbindings.drm.FlinkDrm
+import org.apache.mahout.flinkbindings.drm.CheckpointedFlinkDrm
 
 package object flinkbindings {
-  
+
   private[flinkbindings] val log = LoggerFactory.getLogger("apache.org.mahout.flinkbingings")
+
+  /** Row-wise organized DRM dataset type */
+  type DrmDataSet[K] = DataSet[DrmTuple[K]]
+
+  /**
+   * Blockified DRM dataset (keys of the original DRM are grouped into an array corresponding
+   * to the rows of the Matrix object value).
+   */
+  type BlockifiedDrmDataSet[K] = DataSet[BlockifiedDrmTuple[K]]
+
   
-  
-  
+  implicit def wrapMahoutContext(context: DistributedContext): FlinkDistributedContext = {
+    assert(context.isInstanceOf[FlinkDistributedContext], "it must be FlinkDistributedContext")
+    context.asInstanceOf[FlinkDistributedContext]
+  }
+
+  implicit def wrapContext(env: ExecutionEnvironment): FlinkDistributedContext =
+    new FlinkDistributedContext(env)
+  implicit def unwrapContext(ctx: FlinkDistributedContext): ExecutionEnvironment = ctx.env
+
+  private[flinkbindings] implicit def castCheckpointedDrm[K: ClassTag](drm: CheckpointedDrm[K]): CheckpointedFlinkDrm[K] = {
+    assert(drm.isInstanceOf[CheckpointedFlinkDrm[K]], "it must be a Flink-backed matrix")
+    drm.asInstanceOf[CheckpointedFlinkDrm[K]]
+  }
+
+  implicit def checkpointedDrmToFlinkDrm[K: ClassTag](cp: CheckpointedDrm[K]): FlinkDrm[K] = {
+    val flinkDrm = castCheckpointedDrm(cp)
+    new RowsFlinkDrm[K](flinkDrm.ds, flinkDrm.ncol)
+  }
+
 }
\ No newline at end of file
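
These implicits are what let callers hand a raw Flink ExecutionEnvironment to Mahout
wherever a DistributedContext is expected (the test suites below do exactly this); a
sketch:

  import org.apache.flink.api.java.ExecutionEnvironment
  import org.apache.mahout.math.drm.DistributedContext
  import org.apache.mahout.flinkbindings._

  val env = ExecutionEnvironment.getExecutionEnvironment
  // wrapContext lifts the environment into a FlinkDistributedContext
  implicit val ctx: DistributedContext = env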

http://git-wip-us.apache.org/repos/asf/mahout/blob/98d4ff03/flink/src/test/scala/org/apache/mahout/flinkbindings/DistributedFlinkSuit.scala
----------------------------------------------------------------------
diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/DistributedFlinkSuit.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/DistributedFlinkSuit.scala
new file mode 100644
index 0000000..2412c2c
--- /dev/null
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/DistributedFlinkSuit.scala
@@ -0,0 +1,27 @@
+package org.apache.mahout.flinkbindings
+
+import org.apache.mahout.test.DistributedMahoutSuite
+import org.scalatest.Suite
+import org.apache.mahout.math.drm.DistributedContext
+import org.apache.flink.api.java.ExecutionEnvironment
+
+trait DistributedFlinkSuit extends DistributedMahoutSuite { this: Suite =>
+
+  protected implicit var mahoutCtx: DistributedContext = _
+  protected var env: ExecutionEnvironment = null
+  
+  def initContext() {
+    env = ExecutionEnvironment.getExecutionEnvironment
+    mahoutCtx = env
+  }
+
+  override def beforeEach() {
+    super.beforeEach()
+    initContext()
+  }
+
+  override def afterEach() {
+    super.afterEach()
+//    env.execute("Mahout Flink Binding Test Suite")
+  }
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/mahout/blob/98d4ff03/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
----------------------------------------------------------------------
diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
new file mode 100644
index 0000000..07d6a84
--- /dev/null
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
@@ -0,0 +1,96 @@
+package org.apache.mahout.flinkbindings
+
+import org.junit.runner.RunWith
+import org.scalatest.junit.JUnitRunner
+import org.scalatest.FunSuite
+import org.apache.mahout.math._
+import scalabindings._
+import RLikeOps._
+import org.apache.mahout.math.drm._
+import RLikeDrmOps._
+import org.apache.mahout.flinkbindings._
+import org.apache.mahout.math.function.IntIntFunction
+import scala.util.hashing.MurmurHash3
+import org.slf4j.LoggerFactory
+
+@RunWith(classOf[JUnitRunner])
+class RLikeOpsSuite extends FunSuite with DistributedFlinkSuit {
+
+  val LOGGER = LoggerFactory.getLogger(getClass())
+
+  test("A %*% x") {
+    val inCoreA = dense((1, 2, 3), (2, 3, 4), (3, 4, 5))
+    val A = drmParallelize(m = inCoreA, numPartitions = 2)
+    val x: Vector = (0, 1, 2)
+
+    val res = A %*% x
+
+    val b = res.collect(::, 0)
+    assert(b == dvec(8, 11, 14))
+  }
+
+  test("Power interation 1000 x 1000 matrix") {
+    val dim = 1000
+
+    // we want a symmetric matrix so we can have real eigenvalues
+    val inCoreA = symmetricMatrix(dim, max = 2000)
+
+    val A = drmParallelize(m = inCoreA, numPartitions = 2)
+
+    var x: Vector = 1 to dim map (_ => 1.0 / Math.sqrt(dim))
+    var converged = false
+
+    var iteration = 1
+
+    while (!converged) {
+      LOGGER.info(s"iteration #$iteration...")
+
+      val Ax = A %*% x
+      var x_new = Ax.collect(::, 0)
+      x_new = x_new / x_new.norm(2)
+
+      val diff = (x_new - x).norm(2)
+      LOGGER.info(s"difference norm is $diff")
+
+      converged = diff < 1e-6
+      iteration = iteration + 1
+      x = x_new
+    }
+
+    LOGGER.info("converged")
+    // TODO: add test that it's the 1st PC
+  }
+
+  def symmetricMatrix(dim: Int, max: Int, seed: Int = 0x31337) = {
+    Matrices.functionalMatrixView(dim, dim, new IntIntFunction {
+      def apply(i: Int, j: Int): Double = {
+        val arr = Array(i + j, i * j, i + j + 31, i / (j + 1) + j / (i + 1))
+        Math.abs(MurmurHash3.arrayHash(arr, seed) % max)
+      }
+    })
+  }
+
+  test("A.t") {
+    val inCoreA = dense((1, 2, 3), (2, 3, 4))
+    val A = drmParallelize(m = inCoreA, numPartitions = 2)
+    val res = A.t.collect
+
+    val expected = inCoreA.t
+    assert((res - expected).norm < 1e-6)
+  }
+
+  test("A.t %*% x") {
+    val inCoreA = dense((1, 2, 3), (2, 3, 4))
+    val A = drmParallelize(m = inCoreA, numPartitions = 2)
+    val x = dvec(3, 11)
+    val res = (A.t %*% x).collect(::, 0)
+
+    val expected = inCoreA.t %*% x 
+    assert((res - expected).norm(2) < 1e-6)
+  }
+  
+}
\ No newline at end of file
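
The power iteration test implements the standard recurrence

  x_{k+1} = A x_k / ||A x_k||_2

which for a symmetric A converges to the dominant eigenvector v_1 whenever
|lambda_1| > |lambda_2| and the start vector is not orthogonal to v_1; the uniform unit
start vector used above makes that a safe assumption in practice.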

http://git-wip-us.apache.org/repos/asf/mahout/blob/98d4ff03/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala
----------------------------------------------------------------------
diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala
new file mode 100644
index 0000000..3ce8895
--- /dev/null
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/blas/LATestSuit.scala
@@ -0,0 +1,45 @@
+package org.apache.mahout.flinkbindings.blas
+
+import org.scalatest.FunSuite
+import org.apache.mahout.math._
+import scalabindings._
+import RLikeOps._
+import drm._
+import org.apache.mahout.flinkbindings._
+import org.junit.runner.RunWith
+import org.scalatest.junit.JUnitRunner
+import org.apache.mahout.math.drm.logical.OpAx
+import org.apache.mahout.flinkbindings.drm.CheckpointedFlinkDrm
+import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
+import org.apache.mahout.math.drm.logical.OpAt
+
+@RunWith(classOf[JUnitRunner])
+class LATestSuit extends FunSuite with DistributedFlinkSuit {
+
+  test("Ax") {
+    val inCoreA = dense((1, 2, 3), (2, 3, 4), (3, 4, 5))
+    val A = drmParallelize(m = inCoreA, numPartitions = 2)
+    val x: Vector = (0, 1, 2)
+
+    val opAx = new OpAx(A, x)
+    val res = FlinkOpAx.blockifiedBroadcastAx(opAx, A)
+    val drm = new CheckpointedFlinkDrm(res.deblockify.ds)
+    val output = drm.collect
+
+    val b = output(::, 0)
+    assert(b == dvec(8, 11, 14))
+  }
+
+  test("At") {
+    val inCoreA = dense((1, 2, 3), (2, 3, 4))
+    val A = drmParallelize(m = inCoreA, numPartitions = 2)
+
+    val opAt = new OpAt(A)
+    val res = FlinkOpAt.sparseTrick(opAt, A)
+    val drm = new CheckpointedFlinkDrm(res.deblockify.ds, _nrow=inCoreA.ncol, _ncol=inCoreA.nrow)
+    val output = drm.collect
+
+    assert((output - inCoreA.t).norm < 1e-6)
+  }
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/mahout/blob/98d4ff03/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index f90b41b..35495fb 100644
--- a/pom.xml
+++ b/pom.xml
@@ -811,6 +811,7 @@
     <module>math-scala</module>
     <module>spark</module>
     <module>spark-shell</module>
+    <module>flink</module>
     <module>h2o</module>
   </modules>