You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2017/01/24 01:25:02 UTC

spark git commit: [SPARK-17747][ML] WeightCol support non-double numeric datatypes

Repository: spark
Updated Branches:
  refs/heads/master e4974721f -> 49f5b0ae4


[SPARK-17747][ML] WeightCol support non-double numeric datatypes

## What changes were proposed in this pull request?

1. Add tests for `weightCol` in `MLTestingUtils.checkNumericTypes`.
2. Move the datatype cast into `Predictor.fit`, and supply each algorithm's `train()` with the cast DataFrame.

## How was this patch tested?

Local tests in spark-shell, plus unit tests.

Author: Zheng RuiFeng <ru...@foxmail.com>

Closes #15314 from zhengruifeng/weightCol_support_int.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/49f5b0ae
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/49f5b0ae
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/49f5b0ae

Branch: refs/heads/master
Commit: 49f5b0ae4c31e4b7369104a14e562e1546aa7736
Parents: e497472
Author: Zheng RuiFeng <ru...@foxmail.com>
Authored: Mon Jan 23 17:24:53 2017 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Jan 23 17:24:53 2017 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/Predictor.scala   | 32 ++++++++++--
 .../ml/regression/IsotonicRegression.scala      |  9 ++--
 .../org/apache/spark/ml/PredictorSuite.scala    | 26 ++++++----
 .../LogisticRegressionSuite.scala               |  2 +-
 .../ml/classification/NaiveBayesSuite.scala     |  6 +--
 .../GeneralizedLinearRegressionSuite.scala      |  2 +-
 .../ml/regression/IsotonicRegressionSuite.scala |  2 +-
 .../ml/regression/LinearRegressionSuite.scala   |  2 +-
 .../apache/spark/ml/util/MLTestingUtils.scala   | 52 +++++++++++++++-----
 9 files changed, 95 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/49f5b0ae/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 4b43a3a..215f9d8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -40,7 +40,7 @@ private[ml] trait PredictorParams extends Params
    * @param schema input schema
    * @param fitting whether this is in fitting
    * @param featuresDataType  SQL DataType for FeaturesType.
-   *                          E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+   *                          E.g., [[VectorUDT]] for vector features.
    * @return output schema
    */
   protected def validateAndTransformSchema(
@@ -51,6 +51,14 @@ private[ml] trait PredictorParams extends Params
     SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
     if (fitting) {
       SchemaUtils.checkNumericType(schema, $(labelCol))
+
+      this match {
+        case p: HasWeightCol =>
+          if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
+            SchemaUtils.checkNumericType(schema, $(p.weightCol))
+          }
+        case _ =>
+      }
     }
     SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
   }
