You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by pw...@apache.org on 2014/04/09 05:37:30 UTC

git commit: [SPARK-1434] [MLLIB] change labelParser from anonymous function to trait

Repository: spark
Updated Branches:
  refs/heads/master ce8ec5456 -> b9e0c937d


[SPARK-1434] [MLLIB] change labelParser from anonymous function to trait

This patch addresses @mateiz's comment in https://github.com/apache/spark/pull/245

MLUtils#loadLibSVMData takes an anonymous function (String => Double) as the label parser, which is awkward to supply from Java. So I made a LabelParser trait and provided two implementations: binary and multiclass.

Author: Xiangrui Meng <me...@databricks.com>

Closes #345 from mengxr/label-parser and squashes the following commits:

ac44409 [Xiangrui Meng] use singleton objects for label parsers
3b1a7c6 [Xiangrui Meng] add tests for label parsers
c2e571c [Xiangrui Meng] rename LabelParser.apply to LabelParser.parse; use extends for singleton
11c94e0 [Xiangrui Meng] add return types
7f8eb36 [Xiangrui Meng] change labelParser from anonymous function to trait


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b9e0c937
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b9e0c937
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b9e0c937

Branch: refs/heads/master
Commit: b9e0c937dfa1ca93b63d0b39d5f156b16c2fdc0a
Parents: ce8ec54
Author: Xiangrui Meng <me...@databricks.com>
Authored: Tue Apr 8 20:37:01 2014 -0700
Committer: Patrick Wendell <pw...@gmail.com>
Committed: Tue Apr 8 20:37:01 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/util/LabelParsers.scala  | 49 ++++++++++++++++++++
 .../org/apache/spark/mllib/util/MLUtils.scala   | 28 ++---------
 .../spark/mllib/util/LabelParsersSuite.scala    | 41 ++++++++++++++++
 .../apache/spark/mllib/util/MLUtilsSuite.scala  |  4 +-
 4 files changed, 97 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b9e0c937/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala
new file mode 100644
index 0000000..f7966d3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+/** Trait for parsers that map a raw label string (e.g. a LIBSVM label) to a Double. */
+trait LabelParser extends Serializable {
+  /** Parses a string label into a double label. */
+  def parse(labelString: String): Double
+}
+
+/**
+ * Label parser for binary labels, which outputs 1.0 (positive) if the value is greater than 0.5,
+ * or 0.0 (negative) otherwise, so it works with both +1/-1 and +1/0 labeling.
+ */
+object BinaryLabelParser extends LabelParser {
+  /** Returns this singleton, for Java callers that need a LabelParser instance. */
+  def getInstance(): LabelParser = this
+
+  /**
+   * Parses the input label into positive (1.0) if the value is greater than 0.5,
+   * or negative (0.0) otherwise.
+   */
+  override def parse(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0
+}
+
+/**
+ * Label parser for multiclass labels, which returns the label's numeric value unchanged.
+ */
+object MulticlassLabelParser extends LabelParser {
+  /** Returns this singleton, for Java callers that need a LabelParser instance. */
+  def getInstance(): LabelParser = this
+
+  override def parse(labelString: String): Double = labelString.toDouble
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b9e0c937/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
----------------------------------------------------------------------
NOTE(review): besides changing the labelParser parameter type from String => Double to
LabelParser, this patch removes the public vals MLUtils.multiclassLabelParser and
MLUtils.binaryLabelParser, and (in the @@ -107,14 +96,7 @@ hunk) the overload
loadLibSVMData(sc: SparkContext, path: String, numFeatures: Int). Both removals are
source-incompatible for existing callers. The replacements are MulticlassLabelParser /
BinaryLabelParser and loadLibSVMData(sc, path, BinaryLabelParser, numFeatures).
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index cb85e43..83d1bd3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -39,17 +39,6 @@ object MLUtils {
   }
 
   /**
-   * Multiclass label parser, which parses a string into double.
-   */
-  val multiclassLabelParser: String => Double = _.toDouble
-
-  /**
-   * Binary label parser, which outputs 1.0 (positive) if the value is greater than 0.5,
-   * or 0.0 (negative) otherwise.
-   */
-  val binaryLabelParser: String => Double = label => if (label.toDouble > 0.5) 1.0 else 0.0
-
-  /**
    * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint].
    * The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR.
    * Each line represents a labeled sparse feature vector using the following format:
@@ -69,7 +58,7 @@ object MLUtils {
   def loadLibSVMData(
       sc: SparkContext,
       path: String,
-      labelParser: String => Double,
+      labelParser: LabelParser,
       numFeatures: Int,
       minSplits: Int): RDD[LabeledPoint] = {
     val parsed = sc.textFile(path, minSplits)
@@ -89,7 +78,7 @@ object MLUtils {
       }.reduce(math.max)
     }
     parsed.map { items =>
-      val label = labelParser(items.head)
+      val label = labelParser.parse(items.head)
       val (indices, values) = items.tail.map { item =>
         val indexAndValue = item.split(':')
         val index = indexAndValue(0).toInt - 1
@@ -107,14 +96,7 @@ object MLUtils {
    * with number of features determined automatically and the default number of partitions.
    */
   def loadLibSVMData(sc: SparkContext, path: String): RDD[LabeledPoint] =
-    loadLibSVMData(sc, path, binaryLabelParser, -1, sc.defaultMinSplits)
-
-  /**
-   * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint],
-   * with number of features specified explicitly and the default number of partitions.
-   */
-  def loadLibSVMData(sc: SparkContext, path: String, numFeatures: Int): RDD[LabeledPoint] =
-    loadLibSVMData(sc, path, binaryLabelParser, numFeatures, sc.defaultMinSplits)
+    loadLibSVMData(sc, path, BinaryLabelParser, -1, sc.defaultMinSplits)
 
   /**
    * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint],
@@ -124,7 +106,7 @@ object MLUtils {
   def loadLibSVMData(
       sc: SparkContext,
       path: String,
-      labelParser: String => Double): RDD[LabeledPoint] =
+      labelParser: LabelParser): RDD[LabeledPoint] =
     loadLibSVMData(sc, path, labelParser, -1, sc.defaultMinSplits)
 
   /**
@@ -135,7 +117,7 @@ object MLUtils {
   def loadLibSVMData(
       sc: SparkContext,
       path: String,
-      labelParser: String => Double,
+      labelParser: LabelParser,
       numFeatures: Int): RDD[LabeledPoint] =
     loadLibSVMData(sc, path, labelParser, numFeatures, sc.defaultMinSplits)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b9e0c937/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala
new file mode 100644
index 0000000..ac85677
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.util
+
+import org.scalatest.FunSuite
+
+class LabelParsersSuite extends FunSuite {
+  test("binary label parser") {
+    for (parser <- Seq(BinaryLabelParser, BinaryLabelParser.getInstance())) {
+      assert(parser.parse("+1") === 1.0)
+      assert(parser.parse("1") === 1.0)
+      assert(parser.parse("0") === 0.0)
+      assert(parser.parse("-1") === 0.0)
+    }
+  }
+
+  test("multiclass label parser") {
+    for (parser <- Seq(MulticlassLabelParser, MulticlassLabelParser.getInstance())) {
+      assert(parser.parse("0") === 0.0)
+      assert(parser.parse("+1") === 1.0)
+      assert(parser.parse("1") === 1.0)
+      assert(parser.parse("2") === 2.0)
+      assert(parser.parse("3") === 3.0)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b9e0c937/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
----------------------------------------------------------------------
NOTE(review): the first hunk now passes BinaryLabelParser explicitly because this commit
removes the loadLibSVMData(sc, path, numFeatures) overload from MLUtils. The second hunk
replaces MLUtils.multiclassLabelParser, which was also removed, with MulticlassLabelParser.
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 27d41c7..e451c35 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -80,7 +80,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
     Files.write(lines, file, Charsets.US_ASCII)
     val path = tempDir.toURI.toString
 
-    val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, 6).collect()
+    val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, BinaryLabelParser, 6).collect()
     val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect()
 
     for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
@@ -93,7 +93,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext {
       assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
     }
 
-    val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MLUtils.multiclassLabelParser).collect()
+    val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MulticlassLabelParser).collect()
     assert(multiclassPoints.length === 3)
     assert(multiclassPoints(0).label === 1.0)
     assert(multiclassPoints(1).label === -1.0)