Posted to commits@spark.apache.org by db...@apache.org on 2015/10/28 09:02:13 UTC

spark git commit: [SPARK-11332] [ML] Refactored to use ml.feature.Instance instead of WeightedLeastSquare.Instance

Repository: spark
Updated Branches:
  refs/heads/master 82c1c5772 -> 5f1cee6f1


[SPARK-11332] [ML] Refactored to use ml.feature.Instance instead of WeightedLeastSquare.Instance

WeightedLeastSquares now uses the shared ml.feature.Instance case class instead of its own private Instance case class.
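
For reference, a rough paraphrase of the shared class (field order inferred from the call sites in the diff below; not a verbatim quote of the Spark source):

  package org.apache.spark.ml.feature

  import org.apache.spark.mllib.linalg.Vector

  // Paraphrase of the shared Instance case class: one weighted observation with
  // the field order (label, weight, features).  The removed private
  // WeightedLeastSquares.Instance used (w, a, b) = (weight, features, label).
  private[ml] case class Instance(label: Double, weight: Double, features: Vector)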

Author: Nakul Jindal <nj...@us.ibm.com>

Closes #9325 from nakul02/SPARK-11332_refactor_WeightedLeastSquares_dot_Instance.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5f1cee6f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5f1cee6f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5f1cee6f

Branch: refs/heads/master
Commit: 5f1cee6f158adb1f9f485ed1d529c56bace68adc
Parents: 82c1c57
Author: Nakul Jindal <nj...@us.ibm.com>
Authored: Wed Oct 28 01:02:03 2015 -0700
Committer: DB Tsai <db...@netflix.com>
Committed: Wed Oct 28 01:02:03 2015 -0700

----------------------------------------------------------------------
 .../spark/ml/optim/WeightedLeastSquares.scala   | 25 +++++++-------------
 .../spark/ml/regression/LinearRegression.scala  |  4 ++--
 .../ml/optim/WeightedLeastSquaresSuite.scala    | 10 ++++----
 3 files changed, 15 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5f1cee6f/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index d7eaa5a..3d64f7f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.optim
 
 import org.apache.spark.Logging
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.rdd.RDD
 
@@ -122,16 +123,6 @@ private[ml] class WeightedLeastSquares(
 private[ml] object WeightedLeastSquares {
 
   /**
-   * Case class for weighted observations.
-   * @param w weight, must be positive
-   * @param a features
-   * @param b label
-   */
-  case class Instance(w: Double, a: Vector, b: Double) {
-    require(w >= 0.0, s"Weight cannot be negative: $w.")
-  }
-
-  /**
    * Aggregator to provide necessary summary statistics for solving [[WeightedLeastSquares]].
    */
   // TODO: consolidate aggregates for summary statistics
@@ -168,8 +159,8 @@ private[ml] object WeightedLeastSquares {
      * Adds an instance.
      */
     def add(instance: Instance): this.type = {
-      val Instance(w, a, b) = instance
-      val ak = a.size
+      val Instance(l, w, f) = instance
+      val ak = f.size
       if (!initialized) {
         init(ak)
       }
@@ -177,11 +168,11 @@ private[ml] object WeightedLeastSquares {
       count += 1L
       wSum += w
       wwSum += w * w
-      bSum += w * b
-      bbSum += w * b * b
-      BLAS.axpy(w, a, aSum)
-      BLAS.axpy(w * b, a, abSum)
-      BLAS.spr(w, a, aaSum)
+      bSum += w * l
+      bbSum += w * l * l
+      BLAS.axpy(w, f, aSum)
+      BLAS.axpy(w * l, f, abSum)
+      BLAS.spr(w, f, aaSum)
       this
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5f1cee6f/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index c3ee8b3..f663b9b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -154,10 +154,10 @@ class LinearRegression(override val uid: String)
         "solver is used.'")
      // For low dimensional data, WeightedLeastSquares is more efficient since the
       // training algorithm only requires one pass through the data. (SPARK-10668)
-      val instances: RDD[WeightedLeastSquares.Instance] = dataset.select(
+      val instances: RDD[Instance] = dataset.select(
         col($(labelCol)), w, col($(featuresCol))).map {
           case Row(label: Double, weight: Double, features: Vector) =>
-            WeightedLeastSquares.Instance(weight, features, label)
+            Instance(label, weight, features)
       }
 
       val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
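
The comment kept in the hunk above notes that WeightedLeastSquares is preferred for low-dimensional data because it needs only one pass over the data. As a standalone illustration of why a single scan suffices (plain Scala, not Spark's implementation; the object and method names are made up, and the toy data matches the test suite below), the summary statistics A^T W A and A^T W b are just running sums over (label, weight, features) triples:

  // Standalone sketch (plain Scala, not Spark's code) of why one pass is enough:
  // weighted least squares only needs the k x k matrix A^T W A and the k-vector
  // A^T W b, and both are running sums over (label, weight, features) triples,
  // which is exactly the shape that ml.feature.Instance now carries.
  object OnePassNormalEquations {

    /** Accumulate A^T W A (row-major, length k*k) and A^T W b (length k) in one scan. */
    def accumulate(data: Iterator[(Double, Double, Array[Double])], k: Int)
      : (Array[Double], Array[Double]) = {
      val aaSum = new Array[Double](k * k)  // running A^T W A
      val abSum = new Array[Double](k)      // running A^T W b
      data.foreach { case (label, weight, features) =>
        var i = 0
        while (i < k) {
          abSum(i) += weight * label * features(i)
          var j = 0
          while (j < k) {
            aaSum(i * k + j) += weight * features(i) * features(j)
            j += 1
          }
          i += 1
        }
      }
      (aaSum, abSum)
    }

    def main(args: Array[String]): Unit = {
      // Same toy data as the test suite below, written as (label, weight, features).
      val data = Iterator(
        (17.0, 1.0, Array(0.0, 5.0)),
        (19.0, 2.0, Array(1.0, 7.0)),
        (23.0, 3.0, Array(2.0, 11.0)),
        (29.0, 4.0, Array(3.0, 13.0)))
      val (aa, ab) = accumulate(data, k = 2)
      println("A^T W A = " + aa.mkString(", "))
      println("A^T W b = " + ab.mkString(", "))
    }
  }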

http://git-wip-us.apache.org/repos/asf/spark/blob/5f1cee6f/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
index 652f3ad..b542ba3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/optim/WeightedLeastSquaresSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.ml.optim
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.optim.WeightedLeastSquares.Instance
+import org.apache.spark.ml.feature.Instance
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
@@ -38,10 +38,10 @@ class WeightedLeastSquaresSuite extends SparkFunSuite with MLlibTestSparkContext
        w <- c(1, 2, 3, 4)
      */
     instances = sc.parallelize(Seq(
-      Instance(1.0, Vectors.dense(0.0, 5.0).toSparse, 17.0),
-      Instance(2.0, Vectors.dense(1.0, 7.0), 19.0),
-      Instance(3.0, Vectors.dense(2.0, 11.0), 23.0),
-      Instance(4.0, Vectors.dense(3.0, 13.0), 29.0)
+      Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+      Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)),
+      Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
+      Instance(29.0, 4.0, Vectors.dense(3.0, 13.0))
     ), 2)
   }
 