@@ -59,10 +67,12 @@ private[ml] trait PredictorParams extends Params
 /**
  * :: DeveloperApi ::
  * Abstraction for prediction problems (regression and classification). It accepts all NumericType
- * labels and will automatically cast it to DoubleType in `fit()`.
+ * labels and will automatically cast it to DoubleType in `fit()`. If this predictor supports
+ * weights, it accepts all NumericType weights, which will be automatically casted to DoubleType
+ * in `fit()`.
  *
  * @tparam FeaturesType  Type of features.
- *                       E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+ *                       E.g., [[VectorUDT]] for vector features.
  * @tparam Learner  Specialization of this class.  If you subclass this type, use this type
  *                  parameter to specify the concrete type.
  * @tparam M  Specialization of [[PredictionModel]].  If you subclass this type, use this type
@@ -91,7 +101,19 @@ abstract class Predictor[
 
     // Cast LabelCol to DoubleType and keep the metadata.
     val labelMeta = dataset.schema($(labelCol)).metadata
-    val casted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
+    val labelCasted = dataset.withColumn($(labelCol), col($(labelCol)).cast(DoubleType), labelMeta)
+
+    // Cast WeightCol to DoubleType and keep the metadata.
+    val casted = this match {
+      case p: HasWeightCol =>
+        if (isDefined(p.weightCol) && $(p.weightCol).nonEmpty) {
+          val weightMeta = dataset.schema($(p.weightCol)).metadata
+          labelCasted.withColumn($(p.weightCol), col($(p.weightCol)).cast(DoubleType), weightMeta)
+        } else {
+          labelCasted
+        }
+      case _ => labelCasted
+    }
 
     copyValues(train(casted).setParent(this))
   }
@@ -138,7 +160,7 @@ abstract class Predictor[
  * Abstraction for a model for prediction tasks (regression and classification).
  *
  * @tparam FeaturesType  Type of features.
- *                       E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+ *                       E.g., [[VectorUDT]] for vector features.
  * @tparam M  Specialization of [[PredictionModel]].  If you subclass this type, use this type
  *            parameter to specify the concrete type for the corresponding model.
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/49f5b0ae/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 1ed9d3c..90e77bc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -86,11 +86,8 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
     } else {
       col($(featuresCol))
     }
-    val w = if (hasWeightCol) {
-      col($(weightCol))
-    } else {
-      lit(1.0)
-    }
+    val w = if (hasWeightCol) col($(weightCol)).cast(DoubleType) else lit(1.0)
+
     dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map {
       case Row(label: Double, feature: Double, weight: Double) =>
         (label, feature, weight)
@@ -109,7 +106,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
     if (fitting) {
       SchemaUtils.checkNumericType(schema, $(labelCol))
       if (hasWeightCol) {
-        SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType)
+        SchemaUtils.checkNumericType(schema, $(weightCol))
       } else {
         logInfo("The weight column is not defined. Treat all instance weights as 1.0.")
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/49f5b0ae/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
index 03e0c53..ec45e32 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PredictorSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasWeightCol
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.Dataset
@@ -30,24 +31,28 @@ class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   import PredictorSuite._
 
-  test("should support all NumericType labels and not support other types") {
+  test("should support all NumericType labels and weights, and not support other types") {
     val df = spark.createDataFrame(Seq(
-      (0, Vectors.dense(0, 2, 3)),
-      (1, Vectors.dense(0, 3, 9)),
-      (0, Vectors.dense(0, 2, 6))
-    )).toDF("label", "features")
+      (0, 1, Vectors.dense(0, 2, 3)),
+      (1, 2, Vectors.dense(0, 3, 9)),
+      (0, 3, Vectors.dense(0, 2, 6))
+    )).toDF("label", "weight", "features")
 
     val types =
       Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
 
-    val predictor = new MockPredictor()
+    val predictor = new MockPredictor().setWeightCol("weight")
 
     types.foreach { t =>
-      predictor.fit(df.select(col("label").cast(t), col("features")))
+      predictor.fit(df.select(col("label").cast(t), col("weight").cast(t), col("features")))
     }
 
     intercept[IllegalArgumentException] {
-      predictor.fit(df.select(col("label").cast(StringType), col("features")))
+      predictor.fit(df.select(col("label").cast(StringType), col("weight"), col("features")))
+    }
+
+    intercept[IllegalArgumentException] {
+      predictor.fit(df.select(col("label"), col("weight").cast(StringType), col("features")))
     }
   }
 }
@@ -55,12 +60,15 @@ class PredictorSuite extends SparkFunSuite with MLlibTestSparkContext {
 object PredictorSuite {
 
   class MockPredictor(override val uid: String)
-    extends Predictor[Vector, MockPredictor, MockPredictionModel] {
+    extends Predictor[Vector, MockPredictor, MockPredictionModel] with HasWeightCol {
 
     def this() = this(Identifiable.randomUID("mockpredictor"))
 
+    def setWeightCol(value: String): this.type = set(weightCol, value)
+
     override def train(dataset: Dataset[_]): MockPredictionModel = {
       require(dataset.schema("label").dataType == DoubleType)
+      require(dataset.schema("weight").dataType == DoubleType)
       new MockPredictionModel(uid)
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/49f5b0ae/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index c14dcbd..43547a4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -2066,7 +2066,7 @@ class LogisticRegressionSuite
       checkModelData)
   }
 
-  test("should support all NumericType labels and not support other types") {
+  test("should support all NumericType labels and weights, and not support other types") {
     val lr = new LogisticRegression().setMaxIter(1)
     MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression](
       lr, spark) { (expected, actual) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/49f5b0ae/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 2a69ef1..37d7991 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -283,7 +283,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
   }
 
-  test("should support all NumericType labels and not support other types") {
+  test("should support all NumericType labels and weights, and not support other types") {
     val nb = new NaiveBayes()
     MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes](
       nb, spark) { (expected, actual) =>
@@ -324,8 +324,8 @@ object NaiveBayesSuite {
     sample: Int = 10): Seq[LabeledPoint] = {
     val D = theta(0).length
     val rnd = new Random(seed)
-    val _pi = pi.map(math.pow(math.E, _))
-    val _theta = theta.map(row => row.map(math.pow(math.E, _)))
+    val _pi = pi.map(math.exp)
+    val _theta = theta.map(row => row.map(math.exp))
 
     for (i <- 0 until nPoints) yield {
       val y = calcLabel(rnd.nextDouble(), _pi)

http://git-wip-us.apache.org/repos/asf/spark/blob/49f5b0ae/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index e3c2787..828b95e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -1086,7 +1086,7 @@ class GeneralizedLinearRegressionSuite
       GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
   }
 
-  test("should support all NumericType labels and not support other types") {
+  test("should support all NumericType labels and weights, and not support other types") {
     val glr = new GeneralizedLinearRegression().setMaxIter(1)
     MLTestingUtils.checkNumericTypes[
         GeneralizedLinearRegressionModel, GeneralizedLinearRegression](

http://git-wip-us.apache.org/repos/asf/spark/blob/49f5b0ae/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index c2c7947..8cbb2ac 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -181,7 +181,7 @@ class IsotonicRegressionSuite
       checkModelData)
   }
 
-  test("should support all NumericType labels and not support other types") {
+  test("should support all NumericType labels and weights, and not support other types") {
     val ir = new IsotonicRegression()
     MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression](
       ir, spark, isClassification = false) { (expected, actual) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/49f5b0ae/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index e05d0c9..584a1b2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -988,7 +988,7 @@ class LinearRegressionSuite
       checkModelData)
   }
 
-  test("should support all NumericType labels and not support other types") {
+  test("should support all NumericType labels and weights, and not support other types") {
     for (solver <- Seq("auto", "l-bfgs", "normal")) {
       val lr = new LinearRegression().setMaxIter(1).setSolver(solver)
       MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](

http://git-wip-us.apache.org/repos/asf/spark/blob/49f5b0ae/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index d219c42..f1ed568 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -47,18 +47,44 @@ object MLTestingUtils extends SparkFunSuite {
     } else {
       genRegressionDFWithNumericLabelCol(spark)
     }
-    val expected = estimator.fit(dfs(DoubleType))
-    val actuals = dfs.keys.filter(_ != DoubleType).map(t => estimator.fit(dfs(t)))
+
+    val finalEstimator = estimator match {
+      case weighted: Estimator[M] with HasWeightCol =>
+        weighted.set(weighted.weightCol, "weight")
+        weighted
+      case _ => estimator
+    }
+
+    val expected = finalEstimator.fit(dfs(DoubleType))
+
+    val actuals = dfs.keys.filter(_ != DoubleType).map { t =>
+      finalEstimator.fit(dfs(t))
+    }
+
     actuals.foreach(actual => check(expected, actual))
 
     val dfWithStringLabels = spark.createDataFrame(Seq(
-      ("0", Vectors.dense(0, 2, 3), 0.0)
-    )).toDF("label", "features", "censor")
+      ("0", 1, Vectors.dense(0, 2, 3), 0.0)
+    )).toDF("label", "weight", "features", "censor")
     val thrown = intercept[IllegalArgumentException] {
       estimator.fit(dfWithStringLabels)
     }
     assert(thrown.getMessage.contains(
       "Column label must be of type NumericType but was actually of type StringType"))
+
+    estimator match {
+      case weighted: Estimator[M] with HasWeightCol =>
+        val dfWithStringWeights = spark.createDataFrame(Seq(
+          (0, "1", Vectors.dense(0, 2, 3), 0.0)
+        )).toDF("label", "weight", "features", "censor")
+        weighted.set(weighted.weightCol, "weight")
+        val thrown = intercept[IllegalArgumentException] {
+          weighted.fit(dfWithStringWeights)
+        }
+        assert(thrown.getMessage.contains(
+          "Column weight must be of type NumericType but was actually of type StringType"))
+      case _ =>
+    }
   }
 
   def checkNumericTypesALS(
@@ -75,7 +101,7 @@ object MLTestingUtils extends SparkFunSuite {
     actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) }
 
     val baseDF = dfs(baseType)
-    val others = baseDF.columns.toSeq.diff(Seq(column)).map(col(_))
+    val others = baseDF.columns.toSeq.diff(Seq(column)).map(col)
     val cols = Seq(col(column).cast(StringType)) ++ others
     val strDF = baseDF.select(cols: _*)
     val thrown = intercept[IllegalArgumentException] {
@@ -104,7 +130,8 @@ object MLTestingUtils extends SparkFunSuite {
   def genClassifDFWithNumericLabelCol(
       spark: SparkSession,
       labelColName: String = "label",
-      featuresColName: String = "features"): Map[NumericType, DataFrame] = {
+      featuresColName: String = "features",
+      weightColName: String = "weight"): Map[NumericType, DataFrame] = {
     val df = spark.createDataFrame(Seq(
       (0, Vectors.dense(0, 2, 3)),
       (1, Vectors.dense(0, 3, 1)),
@@ -118,12 +145,14 @@ object MLTestingUtils extends SparkFunSuite {
     types.map { t =>
         val castDF = df.select(col(labelColName).cast(t), col(featuresColName))
         t -> TreeTests.setMetadata(castDF, 2, labelColName, featuresColName)
+          .withColumn(weightColName, round(rand(seed = 42)).cast(t))
       }.toMap
   }
 
   def genRegressionDFWithNumericLabelCol(
       spark: SparkSession,
       labelColName: String = "label",
+      weightColName: String = "weight",
       featuresColName: String = "features",
       censorColName: String = "censor"): Map[NumericType, DataFrame] = {
     val df = spark.createDataFrame(Seq(
@@ -137,10 +166,11 @@ object MLTestingUtils extends SparkFunSuite {
     val types =
       Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
     types.map { t =>
-        val castDF = df.select(col(labelColName).cast(t), col(featuresColName))
-        t -> TreeTests.setMetadata(castDF, 0, labelColName, featuresColName)
-          .withColumn(censorColName, lit(0.0))
-      }.toMap
+      val castDF = df.select(col(labelColName).cast(t), col(featuresColName))
+      t -> TreeTests.setMetadata(castDF, 0, labelColName, featuresColName)
+        .withColumn(censorColName, lit(0.0))
+        .withColumn(weightColName, round(rand(seed = 42)).cast(t))
+    }.toMap
   }
 
   def genRatingsDFWithNumericCols(
@@ -154,7 +184,7 @@ object MLTestingUtils extends SparkFunSuite {
       (4, 50, 5.0)
     )).toDF("user", "item", "rating")
 
-    val others = df.columns.toSeq.diff(Seq(column)).map(col(_))
+    val others = df.columns.toSeq.diff(Seq(column)).map(col)
     val types: Seq[NumericType] =
       Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
     types.map { t =>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org