You are viewing a plain-text version of this content; the link to the canonical version was a hyperlink in the original page and is not preserved here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/06/03 01:51:23 UTC

spark git commit: [SPARK-8049] [MLLIB] drop tmp col from OneVsRest output

Repository: spark
Updated Branches:
  refs/heads/master 605ddbb27 -> 89f21f66b


[SPARK-8049] [MLLIB] drop tmp col from OneVsRest output

The temporary column should be dropped after we get the prediction column. cc @harsha2010

Author: Xiangrui Meng <me...@databricks.com>

Closes #6592 from mengxr/SPARK-8049 and squashes the following commits:

1d89107 [Xiangrui Meng] use SparkFunSuite
6ee70de [Xiangrui Meng] drop tmp col from OneVsRest output


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/89f21f66
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/89f21f66
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/89f21f66

Branch: refs/heads/master
Commit: 89f21f66b5549524d1a6e4fb576a4f80d9fef903
Parents: 605ddbb
Author: Xiangrui Meng <me...@databricks.com>
Authored: Tue Jun 2 16:51:17 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Jun 2 16:51:17 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/classification/OneVsRest.scala      | 1 +
 .../org/apache/spark/ml/classification/OneVsRestSuite.scala | 9 +++++++++
 2 files changed, 10 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/89f21f66/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 7b726da..825f9ed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -131,6 +131,7 @@ final class OneVsRestModel private[ml] (
     // output label and label metadata as prediction
     val labelUdf = callUDF(label, DoubleType, col(accColName))
     aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
+      .drop(accColName)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/89f21f66/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index f439f32..1d04ccb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -93,6 +93,15 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val datasetWithLabelMetadata = dataset.select(labelWithMetadata, features)
     ova.fit(datasetWithLabelMetadata)
   }
+
+  test("SPARK-8049: OneVsRest shouldn't output temp columns") {
+    val logReg = new LogisticRegression()
+      .setMaxIter(1)
+    val ovr = new OneVsRest()
+      .setClassifier(logReg)
+    val output = ovr.fit(dataset).transform(dataset)
+    assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction"))
+  }
 }
 
 private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org