You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2018/04/16 16:27:35 UTC

spark git commit: [SPARK-9312][ML] Add RawPrediction, numClasses, and numFeatures for OneVsRestModel

Repository: spark
Updated Branches:
  refs/heads/master 083cf2235 -> 5003736ad


[SPARK-9312][ML] Add RawPrediction, numClasses, and numFeatures for OneVsRestModel

add RawPrediction as output column
add numClasses and numFeatures to OneVsRestModel

## What changes were proposed in this pull request?

- Add two vals, numClasses and numFeatures, to OneVsRestModel so that it can inherit from Classifier in the future

- Add a rawPrediction output column in transform; the prediction label is calculated from the rawPrediction, like raw2prediction

## How was this patch tested?

(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)
(If this patch involves UI changes, please attach a screenshot; otherwise, remove this)
Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Lu WANG <lu...@databricks.com>

Closes #21044 from ludatabricks/SPARK-9312.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5003736a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5003736a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5003736a

Branch: refs/heads/master
Commit: 5003736ad60c3231bb18264c9561646c08379170
Parents: 083cf22
Author: Lu WANG <lu...@databricks.com>
Authored: Mon Apr 16 11:27:30 2018 -0500
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Apr 16 11:27:30 2018 -0500

----------------------------------------------------------------------
 .../spark/ml/classification/OneVsRest.scala     | 56 ++++++++++++++++----
 .../ml/classification/OneVsRestSuite.scala      |  7 ++-
 2 files changed, 51 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5003736a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index f04fde2..5348d88 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -32,7 +32,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
 import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
 import org.apache.spark.ml.util._
@@ -55,7 +55,7 @@ private[ml] trait ClassifierTypeTrait {
 /**
  * Params for [[OneVsRest]].
  */
-private[ml] trait OneVsRestParams extends PredictorParams
+private[ml] trait OneVsRestParams extends ClassifierParams
   with ClassifierTypeTrait with HasWeightCol {
 
   /**
@@ -138,6 +138,14 @@ final class OneVsRestModel private[ml] (
     @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
   extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
 
+  require(models.nonEmpty, "OneVsRestModel requires at least one model for one class")
+
+  @Since("2.4.0")
+  val numClasses: Int = models.length
+
+  @Since("2.4.0")
+  val numFeatures: Int = models.head.numFeatures
+
   /** @group setParam */
   @Since("2.1.0")
   def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -146,6 +154,10 @@ final class OneVsRestModel private[ml] (
   @Since("2.1.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
+  /** @group setParam */
+  @Since("2.4.0")
+  def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
@@ -181,6 +193,7 @@ final class OneVsRestModel private[ml] (
         val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
           predictions + ((index, prediction(1)))
         }
+
         model.setFeaturesCol($(featuresCol))
         val transformedDataset = model.transform(df).select(columns: _*)
         val updatedDataset = transformedDataset
@@ -195,15 +208,34 @@ final class OneVsRestModel private[ml] (
       newDataset.unpersist()
     }
 
-    // output the index of the classifier with highest confidence as prediction
-    val labelUDF = udf { (predictions: Map[Int, Double]) =>
-      predictions.maxBy(_._2)._1.toDouble
-    }
+    if (getRawPredictionCol != "") {
+      val numClass = models.length
 
-    // output label and label metadata as prediction
-    aggregatedDataset
-      .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
-      .drop(accColName)
+      // output the RawPrediction as vector
+      val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
+        val predArray = Array.fill[Double](numClass)(0.0)
+        predictions.foreach { case (idx, value) => predArray(idx) = value }
+        Vectors.dense(predArray)
+      }
+
+      // output the index of the classifier with highest confidence as prediction
+      val labelUDF = udf { (rawPredictions: Vector) => rawPredictions.argmax.toDouble }
+
+      // output confidence as raw prediction, label and label metadata as prediction
+      aggregatedDataset
+        .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
+        .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
+        .drop(accColName)
+    } else {
+      // output the index of the classifier with highest confidence as prediction
+      val labelUDF = udf { (predictions: Map[Int, Double]) =>
+        predictions.maxBy(_._2)._1.toDouble
+      }
+      // output label and label metadata as prediction
+      aggregatedDataset
+        .withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata)
+        .drop(accColName)
+    }
   }
 
   @Since("1.4.1")
@@ -297,6 +329,10 @@ final class OneVsRest @Since("1.4.0") (
   @Since("1.5.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
+  /** @group setParam */
+  @Since("2.4.0")
+  def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
+
   /**
    * The implementation of parallel one vs. rest runs the classification for
    * each class in a separate threads.

http://git-wip-us.apache.org/repos/asf/spark/blob/5003736a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 11e8836..2c3417c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -72,11 +72,12 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
       .setClassifier(new LogisticRegression)
     assert(ova.getLabelCol === "label")
     assert(ova.getPredictionCol === "prediction")
+    assert(ova.getRawPredictionCol === "rawPrediction")
     val ovaModel = ova.fit(dataset)
 
     MLTestingUtils.checkCopyAndUids(ova, ovaModel)
 
-    assert(ovaModel.models.length === numClasses)
+    assert(ovaModel.numClasses === numClasses)
     val transformedDataset = ovaModel.transform(dataset)
 
     // check for label metadata in prediction col
@@ -179,6 +180,7 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
     val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea"))
     ovaModel.setFeaturesCol("fea")
     ovaModel.setPredictionCol("pred")
+    ovaModel.setRawPredictionCol("")
     val transformedDataset = ovaModel.transform(dataset2)
     val outputFields = transformedDataset.schema.fieldNames.toSet
     assert(outputFields === Set("y", "fea", "pred"))
@@ -190,7 +192,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
     val ovr = new OneVsRest()
       .setClassifier(logReg)
     val output = ovr.fit(dataset).transform(dataset)
-    assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
+    assert(output.schema.fieldNames.toSet
+      === Set("label", "features", "prediction", "rawPrediction"))
   }
 
   test("SPARK-21306: OneVsRest should support setWeightCol") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org