You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2017/01/13 19:52:57 UTC
mahout git commit: MAHOUT-1896: Add convenience methods for
interacting with SparkML closes apache/mahout-263
Repository: mahout
Updated Branches:
refs/heads/master 1e0812876 -> b3b72cb65
MAHOUT-1896: Add convenience methods for interacting with SparkML closes apache/mahout-263
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/b3b72cb6
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/b3b72cb6
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/b3b72cb6
Branch: refs/heads/master
Commit: b3b72cb658ed6ad59a092eb554b06898163a1375
Parents: 1e08128
Author: rawkintrevo <tr...@gmail.com>
Authored: Fri Jan 13 13:52:35 2017 -0600
Committer: rawkintrevo <tr...@gmail.com>
Committed: Fri Jan 13 13:52:35 2017 -0600
----------------------------------------------------------------------
spark/pom.xml | 6 ++
.../apache/mahout/sparkbindings/package.scala | 51 +++++++++++++-
.../mahout/sparkbindings/drm/DrmLikeSuite.scala | 74 ++++++++++++++++++++
3 files changed, 130 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/mahout/blob/b3b72cb6/spark/pom.xml
----------------------------------------------------------------------
diff --git a/spark/pom.xml b/spark/pom.xml
index 5fc9863..f965d38 100644
--- a/spark/pom.xml
+++ b/spark/pom.xml
@@ -149,6 +149,12 @@
</dependency>
<dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-mllib_${scala.compat.version}</artifactId>
+ <version>${spark.version}</version>
+ </dependency>
+
+ <dependency>
<groupId>org.apache.mahout</groupId>
<artifactId>mahout-math-scala_${scala.compat.version}</artifactId>
</dependency>
http://git-wip-us.apache.org/repos/asf/mahout/blob/b3b72cb6/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala
----------------------------------------------------------------------
diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala
index acca75e..8064cf0 100644
--- a/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala
+++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/package.scala
@@ -21,12 +21,15 @@ import java.io._
import org.apache.mahout.logging._
import org.apache.mahout.math.drm._
-import org.apache.mahout.math.{MatrixWritable, VectorWritable, Matrix, Vector}
+import org.apache.mahout.math.{Matrix, MatrixWritable, Vector, VectorWritable}
import org.apache.mahout.sparkbindings.drm.{CheckpointedDrmSpark, CheckpointedDrmSparkOps, SparkBCast}
import org.apache.mahout.util.IOUtilsScala
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.linalg.{Vector => SparkVector, SparseVector => SparseSparkVector, DenseVector => DenseSparkVector}
+import org.apache.spark.sql.DataFrame
import collection._
import collection.generic.Growable
@@ -141,6 +144,52 @@ package object sparkbindings {
new CheckpointedDrmSpark[K](rddInput = rdd, _nrow = nrow, _ncol = ncol, cacheHint = cacheHint,
_canHaveMissingRows = canHaveMissingRows)
+  /**
+   * A drmWrap version that takes an RDD[org.apache.spark.mllib.regression.LabeledPoint] and
+   * returns a DRM in which each row holds the point's features followed by its label, i.e.
+   * the label is the last column.
+   *
+   * Sparse feature vectors are copied entry-by-entry into a Mahout RandomAccessSparseVector
+   * so they are never expanded into a dense array; any other feature vector becomes a Mahout
+   * DenseVector.
+   *
+   * Row keys are the `zipWithIndex` positions narrowed with `toInt`.
+   * NOTE(review): an RDD with more than Int.MaxValue rows would wrap around -- confirm that
+   * callers never pass inputs that large.
+   *
+   * @param rdd                labeled points to wrap
+   * @param nrow               forwarded unchanged to `drmWrap`
+   * @param ncol               forwarded unchanged to `drmWrap`
+   * @param cacheHint          forwarded unchanged to `drmWrap`
+   * @param canHaveMissingRows forwarded unchanged to `drmWrap`
+   * @return the Int-keyed checkpointed DRM produced by `drmWrap`
+   */
+  def drmWrapMLLibLabeledPoint(rdd: RDD[LabeledPoint],
+                               nrow: Long = -1,
+                               ncol: Int = -1,
+                               cacheHint: CacheHint.CacheHint = CacheHint.NONE,
+                               canHaveMissingRows: Boolean = false): CheckpointedDrm[Int] = {
+    val drmRDD: DrmRdd[Int] = rdd.zipWithIndex.map { case (point, idx) =>
+      val row: org.apache.mahout.math.Vector = point.features match {
+        case sv: SparseSparkVector =>
+          // One extra slot at the end for the label.
+          val labelIdx = sv.size
+          val out = new org.apache.mahout.math.RandomAccessSparseVector(labelIdx + 1)
+          // Copy only the stored entries; explicit zeros are skipped so the result holds
+          // exactly the non-zero elements, as the previous dense round-trip did.
+          for (k <- sv.indices.indices if sv.values(k) != 0.0) {
+            out.setQuick(sv.indices(k), sv.values(k))
+          }
+          if (point.label != 0.0) out.setQuick(labelIdx, point.label)
+          out
+        case other =>
+          new org.apache.mahout.math.DenseVector(other.toArray :+ point.label)
+      }
+      (idx.toInt, row)
+    }
+
+    drmWrap(drmRDD, nrow, ncol, cacheHint, canHaveMissingRows)
+  }
+
+
+  /**
+   * A drmWrap version that takes a DataFrame whose cells are all numeric (Row[Double]-like).
+   *
+   * Every cell of every row is converted to a Double and the row becomes a Mahout
+   * DenseVector. Cells are converted directly from the row's values rather than by joining
+   * the row into a comma-separated string and splitting it again, so a cell whose text
+   * contains a comma is no longer split into several columns and trailing empty cells are
+   * no longer silently dropped.
+   *
+   * Row keys are the `zipWithIndex` positions narrowed with `toInt`.
+   *
+   * @param df                 DataFrame to wrap
+   * @param nrow               forwarded unchanged to `drmWrap`
+   * @param ncol               forwarded unchanged to `drmWrap`
+   * @param cacheHint          forwarded unchanged to `drmWrap`
+   * @param canHaveMissingRows forwarded unchanged to `drmWrap`
+   * @return the Int-keyed checkpointed DRM produced by `drmWrap`
+   * @throws NumberFormatException (when the RDD is evaluated) if a cell is null or its text
+   *                               cannot be parsed as a Double
+   */
+  def drmWrapDataFrame(df: DataFrame,
+                       nrow: Long = -1,
+                       ncol: Int = -1,
+                       cacheHint: CacheHint.CacheHint = CacheHint.NONE,
+                       canHaveMissingRows: Boolean = false): CheckpointedDrm[Int] = {
+    val drmRDD: DrmRdd[Int] = df.rdd.zipWithIndex.map { case (row, idx) =>
+      val cells: Array[Double] = row.toSeq.map {
+        case null =>
+          // Same exception type the former "null".toDouble parse produced.
+          throw new NumberFormatException("drmWrapDataFrame: cannot convert a null cell to Double")
+        // Floats go through their decimal text so 0.1f still becomes 0.1, as it did when
+        // rows were rendered to strings and re-parsed.
+        case f: java.lang.Float => f.toString.toDouble
+        case n: java.lang.Number => n.doubleValue()
+        // Non-numeric cells (e.g. numeric strings) keep the parse-from-text behavior.
+        case other => other.toString.toDouble
+      }.toArray
+      val vec: org.apache.mahout.math.Vector = new org.apache.mahout.math.DenseVector(cells)
+      (idx.toInt, vec)
+    }
+
+    drmWrap(drmRDD, nrow, ncol, cacheHint, canHaveMissingRows)
+  }
+
+
+  /**
+   * A drmWrap version that takes an RDD[org.apache.spark.mllib.linalg.Vector].
+   *
+   * Sparse input vectors are copied entry-by-entry into a Mahout RandomAccessSparseVector of
+   * the same cardinality so they are never expanded into a dense array; any other vector
+   * becomes a Mahout DenseVector.
+   *
+   * Row keys are the `zipWithIndex` positions narrowed with `toInt`.
+   * NOTE(review): an RDD with more than Int.MaxValue rows would wrap around -- confirm that
+   * callers never pass inputs that large.
+   *
+   * @param rdd                Spark MLlib vectors to wrap, one per DRM row
+   * @param nrow               forwarded unchanged to `drmWrap`
+   * @param ncol               forwarded unchanged to `drmWrap`
+   * @param cacheHint          forwarded unchanged to `drmWrap`
+   * @param canHaveMissingRows forwarded unchanged to `drmWrap`
+   * @return the Int-keyed checkpointed DRM produced by `drmWrap`
+   */
+  def drmWrapMLLibVector(rdd: RDD[SparkVector],
+                         nrow: Long = -1,
+                         ncol: Int = -1,
+                         cacheHint: CacheHint.CacheHint = CacheHint.NONE,
+                         canHaveMissingRows: Boolean = false): CheckpointedDrm[Int] = {
+    val drmRDD: DrmRdd[Int] = rdd.zipWithIndex.map { case (vec, idx) =>
+      val row: org.apache.mahout.math.Vector = vec match {
+        case sv: SparseSparkVector =>
+          val out = new org.apache.mahout.math.RandomAccessSparseVector(sv.size)
+          // Copy only the stored entries; explicit zeros are skipped so the result holds
+          // exactly the non-zero elements, as the previous dense round-trip did.
+          for (k <- sv.indices.indices if sv.values(k) != 0.0) {
+            out.setQuick(sv.indices(k), sv.values(k))
+          }
+          out
+        case other =>
+          new org.apache.mahout.math.DenseVector(other.toArray)
+      }
+      (idx.toInt, row)
+    }
+    drmWrap(drmRDD, nrow, ncol, cacheHint, canHaveMissingRows)
+  }
/** Another drmWrap version that takes in vertical block-partitioned input to form the matrix. */
def drmWrapBlockified[K: ClassTag](blockifiedDrmRdd: BlockifiedDrmRdd[K], nrow: Long = -1, ncol: Int = -1,
http://git-wip-us.apache.org/repos/asf/mahout/blob/b3b72cb6/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeSuite.scala
----------------------------------------------------------------------
diff --git a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeSuite.scala b/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeSuite.scala
index 8f9b00f..e88e7ef 100644
--- a/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeSuite.scala
+++ b/spark/src/test/scala/org/apache/mahout/sparkbindings/drm/DrmLikeSuite.scala
@@ -23,8 +23,10 @@ import scalabindings._
import drm._
import RLikeOps._
import RLikeDrmOps._
+import org.apache.mahout.sparkbindings._
import org.apache.mahout.sparkbindings.test.DistributedSparkSuite
+/** Three-column Double record; the "DRM wrap spark dataframe" test turns a collection of these into a DataFrame. */
+final case class Thingy(thing1: Double, thing2: Double, thing3: Double)
/** DRMLike tests -- just run common DRM tests in Spark. */
class DrmLikeSuite extends FunSuite with DistributedSparkSuite with DrmLikeSuiteBase {
@@ -63,6 +65,78 @@ class DrmLikeSuite extends FunSuite with DistributedSparkSuite with DrmLikeSuite
throw new AssertionError("Block must be dense.")
keys -> block
}).norm should be < 1e-4
+
+ }
+
+  test("DRM wrap labeled points") {
+
+    import org.apache.spark.mllib.linalg.{Vectors => SparkVector}
+    import org.apache.spark.mllib.regression.LabeledPoint
+
+    val sparkCtx = mahoutCtx.asInstanceOf[SparkDistributedContext].sc
+
+    // Three dense feature rows carrying the labels 1.0, 2.0 and 3.0.
+    val points = Seq(
+      LabeledPoint(1.0, SparkVector.dense(2.0, 0.0, 4.0)),
+      LabeledPoint(2.0, SparkVector.dense(3.0, 0.0, 5.0)),
+      LabeledPoint(3.0, SparkVector.dense(4.0, 0.0, 6.0)))
+    val pointsRDD = sparkCtx.parallelize(points)
+
+    val wrapped = drmWrapMLLibLabeledPoint(rdd = pointsRDD)
+    val actual = wrapped.collect(::,::)
+
+    // Each expected row is the feature values followed by the label in the last column.
+    val expected = dense((2,0,4,1), (3,0,5,2), (4,0,6,3))
+    assert(actual === expected)
}
+  test("DRM wrap spark vectors") {
+
+    import org.apache.spark.mllib.linalg.{Vectors => SparkVector}
+
+    val sparkCtx = mahoutCtx.asInstanceOf[SparkDistributedContext].sc
+
+    // Dense input: the wrapped DRM must collect to the same 3x3 values.
+    val denseRows = Seq(
+      SparkVector.dense(2.0, 0.0, 4.0),
+      SparkVector.dense(3.0, 0.0, 5.0),
+      SparkVector.dense(4.0, 0.0, 6.0))
+    val denseRDD = sparkCtx.parallelize(denseRows)
+    val denseActual = drmWrapMLLibVector(rdd = denseRDD).collect(::,::)
+    val denseExpected = dense((2,0,4), (3,0,5), (4,0,6))
+
+    assert(denseActual === denseExpected)
+
+    // Sparse input: each row has cardinality 3 with two stored entries (values 3 and 4).
+    val sparseRows = Seq(
+      SparkVector.sparse(3, Array(1,2), Array(3,4)),
+      SparkVector.sparse(3, Array(0,2), Array(3,4)),
+      SparkVector.sparse(3, Array(0,1), Array(3,4)))
+    val sparseRDD = sparkCtx.parallelize(sparseRows)
+    val sparseActual = drmWrapMLLibVector(rdd = sparseRDD).collect(::,::)
+
+    val sparseExpected = sparse(
+      (1, 3) :: (2, 4) :: Nil,
+      (0, 3) :: (2, 4) :: Nil,
+      (0, 3) :: (1, 4) :: Nil)
+
+    assert(sparseActual === sparseExpected)
+  }
+
+
+
+  test("DRM wrap spark dataframe") {
+
+    val sc = mahoutCtx.asInstanceOf[SparkDistributedContext].sc
+
+    // The SQLContext implicits supply `toDF()` on an RDD of case-class instances.
+    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
+    import sqlContext.implicits._
+
+    // Three rows of three Double columns each.
+    val myDF = sc.parallelize(Seq(
+        Thingy(2.0, 0.0, 4.0),
+        Thingy(3.0, 0.0, 5.0),
+        Thingy(4.0, 0.0, 6.0)))
+      .toDF()
+
+    val dfDRM = drmWrapDataFrame(df = myDF)
+    val dfM = dfDRM.collect(::,::)
+    val testM = dense((2,0,4), (3,0,5), (4,0,6))
+
+    assert(dfM === testM)
+  }
}