You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/07/24 07:35:51 UTC
spark git commit: [SPARK-8092] [ML] Allow OneVsRest Classifier feature and label column names to be configurable.
Repository: spark
Updated Branches:
refs/heads/master d249636e5 -> d4d762f27
[SPARK-8092] [ML] Allow OneVsRest Classifier feature and label column names to be configurable.
The base classifier input and output columns are ignored in favor of the ones specified in OneVsRest.
Author: Ram Sriharsha <rs...@hw11853.local>
Closes #6631 from harsha2010/SPARK-8092 and squashes the following commits:
6591dc6 [Ram Sriharsha] add documentation for params
b7024b1 [Ram Sriharsha] cleanup
f0e2bfb [Ram Sriharsha] merge with master
108d3d7 [Ram Sriharsha] merge with master
4f74126 [Ram Sriharsha] Allow label/features columns to be configurable
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d4d762f2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d4d762f2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d4d762f2
Branch: refs/heads/master
Commit: d4d762f275749a923356cd84de549b14c22cc3eb
Parents: d249636
Author: Ram Sriharsha <rs...@hw11853.local>
Authored: Thu Jul 23 22:35:41 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Jul 23 22:35:41 2015 -0700
----------------------------------------------------------------------
.../spark/ml/classification/OneVsRest.scala | 17 +++++++++++++-
.../ml/classification/OneVsRestSuite.scala | 24 ++++++++++++++++++++
2 files changed, 40 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/d4d762f2/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index ea757c5..1741f19 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -47,6 +47,8 @@ private[ml] trait OneVsRestParams extends PredictorParams {
/**
* param for the base binary classifier that we reduce multiclass classification into.
+ * The base classifier input and output columns are ignored in favor of
+ * the ones specified in [[OneVsRest]].
* @group param
*/
val classifier: Param[ClassifierType] = new Param(this, "classifier", "base binary classifier")
@@ -160,6 +162,15 @@ final class OneVsRest(override val uid: String)
set(classifier, value.asInstanceOf[ClassifierType])
}
+ /** @group setParam */
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+
+ /** @group setParam */
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
}
@@ -195,7 +206,11 @@ final class OneVsRest(override val uid: String)
val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
val classifier = getClassifier
- classifier.fit(trainingDataset, classifier.labelCol -> labelColName)
+ val paramMap = new ParamMap()
+ paramMap.put(classifier.labelCol -> labelColName)
+ paramMap.put(classifier.featuresCol -> getFeaturesCol)
+ paramMap.put(classifier.predictionCol -> getPredictionCol)
+ classifier.fit(trainingDataset, paramMap)
}.toArray[ClassificationModel[_, _]]
if (handlePersistence) {
http://git-wip-us.apache.org/repos/asf/spark/blob/d4d762f2/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 75cf5bd..3775292 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
@@ -104,6 +105,29 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
ova.fit(datasetWithLabelMetadata)
}
+ test("SPARK-8092: ensure label features and prediction cols are configurable") {
+ val labelIndexer = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("indexed")
+
+ val indexedDataset = labelIndexer
+ .fit(dataset)
+ .transform(dataset)
+ .drop("label")
+ .withColumnRenamed("features", "f")
+
+ val ova = new OneVsRest()
+ ova.setClassifier(new LogisticRegression())
+ .setLabelCol(labelIndexer.getOutputCol)
+ .setFeaturesCol("f")
+ .setPredictionCol("p")
+
+ val ovaModel = ova.fit(indexedDataset)
+ val transformedDataset = ovaModel.transform(indexedDataset)
+ val outputFields = transformedDataset.schema.fieldNames.toSet
+ assert(outputFields.contains("p"))
+ }
+
test("SPARK-8049: OneVsRest shouldn't output temp columns") {
val logReg = new LogisticRegression()
.setMaxIter(1)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org