Posted to commits@spark.apache.org by rx...@apache.org on 2015/05/05 02:02:55 UTC

spark git commit: [SPARK-7243][SQL] Contingency Tables for DataFrames

Repository: spark
Updated Branches:
  refs/heads/master fc8b58195 -> 805541117


[SPARK-7243][SQL] Contingency Tables for DataFrames

Computes a pair-wise frequency table of the given columns. Also known as cross-tabulation.
cc mengxr rxin
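
For reference, a minimal usage sketch (the data mirrors the new DataFrameStatSuite test; it assumes a SQLContext named `sqlContext` with its implicits imported, and the show() output below is abridged):

    import sqlContext.implicits._

    val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b")
    val ct = df.stat.crosstab("a", "b")
    ct.show()
    // a_b  0  1
    // 0    2  null   <- the pair (0, 1) never occurs, so its count is null
    // 1    1  null
    // 2    2  1
    // (row order may vary)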

Author: Burak Yavuz <br...@gmail.com>

Closes #5842 from brkyvz/df-cont and squashes the following commits:

a07c01e [Burak Yavuz] addressed comments v4.1
ae9e01d [Burak Yavuz] fix test
9106585 [Burak Yavuz] addressed comments v4.0
bced829 [Burak Yavuz] fix merge conflicts
a63ad00 [Burak Yavuz] addressed comments v3.0
a0cad97 [Burak Yavuz] addressed comments v3.0
6805df8 [Burak Yavuz] addressed comments and fixed test
939b7c4 [Burak Yavuz] lint python
7f098bc [Burak Yavuz] add crosstab pyTest
fd53b00 [Burak Yavuz] added python support for crosstab
27a5a81 [Burak Yavuz] implemented crosstab


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/80554111
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/80554111
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/80554111

Branch: refs/heads/master
Commit: 80554111703c08e2bedbe303e04ecd162ec119e1
Parents: fc8b581
Author: Burak Yavuz <br...@gmail.com>
Authored: Mon May 4 17:02:49 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Mon May 4 17:02:49 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 | 25 +++++++++
 python/pyspark/sql/tests.py                     |  9 ++++
 .../spark/sql/DataFrameStatFunctions.scala      | 37 +++++++++----
 .../sql/execution/stat/StatFunctions.scala      | 37 +++++++++++--
 .../apache/spark/sql/JavaDataFrameSuite.java    | 28 ++++++++++
 .../apache/spark/sql/DataFrameStatSuite.scala   | 55 +++++++++++++-------
 6 files changed, 160 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/80554111/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 22762c5..f30a92d 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -931,6 +931,26 @@ class DataFrame(object):
             raise ValueError("col2 should be a string.")
         return self._jdf.stat().cov(col1, col2)
 
+    def crosstab(self, col1, col2):
+        """
+        Computes a pair-wise frequency table of the given columns. Also known as a contingency
+        table. The number of distinct values for each column should be less than 1e4. The first
+        column of each row will be the distinct values of `col1` and the column names will be the
+        distinct values of `col2`. The name of the first column will be `$col1_$col2`. Pairs that
+        have no occurrences will have `null` as their counts.
+        :func:`DataFrame.crosstab` and :func:`DataFrameStatFunctions.crosstab` are aliases.
+
+        :param col1: The name of the first column. Distinct items will make the first item of
+            each row.
+        :param col2: The name of the second column. Distinct items will make the column names
+            of the DataFrame.
+        """
+        if not isinstance(col1, str):
+            raise ValueError("col1 should be a string.")
+        if not isinstance(col2, str):
+            raise ValueError("col2 should be a string.")
+        return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx)
+
     def freqItems(self, cols, support=None):
         """
         Finding frequent items for columns, possibly with false positives. Using the
@@ -1423,6 +1443,11 @@ class DataFrameStatFunctions(object):
 
     cov.__doc__ = DataFrame.cov.__doc__
 
+    def crosstab(self, col1, col2):
+        return self.df.crosstab(col1, col2)
+
+    crosstab.__doc__ = DataFrame.crosstab.__doc__
+
     def freqItems(self, cols, support=None):
         return self.df.freqItems(cols, support)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/80554111/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d652c30..7ea6656 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -405,6 +405,15 @@ class SQLTests(ReusedPySparkTestCase):
         cov = df.stat.cov("a", "b")
         self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
 
+    def test_crosstab(self):
+        df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF()
+        ct = df.stat.crosstab("a", "b").collect()
+        ct = sorted(ct, key=lambda x: x[0])
+        for i, row in enumerate(ct):
+            self.assertEqual(row[0], str(i))
+            self.assertEqual(row[1], 1)
+            self.assertEqual(row[2], 1)
+
     def test_math_functions(self):
         df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
         from pyspark.sql import mathfunctions as functions

http://git-wip-us.apache.org/repos/asf/spark/blob/80554111/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 9035321..fcf21ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -28,6 +28,16 @@ import org.apache.spark.sql.execution.stat._
 final class DataFrameStatFunctions private[sql](df: DataFrame) {
 
   /**
+   * Calculate the sample covariance of two numerical columns of a DataFrame.
+   * @param col1 the name of the first column
+   * @param col2 the name of the second column
+   * @return the covariance of the two columns.
+   */
+  def cov(col1: String, col2: String): Double = {
+    StatFunctions.calculateCov(df, Seq(col1, col2))
+  }
+
+  /**
    * Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
    * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in 
    * MLlib's Statistics.
@@ -54,6 +64,23 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
   }
 
   /**
+   * Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
+   * The number of distinct values for each column should be less than 1e4. The first
+   * column of each row will be the distinct values of `col1` and the column names will be the
+   * distinct values of `col2`. The name of the first column will be `$col1_$col2`. Counts will be
+   * returned as `Long`s. Pairs that have no occurrences will have `null` as their counts.
+   *
+   * @param col1 The name of the first column. Distinct items will make the first item of
+   *             each row.
+   * @param col2 The name of the second column. Distinct items will make the column names
+   *             of the DataFrame.
+   * @return A local DataFrame containing the table.
+   */
+  def crosstab(col1: String, col2: String): DataFrame = {
+    StatFunctions.crossTabulate(df, col1, col2)
+  }
+
+  /**
    * Finding frequent items for columns, possibly with false positives. Using the
    * frequent element count algorithm described in
    * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
@@ -94,14 +121,4 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
   def freqItems(cols: Seq[String]): DataFrame = {
     FrequentItems.singlePassFreqItems(df, cols, 0.01)
   }
-
-  /**
-   * Calculate the sample covariance of two numerical columns of a DataFrame.
-   * @param col1 the name of the first column
-   * @param col2 the name of the second column
-   * @return the covariance of the two columns.
-   */
-  def cov(col1: String, col2: String): Double = {
-    StatFunctions.calculateCov(df, Seq(col1, col2))
-  }
 }
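
Continuing the usage sketch from the commit message: the returned DataFrame is local and small, and its schema is derived from the data, so the naming and typing contract described above can be checked directly (column order here matches what the new Scala test asserts):

    ct.schema.fieldNames              // Array(a_b, 0, 1)
    ct.schema.fields.map(_.dataType)  // Array(StringType, LongType, LongType)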

http://git-wip-us.apache.org/repos/asf/spark/blob/80554111/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 67b48e5..b50f606 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -17,11 +17,14 @@
 
 package org.apache.spark.sql.execution.stat
 
-import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.Logging
 import org.apache.spark.sql.{Column, DataFrame}
-import org.apache.spark.sql.types.{DoubleType, NumericType}
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
 
-private[sql] object StatFunctions {
+private[sql] object StatFunctions extends Logging {
   
   /** Calculate the Pearson Correlation Coefficient for the given columns */
   private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
@@ -95,4 +98,32 @@ private[sql] object StatFunctions {
     val counts = collectStatisticalData(df, cols)
     counts.cov
   }
+
+  /** Generate a table of frequencies for the elements of two columns. */
+  private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
+    val tableName = s"${col1}_$col2"
+    val counts = df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).take(1e8.toInt)
+    if (counts.length == 1e8.toInt) {
+      logWarning("The maximum limit of 1e8 pairs have been collected, which may not be all of " +
+        "the pairs. Please try reducing the amount of distinct items in your columns.")
+    }
+    // get the distinct values of column 2, so that we can make them the column names
+    val distinctCol2 = counts.map(_.get(1)).distinct.zipWithIndex.toMap
+    val columnSize = distinctCol2.size
+    require(columnSize < 1e4, s"The number of distinct values for $col2 can't " +
+      s"exceed 1e4. Currently it is $columnSize.")
+    val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) =>
+      val countsRow = new GenericMutableRow(columnSize + 1)
+      rows.foreach { row =>
+        countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2))
+      }
+      // the value of col1 is the first value, the rest are the counts
+      countsRow.setString(0, col1Item.toString)
+      countsRow
+    }.toSeq
+    val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType))
+    val schema = StructType(StructField(tableName, StringType) +: headerNames)
+
+    new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table))
+  }
 }
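
To make the pivot in crossTabulate easier to follow, here is a toy re-implementation of its core on plain Scala collections (illustration only, not Spark code; the (col1, col2, count) triples stand in for the rows returned by the groupBy/count step):

    val counts = Seq((0, 0, 2L), (1, 0, 1L), (2, 0, 2L), (2, 1, 1L))
    // distinct col2 value -> column index, as in distinctCol2 above
    val colIndex = counts.map(_._2).distinct.zipWithIndex.toMap
    val table = counts.groupBy(_._1).map { case (col1Item, rows) =>
      val row = Array.fill[Any](colIndex.size + 1)(null) // missing pairs stay null
      row(0) = col1Item.toString                         // the col1 value becomes the row label
      rows.foreach { case (_, col2Item, n) => row(colIndex(col2Item) + 1) = n }
      row.toSeq
    }
    // table contains Seq("0", 2, null), Seq("1", 1, null), Seq("2", 2, 1)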

http://git-wip-us.apache.org/repos/asf/spark/blob/80554111/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 78e8472..58cc8e5 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -34,6 +34,7 @@ import scala.collection.mutable.Buffer;
 
 import java.io.Serializable;
 import java.util.Arrays;
+import java.util.Comparator;
 import java.util.List;
 import java.util.Map;
 
@@ -178,6 +179,33 @@ public class JavaDataFrameSuite {
       Assert.assertEquals(bean.getD().get(i), d.apply(i));
     }
   }
+
+  private static Comparator<Row> CrosstabRowComparator = new Comparator<Row>() {
+    public int compare(Row row1, Row row2) {
+      String item1 = row1.getString(0);
+      String item2 = row2.getString(0);
+      return item1.compareTo(item2);
+    }
+  };
+
+  @Test
+  public void testCrosstab() {
+    DataFrame df = context.table("testData2");
+    DataFrame crosstab = df.stat().crosstab("a", "b");
+    String[] columnNames = crosstab.schema().fieldNames();
+    Assert.assertEquals(columnNames[0], "a_b");
+    Assert.assertEquals(columnNames[1], "1");
+    Assert.assertEquals(columnNames[2], "2");
+    Row[] rows = crosstab.collect();
+    Arrays.sort(rows, CrosstabRowComparator);
+    Integer count = 1;
+    for (Row row : rows) {
+      Assert.assertEquals(row.get(0).toString(), count.toString());
+      Assert.assertEquals(row.getLong(1), 1L);
+      Assert.assertEquals(row.getLong(2), 1L);
+      count++;
+    }
+  }
   
   @Test
   public void testFrequentItems() {

http://git-wip-us.apache.org/repos/asf/spark/blob/80554111/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 06764d2..46b1845 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -24,26 +24,9 @@ import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext.implicits._
 
 class DataFrameStatSuite extends FunSuite  {
-
-  import TestData._
+  
   val sqlCtx = TestSQLContext
   def toLetter(i: Int): String = (i + 97).toChar.toString
-  
-  test("Frequent Items") {
-    val rows = Seq.tabulate(1000) { i =>
-      if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
-    }
-    val df = rows.toDF("numbers", "letters", "negDoubles")
-
-    val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
-    val items = results.collect().head
-    items.getSeq[Int](0) should contain (1)
-    items.getSeq[String](1) should contain (toLetter(1))
-
-    val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
-    val items2 = singleColResults.collect().head
-    items2.getSeq[Double](0) should contain (-1.0)
-  }
 
   test("pearson correlation") {
     val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
@@ -76,7 +59,43 @@ class DataFrameStatSuite extends FunSuite  {
     intercept[IllegalArgumentException] {
       df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
     }
+    val decimalData = Seq.tabulate(6)(i => (BigDecimal(i % 3), BigDecimal(i % 2))).toDF("a", "b")
     val decimalRes = decimalData.stat.cov("a", "b")
     assert(math.abs(decimalRes) < 1e-12)
   }
+
+  test("crosstab") {
+    val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b")
+    val crosstab = df.stat.crosstab("a", "b")
+    val columnNames = crosstab.schema.fieldNames
+    assert(columnNames(0) === "a_b")
+    assert(columnNames(1) === "0")
+    assert(columnNames(2) === "1")
+    val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0))
+    assert(rows(0).get(0).toString === "0")
+    assert(rows(0).getLong(1) === 2L)
+    assert(rows(0).get(2) === null)
+    assert(rows(1).get(0).toString === "1")
+    assert(rows(1).getLong(1) === 1L)
+    assert(rows(1).get(2) === null)
+    assert(rows(2).get(0).toString === "2")
+    assert(rows(2).getLong(1) === 2L)
+    assert(rows(2).getLong(2) === 1L)
+  }
+
+  test("Frequent Items") {
+    val rows = Seq.tabulate(1000) { i =>
+      if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
+    }
+    val df = rows.toDF("numbers", "letters", "negDoubles")
+
+    val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
+    val items = results.collect().head
+    items.getSeq[Int](0) should contain (1)
+    items.getSeq[String](1) should contain (toLetter(1))
+
+    val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
+    val items2 = singleColResults.collect().head
+    items2.getSeq[Double](0) should contain (-1.0)
+  }
 }

