You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/07/24 07:35:51 UTC

spark git commit: [SPARK-8092] [ML] Allow OneVsRest Classifier feature and label column names to be configurable.

Repository: spark
Updated Branches:
  refs/heads/master d249636e5 -> d4d762f27


[SPARK-8092] [ML] Allow OneVsRest Classifier feature and label column names to be configurable.

The base classifier input and output columns are ignored in favor of the ones specified in OneVsRest.

Author: Ram Sriharsha <rs...@hw11853.local>

Closes #6631 from harsha2010/SPARK-8092 and squashes the following commits:

6591dc6 [Ram Sriharsha] add documentation for params
b7024b1 [Ram Sriharsha] cleanup
f0e2bfb [Ram Sriharsha] merge with master
108d3d7 [Ram Sriharsha] merge with master
4f74126 [Ram Sriharsha] Allow label/features columns to be configurable


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d4d762f2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d4d762f2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d4d762f2

Branch: refs/heads/master
Commit: d4d762f275749a923356cd84de549b14c22cc3eb
Parents: d249636
Author: Ram Sriharsha <rs...@hw11853.local>
Authored: Thu Jul 23 22:35:41 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Jul 23 22:35:41 2015 -0700

----------------------------------------------------------------------
 .../spark/ml/classification/OneVsRest.scala     | 17 +++++++++++++-
 .../ml/classification/OneVsRestSuite.scala      | 24 ++++++++++++++++++++
 2 files changed, 40 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d4d762f2/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index ea757c5..1741f19 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -47,6 +47,8 @@ private[ml] trait OneVsRestParams extends PredictorParams {
 
   /**
    * param for the base binary classifier that we reduce multiclass classification into.
+   * The base classifier input and output columns are ignored in favor of
+   * the ones specified in [[OneVsRest]].
    * @group param
    */
   val classifier: Param[ClassifierType] = new Param(this, "classifier", "base binary classifier")
@@ -160,6 +162,15 @@ final class OneVsRest(override val uid: String)
     set(classifier, value.asInstanceOf[ClassifierType])
   }
 
+  /** @group setParam */
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  /** @group setParam */
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
   }
@@ -195,7 +206,11 @@ final class OneVsRest(override val uid: String)
       val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
       val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
       val classifier = getClassifier
-      classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
+      val paramMap = new ParamMap()
+      paramMap.put(classifier.labelCol -> labelColName)
+      paramMap.put(classifier.featuresCol -> getFeaturesCol)
+      paramMap.put(classifier.predictionCol -> getPredictionCol)
+      classifier.fit(trainingDataset, paramMap)
     }.toArray[ClassificationModel[_, _]]
 
     if (handlePersistence) {

http://git-wip-us.apache.org/repos/asf/spark/blob/d4d762f2/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 75cf5bd..3775292 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.feature.StringIndexer
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
 import org.apache.spark.ml.util.MetadataUtils
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
@@ -104,6 +105,29 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
     ova.fit(datasetWithLabelMetadata)
   }
 
+  test("SPARK-8092: ensure label features and prediction cols are configurable") {
+    val labelIndexer = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("indexed")
+
+    val indexedDataset = labelIndexer
+      .fit(dataset)
+      .transform(dataset)
+      .drop("label")
+      .withColumnRenamed("features", "f")
+
+    val ova = new OneVsRest()
+    ova.setClassifier(new LogisticRegression())
+      .setLabelCol(labelIndexer.getOutputCol)
+      .setFeaturesCol("f")
+      .setPredictionCol("p")
+
+    val ovaModel = ova.fit(indexedDataset)
+    val transformedDataset = ovaModel.transform(indexedDataset)
+    val outputFields = transformedDataset.schema.fieldNames.toSet
+    assert(outputFields.contains("p"))
+  }
+
   test("SPARK-8049: OneVsRest shouldn't output temp columns") {
     val logReg = new LogisticRegression()
       .setMaxIter(1)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org