You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/07/31 08:02:19 UTC

spark git commit: [SPARK-7690] [ML] Multiclass classification Evaluator

Repository: spark
Updated Branches:
  refs/heads/master 83670fc9e -> 4e5919bfb


[SPARK-7690] [ML] Multiclass classification Evaluator

Multiclass classification evaluator for ML Pipelines. Supported metrics are F1 score, precision, recall, weighted precision, and weighted recall.

Author: Ram Sriharsha <rs...@hw11853.local>

Closes #7475 from harsha2010/SPARK-7690 and squashes the following commits:

9bf4ec7 [Ram Sriharsha] fix indentation
3f09a85 [Ram Sriharsha] cleanup doc
16115ae [Ram Sriharsha] code review fixes
032d2a3 [Ram Sriharsha] fix test
eec9865 [Ram Sriharsha] Fix Python Indentation
1dbeffd [Ram Sriharsha] Merge branch 'master' into SPARK-7690
68cea85 [Ram Sriharsha] Merge branch 'master' into SPARK-7690
54c03de [Ram Sriharsha] [SPARK-7690][ml][WIP] Multiclass Evaluator for ML Pipeline


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4e5919bf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4e5919bf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4e5919bf

Branch: refs/heads/master
Commit: 4e5919bfb47a58bcbda90ae01c1bed2128ded983
Parents: 83670fc
Author: Ram Sriharsha <rs...@hw11853.local>
Authored: Thu Jul 30 23:02:11 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Thu Jul 30 23:02:11 2015 -0700

----------------------------------------------------------------------
 .../MulticlassClassificationEvaluator.scala     | 85 ++++++++++++++++++++
 ...MulticlassClassificationEvaluatorSuite.scala | 28 +++++++
 python/pyspark/ml/evaluation.py                 | 66 +++++++++++++++
 3 files changed, 179 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4e5919bf/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
new file mode 100644
index 0000000..44f779c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param}
+import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
+import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
+import org.apache.spark.mllib.evaluation.MulticlassMetrics
+import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.sql.types.DoubleType
+
+/**
+ * :: Experimental ::
+ * Evaluator for multiclass classification, which expects two input columns: prediction and label.
+ */
+@Experimental
+class MulticlassClassificationEvaluator (override val uid: String)
+  extends Evaluator with HasPredictionCol with HasLabelCol {
+
+  def this() = this(Identifiable.randomUID("mcEval"))
+
+  /**
+   * param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`,
+   * `"weightedPrecision"`, `"weightedRecall"`)
+   * @group param
+   */
+  val metricName: Param[String] = {
+    val allowedParams = ParamValidators.inArray(Array("f1", "precision",
+      "recall", "weightedPrecision", "weightedRecall"))
+    new Param(this, "metricName", "metric name in evaluation " +
+      "(f1|precision|recall|weightedPrecision|weightedRecall)", allowedParams)
+  }
+
+  /** @group getParam */
+  def getMetricName: String = $(metricName)
+
+  /** @group setParam */
+  def setMetricName(value: String): this.type = set(metricName, value)
+
+  /** @group setParam */
+  def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+  /** @group setParam */
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  setDefault(metricName -> "f1")
+
+  /** Computes the metric selected by `metricName`; both input columns must be DoubleType. */
+  override def evaluate(dataset: DataFrame): Double = {
+    val schema = dataset.schema
+    SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
+    SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+
+    val predictionAndLabels = dataset.select($(predictionCol), $(labelCol))
+      .map { case Row(prediction: Double, label: Double) => (prediction, label) }
+    val metrics = new MulticlassMetrics(predictionAndLabels)
+    // metricName is restricted by ParamValidators.inArray above, so this match is exhaustive.
+    // "f1" is reported as MulticlassMetrics' weighted F-measure.
+    $(metricName) match {
+      case "f1" => metrics.weightedFMeasure
+      case "precision" => metrics.precision
+      case "recall" => metrics.recall
+      case "weightedPrecision" => metrics.weightedPrecision
+      case "weightedRecall" => metrics.weightedRecall
+    }
+  }
+
+  override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/4e5919bf/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
new file mode 100644
index 0000000..6d8412b
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluatorSuite.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+
+class MulticlassClassificationEvaluatorSuite extends SparkFunSuite {
+  // Generic Params sanity checks (ParamsSuite.checkParams) on a default-constructed evaluator.
+  test("params") {
+    ParamsSuite.checkParams(new MulticlassClassificationEvaluator())
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/4e5919bf/python/pyspark/ml/evaluation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index 595593a..06e8093 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -214,6 +214,72 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
         kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)
 
+
+@inherit_doc
+class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
+    """
+    Evaluator for Multiclass Classification, which expects two input
+    columns: prediction and label.
+    >>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
+    ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]
+    >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["prediction", "label"])
+    ...
+    >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
+    >>> evaluator.evaluate(dataset)
+    0.66...
+    >>> evaluator.evaluate(dataset, {evaluator.metricName: "precision"})
+    0.66...
+    >>> evaluator.evaluate(dataset, {evaluator.metricName: "recall"})
+    0.66...
+    """
+    # a placeholder to make it appear in the generated doc
+    metricName = Param(Params._dummy(), "metricName",
+                       "metric name in evaluation "
+                       "(f1|precision|recall|weightedPrecision|weightedRecall)")
+
+    @keyword_only
+    def __init__(self, predictionCol="prediction", labelCol="label",
+                 metricName="f1"):
+        """
+        __init__(self, predictionCol="prediction", labelCol="label", \
+                 metricName="f1")
+        """
+        super(MulticlassClassificationEvaluator, self).__init__()
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid)
+        # param for metric name in evaluation (f1|precision|recall|weightedPrecision|weightedRecall)
+        self.metricName = Param(self, "metricName",
+                                "metric name in evaluation"
+                                " (f1|precision|recall|weightedPrecision|weightedRecall)")
+        self._setDefault(predictionCol="prediction", labelCol="label",
+                         metricName="f1")
+        kwargs = self.__init__._input_kwargs
+        self._set(**kwargs)
+
+    def setMetricName(self, value):
+        """
+        Sets the value of :py:attr:`metricName`.
+        """
+        self._paramMap[self.metricName] = value
+        return self
+
+    def getMetricName(self):
+        """
+        Gets the value of metricName or its default value.
+        """
+        return self.getOrDefault(self.metricName)
+
+    @keyword_only
+    def setParams(self, predictionCol="prediction", labelCol="label",
+                  metricName="f1"):
+        """
+        setParams(self, predictionCol="prediction", labelCol="label", \
+                  metricName="f1")
+        Sets params for multiclass classification evaluator.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
 if __name__ == "__main__":
     import doctest
     from pyspark.context import SparkContext


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org