You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/06/03 01:51:23 UTC
spark git commit: [SPARK-8049] [MLLIB] drop tmp col from OneVsRest output
Repository: spark
Updated Branches:
refs/heads/master 605ddbb27 -> 89f21f66b
[SPARK-8049] [MLLIB] drop tmp col from OneVsRest output
The temporary column should be dropped after we get the prediction column. harsha2010
Author: Xiangrui Meng <me...@databricks.com>
Closes #6592 from mengxr/SPARK-8049 and squashes the following commits:
1d89107 [Xiangrui Meng] use SparkFunSuite
6ee70de [Xiangrui Meng] drop tmp col from OneVsRest output
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/89f21f66
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/89f21f66
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/89f21f66
Branch: refs/heads/master
Commit: 89f21f66b5549524d1a6e4fb576a4f80d9fef903
Parents: 605ddbb
Author: Xiangrui Meng <me...@databricks.com>
Authored: Tue Jun 2 16:51:17 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Jun 2 16:51:17 2015 -0700
----------------------------------------------------------------------
.../org/apache/spark/ml/classification/OneVsRest.scala | 1 +
.../org/apache/spark/ml/classification/OneVsRestSuite.scala | 9 +++++++++
2 files changed, 10 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/89f21f66/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 7b726da..825f9ed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -131,6 +131,7 @@ final class OneVsRestModel private[ml] (
// output label and label metadata as prediction
val labelUdf = callUDF(label, DoubleType, col(accColName))
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
+ .drop(accColName)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/89f21f66/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index f439f32..1d04ccb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -93,6 +93,15 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
ova.fit(datasetWithLabelMetadata)
}
+
+ test("SPARK-8049: OneVsRest shouldn't output temp columns") {
+ val logReg = new LogisticRegression()
+ .setMaxIter(1)
+ val ovr = new OneVsRest()
+ .setClassifier(logReg)
+ val output = ovr.fit(dataset).transform(dataset)
+ assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
+ }
}
private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org