Posted to commits@spark.apache.org by rx...@apache.org on 2015/05/04 06:44:42 UTC

spark git commit: [SPARK-7241] Pearson correlation for DataFrames

Repository: spark
Updated Branches:
  refs/heads/master 1ffa8cb91 -> 9646018bb


[SPARK-7241] Pearson correlation for DataFrames

Submitting this PR from a phone, excuse the brevity.
Adds Pearson correlation to DataFrames, reusing the covariance calculation code.

cc mengxr rxin

Author: Burak Yavuz <br...@gmail.com>

Closes #5858 from brkyvz/df-corr and squashes the following commits:

285b838 [Burak Yavuz] addressed comments v2.0
d10babb [Burak Yavuz] addressed comments v0.2
4b74b24 [Burak Yavuz] Merge branch 'master' of github.com:apache/spark into df-corr
4fe693b [Burak Yavuz] addressed comments v0.1
a682d06 [Burak Yavuz] ready for PR


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9646018b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9646018b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9646018b

Branch: refs/heads/master
Commit: 9646018bb4466433521b4e602b808f16e8d0ffdb
Parents: 1ffa8cb
Author: Burak Yavuz <br...@gmail.com>
Authored: Sun May 3 21:44:39 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sun May 3 21:44:39 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 | 26 +++++++++
 python/pyspark/sql/tests.py                     |  6 ++
 .../spark/sql/DataFrameStatFunctions.scala      | 26 +++++++++
 .../sql/execution/stat/StatFunctions.scala      | 58 +++++++++++++-------
 .../apache/spark/sql/JavaDataFrameSuite.java    |  7 +++
 .../apache/spark/sql/DataFrameStatSuite.scala   | 33 +++++++++--
 6 files changed, 130 insertions(+), 26 deletions(-)
----------------------------------------------------------------------
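
A minimal usage sketch of the new API from PySpark (assuming a SparkContext `sc` with an initialized SQLContext, as in the test suites below):

    from pyspark.sql import Row

    # a and b are perfectly linearly related, so the Pearson coefficient is 1.0
    df = sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()

    df.stat.corr("a", "b")               # method defaults to "pearson"
    df.corr("a", "b", method="pearson")  # DataFrame.corr is an alias

    # anything other than "pearson" is rejected
    try:
        df.corr("a", "b", method="spearman")
    except ValueError as e:
        print(e)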


http://git-wip-us.apache.org/repos/asf/spark/blob/9646018b/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 8ddcff8..aac5b8c 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -875,6 +875,27 @@ class DataFrame(object):
 
             return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
 
+    def corr(self, col1, col2, method=None):
+        """
+        Calculates the correlation of two columns of a DataFrame as a double value. Currently only
+        supports the Pearson Correlation Coefficient.
+        :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases.
+
+        :param col1: The name of the first column
+        :param col2: The name of the second column
+        :param method: The correlation method. Currently only supports "pearson"
+        """
+        if not isinstance(col1, str):
+            raise ValueError("col1 should be a string.")
+        if not isinstance(col2, str):
+            raise ValueError("col2 should be a string.")
+        if not method:
+            method = "pearson"
+        if method != "pearson":
+            raise ValueError("Currently only the calculation of the Pearson Correlation " +
+                             "coefficient is supported.")
+        return self._jdf.stat().corr(col1, col2, method)
+
     def cov(self, col1, col2):
         """
         Calculate the sample covariance for the given columns, specified by their names, as a
@@ -1359,6 +1380,11 @@ class DataFrameStatFunctions(object):
     def __init__(self, df):
         self.df = df
 
+    def corr(self, col1, col2, method=None):
+        return self.df.corr(col1, col2, method)
+
+    corr.__doc__ = DataFrame.corr.__doc__
+
     def cov(self, col1, col2):
         return self.df.cov(col1, col2)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9646018b/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 613efc0..d652c30 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -394,6 +394,12 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
         self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
 
+    def test_corr(self):
+        import math
+        df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
+        corr = df.stat.corr("a", "b")
+        self.assertTrue(abs(corr - 0.95734012) < 1e-6)
+
     def test_cov(self):
         df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
         cov = df.stat.cov("a", "b")
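
The expected value in test_corr can be cross-checked outside Spark (assuming numpy and scipy are available):

    import math
    import numpy as np
    from scipy.stats import pearsonr

    a = np.arange(10, dtype=float)
    b = np.array([math.sqrt(i) for i in range(10)])
    r, _ = pearsonr(a, b)
    print(r)  # ~0.9573401, within the 1e-6 tolerance asserted above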

http://git-wip-us.apache.org/repos/asf/spark/blob/9646018b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index e8fa829..9035321 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -28,6 +28,32 @@ import org.apache.spark.sql.execution.stat._
 final class DataFrameStatFunctions private[sql](df: DataFrame) {
 
   /**
+   * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
+   * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in 
+   * MLlib's Statistics.
+   *
+   * @param col1 the name of the column
+   * @param col2 the name of the column to calculate the correlation against
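+   * @param method the correlation method; currently only "pearson" is supported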
+   * @return The Pearson Correlation Coefficient as a Double.
+   */
+  def corr(col1: String, col2: String, method: String): Double = {
+    require(method == "pearson", "Currently only the calculation of the Pearson Correlation " +
+      "coefficient is supported.")
+    StatFunctions.pearsonCorrelation(df, Seq(col1, col2))
+  }
+
+  /**
+   * Calculates the Pearson Correlation Coefficient of two columns of a DataFrame.
+   *
+   * @param col1 the name of the column
+   * @param col2 the name of the column to calculate the correlation against
+   * @return The Pearson Correlation Coefficient as a Double.
+   */
+  def corr(col1: String, col2: String): Double = {
+    corr(col1, col2, "pearson")
+  }
+
+  /**
    * Finding frequent items for columns, possibly with false positives. Using the
    * frequent element count algorithm described in
    * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].

http://git-wip-us.apache.org/repos/asf/spark/blob/9646018b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index d4a94c2..67b48e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -23,29 +23,43 @@ import org.apache.spark.sql.types.{DoubleType, NumericType}
 
 private[sql] object StatFunctions {
   
+  /** Calculate the Pearson Correlation Coefficient for the given columns */
+  private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
+    val counts = collectStatisticalData(df, cols)
+    counts.Ck / math.sqrt(counts.MkX * counts.MkY)
+  }
+
   /** Helper class to simplify tracking and merging counts. */
   private class CovarianceCounter extends Serializable {
-    var xAvg = 0.0
-    var yAvg = 0.0
-    var Ck = 0.0
-    var count = 0L
+    var xAvg = 0.0 // the mean of all examples seen so far in col1
+    var yAvg = 0.0 // the mean of all examples seen so far in col2
+    var Ck = 0.0 // the co-moment after k examples
+    var MkX = 0.0 // sum of squares of differences from the (current) mean for col1
+    var MkY = 0.0 // sum of squares of differences from the (current) mean for col2
+    var count = 0L // count of observed examples
     // add an example to the calculation
     def add(x: Double, y: Double): this.type = {
-      val oldX = xAvg
+      val deltaX = x - xAvg
+      val deltaY = y - yAvg
       count += 1
-      xAvg += (x - xAvg) / count
-      yAvg += (y - yAvg) / count
-      Ck += (y - yAvg) * (x - oldX)
+      xAvg += deltaX / count
+      yAvg += deltaY / count
+      Ck += deltaX * (y - yAvg)
+      MkX += deltaX * (x - xAvg)
+      MkY += deltaY * (y - yAvg)
       this
     }
     // merge counters from other partitions. Formula can be found at:
-    // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Covariance
+    // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
     def merge(other: CovarianceCounter): this.type = {
       val totalCount = count + other.count
-      Ck += other.Ck + 
-        (xAvg - other.xAvg) * (yAvg - other.yAvg) * count / totalCount * other.count
+      val deltaX = xAvg - other.xAvg
+      val deltaY = yAvg - other.yAvg
+      Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
       xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
       yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
+      MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
+      MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
       count = totalCount
       this
     }
@@ -53,13 +67,7 @@ private[sql] object StatFunctions {
     def cov: Double = Ck / (count - 1)
   }
 
-  /**
-   * Calculate the covariance of two numerical columns of a DataFrame.
-   * @param df The DataFrame
-   * @param cols the column names
-   * @return the covariance of the two columns.
-   */
-  private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
+  private def collectStatisticalData(df: DataFrame, cols: Seq[String]): CovarianceCounter = {
     require(cols.length == 2, "Currently cov supports calculating the covariance " +
       "between two columns.")
     cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) =>
@@ -68,13 +76,23 @@ private[sql] object StatFunctions {
         s"with dataType ${data.get.dataType} not supported.")
     }
     val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
-    val counts = df.select(columns:_*).rdd.aggregate(new CovarianceCounter)(
+    df.select(columns: _*).rdd.aggregate(new CovarianceCounter)(
       seqOp = (counter, row) => {
         counter.add(row.getDouble(0), row.getDouble(1))
       },
       combOp = (baseCounter, other) => {
         baseCounter.merge(other)
-      })
+    })
+  }
+
+  /**
+   * Calculate the covariance of two numerical columns of a DataFrame.
+   * @param df The DataFrame
+   * @param cols the column names
+   * @return the covariance of the two columns.
+   */
+  private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
+    val counts = collectStatisticalData(df, cols)
     counts.cov
   }
 }
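
The counter above does everything in one pass: Ck accumulates the co-moment sum((x - xAvg) * (y - yAvg)) online, and MkX/MkY accumulate the sums of squared deviations, so corr = Ck / sqrt(MkX * MkY) and cov = Ck / (count - 1) come out without a second scan; the merge step is the pairwise-update formula from the Wikipedia article cited in the code. A standalone sketch of the same add/merge logic in plain Python (no Spark required; field names mirror the Scala class), checking that merging two per-partition counters agrees with a single sequential pass:

    import math

    class CovarianceCounter(object):
        """One-pass co-moment and sum-of-squares tracking, mirroring the Scala class."""
        def __init__(self):
            self.xAvg = 0.0  # running mean of x
            self.yAvg = 0.0  # running mean of y
            self.Ck = 0.0    # co-moment after k examples
            self.MkX = 0.0   # sum of squared deviations from the mean of x
            self.MkY = 0.0   # sum of squared deviations from the mean of y
            self.count = 0   # number of examples seen

        def add(self, x, y):
            deltaX = x - self.xAvg
            deltaY = y - self.yAvg
            self.count += 1
            self.xAvg += deltaX / self.count
            self.yAvg += deltaY / self.count
            self.Ck += deltaX * (y - self.yAvg)
            self.MkX += deltaX * (x - self.xAvg)
            self.MkY += deltaY * (y - self.yAvg)

        def merge(self, other):
            # pairwise update: C = C_a + C_b + deltaX * deltaY * n_a * n_b / n
            total = self.count + other.count
            deltaX = self.xAvg - other.xAvg
            deltaY = self.yAvg - other.yAvg
            self.Ck += other.Ck + deltaX * deltaY * self.count / total * other.count
            self.xAvg = (self.xAvg * self.count + other.xAvg * other.count) / total
            self.yAvg = (self.yAvg * self.count + other.yAvg * other.count) / total
            self.MkX += other.MkX + deltaX * deltaX * self.count / total * other.count
            self.MkY += other.MkY + deltaY * deltaY * self.count / total * other.count
            self.count = total

        def corr(self):
            return self.Ck / math.sqrt(self.MkX * self.MkY)

    data = [(float(i), i * i - 2.0 * i + 3.5) for i in range(20)]

    sequential = CovarianceCounter()
    for x, y in data:
        sequential.add(x, y)

    left, right = CovarianceCounter(), CovarianceCounter()
    for x, y in data[:10]:
        left.add(x, y)
    for x, y in data[10:]:
        right.add(x, y)
    left.merge(right)

    assert abs(sequential.corr() - left.corr()) < 1e-10
    print(sequential.corr())  # ~0.9572339139, as in the Scala test below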

http://git-wip-us.apache.org/repos/asf/spark/blob/9646018b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 96fe66d..78e8472 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -188,6 +188,13 @@ public class JavaDataFrameSuite {
   }
 
   @Test
+  public void testCorrelation() {
+    DataFrame df = context.table("testData2");
+    Double pearsonCorr = df.stat().corr("a", "b", "pearson");
+    Assert.assertTrue(Math.abs(pearsonCorr) < 1e-6);
+  }
+
+  @Test
   public void testCovariance() {
     DataFrame df = context.table("testData2");
     Double result = df.stat().cov("a", "b");

http://git-wip-us.apache.org/repos/asf/spark/blob/9646018b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 4f5a2ff..06764d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -30,10 +30,10 @@ class DataFrameStatSuite extends FunSuite  {
   def toLetter(i: Int): String = (i + 97).toChar.toString
   
   test("Frequent Items") {
-    val rows = Array.tabulate(1000) { i =>
+    val rows = Seq.tabulate(1000) { i =>
       if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
     }
-    val df = sqlCtx.sparkContext.parallelize(rows).toDF("numbers", "letters", "negDoubles")
+    val df = rows.toDF("numbers", "letters", "negDoubles")
 
     val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
     val items = results.collect().head
@@ -43,19 +43,40 @@ class DataFrameStatSuite extends FunSuite  {
     val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
     val items2 = singleColResults.collect().head
     items2.getSeq[Double](0) should contain (-1.0)
+  }
 
+  test("pearson correlation") {
+    val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
+    val corr1 = df.stat.corr("a", "b", "pearson")
+    assert(math.abs(corr1 - 1.0) < 1e-12)
+    val corr2 = df.stat.corr("a", "c", "pearson")
+    assert(math.abs(corr2 + 1.0) < 1e-12)
+    // non-trivial example. To reproduce in python, use:
+    // >>> from scipy.stats import pearsonr
+    // >>> import numpy as np
+    // >>> a = np.array(range(20))
+    // >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)])
+    // >>> pearsonr(a, b)
+    // (0.95723391394758572, 3.8902121417802199e-11)
+    // In R, use:
+    // > a <- 0:19
+    // > b <- mapply(function(x) x * x - 2 * x + 3.5, a)
+    // > cor(a, b)
+    // [1] 0.957233913947585835
+    val df2 = Seq.tabulate(20)(x => (x, x * x - 2 * x + 3.5)).toDF("a", "b")
+    val corr3 = df2.stat.corr("a", "b", "pearson")
+    assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
   }
 
   test("covariance") {
-    val rows = Array.tabulate(10)(i => (i, 2.0 * i, toLetter(i)))
-    val df = sqlCtx.sparkContext.parallelize(rows).toDF("singles", "doubles", "letters")
+    val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")
 
     val results = df.stat.cov("singles", "doubles")
-    assert(math.abs(results - 55.0 / 3) < 1e-6)
+    assert(math.abs(results - 55.0 / 3) < 1e-12)
     intercept[IllegalArgumentException] {
       df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
     }
     val decimalRes = decimalData.stat.cov("a", "b")
-    assert(math.abs(decimalRes) < 1e-6)
+    assert(math.abs(decimalRes) < 1e-12)
   }
 }

