You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yl...@apache.org on 2016/12/05 08:33:06 UTC

spark git commit: [SPARK-18625][ML] OneVsRestModel should support setFeaturesCol and setPredictionCol

Repository: spark
Updated Branches:
  refs/heads/master e9730b707 -> bdfe7f674


[SPARK-18625][ML] OneVsRestModel should support setFeaturesCol and setPredictionCol

## What changes were proposed in this pull request?
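Add `setFeaturesCol` and `setPredictionCol` to `OneVsRestModel`. During `transform`, the configured features column is now also set on each underlying classifier model before it is applied.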

## How was this patch tested?
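Added a unit test to `OneVsRestSuite` that renames the input columns, sets the features and prediction columns on the fitted model, and checks that the transformed dataset contains exactly the expected columns.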

Author: Zheng RuiFeng <ru...@foxmail.com>

Closes #16059 from zhengruifeng/ovrm_setCol.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bdfe7f67
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bdfe7f67
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bdfe7f67

Branch: refs/heads/master
Commit: bdfe7f67468ecfd9927a1fec60d6605dd05ebe3f
Parents: e9730b7
Author: Zheng RuiFeng <ru...@foxmail.com>
Authored: Mon Dec 5 00:32:58 2016 -0800
Committer: Yanbo Liang <yb...@gmail.com>
Committed: Mon Dec 5 00:32:58 2016 -0800

----------------------------------------------------------------------
 .../apache/spark/ml/classification/OneVsRest.scala    |  9 +++++++++
 .../spark/ml/classification/OneVsRestSuite.scala      | 14 +++++++++++++-
 2 files changed, 22 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bdfe7f67/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index f4ab0a0..e58b30d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -140,6 +140,14 @@ final class OneVsRestModel private[ml] (
     this(uid, Metadata.empty, models.asScala.toArray)
   }
 
+  /** @group setParam */
+  @Since("2.1.0")
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  @Since("2.1.0")
+  def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
@@ -175,6 +183,7 @@ final class OneVsRestModel private[ml] (
         val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) =>
           predictions + ((index, prediction(1)))
         }
+        model.setFeaturesCol($(featuresCol))
         val transformedDataset = model.transform(df).select(columns: _*)
         val updatedDataset = transformedDataset
           .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol)))

http://git-wip-us.apache.org/repos/asf/spark/blob/bdfe7f67/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 3f9bcec..aacb792 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.attribute.NominalAttribute
 import org.apache.spark.ml.classification.LogisticRegressionSuite._
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.feature.StringIndexer
-import org.apache.spark.ml.linalg.{DenseMatrix, Vectors}
+import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils}
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
@@ -33,6 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.Metadata
 
 class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -136,6 +137,17 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     assert(outputFields.contains("p"))
   }
 
+  test("SPARK-18625 : OneVsRestModel should support setFeaturesCol and setPredictionCol") {
+    val ova = new OneVsRest().setClassifier(new LogisticRegression)
+    val ovaModel = ova.fit(dataset)
+    val dataset2 = dataset.select(col("label").as("y"), col("features").as("fea"))
+    ovaModel.setFeaturesCol("fea")
+    ovaModel.setPredictionCol("pred")
+    val transformedDataset = ovaModel.transform(dataset2)
+    val outputFields = transformedDataset.schema.fieldNames.toSet
+    assert(outputFields === Set("y", "fea", "pred"))
+  }
+
   test("SPARK-8049: OneVsRest shouldn't output temp columns") {
     val logReg = new LogisticRegression()
       .setMaxIter(1)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org