You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ml...@apache.org on 2017/02/28 14:17:37 UTC
spark git commit: [SPARK-14489][ML][PYSPARK] ALS unknown user/item prediction strategy
Repository: spark
Updated Branches:
refs/heads/master 9b8eca65d -> b40546651
[SPARK-14489][ML][PYSPARK] ALS unknown user/item prediction strategy
This PR adds a param to `ALS`/`ALSModel` to set the strategy used when encountering unknown users or items at prediction time in `transform`. This can occur in 2 scenarios: (a) production scoring, and (b) cross-validation & evaluation.
The current behavior returns `NaN` if a user/item is unknown. In scenario (b), this can easily occur when using `CrossValidator` or `TrainValidationSplit` since some users/items may only occur in the test set and not in the training set. In this case, the evaluator returns `NaN` for all metrics, making model selection impossible.
The new param, `coldStartStrategy`, defaults to `nan` (the current behavior). The other option supported initially is `drop`, which drops all rows with `NaN` predictions. This flag allows users to use `ALS` in cross-validation settings. It is marked as an `expertParam`. The param is a string so that the set of strategies can be extended in the future (some options are discussed in [SPARK-14489](https://issues.apache.org/jira/browse/SPARK-14489)).
## How was this patch tested?
New unit tests, and manual "before and after" tests for Scala & Python using MovieLens `ml-latest-small` as example data. Here, using `CrossValidator` or `TrainValidationSplit` with the default param setting results in metrics that are all `NaN`, while setting `coldStartStrategy` to `drop` results in valid metrics.
Author: Nick Pentreath <ni...@za.ibm.com>
Closes #12896 from MLnick/SPARK-14489-als-nan.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b4054665
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b4054665
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b4054665
Branch: refs/heads/master
Commit: b405466513bcc02cadf1477b6b682ace95d81658
Parents: 9b8eca6
Author: Nick Pentreath <ni...@za.ibm.com>
Authored: Tue Feb 28 16:17:35 2017 +0200
Committer: Nick Pentreath <ni...@za.ibm.com>
Committed: Tue Feb 28 16:17:35 2017 +0200
----------------------------------------------------------------------
.../apache/spark/ml/recommendation/ALS.scala | 44 ++++++++++++++++-
.../spark/ml/recommendation/ALSSuite.scala | 51 +++++++++++++++++++-
python/pyspark/ml/recommendation.py | 30 ++++++++++--
3 files changed, 116 insertions(+), 9 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/b4054665/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 97c8655..af00762 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -90,6 +90,27 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo
n.toInt
}
}
+
+ /**
+ * Param for strategy for dealing with unknown or new users/items at prediction time.
+ * This may be useful in cross-validation or production scenarios, for handling user/item ids
+ * the model has not seen in the training data.
+ * Supported values:
+ * - "nan": predicted value for unknown ids will be NaN.
+ * - "drop": rows in the input DataFrame containing unknown ids will be dropped from
+ * the output DataFrame containing predictions.
+ * Default: "nan".
+ * @group expertParam
+ */
+ val coldStartStrategy = new Param[String](this, "coldStartStrategy",
+ "strategy for dealing with unknown or new users/items at prediction time. This may be " +
+ "useful in cross-validation or production scenarios, for handling user/item ids the model " +
+ "has not seen in the training data. Supported values: " +
+ s"${ALSModel.supportedColdStartStrategies.mkString(",")}.",
+ (s: String) => ALSModel.supportedColdStartStrategies.contains(s.toLowerCase))
+
+ /** @group expertGetParam */
+ def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase
}
/**
@@ -203,7 +224,8 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
- intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK")
+ intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK",
+ coldStartStrategy -> "nan")
/**
* Validates and transforms the input schema.
@@ -248,6 +270,10 @@ class ALSModel private[ml] (
@Since("1.3.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
+ /** @group expertSetParam */
+ @Since("2.2.0")
+ def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
+
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
@@ -260,13 +286,19 @@ class ALSModel private[ml] (
Float.NaN
}
}
- dataset
+ val predictions = dataset
.join(userFactors,
checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left")
.join(itemFactors,
checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left")
.select(dataset("*"),
predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
+ getColdStartStrategy match {
+ case ALSModel.Drop =>
+ predictions.na.drop("all", Seq($(predictionCol)))
+ case ALSModel.NaN =>
+ predictions
+ }
}
@Since("1.3.0")
@@ -290,6 +322,10 @@ class ALSModel private[ml] (
@Since("1.6.0")
object ALSModel extends MLReadable[ALSModel] {
+ private val NaN = "nan"
+ private val Drop = "drop"
+ private[recommendation] final val supportedColdStartStrategies = Array(NaN, Drop)
+
@Since("1.6.0")
override def read: MLReader[ALSModel] = new ALSModelReader
@@ -432,6 +468,10 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
@Since("2.0.0")
def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value)
+ /** @group expertSetParam */
+ @Since("2.2.0")
+ def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
+
/**
* Sets both numUserBlocks and numItemBlocks to the specific value.
*
http://git-wip-us.apache.org/repos/asf/spark/blob/b4054665/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index b923bac..c9e7b50 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -498,8 +498,8 @@ class ALSSuite
(ex, act) =>
ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1)
} { (ex, act, _) =>
- ex.transform(_: DataFrame).select("prediction").first.getFloat(0) ~==
- act.transform(_: DataFrame).select("prediction").first.getFloat(0) absTol 1e-6
+ ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~==
+ act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6
}
}
// check user/item ids falling outside of Int range
@@ -547,6 +547,53 @@ class ALSSuite
ALS.train(ratings)
}
}
+
+ test("ALS cold start user/item prediction strategy") {
+ val spark = this.spark
+ import spark.implicits._
+ import org.apache.spark.sql.functions._
+
+ val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
+ val data = ratings.toDF
+ val knownUser = data.select(max("user")).as[Int].first()
+ val unknownUser = knownUser + 10
+ val knownItem = data.select(max("item")).as[Int].first()
+ val unknownItem = knownItem + 20
+ val test = Seq(
+ (unknownUser, unknownItem),
+ (knownUser, unknownItem),
+ (unknownUser, knownItem),
+ (knownUser, knownItem)
+ ).toDF("user", "item")
+
+ val als = new ALS().setMaxIter(1).setRank(1)
+ // default is 'nan'
+ val defaultModel = als.fit(data)
+ val defaultPredictions = defaultModel.transform(test).select("prediction").as[Float].collect()
+ assert(defaultPredictions.length == 4)
+ assert(defaultPredictions.slice(0, 3).forall(_.isNaN))
+ assert(!defaultPredictions.last.isNaN)
+
+ // check 'drop' strategy should filter out rows with unknown users/items
+ val dropPredictions = defaultModel
+ .setColdStartStrategy("drop")
+ .transform(test)
+ .select("prediction").as[Float].collect()
+ assert(dropPredictions.length == 1)
+ assert(!dropPredictions.head.isNaN)
+ assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14)
+ }
+
+ test("case insensitive cold start param value") {
+ val spark = this.spark
+ import spark.implicits._
+ val (ratings, _) = genExplicitTestData(numUsers = 2, numItems = 2, rank = 1)
+ val data = ratings.toDF
+ val model = new ALS().fit(data)
+ Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s =>
+ model.setColdStartStrategy(s).transform(data)
+ }
+ }
}
class ALSCleanerSuite extends SparkFunSuite {
http://git-wip-us.apache.org/repos/asf/spark/blob/b4054665/python/pyspark/ml/recommendation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index e28d38b..43f82da 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -125,19 +125,25 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
finalStorageLevel = Param(Params._dummy(), "finalStorageLevel",
"StorageLevel for ALS model factors.",
typeConverter=TypeConverters.toString)
+ coldStartStrategy = Param(Params._dummy(), "coldStartStrategy", "strategy for dealing with " +
+ "unknown or new users/items at prediction time. This may be useful " +
+ "in cross-validation or production scenarios, for handling " +
+ "user/item ids the model has not seen in the training data. " +
+ "Supported values: 'nan', 'drop'.",
+ typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
ratingCol="rating", nonnegative=False, checkpointInterval=10,
intermediateStorageLevel="MEMORY_AND_DISK",
- finalStorageLevel="MEMORY_AND_DISK"):
+ finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"):
"""
__init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \
ratingCol="rating", nonnegative=false, checkpointInterval=10, \
intermediateStorageLevel="MEMORY_AND_DISK", \
- finalStorageLevel="MEMORY_AND_DISK")
+ finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
"""
super(ALS, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
@@ -145,7 +151,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item",
ratingCol="rating", nonnegative=False, checkpointInterval=10,
intermediateStorageLevel="MEMORY_AND_DISK",
- finalStorageLevel="MEMORY_AND_DISK")
+ finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@@ -155,13 +161,13 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
ratingCol="rating", nonnegative=False, checkpointInterval=10,
intermediateStorageLevel="MEMORY_AND_DISK",
- finalStorageLevel="MEMORY_AND_DISK"):
+ finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"):
"""
setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \
ratingCol="rating", nonnegative=False, checkpointInterval=10, \
intermediateStorageLevel="MEMORY_AND_DISK", \
- finalStorageLevel="MEMORY_AND_DISK")
+ finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
Sets params for ALS.
"""
kwargs = self.setParams._input_kwargs
@@ -332,6 +338,20 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
"""
return self.getOrDefault(self.finalStorageLevel)
+ @since("2.2.0")
+ def setColdStartStrategy(self, value):
+ """
+ Sets the value of :py:attr:`coldStartStrategy`.
+ """
+ return self._set(coldStartStrategy=value)
+
+ @since("2.2.0")
+ def getColdStartStrategy(self):
+ """
+ Gets the value of coldStartStrategy or its default value.
+ """
+ return self.getOrDefault(self.coldStartStrategy)
+
class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org