You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2018/02/25 15:30:05 UTC
spark git commit: [SPARK-22886][ML][TESTS] ML test for structured streaming: ml.recomme…
Repository: spark
Updated Branches:
refs/heads/master 1a198ce8f -> 3ca9a2c56
[SPARK-22886][ML][TESTS] ML test for structured streaming: ml.recomme…
## What changes were proposed in this pull request?
Converting spark.ml.recommendation tests to also check code with structured streaming, using the ML testing infrastructure implemented in SPARK-22882.
## How was this patch tested?
Automated: Passed the Jenkins build.
Author: Gabor Somogyi <ga...@gmail.com>
Closes #20362 from gaborgsomogyi/SPARK-22886.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3ca9a2c5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3ca9a2c5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3ca9a2c5
Branch: refs/heads/master
Commit: 3ca9a2c56513444d7b233088b020d2d43fa6b77a
Parents: 1a198ce
Author: Gabor Somogyi <ga...@gmail.com>
Authored: Sun Feb 25 09:29:59 2018 -0600
Committer: Sean Owen <so...@cloudera.com>
Committed: Sun Feb 25 09:29:59 2018 -0600
----------------------------------------------------------------------
.../spark/ml/recommendation/ALSSuite.scala | 213 +++++++++++++------
.../apache/spark/ml/util/MLTestingUtils.scala | 44 ----
2 files changed, 143 insertions(+), 114 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/3ca9a2c5/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index addcd21..e3dfe2f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -22,8 +22,7 @@ import java.util.Random
import scala.collection.JavaConverters._
import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.WrappedArray
+import scala.collection.mutable.{ArrayBuffer, WrappedArray}
import scala.language.existentials
import com.github.fommil.netlib.BLAS.{getInstance => blas}
@@ -35,21 +34,20 @@ import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.recommendation.ALS._
-import org.apache.spark.ml.recommendation.ALS.Rating
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.recommendation.MatrixFactorizationModelSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
-import org.apache.spark.sql.{DataFrame, Row, SparkSession}
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.{DataFrame, Encoder, Row, SparkSession}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.streaming.StreamingQueryException
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
-class ALSSuite
- extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging {
+class ALSSuite extends MLTest with DefaultReadWriteTest with Logging {
override def beforeAll(): Unit = {
super.beforeAll()
@@ -413,34 +411,36 @@ class ALSSuite
.setSeed(0)
val alpha = als.getAlpha
val model = als.fit(training.toDF())
- val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map {
- case Row(rating: Float, prediction: Float) =>
- (rating.toDouble, prediction.toDouble)
+ testTransformerByGlobalCheckFunc[Rating[Int]](test.toDF(), model, "rating", "prediction") {
+ case rows: Seq[Row] =>
+ val predictions = rows.map(row => (row.getFloat(0).toDouble, row.getFloat(1).toDouble))
+
+ val rmse =
+ if (implicitPrefs) {
+ // TODO: Use a better (rank-based?) evaluation metric for implicit feedback.
+ // We limit the ratings and the predictions to interval [0, 1] and compute the
+ // weighted RMSE with the confidence scores as weights.
+ val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) =>
+ val confidence = 1.0 + alpha * math.abs(rating)
+ val rating01 = math.max(math.min(rating, 1.0), 0.0)
+ val prediction01 = math.max(math.min(prediction, 1.0), 0.0)
+ val err = prediction01 - rating01
+ (confidence, confidence * err * err)
+ }.reduce[(Double, Double)] { case ((c0, e0), (c1, e1)) =>
+ (c0 + c1, e0 + e1)
+ }
+ math.sqrt(weightedSumSq / totalWeight)
+ } else {
+ val errorSquares = predictions.map { case (rating, prediction) =>
+ val err = rating - prediction
+ err * err
+ }
+ val mse = errorSquares.sum / errorSquares.length
+ math.sqrt(mse)
+ }
+ logInfo(s"Test RMSE is $rmse.")
+ assert(rmse < targetRMSE)
}
- val rmse =
- if (implicitPrefs) {
- // TODO: Use a better (rank-based?) evaluation metric for implicit feedback.
- // We limit the ratings and the predictions to interval [0, 1] and compute the weighted RMSE
- // with the confidence scores as weights.
- val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) =>
- val confidence = 1.0 + alpha * math.abs(rating)
- val rating01 = math.max(math.min(rating, 1.0), 0.0)
- val prediction01 = math.max(math.min(prediction, 1.0), 0.0)
- val err = prediction01 - rating01
- (confidence, confidence * err * err)
- }.reduce { case ((c0, e0), (c1, e1)) =>
- (c0 + c1, e0 + e1)
- }
- math.sqrt(weightedSumSq / totalWeight)
- } else {
- val mse = predictions.map { case (rating, prediction) =>
- val err = rating - prediction
- err * err
- }.mean()
- math.sqrt(mse)
- }
- logInfo(s"Test RMSE is $rmse.")
- assert(rmse < targetRMSE)
MLTestingUtils.checkCopyAndUids(als, model)
}
@@ -586,6 +586,68 @@ class ALSSuite
allModelParamSettings, checkModelData)
}
+ private def checkNumericTypesALS(
+ estimator: ALS,
+ spark: SparkSession,
+ column: String,
+ baseType: NumericType)
+ (check: (ALSModel, ALSModel) => Unit)
+ (check2: (ALSModel, ALSModel, DataFrame, Encoder[_]) => Unit): Unit = {
+ val dfs = genRatingsDFWithNumericCols(spark, column)
+ val df = dfs.find {
+ case (numericTypeWithEncoder, _) => numericTypeWithEncoder.numericType == baseType
+ } match {
+ case Some((_, df)) => df
+ }
+ val expected = estimator.fit(df)
+ val actuals = dfs.filter(_ != baseType).map(t => (t, estimator.fit(t._2)))
+ actuals.foreach { case (_, actual) => check(expected, actual) }
+ actuals.foreach { case (t, actual) => check2(expected, actual, t._2, t._1.encoder) }
+
+ val baseDF = dfs.find(_._1.numericType == baseType).get._2
+ val others = baseDF.columns.toSeq.diff(Seq(column)).map(col)
+ val cols = Seq(col(column).cast(StringType)) ++ others
+ val strDF = baseDF.select(cols: _*)
+ val thrown = intercept[IllegalArgumentException] {
+ estimator.fit(strDF)
+ }
+ assert(thrown.getMessage.contains(
+ s"$column must be of type NumericType but was actually of type StringType"))
+ }
+
+ private class NumericTypeWithEncoder[A](val numericType: NumericType)
+ (implicit val encoder: Encoder[(A, Int, Double)])
+
+ private def genRatingsDFWithNumericCols(
+ spark: SparkSession,
+ column: String) = {
+
+ import testImplicits._
+
+ val df = spark.createDataFrame(Seq(
+ (0, 10, 1.0),
+ (1, 20, 2.0),
+ (2, 30, 3.0),
+ (3, 40, 4.0),
+ (4, 50, 5.0)
+ )).toDF("user", "item", "rating")
+
+ val others = df.columns.toSeq.diff(Seq(column)).map(col)
+ val types =
+ Seq(new NumericTypeWithEncoder[Short](ShortType),
+ new NumericTypeWithEncoder[Long](LongType),
+ new NumericTypeWithEncoder[Int](IntegerType),
+ new NumericTypeWithEncoder[Float](FloatType),
+ new NumericTypeWithEncoder[Byte](ByteType),
+ new NumericTypeWithEncoder[Double](DoubleType),
+ new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder())
+ )
+ types.map { t =>
+ val cols = Seq(col(column).cast(t.numericType)) ++ others
+ t -> df.select(cols: _*)
+ }
+ }
+
test("input type validation") {
val spark = this.spark
import spark.implicits._
@@ -595,12 +657,16 @@ class ALSSuite
val als = new ALS().setMaxIter(1).setRank(1)
Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach {
case (colName, sqlType) =>
- MLTestingUtils.checkNumericTypesALS(als, spark, colName, sqlType) {
+ checkNumericTypesALS(als, spark, colName, sqlType) {
(ex, act) =>
- ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1)
- } { (ex, act, _) =>
- ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~==
- act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6
+ ex.userFactors.first().getSeq[Float](1) === act.userFactors.first().getSeq[Float](1)
+ } { (ex, act, df, enc) =>
+ val expected = ex.transform(df).selectExpr("prediction")
+ .first().getFloat(0)
+ testTransformerByGlobalCheckFunc(df, act, "prediction") {
+ case rows: Seq[Row] =>
+ expected ~== rows.head.getFloat(0) absTol 1e-6
+ }(enc)
}
}
// check user/item ids falling outside of Int range
@@ -628,18 +694,22 @@ class ALSSuite
}
withClue("transform should fail when ids exceed integer range. ") {
val model = als.fit(df)
- assert(intercept[SparkException] {
- model.transform(df.select(df("user_big").as("user"), df("item"))).first
- }.getMessage.contains(msg))
- assert(intercept[SparkException] {
- model.transform(df.select(df("user_small").as("user"), df("item"))).first
- }.getMessage.contains(msg))
- assert(intercept[SparkException] {
- model.transform(df.select(df("item_big").as("item"), df("user"))).first
- }.getMessage.contains(msg))
- assert(intercept[SparkException] {
- model.transform(df.select(df("item_small").as("item"), df("user"))).first
- }.getMessage.contains(msg))
+ def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = {
+ assert(intercept[SparkException] {
+ model.transform(dataFrame).first
+ }.getMessage.contains(msg))
+ assert(intercept[StreamingQueryException] {
+ testTransformer[A](dataFrame, model, "prediction") { _ => }
+ }.getMessage.contains(msg))
+ }
+ testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"),
+ df("item")))
+ testTransformIdExceedsIntRange[(Double, Int)](df.select(df("user_small").as("user"),
+ df("item")))
+ testTransformIdExceedsIntRange[(Long, Int)](df.select(df("item_big").as("item"),
+ df("user")))
+ testTransformIdExceedsIntRange[(Double, Int)](df.select(df("item_small").as("item"),
+ df("user")))
}
}
@@ -662,28 +732,31 @@ class ALSSuite
val knownItem = data.select(max("item")).as[Int].first()
val unknownItem = knownItem + 20
val test = Seq(
- (unknownUser, unknownItem),
- (knownUser, unknownItem),
- (unknownUser, knownItem),
- (knownUser, knownItem)
- ).toDF("user", "item")
+ (unknownUser, unknownItem, true),
+ (knownUser, unknownItem, true),
+ (unknownUser, knownItem, true),
+ (knownUser, knownItem, false)
+ ).toDF("user", "item", "expectedIsNaN")
val als = new ALS().setMaxIter(1).setRank(1)
// default is 'nan'
val defaultModel = als.fit(data)
- val defaultPredictions = defaultModel.transform(test).select("prediction").as[Float].collect()
- assert(defaultPredictions.length == 4)
- assert(defaultPredictions.slice(0, 3).forall(_.isNaN))
- assert(!defaultPredictions.last.isNaN)
+ testTransformer[(Int, Int, Boolean)](test, defaultModel, "expectedIsNaN", "prediction") {
+ case Row(expectedIsNaN: Boolean, prediction: Float) =>
+ assert(prediction.isNaN === expectedIsNaN)
+ }
// check 'drop' strategy should filter out rows with unknown users/items
- val dropPredictions = defaultModel
- .setColdStartStrategy("drop")
- .transform(test)
- .select("prediction").as[Float].collect()
- assert(dropPredictions.length == 1)
- assert(!dropPredictions.head.isNaN)
- assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14)
+ val defaultPrediction = defaultModel.transform(test).select("prediction")
+ .as[Float].filter(!_.isNaN).first()
+ testTransformerByGlobalCheckFunc[(Int, Int, Boolean)](test,
+ defaultModel.setColdStartStrategy("drop"), "prediction") {
+ case rows: Seq[Row] =>
+ val dropPredictions = rows.map(_.getFloat(0))
+ assert(dropPredictions.length == 1)
+ assert(!dropPredictions.head.isNaN)
+ assert(dropPredictions.head ~== defaultPrediction relTol 1e-14)
+ }
}
test("case insensitive cold start param value") {
@@ -693,7 +766,7 @@ class ALSSuite
val data = ratings.toDF
val model = new ALS().fit(data)
Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s =>
- model.setColdStartStrategy(s).transform(data)
+ testTransformer[Rating[Int]](data, model.setColdStartStrategy(s), "prediction") { _ => }
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/3ca9a2c5/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index aef81c8..c328d81 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -91,30 +91,6 @@ object MLTestingUtils extends SparkFunSuite {
}
}
- def checkNumericTypesALS(
- estimator: ALS,
- spark: SparkSession,
- column: String,
- baseType: NumericType)
- (check: (ALSModel, ALSModel) => Unit)
- (check2: (ALSModel, ALSModel, DataFrame) => Unit): Unit = {
- val dfs = genRatingsDFWithNumericCols(spark, column)
- val expected = estimator.fit(dfs(baseType))
- val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t))))
- actuals.foreach { case (_, actual) => check(expected, actual) }
- actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) }
-
- val baseDF = dfs(baseType)
- val others = baseDF.columns.toSeq.diff(Seq(column)).map(col)
- val cols = Seq(col(column).cast(StringType)) ++ others
- val strDF = baseDF.select(cols: _*)
- val thrown = intercept[IllegalArgumentException] {
- estimator.fit(strDF)
- }
- assert(thrown.getMessage.contains(
- s"$column must be of type NumericType but was actually of type StringType"))
- }
-
def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = {
val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction")
val expected = evaluator.evaluate(dfs(DoubleType))
@@ -176,26 +152,6 @@ object MLTestingUtils extends SparkFunSuite {
}.toMap
}
- def genRatingsDFWithNumericCols(
- spark: SparkSession,
- column: String): Map[NumericType, DataFrame] = {
- val df = spark.createDataFrame(Seq(
- (0, 10, 1.0),
- (1, 20, 2.0),
- (2, 30, 3.0),
- (3, 40, 4.0),
- (4, 50, 5.0)
- )).toDF("user", "item", "rating")
-
- val others = df.columns.toSeq.diff(Seq(column)).map(col)
- val types: Seq[NumericType] =
- Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
- types.map { t =>
- val cols = Seq(col(column).cast(t)) ++ others
- t -> df.select(cols: _*)
- }.toMap
- }
-
def genEvaluatorDFWithNumericLabelCol(
spark: SparkSession,
labelColName: String = "label",
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org