You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@spark.apache.org by me...@apache.org on 2015/09/01 01:06:39 UTC
spark git commit: [SPARK-10349] [ML] OneVsRest use 'when ... otherwise' not UDF to generate new label at binary reduction
Repository: spark
Updated Branches:
refs/heads/master 540bdee93 -> fe16fd0b8
[SPARK-10349] [ML] OneVsRest use 'when ... otherwise' not UDF to generate new label at binary reduction
Currently, OneVsRest uses a UDF to generate the new binary label during training.
Now that [SPARK-7321](https://issues.apache.org/jira/browse/SPARK-7321) has been merged, we can use `when ... otherwise` instead, which is more efficient.
Author: Yanbo Liang <yb...@gmail.com>
Closes #8519 from yanboliang/spark-10349.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fe16fd0b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fe16fd0b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fe16fd0b
Branch: refs/heads/master
Commit: fe16fd0b8b717f01151bc659ec3299dab091c97a
Parents: 540bdee
Author: Yanbo Liang <yb...@gmail.com>
Authored: Mon Aug 31 16:06:38 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon Aug 31 16:06:38 2015 -0700
----------------------------------------------------------------------
.../org/apache/spark/ml/classification/OneVsRest.scala | 10 ++--------
1 file changed, 2 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/fe16fd0b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index c62e132..debc164 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -91,7 +91,6 @@ final class OneVsRestModel private[ml] (
// add an accumulator column to store predictions of all the models
val accColName = "mbc$acc" + UUID.randomUUID().toString
val initUDF = udf { () => Map[Int, Double]() }
- val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false)
val newDataset = dataset.withColumn(accColName, initUDF())
// persist if underlying dataset is not persistent.
@@ -195,16 +194,11 @@ final class OneVsRest(override val uid: String)
// create k columns, one for each binary classifier.
val models = Range(0, numClasses).par.map { index =>
- val labelUDF = udf { (label: Double) =>
- if (label.toInt == index) 1.0 else 0.0
- }
-
// generate new label metadata for the binary problem.
- // TODO: use when ... otherwise after SPARK-7321 is merged
val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
val labelColName = "mc2b$" + index
- val trainingDataset =
- multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta)
+ val trainingDataset = multiclassLabeled.withColumn(
+ labelColName, when(col($(labelCol)) === index.toDouble, 1.0).otherwise(0.0), newLabelMeta)
val classifier = getClassifier
val paramMap = new ParamMap()
paramMap.put(classifier.labelCol -> labelColName)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org