Posted to commits@spark.apache.org by me...@apache.org on 2015/05/29 07:38:41 UTC

spark git commit: [SPARK-7922] [MLLIB] use DataFrames for user/item factors in ALSModel

Repository: spark
Updated Branches:
  refs/heads/master cd3d9a5c0 -> db9513789


[SPARK-7922] [MLLIB] use DataFrames for user/item factors in ALSModel

Expose user/item factors as DataFrames. This makes the model more consistent with the pipeline API and helps keep the APIs consistent across languages. This PR also removes the fitting params from `ALSModel`.
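
For illustration only (not part of the commit): a minimal Scala sketch of how the factors can be inspected once they are exposed as DataFrames. The `training` DataFrame and its column names below are assumptions for this sketch.

    import org.apache.spark.ml.recommendation.ALS

    // Assumes `training` is a DataFrame with "user", "item", and "rating" columns.
    val als = new ALS().setRank(10).setMaxIter(5)
    val model = als.fit(training)
    // After this change, each factor set is a DataFrame with "id" and "features" columns.
    model.userFactors.orderBy("id").show()
    model.itemFactors.orderBy("id").show()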

coderxiang

Author: Xiangrui Meng <me...@databricks.com>

Closes #6468 from mengxr/SPARK-7922 and squashes the following commits:

7bfb1d5 [Xiangrui Meng] update ALSModel in PySpark
1ba5607 [Xiangrui Meng] use DataFrames for user/item factors in ALS


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/db951378
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/db951378
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/db951378

Branch: refs/heads/master
Commit: db9513789756da4f16bb1fe8cf1d19500f231f54
Parents: cd3d9a5
Author: Xiangrui Meng <me...@databricks.com>
Authored: Thu May 28 22:38:38 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Thu May 28 22:38:38 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/recommendation/ALS.scala    | 101 +++++++++++--------
 python/pyspark/ml/recommendation.py             |  30 +++++-
 python/pyspark/mllib/common.py                  |   5 +-
 3 files changed, 89 insertions(+), 47 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/db951378/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 900b637..df009d8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -35,21 +35,46 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
 import org.apache.spark.mllib.optimization.NNLS
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
 import org.apache.spark.util.random.XORShiftRandom
 
 /**
+ * Common params for ALS and ALSModel.
+ */
+private[recommendation] trait ALSModelParams extends Params with HasPredictionCol {
+  /**
+   * Param for the column name for user ids.
+   * Default: "user"
+   * @group param
+   */
+  val userCol = new Param[String](this, "userCol", "column name for user ids")
+
+  /** @group getParam */
+  def getUserCol: String = $(userCol)
+
+  /**
+   * Param for the column name for item ids.
+   * Default: "item"
+   * @group param
+   */
+  val itemCol = new Param[String](this, "itemCol", "column name for item ids")
+
+  /** @group getParam */
+  def getItemCol: String = $(itemCol)
+}
+
+/**
  * Common params for ALS.
  */
-private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
+private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter with HasRegParam
   with HasPredictionCol with HasCheckpointInterval with HasSeed {
 
   /**
@@ -106,26 +131,6 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
   def getAlpha: Double = $(alpha)
 
   /**
-   * Param for the column name for user ids.
-   * Default: "user"
-   * @group param
-   */
-  val userCol = new Param[String](this, "userCol", "column name for user ids")
-
-  /** @group getParam */
-  def getUserCol: String = $(userCol)
-
-  /**
-   * Param for the column name for item ids.
-   * Default: "item"
-   * @group param
-   */
-  val itemCol = new Param[String](this, "itemCol", "column name for item ids")
-
-  /** @group getParam */
-  def getItemCol: String = $(itemCol)
-
-  /**
    * Param for the column name for ratings.
    * Default: "rating"
    * @group param
@@ -156,55 +161,60 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    require(schema($(userCol)).dataType == IntegerType)
-    require(schema($(itemCol)).dataType== IntegerType)
+    SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
+    SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
     val ratingType = schema($(ratingCol)).dataType
     require(ratingType == FloatType || ratingType == DoubleType)
-    val predictionColName = $(predictionCol)
-    require(!schema.fieldNames.contains(predictionColName),
-      s"Prediction column $predictionColName already exists.")
-    val newFields = schema.fields :+ StructField($(predictionCol), FloatType, nullable = false)
-    StructType(newFields)
+    SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
   }
 }
 
 /**
  * :: Experimental ::
  * Model fitted by ALS.
+ *
+ * @param rank rank of the matrix factorization model
+ * @param userFactors a DataFrame that stores user factors in two columns: `id` and `features`
+ * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features`
  */
 @Experimental
 class ALSModel private[ml] (
     override val uid: String,
-    k: Int,
-    userFactors: RDD[(Int, Array[Float])],
-    itemFactors: RDD[(Int, Array[Float])])
-  extends Model[ALSModel] with ALSParams {
+    val rank: Int,
+    @transient val userFactors: DataFrame,
+    @transient val itemFactors: DataFrame)
+  extends Model[ALSModel] with ALSModelParams {
+
+  /** @group setParam */
+  def setUserCol(value: String): this.type = set(userCol, value)
+
+  /** @group setParam */
+  def setItemCol(value: String): this.type = set(itemCol, value)
 
   /** @group setParam */
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
   override def transform(dataset: DataFrame): DataFrame = {
-    import dataset.sqlContext.implicits._
-    val users = userFactors.toDF("id", "features")
-    val items = itemFactors.toDF("id", "features")
-
     // Register a UDF for DataFrame, and then
     // create a new column named map(predictionCol) by running the predict UDF.
     val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
       if (userFeatures != null && itemFeatures != null) {
-        blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
+        blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
       } else {
         Float.NaN
       }
     }
     dataset
-      .join(users, dataset($(userCol)) === users("id"), "left")
-      .join(items, dataset($(itemCol)) === items("id"), "left")
-      .select(dataset("*"), predict(users("features"), items("features")).as($(predictionCol)))
+      .join(userFactors, dataset($(userCol)) === userFactors("id"), "left")
+      .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left")
+      .select(dataset("*"),
+        predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
+    SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
+    SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
   }
 }
 
@@ -299,6 +309,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
   }
 
   override def fit(dataset: DataFrame): ALSModel = {
+    import dataset.sqlContext.implicits._
     val ratings = dataset
       .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType),
         col($(ratingCol)).cast(FloatType))
@@ -310,7 +321,9 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
       maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
       alpha = $(alpha), nonnegative = $(nonnegative),
       checkpointInterval = $(checkpointInterval), seed = $(seed))
-    val model = new ALSModel(uid, $(rank), userFactors, itemFactors).setParent(this)
+    val userDF = userFactors.toDF("id", "features")
+    val itemDF = itemFactors.toDF("id", "features")
+    val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
     copyValues(model)
   }
 

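As a hedged usage sketch (assumptions: a fitted `model` and an active `sqlContext`; user id 99 was not seen during fitting), the left joins in the reworked `transform` mean that users or items absent from the factor DataFrames yield NaN predictions:

    import sqlContext.implicits._

    // "user" and "item" are the default input column names.
    val test = Seq((0, 2), (1, 0), (99, 0)).toDF("user", "item")
    model.transform(test).show()
    // The row with user = 99 has no match in userFactors, so its prediction is Float.NaN.
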
http://git-wip-us.apache.org/repos/asf/spark/blob/db951378/python/pyspark/ml/recommendation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index b3e0dd7..b06099a 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -63,8 +63,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
     indicated user preferences rather than explicit ratings given to
     items.
 
+    >>> df = sqlContext.createDataFrame(
+    ...     [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
+    ...     ["user", "item", "rating"])
     >>> als = ALS(rank=10, maxIter=5)
     >>> model = als.fit(df)
+    >>> model.rank
+    10
+    >>> model.userFactors.orderBy("id").collect()
+    [Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)]
     >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
     >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
     >>> predictions[0]
@@ -260,6 +267,27 @@ class ALSModel(JavaModel):
     Model fitted by ALS.
     """
 
+    @property
+    def rank(self):
+        """rank of the matrix factorization model"""
+        return self._call_java("rank")
+
+    @property
+    def userFactors(self):
+        """
+        a DataFrame that stores user factors in two columns: `id` and
+        `features`
+        """
+        return self._call_java("userFactors")
+
+    @property
+    def itemFactors(self):
+        """
+        a DataFrame that stores item factors in two columns: `id` and
+        `features`
+        """
+        return self._call_java("itemFactors")
+
 
 if __name__ == "__main__":
     import doctest
@@ -272,8 +300,6 @@ if __name__ == "__main__":
     sqlContext = SQLContext(sc)
     globs['sc'] = sc
     globs['sqlContext'] = sqlContext
-    globs['df'] = sqlContext.createDataFrame([(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0),
-                                              (2, 1, 1.0), (2, 2, 5.0)], ["user", "item", "rating"])
     (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     sc.stop()
     if failure_count:

http://git-wip-us.apache.org/repos/asf/spark/blob/db951378/python/pyspark/mllib/common.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
index ba60589..855e85f 100644
--- a/python/pyspark/mllib/common.py
+++ b/python/pyspark/mllib/common.py
@@ -27,7 +27,7 @@ from py4j.java_collections import ListConverter, JavaArray, JavaList
 
 from pyspark import RDD, SparkContext
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-
+from pyspark.sql import DataFrame, SQLContext
 
 # Hack for support float('inf') in Py4j
 _old_smart_decode = py4j.protocol.smart_decode
@@ -99,6 +99,9 @@ def _java2py(sc, r, encoding="bytes"):
             jrdd = sc._jvm.SerDe.javaToPython(r)
             return RDD(jrdd, sc)
 
+        if clsName == 'DataFrame':
+            return DataFrame(r, SQLContext(sc))
+
         if clsName in _picklable_classes:
             r = sc._jvm.SerDe.dumps(r)
         elif isinstance(r, (JavaArray, JavaList)):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org