You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2018/04/16 16:27:35 UTC
spark git commit: [SPARK-9312][ML] Add RawPrediction, numClasses,
and numFeatures for OneVsRestModel
Repository: spark
Updated Branches:
refs/heads/master 083cf2235 -> 5003736ad
[SPARK-9312][ML] Add RawPrediction, numClasses, and numFeatures for OneVsRestModel
add RawPrediction as output column
add numClasses and numFeatures to OneVsRestModel
## What changes were proposed in this pull request?
- Add two vals, numClasses and numFeatures, to OneVsRestModel so that we can inherit from Classifier in the future
- Add a rawPrediction output column in transform; the prediction label is calculated from the rawPrediction, like raw2prediction
## How was this patch tested?
(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)
(If this patch involves UI changes, please attach a screenshot; otherwise, remove this)
Please review http://spark.apache.org/contributing.html before opening a pull request.
Author: Lu WANG <lu...@databricks.com>
Closes #21044 from ludatabricks/SPARK-9312.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5003736a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5003736a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5003736a
Branch: refs/heads/master
Commit: 5003736ad60c3231bb18264c9561646c08379170
Parents: 083cf22
Author: Lu WANG <lu...@databricks.com>
Authored: Mon Apr 16 11:27:30 2018 -0500
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Apr 16 11:27:30 2018 -0500
----------------------------------------------------------------------
.../spark/ml/classification/OneVsRest.scala | 56 ++++++++++++++++----
.../ml/classification/OneVsRestSuite.scala | 7 ++-
2 files changed, 51 insertions(+), 12 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/5003736a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index f04fde2..5348d88 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -32,7 +32,7 @@ import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
import org.apache.spark.ml.util._
@@ -55,7 +55,7 @@ private[ml] trait ClassifierTypeTrait {
/**
* Params for [[OneVsRest]].
*/
-private[ml] trait OneVsRestParams extends PredictorParams
+private[ml] trait OneVsRestParams extends ClassifierParams
with ClassifierTypeTrait with HasWeightCol {
/**
@@ -138,6 +138,14 @@ final class OneVsRestModel private[ml] (
@Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
extends Model[OneVsRestModel] with OneVsRestParams with MLWritable {
+ require(models.nonEmpty, "OneVsRestModel requires at least one model for one class")
+
+ @Since("2.4.0")
+ val numClasses: Int = models.length
+
+ @Since("2.4.0")
+ val numFeatures: Int = models.head.numFeatures
+
/** @group setParam */
@Since("2.1.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -146,6 +154,10 @@ final class OneVsRestModel private[ml] (
@Since("2.1.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
+ /** @group setParam */
+ @Since("2.4.0")
+ def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
+
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
@@ -181,6 +193,7 @@ final class OneVsRestModel private[ml] (
val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
predictions + ((index, prediction(1)))
}
+
model.setFeaturesCol($(featuresCol))
val transformedDataset = model.transform(df).select(columns: _*)
val updatedDataset = transformedDataset
@@ -195,15 +208,34 @@ final class OneVsRestModel private[ml] (
newDataset.unpersist()
}
- // output the index of the classifier with highest confidence as prediction
- val labelUDF = udf { (predictions: Map[Int, Double]) =>
- predictions.maxBy(_._2)._1.toDouble
- }
+ if (getRawPredictionCol != "") {
+ val numClass = models.length
- // output label and label metadata as prediction
- aggregatedDataset
- .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
- .drop(accColName)
+ // output the RawPrediction as vector
+ val rawPredictionUDF = udf { (predictions: Map[Int, Double]) =>
+ val predArray = Array.fill[Double](numClass)(0.0)
+ predictions.foreach { case (idx, value) => predArray(idx) = value }
+ Vectors.dense(predArray)
+ }
+
+ // output the index of the classifier with highest confidence as prediction
+ val labelUDF = udf { (rawPredictions: Vector) => rawPredictions.argmax.toDouble }
+
+ // output confidence as raw prediction, label and label metadata as prediction
+ aggregatedDataset
+ .withColumn(getRawPredictionCol, rawPredictionUDF(col(accColName)))
+ .withColumn(getPredictionCol, labelUDF(col(getRawPredictionCol)), labelMetadata)
+ .drop(accColName)
+ } else {
+ // output the index of the classifier with highest confidence as prediction
+ val labelUDF = udf { (predictions: Map[Int, Double]) =>
+ predictions.maxBy(_._2)._1.toDouble
+ }
+ // output label and label metadata as prediction
+ aggregatedDataset
+ .withColumn(getPredictionCol, labelUDF(col(accColName)), labelMetadata)
+ .drop(accColName)
+ }
}
@Since("1.4.1")
@@ -297,6 +329,10 @@ final class OneVsRest @Since("1.4.0") (
@Since("1.5.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
+ /** @group setParam */
+ @Since("2.4.0")
+ def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
+
/**
* The implementation of parallel one vs. rest runs the classification for
* each class in a separate threads.
http://git-wip-us.apache.org/repos/asf/spark/blob/5003736a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 11e8836..2c3417c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -72,11 +72,12 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
.setClassifier(new LogisticRegression)
assert(ova.getLabelCol === "label")
assert(ova.getPredictionCol === "prediction")
+ assert(ova.getRawPredictionCol === "rawPrediction")
val ovaModel = ova.fit(dataset)
MLTestingUtils.checkCopyAndUids(ova, ovaModel)
- assert(ovaModel.models.length === numClasses)
+ assert(ovaModel.numClasses === numClasses)
val transformedDataset = ovaModel.transform(dataset)
// check for label metadata in prediction col
@@ -179,6 +180,7 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea"))
ovaModel.setFeaturesCol("fea")
ovaModel.setPredictionCol("pred")
+ ovaModel.setRawPredictionCol("")
val transformedDataset = ovaModel.transform(dataset2)
val outputFields = transformedDataset.schema.fieldNames.toSet
assert(outputFields === Set("y", "fea", "pred"))
@@ -190,7 +192,8 @@ class OneVsRestSuite extends MLTest with DefaultReadWriteTest {
val ovr = new OneVsRest()
.setClassifier(logReg)
val output = ovr.fit(dataset).transform(dataset)
- assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
+ assert(output.schema.fieldNames.toSet
+ === Set("label", "features", "prediction", "rawPrediction"))
}
test("SPARK-21306: OneVsRest should support setWeightCol") {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org