Posted to commits@mahout.apache.org by dl...@apache.org on 2015/10/20 07:36:58 UTC
[15/32] mahout git commit: MAHOUT-1711: Flink: drmBroadcast implemented
MAHOUT-1711: Flink: drmBroadcast implemented
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/08ad113f
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/08ad113f
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/08ad113f
Branch: refs/heads/flink-binding
Commit: 08ad113f732adfe844cb5fb68ab5897f75aa2456
Parents: 58f7948
Author: Alexey Grigorev <al...@gmail.com>
Authored: Thu Jun 18 18:01:57 2015 +0200
Committer: Alexey Grigorev <al...@gmail.com>
Committed: Fri Sep 25 17:41:51 2015 +0200
----------------------------------------------------------------------
.../mahout/flinkbindings/DataSetOps.scala | 154 +++---
.../mahout/flinkbindings/FlinkByteBCast.scala | 83 ++++
.../mahout/flinkbindings/FlinkEngine.scala | 470 ++++++++++---------
.../mahout/flinkbindings/blas/FlinkOpAtB.scala | 202 ++++----
.../flinkbindings/blas/FlinkOpMapBlock.scala | 3 +
.../drm/CheckpointedFlinkDrm.scala | 348 +++++++-------
.../apache/mahout/flinkbindings/package.scala | 210 ++++-----
.../flinkbindings/FlinkByteBCastSuite.scala | 27 ++
.../mahout/flinkbindings/RLikeOpsSuite.scala | 35 +-
.../mahout/flinkbindings/UseCasesSuite.scala | 22 +-
.../flinkbindings/examples/ReadCsvExample.scala | 78 +--
11 files changed, 878 insertions(+), 754 deletions(-)
----------------------------------------------------------------------
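Note (not part of the commit): the headline change is that drmBroadcast is now backed by the new FlinkByteBCast wrapper, so broadcast values can be used inside mapBlock closures. A minimal usage sketch from the Samsara DSL side follows; the context construction and the example itself are assumptions based on the bindings in this branch, not code from this patch.

  import org.apache.flink.api.java.ExecutionEnvironment
  import org.apache.mahout.math.DenseMatrix
  import org.apache.mahout.math.drm._
  import org.apache.mahout.math.drm.RLikeDrmOps._
  import org.apache.mahout.math.scalabindings._
  import org.apache.mahout.math.scalabindings.RLikeOps._
  import org.apache.mahout.flinkbindings._

  // assumption: wrap the local Flink environment the same way the package object does
  implicit val ctx = new FlinkDistributedContext(ExecutionEnvironment.getExecutionEnvironment)

  val drmA = drmParallelize(dense((1, 2, 3), (3, 4, 5)))   // Int-keyed DRM
  val bcastX = drmBroadcast(dvec(1, 1, 1))                 // a FlinkByteBCast under the hood

  // the wrapper travels with the UDF closure; value() deserializes it on the worker
  val drmAx = drmA.mapBlock(ncol = 1) { case (keys, block) =>
    val out = new DenseMatrix(block.nrow, 1)
    for (r <- 0 until block.nrow) out(r, 0) = block(r, ::) dot bcastX.value
    keys -> out
  }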
http://git-wip-us.apache.org/repos/asf/mahout/blob/08ad113f/flink/src/main/scala/org/apache/mahout/flinkbindings/DataSetOps.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/DataSetOps.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/DataSetOps.scala
index 840b4e6..4f437ae 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/DataSetOps.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/DataSetOps.scala
@@ -1,78 +1,78 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.mahout.flinkbindings
-
-import java.lang.Iterable
-import java.util.Collections
-import java.util.Comparator
-import scala.collection.JavaConverters._
-import org.apache.flink.util.Collector
-import org.apache.flink.api.java.DataSet
-import org.apache.flink.api.java.tuple.Tuple2
-import org.apache.flink.api.common.functions.RichMapPartitionFunction
-import org.apache.flink.configuration.Configuration
-import scala.reflect.ClassTag
-
-
-class DataSetOps[K: ClassTag](val ds: DataSet[K]) {
-
- /**
- * Implementation taken from http://stackoverflow.com/questions/30596556/zipwithindex-on-apache-flink
- *
- * TODO: remove when FLINK-2152 is committed and released
- */
- def zipWithIndex(): DataSet[(Int, K)] = {
-
- // first for each partition count the number of elements - to calculate the offsets
- val counts = ds.mapPartition(new RichMapPartitionFunction[K, (Int, Int)] {
- override def mapPartition(values: Iterable[K], out: Collector[(Int, Int)]): Unit = {
- val cnt: Int = values.asScala.count(_ => true)
- val subtaskIdx = getRuntimeContext.getIndexOfThisSubtask
- out.collect((subtaskIdx, cnt))
- }
- })
-
- // then use the offsets to index items of each partition
- val zipped = ds.mapPartition(new RichMapPartitionFunction[K, (Int, K)] {
- var offset: Int = 0
-
- override def open(parameters: Configuration): Unit = {
- val offsetsJava: java.util.List[(Int, Int)] =
- getRuntimeContext.getBroadcastVariable("counts")
- val offsets = offsetsJava.asScala
-
- val sortedOffsets =
- offsets sortBy { case (id, _) => id } map { case (_, cnt) => cnt }
-
- val subtaskId = getRuntimeContext.getIndexOfThisSubtask
- offset = sortedOffsets.take(subtaskId).sum.toInt
- }
-
- override def mapPartition(values: Iterable[K], out: Collector[(Int, K)]): Unit = {
- val it = values.asScala
- it.zipWithIndex.foreach { case (value, idx) =>
- out.collect((idx + offset, value))
- }
- }
- }).withBroadcastSet(counts, "counts");
-
- zipped
- }
-
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.mahout.flinkbindings
+
+import java.lang.Iterable
+import java.util.Collections
+import java.util.Comparator
+import scala.collection.JavaConverters._
+import org.apache.flink.util.Collector
+import org.apache.flink.api.java.DataSet
+import org.apache.flink.api.java.tuple.Tuple2
+import org.apache.flink.api.common.functions.RichMapPartitionFunction
+import org.apache.flink.configuration.Configuration
+import scala.reflect.ClassTag
+
+
+class DataSetOps[K: ClassTag](val ds: DataSet[K]) {
+
+ /**
+ * Implementation taken from http://stackoverflow.com/questions/30596556/zipwithindex-on-apache-flink
+ *
+ * TODO: remove when FLINK-2152 is committed and released
+ */
+ def zipWithIndex(): DataSet[(Int, K)] = {
+
+ // first for each partition count the number of elements - to calculate the offsets
+ val counts = ds.mapPartition(new RichMapPartitionFunction[K, (Int, Int)] {
+ override def mapPartition(values: Iterable[K], out: Collector[(Int, Int)]): Unit = {
+ val cnt: Int = values.asScala.count(_ => true)
+ val subtaskIdx = getRuntimeContext.getIndexOfThisSubtask
+ out.collect((subtaskIdx, cnt))
+ }
+ })
+
+ // then use the offsets to index items of each partition
+ val zipped = ds.mapPartition(new RichMapPartitionFunction[K, (Int, K)] {
+ var offset: Int = 0
+
+ override def open(parameters: Configuration): Unit = {
+ val offsetsJava: java.util.List[(Int, Int)] =
+ getRuntimeContext.getBroadcastVariable("counts")
+ val offsets = offsetsJava.asScala
+
+ val sortedOffsets =
+ offsets sortBy { case (id, _) => id } map { case (_, cnt) => cnt }
+
+ val subtaskId = getRuntimeContext.getIndexOfThisSubtask
+ offset = sortedOffsets.take(subtaskId).sum.toInt
+ }
+
+ override def mapPartition(values: Iterable[K], out: Collector[(Int, K)]): Unit = {
+ val it = values.asScala
+ it.zipWithIndex.foreach { case (value, idx) =>
+ out.collect((idx + offset, value))
+ }
+ }
+ }).withBroadcastSet(counts, "counts");
+
+ zipped
+ }
+
}
\ No newline at end of file
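For reference, a tiny sketch (not part of the patch) of what the zipWithIndex helper above is used for; datasetToDrm in package.scala relies on it to key unindexed rows, and print/collect semantics depend on the Flink version in use.

  import org.apache.flink.api.java.ExecutionEnvironment
  import org.apache.mahout.flinkbindings.DataSetOps

  val env = ExecutionEnvironment.getExecutionEnvironment
  val ds = env.fromElements("a", "b", "c", "d")

  // pass 1 counts elements per subtask; pass 2 offsets each subtask's local index
  // by the counts of all lower-numbered subtasks, producing a consecutive global
  // index: (0,"a"), (1,"b"), (2,"c"), (3,"d") up to partition ordering
  val indexed = new DataSetOps(ds).zipWithIndex()
  indexed.print()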
http://git-wip-us.apache.org/repos/asf/mahout/blob/08ad113f/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkByteBCast.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkByteBCast.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkByteBCast.scala
new file mode 100644
index 0000000..70d0545
--- /dev/null
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkByteBCast.scala
@@ -0,0 +1,83 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.mahout.flinkbindings
+
+import org.apache.mahout.math.Matrix
+import org.apache.mahout.math.MatrixWritable
+import org.apache.mahout.math.Vector
+import org.apache.mahout.math.VectorWritable
+import org.apache.mahout.math.drm.BCast
+
+import com.google.common.io.ByteStreams
+
+/**
+ * FlinkByteBCast wraps vector/matrix objects, represents them as byte arrays, and when
+ * it's used in UDFs, they are serialized using standard Java serialization along with
+ * UDFs (as a part of closure) and broadcasted to worker nodes.
+ *
+ * There should be a smarter way of doing it with some macro and then rewriting the UDF and
+ * appending `withBroadcastSet` to flink dataSet pipeline, but it's not implemented at the moment.
+ */
+class FlinkByteBCast[T](private val arr: Array[Byte]) extends BCast[T] with Serializable {
+
+ private lazy val _value = {
+ val stream = ByteStreams.newDataInput(arr)
+ val streamType = stream.readInt()
+
+ if (streamType == FlinkByteBCast.StreamTypeVector) {
+ val writeable = new VectorWritable()
+ writeable.readFields(stream)
+ writeable.get.asInstanceOf[T]
+ } else if (streamType == FlinkByteBCast.StreamTypeMatrix) {
+ val writeable = new MatrixWritable()
+ writeable.readFields(stream)
+ writeable.get.asInstanceOf[T]
+ } else {
+ throw new IllegalArgumentException(s"unexpected type tag $streamType")
+ }
+ }
+
+ override def value: T = _value
+
+}
+
+object FlinkByteBCast {
+
+ val StreamTypeVector = 0x0000
+ val StreamTypeMatrix = 0xFFFF
+
+ def wrap(v: Vector): FlinkByteBCast[Vector] = {
+ val writeable = new VectorWritable(v)
+ val dataOutput = ByteStreams.newDataOutput()
+ dataOutput.writeInt(StreamTypeVector)
+ writeable.write(dataOutput)
+ val array = dataOutput.toByteArray()
+ return new FlinkByteBCast[Vector](array)
+ }
+
+ def wrap(m: Matrix): FlinkByteBCast[Matrix] = {
+ val writeable = new MatrixWritable(m)
+ val dataOutput = ByteStreams.newDataOutput()
+ dataOutput.writeInt(StreamTypeMatrix)
+ writeable.write(dataOutput)
+ val array = dataOutput.toByteArray()
+ return new FlinkByteBCast[Matrix](array)
+ }
+
+}
\ No newline at end of file
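A small round-trip sketch for the wrapper introduced above (illustrative only, not part of the commit):

  import org.apache.mahout.math.scalabindings._
  import org.apache.mahout.flinkbindings.FlinkByteBCast

  // wrap() writes an int type tag followed by the Writable form into a byte array
  val bcast = FlinkByteBCast.wrap(dvec(1.0, 2.0, 3.0))

  // the wrapper is Serializable, so it can ride along inside a Flink UDF closure;
  // value is deserialized lazily on first access on the worker
  assert(bcast.value.get(1) == 2.0)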
http://git-wip-us.apache.org/repos/asf/mahout/blob/08ad113f/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
index 074676c..18e17db 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/FlinkEngine.scala
@@ -1,234 +1,238 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.mahout.flinkbindings
-
-import java.util.Collection
-import scala.reflect.ClassTag
-import scala.collection.JavaConverters._
-import com.google.common.collect._
-import org.apache.mahout.math._
-import org.apache.mahout.math.drm._
-import org.apache.mahout.math.indexeddataset._
-import org.apache.mahout.math.scalabindings._
-import org.apache.mahout.math.scalabindings.RLikeOps._
-import org.apache.mahout.math.drm.DrmTuple
-import org.apache.mahout.math.drm.logical._
-import org.apache.mahout.math.indexeddataset.BiDictionary
-import org.apache.mahout.flinkbindings._
-import org.apache.mahout.flinkbindings.drm._
-import org.apache.mahout.flinkbindings.blas._
-import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.common.functions._
-import org.apache.flink.api.common.functions.MapFunction
-import org.apache.flink.api.java.typeutils.TypeExtractor
-import org.apache.flink.api.scala.DataSet
-import org.apache.flink.api.java.io.TypeSerializerInputFormat
-import org.apache.flink.api.common.io.SerializedInputFormat
-import org.apache.hadoop.mapred.JobConf
-import org.apache.hadoop.mapred.SequenceFileInputFormat
-import org.apache.hadoop.mapred.FileInputFormat
-import org.apache.mahout.flinkbindings.io._
-import org.apache.hadoop.io.Writable
-import org.apache.flink.api.java.tuple.Tuple2
-
-object FlinkEngine extends DistributedEngine {
-
- // By default, use Hadoop 1 utils
- var hdfsUtils: HDFSUtil = Hadoop1HDFSUtil
-
- /**
- * Load DRM from hdfs (as in Mahout DRM format).
- *
- * @param path The DFS path to load from
- * @param parMin Minimum parallelism after load (equivalent to #par(min=...)).
- */
- override def drmDfsRead(path: String, parMin: Int = 0)
- (implicit dc: DistributedContext): CheckpointedDrm[_] = {
- val metadata = hdfsUtils.readDrmHeader(path)
- val unwrapKey = metadata.unwrapKeyFunction
-
- val job = new JobConf
- val hadoopInput = new SequenceFileInputFormat[Writable, VectorWritable]
- FileInputFormat.addInputPath(job, new org.apache.hadoop.fs.Path(path))
-
- val writables = dc.env.createHadoopInput(hadoopInput, classOf[Writable], classOf[VectorWritable], job)
-
- val res = writables.map(new MapFunction[Tuple2[Writable, VectorWritable], (Any, Vector)] {
- def map(tuple: Tuple2[Writable, VectorWritable]): (Any, Vector) = {
- (unwrapKey(tuple.f0), tuple.f1)
- }
- })
-
- datasetWrap(res)(metadata.keyClassTag.asInstanceOf[ClassTag[Any]])
- }
-
- override def indexedDatasetDFSRead(src: String, schema: Schema, existingRowIDs: Option[BiDictionary])
- (implicit sc: DistributedContext): IndexedDataset = ???
-
- override def indexedDatasetDFSReadElements(src: String,schema: Schema, existingRowIDs: Option[BiDictionary])
- (implicit sc: DistributedContext): IndexedDataset = ???
-
-
- /**
- * Translates logical plan into Flink execution plan.
- **/
- override def toPhysical[K: ClassTag](plan: DrmLike[K], ch: CacheHint.CacheHint): CheckpointedDrm[K] = {
- // Flink-specific Physical Plan translation.
- val drm = flinkTranslate(plan)
-
- // to Help Flink's type inference had to use just one specific type - Int
- // see org.apache.mahout.flinkbindings.blas classes with TODO: casting inside
- val cls = implicitly[ClassTag[K]]
- if (!cls.runtimeClass.equals(classOf[Int])) {
- throw new IllegalArgumentException(s"At the moment only Int indexes are supported. Got $cls")
- }
-
- val newcp = new CheckpointedFlinkDrm(ds = drm.deblockify.ds, _nrow = plan.nrow, _ncol = plan.ncol)
- newcp.cache()
- }
-
- private def flinkTranslate[K: ClassTag](oper: DrmLike[K]): FlinkDrm[K] = oper match {
- case op @ OpAx(a, x) => FlinkOpAx.blockifiedBroadcastAx(op, flinkTranslate(a)(op.classTagA))
- case op @ OpAt(a) => FlinkOpAt.sparseTrick(op, flinkTranslate(a)(op.classTagA))
- case op @ OpAtx(a, x) => {
- // express Atx as (A.t) %*% x
- // TODO: create specific implementation of Atx
- val opAt = OpAt(a)
- val at = FlinkOpAt.sparseTrick(opAt, flinkTranslate(a)(op.classTagA))
- val atCast = new CheckpointedFlinkDrm(at.deblockify.ds, _nrow=opAt.nrow, _ncol=opAt.ncol)
- val opAx = OpAx(atCast, x)
- FlinkOpAx.blockifiedBroadcastAx(opAx, flinkTranslate(atCast)(op.classTagA))
- }
- case op @ OpAtB(a, b) => FlinkOpAtB.notZippable(op, flinkTranslate(a)(op.classTagA),
- flinkTranslate(b)(op.classTagA))
- case op @ OpABt(a, b) => {
- // express ABt via AtB: let C=At and D=Bt, and calculate CtD
- // TODO: create specific implementation of ABt
- val opAt = OpAt(a.asInstanceOf[DrmLike[Int]]) // TODO: casts!
- val at = FlinkOpAt.sparseTrick(opAt, flinkTranslate(a.asInstanceOf[DrmLike[Int]]))
- val c = new CheckpointedFlinkDrm(at.deblockify.ds, _nrow=opAt.nrow, _ncol=opAt.ncol)
-
- val opBt = OpAt(b.asInstanceOf[DrmLike[Int]]) // TODO: casts!
- val bt = FlinkOpAt.sparseTrick(opBt, flinkTranslate(b.asInstanceOf[DrmLike[Int]]))
- val d = new CheckpointedFlinkDrm(bt.deblockify.ds, _nrow=opBt.nrow, _ncol=opBt.ncol)
-
- FlinkOpAtB.notZippable(OpAtB(c, d), flinkTranslate(c), flinkTranslate(d))
- .asInstanceOf[FlinkDrm[K]]
- }
- case op @ OpAtA(a) => {
- // express AtA via AtB
- // TODO: create specific implementation of AtA
- val aInt = a.asInstanceOf[DrmLike[Int]] // TODO: casts!
- val opAtB = OpAtB(aInt, aInt)
- val aTranslated = flinkTranslate(aInt)
- FlinkOpAtB.notZippable(opAtB, aTranslated, aTranslated)
- }
- case op @ OpTimesRightMatrix(a, b) =>
- FlinkOpTimesRightMatrix.drmTimesInCore(op, flinkTranslate(a)(op.classTagA), b)
- case op @ OpAewScalar(a, scalar, _) =>
- FlinkOpAewScalar.opScalarNoSideEffect(op, flinkTranslate(a)(op.classTagA), scalar)
- case op @ OpAewB(a, b, _) =>
- FlinkOpAewB.rowWiseJoinNoSideEffect(op, flinkTranslate(a)(op.classTagA), flinkTranslate(b)(op.classTagA))
- case op @ OpCbind(a, b) =>
- FlinkOpCBind.cbind(op, flinkTranslate(a)(op.classTagA), flinkTranslate(b)(op.classTagA))
- case op @ OpRbind(a, b) =>
- FlinkOpRBind.rbind(op, flinkTranslate(a)(op.classTagA), flinkTranslate(b)(op.classTagA))
- case op @ OpRowRange(a, _) =>
- FlinkOpRowRange.slice(op, flinkTranslate(a)(op.classTagA))
- case op: OpMapBlock[K, _] =>
- FlinkOpMapBlock.apply(flinkTranslate(op.A)(op.classTagA), op.ncol, op.bmf)
- case cp: CheckpointedFlinkDrm[K] => new RowsFlinkDrm(cp.ds, cp.ncol)
- case _ => throw new NotImplementedError(s"operator $oper is not implemented yet")
- }
-
- /**
- * returns a vector that contains a column-wise sum from DRM
- */
- override def colSums[K: ClassTag](drm: CheckpointedDrm[K]): Vector = {
- val sum = drm.ds.map(new MapFunction[(K, Vector), Vector] {
- def map(tuple: (K, Vector)): Vector = tuple._2
- }).reduce(new ReduceFunction[Vector] {
- def reduce(v1: Vector, v2: Vector) = v1 + v2
- })
-
- val list = sum.collect.asScala.toList
- list.head
- }
-
- /** Engine-specific numNonZeroElementsPerColumn implementation based on a checkpoint. */
- override def numNonZeroElementsPerColumn[K: ClassTag](drm: CheckpointedDrm[K]): Vector = ???
-
- /**
- * returns a vector that contains a column-wise mean from DRM
- */
- override def colMeans[K: ClassTag](drm: CheckpointedDrm[K]): Vector = {
- drm.colSums() / drm.nrow
- }
-
- /**
- * Calculates the element-wise squared norm of a matrix
- */
- override def norm[K: ClassTag](drm: CheckpointedDrm[K]): Double = {
- val sumOfSquares = drm.ds.map(new MapFunction[(K, Vector), Double] {
- def map(tuple: (K, Vector)): Double = tuple match {
- case (idx, vec) => vec dot vec
- }
- }).reduce(new ReduceFunction[Double] {
- def reduce(v1: Double, v2: Double) = v1 + v2
- })
-
- val list = sumOfSquares.collect.asScala.toList
- list.head
- }
-
- /** Broadcast support */
- override def drmBroadcast(v: Vector)(implicit dc: DistributedContext): BCast[Vector] = ???
-
- /** Broadcast support */
- override def drmBroadcast(m: Matrix)(implicit dc: DistributedContext): BCast[Matrix] = ???
-
-
- /** Parallelize in-core matrix as spark distributed matrix, using row ordinal indices as data set keys. */
- override def drmParallelizeWithRowIndices(m: Matrix, numPartitions: Int = 1)
- (implicit dc: DistributedContext): CheckpointedDrm[Int] = {
- val parallelDrm = parallelize(m, numPartitions)
- new CheckpointedFlinkDrm(ds=parallelDrm, _nrow=m.numRows(), _ncol=m.numCols())
- }
-
- private[flinkbindings] def parallelize(m: Matrix, parallelismDegree: Int)
- (implicit dc: DistributedContext): DrmDataSet[Int] = {
- val rows = (0 until m.nrow).map(i => (i, m(i, ::)))
- val rowsJava: Collection[DrmTuple[Int]] = rows.asJava
-
- val dataSetType = TypeExtractor.getForObject(rows.head)
- dc.env.fromCollection(rowsJava, dataSetType).setParallelism(parallelismDegree)
- }
-
- /** Parallelize in-core matrix as spark distributed matrix, using row labels as a data set keys. */
- override def drmParallelizeWithRowLabels(m: Matrix, numPartitions: Int = 1)
- (implicit sc: DistributedContext): CheckpointedDrm[String] = ???
-
- /** This creates an empty DRM with specified number of partitions and cardinality. */
- override def drmParallelizeEmpty(nrow: Int, ncol: Int, numPartitions: Int = 10)
- (implicit sc: DistributedContext): CheckpointedDrm[Int] = ???
-
- /** Creates empty DRM with non-trivial height */
- override def drmParallelizeEmptyLong(nrow: Long, ncol: Int, numPartitions: Int = 10)
- (implicit sc: DistributedContext): CheckpointedDrm[Long] = ???
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.mahout.flinkbindings
+
+import java.util.Collection
+import scala.reflect.ClassTag
+import scala.collection.JavaConverters._
+import com.google.common.collect._
+import org.apache.mahout.math._
+import org.apache.mahout.math.drm._
+import org.apache.mahout.math.indexeddataset._
+import org.apache.mahout.math.scalabindings._
+import org.apache.mahout.math.scalabindings.RLikeOps._
+import org.apache.mahout.math.drm.DrmTuple
+import org.apache.mahout.math.drm.logical._
+import org.apache.mahout.math.indexeddataset.BiDictionary
+import org.apache.mahout.flinkbindings._
+import org.apache.mahout.flinkbindings.drm._
+import org.apache.mahout.flinkbindings.blas._
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.common.functions._
+import org.apache.flink.api.common.functions.MapFunction
+import org.apache.flink.api.java.typeutils.TypeExtractor
+import org.apache.flink.api.scala.DataSet
+import org.apache.flink.api.java.io.TypeSerializerInputFormat
+import org.apache.flink.api.common.io.SerializedInputFormat
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapred.SequenceFileInputFormat
+import org.apache.hadoop.mapred.FileInputFormat
+import org.apache.mahout.flinkbindings.io._
+import org.apache.hadoop.io.Writable
+import org.apache.flink.api.java.tuple.Tuple2
+
+object FlinkEngine extends DistributedEngine {
+
+ // By default, use Hadoop 1 utils
+ var hdfsUtils: HDFSUtil = Hadoop1HDFSUtil
+
+ /**
+ * Load DRM from hdfs (as in Mahout DRM format).
+ *
+ * @param path The DFS path to load from
+ * @param parMin Minimum parallelism after load (equivalent to #par(min=...)).
+ */
+ override def drmDfsRead(path: String, parMin: Int = 0)
+ (implicit dc: DistributedContext): CheckpointedDrm[_] = {
+ val metadata = hdfsUtils.readDrmHeader(path)
+ val unwrapKey = metadata.unwrapKeyFunction
+
+ val job = new JobConf
+ val hadoopInput = new SequenceFileInputFormat[Writable, VectorWritable]
+ FileInputFormat.addInputPath(job, new org.apache.hadoop.fs.Path(path))
+
+ val writables = dc.env.createHadoopInput(hadoopInput, classOf[Writable], classOf[VectorWritable], job)
+
+ val res = writables.map(new MapFunction[Tuple2[Writable, VectorWritable], (Any, Vector)] {
+ def map(tuple: Tuple2[Writable, VectorWritable]): (Any, Vector) = {
+ (unwrapKey(tuple.f0), tuple.f1)
+ }
+ })
+
+ datasetWrap(res)(metadata.keyClassTag.asInstanceOf[ClassTag[Any]])
+ }
+
+ override def indexedDatasetDFSRead(src: String, schema: Schema, existingRowIDs: Option[BiDictionary])
+ (implicit sc: DistributedContext): IndexedDataset = ???
+
+ override def indexedDatasetDFSReadElements(src: String,schema: Schema, existingRowIDs: Option[BiDictionary])
+ (implicit sc: DistributedContext): IndexedDataset = ???
+
+
+ /**
+ * Translates logical plan into Flink execution plan.
+ **/
+ override def toPhysical[K: ClassTag](plan: DrmLike[K], ch: CacheHint.CacheHint): CheckpointedDrm[K] = {
+ // Flink-specific Physical Plan translation.
+ val drm = flinkTranslate(plan)
+
+ // to Help Flink's type inference had to use just one specific type - Int
+ // see org.apache.mahout.flinkbindings.blas classes with TODO: casting inside
+ // see MAHOUT-1747 and MAHOUT-1748
+ val cls = implicitly[ClassTag[K]]
+ if (!cls.runtimeClass.equals(classOf[Int])) {
+ throw new IllegalArgumentException(s"At the moment only Int indexes are supported. Got $cls")
+ }
+
+ val newcp = new CheckpointedFlinkDrm(ds = drm.deblockify.ds, _nrow = plan.nrow, _ncol = plan.ncol)
+ newcp.cache()
+ }
+
+ private def flinkTranslate[K: ClassTag](oper: DrmLike[K]): FlinkDrm[K] = oper match {
+ case op @ OpAx(a, x) => FlinkOpAx.blockifiedBroadcastAx(op, flinkTranslate(a)(op.classTagA))
+ case op @ OpAt(a) => FlinkOpAt.sparseTrick(op, flinkTranslate(a)(op.classTagA))
+ case op @ OpAtx(a, x) => {
+ // express Atx as (A.t) %*% x
+ // TODO: create specific implementation of Atx, see MAHOUT-1749
+ val opAt = OpAt(a)
+ val at = FlinkOpAt.sparseTrick(opAt, flinkTranslate(a)(op.classTagA))
+ val atCast = new CheckpointedFlinkDrm(at.deblockify.ds, _nrow=opAt.nrow, _ncol=opAt.ncol)
+ val opAx = OpAx(atCast, x)
+ FlinkOpAx.blockifiedBroadcastAx(opAx, flinkTranslate(atCast)(op.classTagA))
+ }
+ case op @ OpAtB(a, b) => FlinkOpAtB.notZippable(op, flinkTranslate(a)(op.classTagA),
+ flinkTranslate(b)(op.classTagA))
+ case op @ OpABt(a, b) => {
+ // express ABt via AtB: let C=At and D=Bt, and calculate CtD
+ // TODO: create specific implementation of ABt, see MAHOUT-1750
+ val opAt = OpAt(a.asInstanceOf[DrmLike[Int]]) // TODO: casts!
+ val at = FlinkOpAt.sparseTrick(opAt, flinkTranslate(a.asInstanceOf[DrmLike[Int]]))
+ val c = new CheckpointedFlinkDrm(at.deblockify.ds, _nrow=opAt.nrow, _ncol=opAt.ncol)
+
+ val opBt = OpAt(b.asInstanceOf[DrmLike[Int]]) // TODO: casts!
+ val bt = FlinkOpAt.sparseTrick(opBt, flinkTranslate(b.asInstanceOf[DrmLike[Int]]))
+ val d = new CheckpointedFlinkDrm(bt.deblockify.ds, _nrow=opBt.nrow, _ncol=opBt.ncol)
+
+ FlinkOpAtB.notZippable(OpAtB(c, d), flinkTranslate(c), flinkTranslate(d))
+ .asInstanceOf[FlinkDrm[K]]
+ }
+ case op @ OpAtA(a) => {
+ // express AtA via AtB
+ // TODO: create specific implementation of AtA, see MAHOUT-1751
+ val aInt = a.asInstanceOf[DrmLike[Int]] // TODO: casts!
+ val opAtB = OpAtB(aInt, aInt)
+ val aTranslated = flinkTranslate(aInt)
+ FlinkOpAtB.notZippable(opAtB, aTranslated, aTranslated)
+ }
+ case op @ OpTimesRightMatrix(a, b) =>
+ FlinkOpTimesRightMatrix.drmTimesInCore(op, flinkTranslate(a)(op.classTagA), b)
+ case op @ OpAewScalar(a, scalar, _) =>
+ FlinkOpAewScalar.opScalarNoSideEffect(op, flinkTranslate(a)(op.classTagA), scalar)
+ case op @ OpAewB(a, b, _) =>
+ FlinkOpAewB.rowWiseJoinNoSideEffect(op, flinkTranslate(a)(op.classTagA), flinkTranslate(b)(op.classTagA))
+ case op @ OpCbind(a, b) =>
+ FlinkOpCBind.cbind(op, flinkTranslate(a)(op.classTagA), flinkTranslate(b)(op.classTagA))
+ case op @ OpRbind(a, b) =>
+ FlinkOpRBind.rbind(op, flinkTranslate(a)(op.classTagA), flinkTranslate(b)(op.classTagA))
+ case op @ OpRowRange(a, _) =>
+ FlinkOpRowRange.slice(op, flinkTranslate(a)(op.classTagA))
+ case op: OpMapBlock[K, _] =>
+ FlinkOpMapBlock.apply(flinkTranslate(op.A)(op.classTagA), op.ncol, op.bmf)
+ case cp: CheckpointedFlinkDrm[K] => new RowsFlinkDrm(cp.ds, cp.ncol)
+ case _ => throw new NotImplementedError(s"operator $oper is not implemented yet")
+ }
+
+ /**
+ * returns a vector that contains a column-wise sum from DRM
+ */
+ override def colSums[K: ClassTag](drm: CheckpointedDrm[K]): Vector = {
+ val sum = drm.ds.map(new MapFunction[(K, Vector), Vector] {
+ def map(tuple: (K, Vector)): Vector = tuple._2
+ }).reduce(new ReduceFunction[Vector] {
+ def reduce(v1: Vector, v2: Vector) = v1 + v2
+ })
+
+ val list = sum.collect.asScala.toList
+ list.head
+ }
+
+ /** Engine-specific numNonZeroElementsPerColumn implementation based on a checkpoint. */
+ override def numNonZeroElementsPerColumn[K: ClassTag](drm: CheckpointedDrm[K]): Vector = ???
+
+ /**
+ * returns a vector that contains a column-wise mean from DRM
+ */
+ override def colMeans[K: ClassTag](drm: CheckpointedDrm[K]): Vector = {
+ drm.colSums() / drm.nrow
+ }
+
+ /**
+ * Calculates the element-wise squared norm of a matrix
+ */
+ override def norm[K: ClassTag](drm: CheckpointedDrm[K]): Double = {
+ val sumOfSquares = drm.ds.map(new MapFunction[(K, Vector), Double] {
+ def map(tuple: (K, Vector)): Double = tuple match {
+ case (idx, vec) => vec dot vec
+ }
+ }).reduce(new ReduceFunction[Double] {
+ def reduce(v1: Double, v2: Double) = v1 + v2
+ })
+
+ val list = sumOfSquares.collect.asScala.toList
+ list.head
+ }
+
+ /** Broadcast support */
+ override def drmBroadcast(v: Vector)(implicit dc: DistributedContext): BCast[Vector] =
+ FlinkByteBCast.wrap(v)
+
+
+ /** Broadcast support */
+ override def drmBroadcast(m: Matrix)(implicit dc: DistributedContext): BCast[Matrix] =
+ FlinkByteBCast.wrap(m)
+
+
+ /** Parallelize in-core matrix as spark distributed matrix, using row ordinal indices as data set keys. */
+ override def drmParallelizeWithRowIndices(m: Matrix, numPartitions: Int = 1)
+ (implicit dc: DistributedContext): CheckpointedDrm[Int] = {
+ val parallelDrm = parallelize(m, numPartitions)
+ new CheckpointedFlinkDrm(ds=parallelDrm, _nrow=m.numRows(), _ncol=m.numCols())
+ }
+
+ private[flinkbindings] def parallelize(m: Matrix, parallelismDegree: Int)
+ (implicit dc: DistributedContext): DrmDataSet[Int] = {
+ val rows = (0 until m.nrow).map(i => (i, m(i, ::)))
+ val rowsJava: Collection[DrmTuple[Int]] = rows.asJava
+
+ val dataSetType = TypeExtractor.getForObject(rows.head)
+ dc.env.fromCollection(rowsJava, dataSetType).setParallelism(parallelismDegree)
+ }
+
+ /** Parallelize in-core matrix as spark distributed matrix, using row labels as a data set keys. */
+ override def drmParallelizeWithRowLabels(m: Matrix, numPartitions: Int = 1)
+ (implicit sc: DistributedContext): CheckpointedDrm[String] = ???
+
+ /** This creates an empty DRM with specified number of partitions and cardinality. */
+ override def drmParallelizeEmpty(nrow: Int, ncol: Int, numPartitions: Int = 10)
+ (implicit sc: DistributedContext): CheckpointedDrm[Int] = ???
+
+ /** Creates empty DRM with non-trivial height */
+ override def drmParallelizeEmptyLong(nrow: Long, ncol: Int, numPartitions: Int = 10)
+ (implicit sc: DistributedContext): CheckpointedDrm[Long] = ???
}
\ No newline at end of file
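The OpABt and OpAtA cases in flinkTranslate above lean on algebraic rewrites through OpAtB rather than dedicated kernels. A quick in-core check of those identities (illustrative only, using plain scalabindings; atb here just mimics what the AtB operator computes):

  import org.apache.mahout.math.Matrix
  import org.apache.mahout.math.scalabindings._
  import org.apache.mahout.math.scalabindings.RLikeOps._

  val a = dense((1, 2), (3, 4))
  val b = dense((5, 6), (7, 8))

  // what FlinkOpAtB computes, expressed in-core
  def atb(x: Matrix, y: Matrix): Matrix = x.t %*% y

  assert(((a %*% b.t) - atb(a.t, b.t)).zSum().abs < 1e-10)  // ABt via AtB with C = A.t, D = B.t
  assert(((a.t %*% a) - atb(a, a)).zSum().abs < 1e-10)      // AtA via AtB with B = A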
http://git-wip-us.apache.org/repos/asf/mahout/blob/08ad113f/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
index b5eb17c..462dc4a 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpAtB.scala
@@ -1,102 +1,102 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.mahout.flinkbindings.blas
-
-import scala.reflect.ClassTag
-import org.apache.mahout.flinkbindings.drm.FlinkDrm
-import org.apache.mahout.math.drm.logical.OpAtB
-import org.apache.flink.api.common.functions.MapFunction
-import org.apache.flink.api.java.tuple.Tuple2
-import org.apache.mahout.math.Vector
-import org.apache.mahout.math.Matrix
-import org.apache.flink.api.common.functions.FlatMapFunction
-import org.apache.flink.util.Collector
-import org.apache.mahout.math.drm._
-import org.apache.mahout.math.scalabindings._
-import RLikeOps._
-import org.apache.flink.api.common.functions.GroupReduceFunction
-import java.lang.Iterable
-import scala.collection.JavaConverters._
-import com.google.common.collect.Lists
-import org.apache.mahout.flinkbindings.drm.BlockifiedFlinkDrm
-import org.apache.mahout.flinkbindings.BlockifiedDrmDataSet
-import org.apache.flink.api.scala._
-import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.mahout.flinkbindings.DrmDataSet
-
-
-/**
- * Implementation is taken from Spark's AtB
- * https://github.com/apache/mahout/blob/master/spark/src/main/scala/org/apache/mahout/sparkbindings/blas/AtB.scala
- */
-object FlinkOpAtB {
-
- def notZippable[K: ClassTag](op: OpAtB[K], At: FlinkDrm[K], B: FlinkDrm[K]): FlinkDrm[Int] = {
- // TODO: to help Flink's type inference
- // only Int is supported now
- val rowsAt = At.deblockify.ds.asInstanceOf[DrmDataSet[Int]]
- val rowsB = B.deblockify.ds.asInstanceOf[DrmDataSet[Int]]
- val joined = rowsAt.join(rowsB).where(tuple_1[Vector]).equalTo(tuple_1[Vector])
-
- val ncol = op.ncol
- val nrow = op.nrow
- val blockHeight = 10
- val blockCount = safeToNonNegInt((ncol - 1) / blockHeight + 1)
-
- val preProduct = joined.flatMap(new FlatMapFunction[Tuple2[(Int, Vector), (Int, Vector)],
- (Int, Matrix)] {
- def flatMap(in: Tuple2[(Int, Vector), (Int, Vector)],
- out: Collector[(Int, Matrix)]): Unit = {
- val avec = in.f0._2
- val bvec = in.f1._2
-
- 0.until(blockCount) map { blockKey =>
- val blockStart = blockKey * blockHeight
- val blockEnd = Math.min(ncol, blockStart + blockHeight)
-
- // Create block by cross product of proper slice of aRow and qRow
- val outer = avec(blockStart until blockEnd) cross bvec
- out.collect((blockKey, outer))
- }
- }
- })
-
- val res: BlockifiedDrmDataSet[Int] = preProduct.groupBy(tuple_1[Matrix]).reduceGroup(
- new GroupReduceFunction[(Int, Matrix), BlockifiedDrmTuple[Int]] {
- def reduce(values: Iterable[(Int, Matrix)], out: Collector[BlockifiedDrmTuple[Int]]): Unit = {
- val it = Lists.newArrayList(values).asScala
- val (idx, _) = it.head
-
- val block = it.map(t => t._2).reduce((m1, m2) => m1 + m2)
-
- val keys = idx.until(block.nrow).toArray[Int]
- out.collect((keys, block))
- }
- })
-
- new BlockifiedFlinkDrm(res, ncol)
- }
-
-}
-
-class DrmTupleToFlinkTupleMapper[K: ClassTag] extends MapFunction[(K, Vector), Tuple2[Int, Vector]] {
- def map(tuple: (K, Vector)): Tuple2[Int, Vector] = tuple match {
- case (key, vec) => new Tuple2[Int, Vector](key.asInstanceOf[Int], vec)
- }
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.mahout.flinkbindings.blas
+
+import scala.reflect.ClassTag
+import org.apache.mahout.flinkbindings.drm.FlinkDrm
+import org.apache.mahout.math.drm.logical.OpAtB
+import org.apache.flink.api.common.functions.MapFunction
+import org.apache.flink.api.java.tuple.Tuple2
+import org.apache.mahout.math.Vector
+import org.apache.mahout.math.Matrix
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.util.Collector
+import org.apache.mahout.math.drm._
+import org.apache.mahout.math.scalabindings._
+import RLikeOps._
+import org.apache.flink.api.common.functions.GroupReduceFunction
+import java.lang.Iterable
+import scala.collection.JavaConverters._
+import com.google.common.collect.Lists
+import org.apache.mahout.flinkbindings.drm.BlockifiedFlinkDrm
+import org.apache.mahout.flinkbindings.BlockifiedDrmDataSet
+import org.apache.flink.api.scala._
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.mahout.flinkbindings.DrmDataSet
+
+
+/**
+ * Implementation is taken from Spark's AtB
+ * https://github.com/apache/mahout/blob/master/spark/src/main/scala/org/apache/mahout/sparkbindings/blas/AtB.scala
+ */
+object FlinkOpAtB {
+
+ def notZippable[K: ClassTag](op: OpAtB[K], At: FlinkDrm[K], B: FlinkDrm[K]): FlinkDrm[Int] = {
+ // TODO: to help Flink's type inference
+ // only Int is supported now
+ val rowsAt = At.deblockify.ds.asInstanceOf[DrmDataSet[Int]]
+ val rowsB = B.deblockify.ds.asInstanceOf[DrmDataSet[Int]]
+ val joined = rowsAt.join(rowsB).where(tuple_1[Vector]).equalTo(tuple_1[Vector])
+
+ val ncol = op.ncol
+ val nrow = op.nrow
+ val blockHeight = 10
+ val blockCount = safeToNonNegInt((ncol - 1) / blockHeight + 1)
+
+ val preProduct = joined.flatMap(new FlatMapFunction[Tuple2[(Int, Vector), (Int, Vector)],
+ (Int, Matrix)] {
+ def flatMap(in: Tuple2[(Int, Vector), (Int, Vector)],
+ out: Collector[(Int, Matrix)]): Unit = {
+ val avec = in.f0._2
+ val bvec = in.f1._2
+
+ 0.until(blockCount) map { blockKey =>
+ val blockStart = blockKey * blockHeight
+ val blockEnd = Math.min(ncol, blockStart + blockHeight)
+
+ // Create block by cross product of proper slice of aRow and qRow
+ val outer = avec(blockStart until blockEnd) cross bvec
+ out.collect((blockKey, outer))
+ }
+ }
+ })
+
+ val res: BlockifiedDrmDataSet[Int] = preProduct.groupBy(tuple_1[Matrix]).reduceGroup(
+ new GroupReduceFunction[(Int, Matrix), BlockifiedDrmTuple[Int]] {
+ def reduce(values: Iterable[(Int, Matrix)], out: Collector[BlockifiedDrmTuple[Int]]): Unit = {
+ val it = Lists.newArrayList(values).asScala
+ val (idx, _) = it.head
+
+ val block = it.map(t => t._2).reduce((m1, m2) => m1 + m2)
+
+ val keys = idx.until(block.nrow).toArray[Int]
+ out.collect((keys, block))
+ }
+ })
+
+ new BlockifiedFlinkDrm(res, ncol)
+ }
+
+}
+
+class DrmTupleToFlinkTupleMapper[K: ClassTag] extends MapFunction[(K, Vector), Tuple2[Int, Vector]] {
+ def map(tuple: (K, Vector)): Tuple2[Int, Vector] = tuple match {
+ case (key, vec) => new Tuple2[Int, Vector](key.asInstanceOf[Int], vec)
+ }
}
\ No newline at end of file
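The operator above relies on A.t %*% B being the sum, over matching row keys i, of the outer products a(i,::) cross b(i,::); the Flink code additionally slices a(i,::) into key ranges of height blockHeight so each partial outer product lands in a row block. A small in-core check of that decomposition (illustrative only):

  import org.apache.mahout.math.scalabindings._
  import org.apache.mahout.math.scalabindings.RLikeOps._

  val a = dense((1, 2), (3, 4), (5, 6))   // 3 x 2
  val b = dense((1, 0), (0, 1), (1, 1))   // 3 x 2

  val direct = a.t %*% b
  val viaOuter = (0 until a.nrow).map(i => a(i, ::) cross b(i, ::)).reduce(_ + _)
  assert((direct - viaOuter).zSum().abs < 1e-10)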
http://git-wip-us.apache.org/repos/asf/mahout/blob/08ad113f/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpMapBlock.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpMapBlock.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpMapBlock.scala
index 5d73f59..c8c1fa4 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpMapBlock.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/blas/FlinkOpMapBlock.scala
@@ -36,6 +36,9 @@ import RLikeOps._
object FlinkOpMapBlock {
def apply[S, R: ClassTag](src: FlinkDrm[S], ncol: Int, function: BlockMapFunc[S, R]): FlinkDrm[R] = {
+
+
+
val res = src.blockify.ds.map(new MapFunction[(Array[S], Matrix), (Array[R], Matrix)] {
def map(block: (Array[S], Matrix)): (Array[R], Matrix) = {
val out = function(block)
http://git-wip-us.apache.org/repos/asf/mahout/blob/08ad113f/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
index 1a42f84..45c944a 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/drm/CheckpointedFlinkDrm.scala
@@ -1,175 +1,175 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.mahout.flinkbindings.drm
-
-import scala.reflect.ClassTag
-import org.apache.mahout.math.drm._
-import org.apache.mahout.math.scalabindings._
-import RLikeOps._
-import org.apache.mahout.flinkbindings._
-import org.apache.mahout.math.drm.CheckpointedDrm
-import org.apache.mahout.math.Matrix
-import org.apache.mahout.flinkbindings.FlinkDistributedContext
-import org.apache.flink.api.scala.ExecutionEnvironment
-import org.apache.mahout.math.drm.CacheHint
-import scala.util.Random
-import org.apache.mahout.math.drm.DistributedContext
-import org.apache.mahout.math.DenseMatrix
-import org.apache.mahout.math.SparseMatrix
-import org.apache.flink.api.java.io.LocalCollectionOutputFormat
-import java.util.ArrayList
-import scala.collection.JavaConverters._
-import org.apache.flink.api.common.functions.MapFunction
-import org.apache.flink.api.common.functions.ReduceFunction
-import org.apache.flink.api.java.DataSet
-import org.apache.hadoop.io.Writable
-import org.apache.hadoop.io.IntWritable
-import org.apache.hadoop.io.Text
-import org.apache.hadoop.io.LongWritable
-import org.apache.mahout.math.VectorWritable
-import org.apache.mahout.math.Vector
-import org.apache.hadoop.mapred.SequenceFileOutputFormat
-import org.apache.hadoop.mapred.JobConf
-import org.apache.hadoop.mapred.FileOutputFormat
-import org.apache.flink.api.java.tuple.Tuple2
-import org.apache.flink.api.java.hadoop.mapred.HadoopOutputFormat
-
-class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K],
- private var _nrow: Long = CheckpointedFlinkDrm.UNKNOWN,
- private var _ncol: Int = CheckpointedFlinkDrm.UNKNOWN,
- override protected[mahout] val partitioningTag: Long = Random.nextLong(),
- private var _canHaveMissingRows: Boolean = false
- ) extends CheckpointedDrm[K] {
-
- lazy val nrow: Long = if (_nrow >= 0) _nrow else computeNRow
- lazy val ncol: Int = if (_ncol >= 0) _ncol else computeNCol
-
- protected def computeNRow: Long = {
- val count = ds.map(new MapFunction[DrmTuple[K], Long] {
- def map(value: DrmTuple[K]): Long = 1L
- }).reduce(new ReduceFunction[Long] {
- def reduce(a1: Long, a2: Long) = a1 + a2
- })
-
- val list = count.collect().asScala.toList
- list.head
- }
-
- protected def computeNCol: Int = {
- val max = ds.map(new MapFunction[DrmTuple[K], Int] {
- def map(value: DrmTuple[K]): Int = value._2.length
- }).reduce(new ReduceFunction[Int] {
- def reduce(a1: Int, a2: Int) = Math.max(a1, a2)
- })
-
- val list = max.collect().asScala.toList
- list.head
- }
-
- def keyClassTag: ClassTag[K] = implicitly[ClassTag[K]]
-
- def cache() = {
- // TODO
- this
- }
-
- def uncache = ???
-
- // Members declared in org.apache.mahout.math.drm.DrmLike
-
- protected[mahout] def canHaveMissingRows: Boolean = _canHaveMissingRows
-
- def checkpoint(cacheHint: CacheHint.CacheHint): CheckpointedDrm[K] = this
-
- def collect: Matrix = {
- val data = ds.collect().asScala.toList
- val isDense = data.forall(_._2.isDense)
-
- val m = if (isDense) {
- val cols = data.head._2.size()
- val rows = data.length
- new DenseMatrix(rows, cols)
- } else {
- val cols = ncol
- val rows = safeToNonNegInt(nrow)
- new SparseMatrix(rows, cols)
- }
-
- val intRowIndices = keyClassTag == implicitly[ClassTag[Int]]
-
- if (intRowIndices)
- data.foreach(t => m(t._1.asInstanceOf[Int], ::) := t._2)
- else {
- // assign all rows sequentially
- val d = data.zipWithIndex
- d.foreach(t => m(t._2, ::) := t._1._2)
-
- val rowBindings = d.map(t => (t._1._1.toString, t._2: java.lang.Integer)).toMap.asJava
- m.setRowLabelBindings(rowBindings)
- }
-
- m
- }
-
- def dfsWrite(path: String): Unit = {
- val env = ds.getExecutionEnvironment
-
- val keyTag = implicitly[ClassTag[K]]
- val convertKey = keyToWritableFunc(keyTag)
-
- val writableDataset = ds.map(new MapFunction[(K, Vector), Tuple2[Writable, VectorWritable]] {
- def map(tuple: (K, Vector)): Tuple2[Writable, VectorWritable] = tuple match {
- case (idx, vec) => new Tuple2(convertKey(idx), new VectorWritable(vec))
- }
- })
-
- val job = new JobConf
- val sequenceFormat = new SequenceFileOutputFormat[Writable, VectorWritable]
- FileOutputFormat.setOutputPath(job, new org.apache.hadoop.fs.Path(path))
-
- val hadoopOutput = new HadoopOutputFormat(sequenceFormat, job)
- writableDataset.output(hadoopOutput)
-
- env.execute(s"dfsWrite($path)")
- }
-
- private def keyToWritableFunc[K: ClassTag](keyTag: ClassTag[K]): (K) => Writable = {
- if (keyTag.runtimeClass == classOf[Int]) {
- (x: K) => new IntWritable(x.asInstanceOf[Int])
- } else if (keyTag.runtimeClass == classOf[String]) {
- (x: K) => new Text(x.asInstanceOf[String])
- } else if (keyTag.runtimeClass == classOf[Long]) {
- (x: K) => new LongWritable(x.asInstanceOf[Long])
- } else if (classOf[Writable].isAssignableFrom(keyTag.runtimeClass)) {
- (x: K) => x.asInstanceOf[Writable]
- } else {
- throw new IllegalArgumentException("Do not know how to convert class tag %s to Writable.".format(keyTag))
- }
- }
-
- def newRowCardinality(n: Int): CheckpointedDrm[K] = ???
-
- override val context: DistributedContext = ds.getExecutionEnvironment
-
-}
-
-object CheckpointedFlinkDrm {
- val UNKNOWN = -1
-
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.mahout.flinkbindings.drm
+
+import scala.reflect.ClassTag
+import org.apache.mahout.math.drm._
+import org.apache.mahout.math.scalabindings._
+import RLikeOps._
+import org.apache.mahout.flinkbindings._
+import org.apache.mahout.math.drm.CheckpointedDrm
+import org.apache.mahout.math.Matrix
+import org.apache.mahout.flinkbindings.FlinkDistributedContext
+import org.apache.flink.api.scala.ExecutionEnvironment
+import org.apache.mahout.math.drm.CacheHint
+import scala.util.Random
+import org.apache.mahout.math.drm.DistributedContext
+import org.apache.mahout.math.DenseMatrix
+import org.apache.mahout.math.SparseMatrix
+import org.apache.flink.api.java.io.LocalCollectionOutputFormat
+import java.util.ArrayList
+import scala.collection.JavaConverters._
+import org.apache.flink.api.common.functions.MapFunction
+import org.apache.flink.api.common.functions.ReduceFunction
+import org.apache.flink.api.java.DataSet
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.io.IntWritable
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.io.LongWritable
+import org.apache.mahout.math.VectorWritable
+import org.apache.mahout.math.Vector
+import org.apache.hadoop.mapred.SequenceFileOutputFormat
+import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapred.FileOutputFormat
+import org.apache.flink.api.java.tuple.Tuple2
+import org.apache.flink.api.java.hadoop.mapred.HadoopOutputFormat
+
+class CheckpointedFlinkDrm[K: ClassTag](val ds: DrmDataSet[K],
+ private var _nrow: Long = CheckpointedFlinkDrm.UNKNOWN,
+ private var _ncol: Int = CheckpointedFlinkDrm.UNKNOWN,
+ override protected[mahout] val partitioningTag: Long = Random.nextLong(),
+ private var _canHaveMissingRows: Boolean = false
+ ) extends CheckpointedDrm[K] {
+
+ lazy val nrow: Long = if (_nrow >= 0) _nrow else computeNRow
+ lazy val ncol: Int = if (_ncol >= 0) _ncol else computeNCol
+
+ protected def computeNRow: Long = {
+ val count = ds.map(new MapFunction[DrmTuple[K], Long] {
+ def map(value: DrmTuple[K]): Long = 1L
+ }).reduce(new ReduceFunction[Long] {
+ def reduce(a1: Long, a2: Long) = a1 + a2
+ })
+
+ val list = count.collect().asScala.toList
+ list.head
+ }
+
+ protected def computeNCol: Int = {
+ val max = ds.map(new MapFunction[DrmTuple[K], Int] {
+ def map(value: DrmTuple[K]): Int = value._2.length
+ }).reduce(new ReduceFunction[Int] {
+ def reduce(a1: Int, a2: Int) = Math.max(a1, a2)
+ })
+
+ val list = max.collect().asScala.toList
+ list.head
+ }
+
+ def keyClassTag: ClassTag[K] = implicitly[ClassTag[K]]
+
+ def cache() = {
+ // TODO
+ this
+ }
+
+ def uncache = ???
+
+ // Members declared in org.apache.mahout.math.drm.DrmLike
+
+ protected[mahout] def canHaveMissingRows: Boolean = _canHaveMissingRows
+
+ def checkpoint(cacheHint: CacheHint.CacheHint): CheckpointedDrm[K] = this
+
+ def collect: Matrix = {
+ val data = ds.collect().asScala.toList
+ val isDense = data.forall(_._2.isDense)
+
+ val m = if (isDense) {
+ val cols = data.head._2.size()
+ val rows = data.length
+ new DenseMatrix(rows, cols)
+ } else {
+ val cols = ncol
+ val rows = safeToNonNegInt(nrow)
+ new SparseMatrix(rows, cols)
+ }
+
+ val intRowIndices = keyClassTag == implicitly[ClassTag[Int]]
+
+ if (intRowIndices)
+ data.foreach(t => m(t._1.asInstanceOf[Int], ::) := t._2)
+ else {
+ // assign all rows sequentially
+ val d = data.zipWithIndex
+ d.foreach(t => m(t._2, ::) := t._1._2)
+
+ val rowBindings = d.map(t => (t._1._1.toString, t._2: java.lang.Integer)).toMap.asJava
+ m.setRowLabelBindings(rowBindings)
+ }
+
+ m
+ }
+
+ def dfsWrite(path: String): Unit = {
+ val env = ds.getExecutionEnvironment
+
+ val keyTag = implicitly[ClassTag[K]]
+ val convertKey = keyToWritableFunc(keyTag)
+
+ val writableDataset = ds.map(new MapFunction[(K, Vector), Tuple2[Writable, VectorWritable]] {
+ def map(tuple: (K, Vector)): Tuple2[Writable, VectorWritable] = tuple match {
+ case (idx, vec) => new Tuple2(convertKey(idx), new VectorWritable(vec))
+ }
+ })
+
+ val job = new JobConf
+ val sequenceFormat = new SequenceFileOutputFormat[Writable, VectorWritable]
+ FileOutputFormat.setOutputPath(job, new org.apache.hadoop.fs.Path(path))
+
+ val hadoopOutput = new HadoopOutputFormat(sequenceFormat, job)
+ writableDataset.output(hadoopOutput)
+
+ env.execute(s"dfsWrite($path)")
+ }
+
+ private def keyToWritableFunc[K: ClassTag](keyTag: ClassTag[K]): (K) => Writable = {
+ if (keyTag.runtimeClass == classOf[Int]) {
+ (x: K) => new IntWritable(x.asInstanceOf[Int])
+ } else if (keyTag.runtimeClass == classOf[String]) {
+ (x: K) => new Text(x.asInstanceOf[String])
+ } else if (keyTag.runtimeClass == classOf[Long]) {
+ (x: K) => new LongWritable(x.asInstanceOf[Long])
+ } else if (classOf[Writable].isAssignableFrom(keyTag.runtimeClass)) {
+ (x: K) => x.asInstanceOf[Writable]
+ } else {
+ throw new IllegalArgumentException("Do not know how to convert class tag %s to Writable.".format(keyTag))
+ }
+ }
+
+ def newRowCardinality(n: Int): CheckpointedDrm[K] = ???
+
+ override val context: DistributedContext = ds.getExecutionEnvironment
+
+}
+
+object CheckpointedFlinkDrm {
+ val UNKNOWN = -1
+
}
\ No newline at end of file
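A hypothetical round trip through the persistence code above (path, context construction and DSL calls are placeholders and assumptions, not taken from the commit):

  import org.apache.flink.api.java.ExecutionEnvironment
  import org.apache.mahout.math.drm._
  import org.apache.mahout.math.scalabindings._
  import org.apache.mahout.flinkbindings._

  implicit val ctx = new FlinkDistributedContext(ExecutionEnvironment.getExecutionEnvironment)

  val drmA = drmParallelize(dense((1, 2), (3, 4)))
  drmA.dfsWrite("/tmp/drmA")          // placeholder path; Int keys -> IntWritable, rows -> VectorWritable
  val drmB = drmDfsRead("/tmp/drmA")  // reads the sequence file back through FlinkEngine.drmDfsRead
  val inCore = drmB.collect           // materialized into an in-core matrix by the collect method above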
http://git-wip-us.apache.org/repos/asf/mahout/blob/08ad113f/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala
----------------------------------------------------------------------
diff --git a/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala b/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala
index 56c737a..e46e605 100644
--- a/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala
+++ b/flink/src/main/scala/org/apache/mahout/flinkbindings/package.scala
@@ -1,106 +1,106 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.mahout
-
-import scala.reflect.ClassTag
-import org.slf4j.LoggerFactory
-import org.apache.flink.api.java.DataSet
-import org.apache.flink.api.java.ExecutionEnvironment
-import org.apache.flink.api.common.functions.MapFunction
-import org.apache.mahout.math.Vector
-import org.apache.mahout.math.DenseVector
-import org.apache.mahout.math.Matrix
-import org.apache.mahout.math.MatrixWritable
-import org.apache.mahout.math.VectorWritable
-import org.apache.mahout.math.drm._
-import org.apache.mahout.math.scalabindings._
-import org.apache.mahout.flinkbindings.FlinkDistributedContext
-import org.apache.mahout.flinkbindings.drm.FlinkDrm
-import org.apache.mahout.flinkbindings.drm.BlockifiedFlinkDrm
-import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
-import org.apache.mahout.flinkbindings.drm.CheckpointedFlinkDrm
-import org.apache.flink.api.common.functions.FilterFunction
-
-package object flinkbindings {
-
- private[flinkbindings] val log = LoggerFactory.getLogger("apache.org.mahout.flinkbingings")
-
- /** Row-wise organized DRM dataset type */
- type DrmDataSet[K] = DataSet[DrmTuple[K]]
-
- /**
- * Blockifed DRM dataset (keys of original DRM are grouped into array corresponding to rows of Matrix
- * object value
- */
- type BlockifiedDrmDataSet[K] = DataSet[BlockifiedDrmTuple[K]]
-
-
- implicit def wrapMahoutContext(context: DistributedContext): FlinkDistributedContext = {
- assert(context.isInstanceOf[FlinkDistributedContext], "it must be FlinkDistributedContext")
- context.asInstanceOf[FlinkDistributedContext]
- }
-
- implicit def wrapContext(env: ExecutionEnvironment): FlinkDistributedContext =
- new FlinkDistributedContext(env)
- implicit def unwrapContext(ctx: FlinkDistributedContext): ExecutionEnvironment = ctx.env
-
- private[flinkbindings] implicit def castCheckpointedDrm[K: ClassTag](drm: CheckpointedDrm[K]): CheckpointedFlinkDrm[K] = {
- assert(drm.isInstanceOf[CheckpointedFlinkDrm[K]], "it must be a Flink-backed matrix")
- drm.asInstanceOf[CheckpointedFlinkDrm[K]]
- }
-
- implicit def checkpointeDrmToFlinkDrm[K: ClassTag](cp: CheckpointedDrm[K]): FlinkDrm[K] = {
- val flinkDrm = castCheckpointedDrm(cp)
- new RowsFlinkDrm[K](flinkDrm.ds, flinkDrm.ncol)
- }
-
- private[flinkbindings] implicit def wrapAsWritable(m: Matrix): MatrixWritable = new MatrixWritable(m)
- private[flinkbindings] implicit def wrapAsWritable(v: Vector): VectorWritable = new VectorWritable(v)
- private[flinkbindings] implicit def unwrapFromWritable(w: MatrixWritable): Matrix = w.get()
- private[flinkbindings] implicit def unwrapFromWritable(w: VectorWritable): Vector = w.get()
-
-
- def readCsv(file: String, delim: String = ",", comment: String = "#")
- (implicit dc: DistributedContext): CheckpointedDrm[Int] = {
- val vectors = dc.env.readTextFile(file)
- .filter(new FilterFunction[String] {
- def filter(in: String): Boolean = {
- !in.startsWith(comment)
- }
- })
- .map(new MapFunction[String, Vector] {
- def map(in: String): Vector = {
- val array = in.split(delim).map(_.toDouble)
- new DenseVector(array)
- }
- })
- datasetToDrm(vectors)
- }
-
- def datasetToDrm(ds: DataSet[Vector]): CheckpointedDrm[Int] = {
- val zipped = new DataSetOps(ds).zipWithIndex
- datasetWrap(zipped)
- }
-
- def datasetWrap[K: ClassTag](dataset: DataSet[(K, Vector)]): CheckpointedDrm[K] = {
- new CheckpointedFlinkDrm[K](dataset)
- }
-
-
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.mahout
+
+import scala.reflect.ClassTag
+import org.slf4j.LoggerFactory
+import org.apache.flink.api.java.DataSet
+import org.apache.flink.api.java.ExecutionEnvironment
+import org.apache.flink.api.common.functions.MapFunction
+import org.apache.mahout.math.Vector
+import org.apache.mahout.math.DenseVector
+import org.apache.mahout.math.Matrix
+import org.apache.mahout.math.MatrixWritable
+import org.apache.mahout.math.VectorWritable
+import org.apache.mahout.math.drm._
+import org.apache.mahout.math.scalabindings._
+import org.apache.mahout.flinkbindings.FlinkDistributedContext
+import org.apache.mahout.flinkbindings.drm.FlinkDrm
+import org.apache.mahout.flinkbindings.drm.BlockifiedFlinkDrm
+import org.apache.mahout.flinkbindings.drm.RowsFlinkDrm
+import org.apache.mahout.flinkbindings.drm.CheckpointedFlinkDrm
+import org.apache.flink.api.common.functions.FilterFunction
+
+package object flinkbindings {
+
+ private[flinkbindings] val log = LoggerFactory.getLogger("apache.org.mahout.flinkbingings")
+
+ /** Row-wise organized DRM dataset type */
+ type DrmDataSet[K] = DataSet[DrmTuple[K]]
+
+ /**
+ * Blockifed DRM dataset (keys of original DRM are grouped into array corresponding to rows of Matrix
+ * object value
+ */
+ type BlockifiedDrmDataSet[K] = DataSet[BlockifiedDrmTuple[K]]
+
+
+ implicit def wrapMahoutContext(context: DistributedContext): FlinkDistributedContext = {
+ assert(context.isInstanceOf[FlinkDistributedContext], "it must be FlinkDistributedContext")
+ context.asInstanceOf[FlinkDistributedContext]
+ }
+
+ implicit def wrapContext(env: ExecutionEnvironment): FlinkDistributedContext =
+ new FlinkDistributedContext(env)
+ implicit def unwrapContext(ctx: FlinkDistributedContext): ExecutionEnvironment = ctx.env
+
+ private[flinkbindings] implicit def castCheckpointedDrm[K: ClassTag](drm: CheckpointedDrm[K]): CheckpointedFlinkDrm[K] = {
+ assert(drm.isInstanceOf[CheckpointedFlinkDrm[K]], "it must be a Flink-backed matrix")
+ drm.asInstanceOf[CheckpointedFlinkDrm[K]]
+ }
+
+ implicit def checkpointeDrmToFlinkDrm[K: ClassTag](cp: CheckpointedDrm[K]): FlinkDrm[K] = {
+ val flinkDrm = castCheckpointedDrm(cp)
+ new RowsFlinkDrm[K](flinkDrm.ds, flinkDrm.ncol)
+ }
+
+ private[flinkbindings] implicit def wrapAsWritable(m: Matrix): MatrixWritable = new MatrixWritable(m)
+ private[flinkbindings] implicit def wrapAsWritable(v: Vector): VectorWritable = new VectorWritable(v)
+ private[flinkbindings] implicit def unwrapFromWritable(w: MatrixWritable): Matrix = w.get()
+ private[flinkbindings] implicit def unwrapFromWritable(w: VectorWritable): Vector = w.get()
+
+
+ def readCsv(file: String, delim: String = ",", comment: String = "#")
+ (implicit dc: DistributedContext): CheckpointedDrm[Int] = {
+ val vectors = dc.env.readTextFile(file)
+ .filter(new FilterFunction[String] {
+ def filter(in: String): Boolean = {
+ !in.startsWith(comment)
+ }
+ })
+ .map(new MapFunction[String, Vector] {
+ def map(in: String): Vector = {
+ val array = in.split(delim).map(_.toDouble)
+ new DenseVector(array)
+ }
+ })
+ datasetToDrm(vectors)
+ }
+
+ def datasetToDrm(ds: DataSet[Vector]): CheckpointedDrm[Int] = {
+ val zipped = new DataSetOps(ds).zipWithIndex
+ datasetWrap(zipped)
+ }
+
+ def datasetWrap[K: ClassTag](dataset: DataSet[(K, Vector)]): CheckpointedDrm[K] = {
+ new CheckpointedFlinkDrm[K](dataset)
+ }
+
+
}
\ No newline at end of file
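The package object above supplies the glue that lets plain Flink types flow into the DRM API: an ExecutionEnvironment is implicitly wrapped as a FlinkDistributedContext, and a DataSet[Vector] becomes an Int-keyed DRM through zipWithIndex. A small local sketch under those assumptions (the two rows are illustrative values only):

    import org.apache.flink.api.java.ExecutionEnvironment
    import org.apache.mahout.flinkbindings._
    import org.apache.mahout.math.{DenseVector, Vector}

    val env = ExecutionEnvironment.getExecutionEnvironment
    implicit val ctx = new FlinkDistributedContext(env)

    // Two dense rows wrapped as a row-indexed DRM; keys 0 and 1 are assigned by zipWithIndex.
    val rows = env.fromElements[Vector](new DenseVector(Array(1.0, 2.0)),
                                        new DenseVector(Array(3.0, 4.0)))
    val drm = datasetToDrm(rows)
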
http://git-wip-us.apache.org/repos/asf/mahout/blob/08ad113f/flink/src/test/scala/org/apache/mahout/flinkbindings/FlinkByteBCastSuite.scala
----------------------------------------------------------------------
diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/FlinkByteBCastSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/FlinkByteBCastSuite.scala
new file mode 100644
index 0000000..6dcedd9
--- /dev/null
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/FlinkByteBCastSuite.scala
@@ -0,0 +1,27 @@
+package org.apache.mahout.flinkbindings
+
+import org.apache.mahout.flinkbindings._
+import org.apache.mahout.math._
+import org.apache.mahout.math.drm._
+import org.apache.mahout.math.drm.RLikeDrmOps._
+import org.apache.mahout.math.scalabindings._
+import org.apache.mahout.math.scalabindings.RLikeOps._
+import org.junit.runner.RunWith
+import org.scalatest.FunSuite
+import org.scalatest.junit.JUnitRunner
+
+@RunWith(classOf[JUnitRunner])
+class FlinkByteBCastSuite extends FunSuite {
+
+ test("BCast vector") {
+ val v = dvec(1, 2, 3)
+ val vBc = FlinkByteBCast.wrap(v)
+ assert((v - vBc.value).norm(2) <= 1e-6)
+ }
+
+ test("BCast matrix") {
+ val m = dense((1, 2), (3, 4))
+ val mBc = FlinkByteBCast.wrap(m)
+ assert((m - mBc.value).norm <= 1e-6)
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/08ad113f/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
----------------------------------------------------------------------
diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
index 707bfc9..fa924a9 100644
--- a/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/RLikeOpsSuite.scala
@@ -18,22 +18,18 @@
*/
package org.apache.mahout.flinkbindings
-import org.junit.runner.RunWith
-import org.scalatest.junit.JUnitRunner
-import org.scalatest.FunSuite
+import org.apache.mahout.flinkbindings._
import org.apache.mahout.math._
-import scalabindings._
-import RLikeOps._
import org.apache.mahout.math.drm._
-import RLikeDrmOps._
-import org.apache.mahout.flinkbindings._
-import org.apache.mahout.math.function.IntIntFunction
-import scala.util.Random
-import scala.util.MurmurHash
-import scala.util.hashing.MurmurHash3
+import org.apache.mahout.math.drm.RLikeDrmOps._
+import org.apache.mahout.math.scalabindings._
+import org.apache.mahout.math.scalabindings.RLikeOps._
+import org.junit.runner.RunWith
+import org.scalatest.FunSuite
+import org.scalatest.junit.JUnitRunner
import org.slf4j.Logger
import org.slf4j.LoggerFactory
-import org.scalatest.Ignore
+
@RunWith(classOf[JUnitRunner])
class RLikeOpsSuite extends FunSuite with DistributedFlinkSuit {
@@ -251,4 +247,19 @@ class RLikeOpsSuite extends FunSuite with DistributedFlinkSuit {
assert((res.collect - expected).norm < 1e-6)
}
+ test("drmBroadcast") {
+ val inCoreA = dense((1, 2), (3, 4), (11, 4))
+ val x = dvec(1, 2)
+ val A = drmParallelize(m = inCoreA, numPartitions = 2)
+
+ val b = drmBroadcast(x)
+
+ val res = A.mapBlock(1) { case (idx, block) =>
+ (idx, (block %*% b).toColMatrix)
+ }
+
+ val expected = inCoreA %*% x
+ assert((res.collect(::, 0) - expected).norm(2) < 1e-6)
+ }
+
}
\ No newline at end of file
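The new drmBroadcast test is the interesting part of this hunk: the in-core vector x is wrapped in a broadcast handle, and inside mapBlock the handle is dereferenced through the implicit BCast-to-value conversion, so block %*% b multiplies every block by the shared vector. The same pattern should carry over to an in-core matrix; a hedged sketch reusing the test's fixtures (not part of the patch):

    // Sketch only: broadcast an in-core matrix instead of a vector.
    val inCoreB = dense((1, 2), (3, 4))
    val bB = drmBroadcast(inCoreB)                 // broadcast handle for the matrix
    val AB = A.mapBlock(inCoreB.ncol) { case (keys, block) =>
      keys -> (block %*% bB)                       // handle dereferenced inside the closure
    }
    // AB.collect is expected to equal inCoreA %*% inCoreB
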
http://git-wip-us.apache.org/repos/asf/mahout/blob/08ad113f/flink/src/test/scala/org/apache/mahout/flinkbindings/UseCasesSuite.scala
----------------------------------------------------------------------
diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/UseCasesSuite.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/UseCasesSuite.scala
index 5b0bc46..a144f6d 100644
--- a/flink/src/test/scala/org/apache/mahout/flinkbindings/UseCasesSuite.scala
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/UseCasesSuite.scala
@@ -18,22 +18,18 @@
*/
package org.apache.mahout.flinkbindings
-import org.junit.runner.RunWith
-import org.scalatest.junit.JUnitRunner
-import org.scalatest.FunSuite
-import org.apache.mahout.math._
-import scalabindings._
-import RLikeOps._
+import scala.util.hashing.MurmurHash3
+import org.apache.mahout.math.Matrices
+import org.apache.mahout.math.Vector
import org.apache.mahout.math.drm._
-import RLikeDrmOps._
-import org.apache.mahout.flinkbindings._
+import org.apache.mahout.math.drm.RLikeDrmOps._
+import org.apache.mahout.math.scalabindings._
+import org.apache.mahout.math.scalabindings.RLikeOps._
import org.apache.mahout.math.function.IntIntFunction
-import scala.util.Random
-import scala.util.MurmurHash
-import scala.util.hashing.MurmurHash3
-import org.slf4j.Logger
+import org.junit.runner.RunWith
import org.slf4j.LoggerFactory
-import org.scalatest.Ignore
+import org.scalatest.FunSuite
+import org.scalatest.junit.JUnitRunner
@RunWith(classOf[JUnitRunner])
class UseCasesSuite extends FunSuite with DistributedFlinkSuit {
http://git-wip-us.apache.org/repos/asf/mahout/blob/08ad113f/flink/src/test/scala/org/apache/mahout/flinkbindings/examples/ReadCsvExample.scala
----------------------------------------------------------------------
diff --git a/flink/src/test/scala/org/apache/mahout/flinkbindings/examples/ReadCsvExample.scala b/flink/src/test/scala/org/apache/mahout/flinkbindings/examples/ReadCsvExample.scala
index 074f9a2..a9e8436 100644
--- a/flink/src/test/scala/org/apache/mahout/flinkbindings/examples/ReadCsvExample.scala
+++ b/flink/src/test/scala/org/apache/mahout/flinkbindings/examples/ReadCsvExample.scala
@@ -1,39 +1,39 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.mahout.flinkbindings.examples
-
-import org.apache.flink.api.java.ExecutionEnvironment
-import org.apache.mahout.math.drm._
-import org.apache.mahout.math.drm.RLikeDrmOps._
-import org.apache.mahout.flinkbindings._
-
-object ReadCsvExample {
-
- def main(args: Array[String]): Unit = {
- val filePath = "file:///c:/tmp/data/slashdot0902/Slashdot0902.txt"
-
- val env = ExecutionEnvironment.getExecutionEnvironment
- implicit val ctx = new FlinkDistributedContext(env)
-
- val drm = readCsv(filePath, delim = "\t", comment = "#")
- val C = drm.t %*% drm
- println(C.collect)
- }
-
-}
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.mahout.flinkbindings.examples
+
+import org.apache.flink.api.java.ExecutionEnvironment
+import org.apache.mahout.math.drm._
+import org.apache.mahout.math.drm.RLikeDrmOps._
+import org.apache.mahout.flinkbindings._
+
+object ReadCsvExample {
+
+ def main(args: Array[String]): Unit = {
+ val filePath = "file:///c:/tmp/data/slashdot0902/Slashdot0902.txt"
+
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ implicit val ctx = new FlinkDistributedContext(env)
+
+ val drm = readCsv(filePath, delim = "\t", comment = "#")
+ val C = drm.t %*% drm
+ println(C.collect)
+ }
+
+}
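For reference, the example's pipeline computes C = A' A, the Gram matrix of the rows parsed from the tab-delimited file. An in-core equivalent on a toy matrix (sketch only, assuming the data fits in memory):

    import org.apache.mahout.math.scalabindings._
    import org.apache.mahout.math.scalabindings.RLikeOps._

    val a = dense((1, 0, 1), (0, 1, 1))   // toy stand-in for the parsed rows
    val c = a.t %*% a                     // same C = A' A as the distributed example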