You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/05/22 18:59:47 UTC

spark git commit: [SPARK-7404] [ML] Add RegressionEvaluator to spark.ml

Repository: spark
Updated Branches:
  refs/heads/master 3b68cb043 -> f490b3b4c


[SPARK-7404] [ML] Add RegressionEvaluator to spark.ml

Author: Ram Sriharsha <rs...@hw11853.local>

Closes #6344 from harsha2010/SPARK-7404 and squashes the following commits:

16b9d77 [Ram Sriharsha] consistent naming
7f100b6 [Ram Sriharsha] cleanup
c46044d [Ram Sriharsha] Merge with Master + Code Review Fixes
188fa0a [Ram Sriharsha] Merge branch 'master' into SPARK-7404
f5b6a4c [Ram Sriharsha] cleanup doc
97beca5 [Ram Sriharsha] update test to use R packages
32dd310 [Ram Sriharsha] fix indentation
f93b812 [Ram Sriharsha] fix test
1b6ebb3 [Ram Sriharsha] [SPARK-7404][ml] Add RegressionEvaluator to spark.ml


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f490b3b4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f490b3b4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f490b3b4

Branch: refs/heads/master
Commit: f490b3b4c706c92aa65d000b9d885f4d160a5f39
Parents: 3b68cb0
Author: Ram Sriharsha <rs...@hw11853.local>
Authored: Fri May 22 09:59:44 2015 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Fri May 22 09:59:44 2015 -0700

----------------------------------------------------------------------
 .../ml/evaluation/RegressionEvaluator.scala     | 84 ++++++++++++++++++++
 .../evaluation/RegressionEvaluatorSuite.scala   | 71 +++++++++++++++++
 2 files changed, 155 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f490b3b4/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
new file mode 100644
index 0000000..ec493f8
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.{Param, ParamValidators}
+import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.mllib.evaluation.RegressionMetrics
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.types.DoubleType
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Evaluator for regression, which expects two input columns of type Double: prediction and label.
+ */
+@AlphaComponent
+class RegressionEvaluator(override val uid: String)
+  extends Evaluator with HasPredictionCol with HasLabelCol {
+
+  def this() = this(Identifiable.randomUID("regEval"))
+
+  /**
+   * Param for the metric returned by `evaluate`: "mse", "rmse" (the default), "r2" or "mae".
+   * @group param
+   */
+  val metricName: Param[String] = {
+    val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae"))
+    new Param(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae)", allowedParams)
+  }
+
+  /** @group getParam */
+  def getMetricName: String = $(metricName)
+
+  /** @group setParam */
+  def setMetricName(value: String): this.type = set(metricName, value)
+
+  /** @group setParam */
+  def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+  /** @group setParam */
+  def setLabelCol(value: String): this.type = set(labelCol, value)
+
+  setDefault(metricName -> "rmse")
+  /** Returns the metric selected by `metricName`, computed over the prediction/label columns. */
+  override def evaluate(dataset: DataFrame): Double = {
+    val schema = dataset.schema
+    SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
+    SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+    // Pair up each row's (prediction, label) values as the input to RegressionMetrics.
+    val predictionAndLabels = dataset.select($(predictionCol), $(labelCol))
+      .map { case Row(prediction: Double, label: Double) =>
+        (prediction, label)
+      }
+    val metrics = new RegressionMetrics(predictionAndLabels)
+    val metric = $(metricName) match {
+      case "rmse" =>
+        metrics.rootMeanSquaredError
+      case "mse" =>
+        metrics.meanSquaredError
+      case "r2" =>
+        metrics.r2
+      case "mae" =>
+        metrics.meanAbsoluteError
+    }
+    metric
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f490b3b4/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
new file mode 100644
index 0000000..983f8b4
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.evaluation
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
+import org.apache.spark.mllib.util.TestingUtils._
+
+class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext {
+
+  test("Regression Evaluator: default params") {
+    /**
+     * Instructions for exporting the test data to CSV format,
+     * so the metrics can be validated against `mmetric` from R's rminer package.
+     *
+     * import org.apache.spark.mllib.util.LinearDataGenerator
+     * val data = sc.parallelize(LinearDataGenerator.generateLinearInput(6.3,
+     *   Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1))
+     * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1))
+     *   .saveAsTextFile("path")
+     */
+    val dataset = sqlContext.createDataFrame(
+      sc.parallelize(LinearDataGenerator.generateLinearInput(
+        6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
+    /**
+     * Use the following R code to load the data, train the model and evaluate metrics.
+     *
+     * > library("glmnet")
+     * > library("rminer")
+     * > data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
+     * > features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3)))
+     * > label <- as.numeric(data$V1)
+     * > model <- glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)
+     * > rmse <- mmetric(label, predict(model, features), metric='RMSE')
+     * > mae <- mmetric(label, predict(model, features), metric='MAE')
+     * > r2 <- mmetric(label, predict(model, features), metric='R2')
+     */
+    val trainer = new LinearRegression
+    val model = trainer.fit(dataset)
+    val predictions = model.transform(dataset)
+
+    // default = rmse
+    val evaluator = new RegressionEvaluator()
+    assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001)
+
+    // r2 score
+    evaluator.setMetricName("r2")
+    assert(evaluator.evaluate(predictions) ~== 0.9998196 absTol 0.001)
+
+    // mae
+    evaluator.setMetricName("mae")
+    assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org