You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ho...@apache.org on 2017/02/01 22:11:42 UTC

spark git commit: [SPARK-14352][SQL] approxQuantile should support multi columns

Repository: spark
Updated Branches:
  refs/heads/master c5fcb7f68 -> b0985764f


[SPARK-14352][SQL] approxQuantile should support multi columns

## What changes were proposed in this pull request?

1. Add multi-column support based on the current private API.
2. Add multi-column support to PySpark.

## How was this patch tested?

Unit tests.

Author: Zheng RuiFeng <ru...@foxmail.com>
Author: Ruifeng Zheng <ru...@foxmail.com>

Closes #12135 from zhengruifeng/quantile4multicols.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b0985764
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b0985764
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b0985764

Branch: refs/heads/master
Commit: b0985764f00acea97df7399a6b337262fc97f5ee
Parents: c5fcb7f
Author: Zheng RuiFeng <ru...@foxmail.com>
Authored: Wed Feb 1 14:11:28 2017 -0800
Committer: Holden Karau <ho...@us.ibm.com>
Committed: Wed Feb 1 14:11:28 2017 -0800

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 | 37 ++++++++++++++++----
 python/pyspark/sql/tests.py                     | 23 +++++++++++-
 .../spark/sql/DataFrameStatFunctions.scala      | 37 ++++++++++++++++++--
 .../apache/spark/sql/DataFrameStatSuite.scala   | 15 ++++++++
 4 files changed, 101 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b0985764/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 10e42d0..50373b8 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -16,7 +16,6 @@
 #
 
 import sys
-import warnings
 import random
 
 if sys.version >= '3':
@@ -1348,7 +1347,7 @@ class DataFrame(object):
     @since(2.0)
     def approxQuantile(self, col, probabilities, relativeError):
         """
-        Calculates the approximate quantiles of a numerical column of a
+        Calculates the approximate quantiles of numerical columns of a
         DataFrame.
 
         The result of this algorithm has the following deterministic bound:
@@ -1365,7 +1364,10 @@ class DataFrame(object):
         Space-efficient Online Computation of Quantile Summaries]]
         by Greenwald and Khanna.
 
-        :param col: the name of the numerical column
+        Note that rows containing any null values will be removed before calculation.
+
+        :param col: str, list or tuple.
+          Can be a single column name, or a list or tuple of names for multiple columns.
         :param probabilities: a list of quantile probabilities
           Each number must belong to [0, 1].
           For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
@@ -1373,10 +1375,30 @@ class DataFrame(object):
           (>= 0). If set to zero, the exact quantiles are computed, which
           could be very expensive. Note that values greater than 1 are
           accepted but give the same result as 1.
-        :return:  the approximate quantiles at the given probabilities
+        :return:  the approximate quantiles at the given probabilities. If
+          the input `col` is a string, the output is a list of floats. If the
+          input `col` is a list or tuple of strings, the output is also a
+          list, but each element in it is a list of floats, i.e., the output
+          is a list of list of floats.
+
+        .. versionchanged:: 2.2
+           Added support for multiple columns.
         """
-        if not isinstance(col, str):
-            raise ValueError("col should be a string.")
+
+        if not isinstance(col, (str, list, tuple)):
+            raise ValueError("col should be a string, list or tuple, but got %r" % type(col))
+
+        isStr = isinstance(col, str)
+
+        if isinstance(col, tuple):
+            col = list(col)
+        elif isinstance(col, str):
+            col = [col]
+
+        for c in col:
+            if not isinstance(c, str):
+                raise ValueError("columns should be strings, but got %r" % type(c))
+        col = _to_list(self._sc, col)
 
         if not isinstance(probabilities, (list, tuple)):
             raise ValueError("probabilities should be a list or tuple")
@@ -1392,7 +1414,8 @@ class DataFrame(object):
         relativeError = float(relativeError)
 
         jaq = self._jdf.stat().approxQuantile(col, probabilities, relativeError)
-        return list(jaq)
+        jaq_list = [list(j) for j in jaq]
+        return jaq_list[0] if isStr else jaq_list
 
     @since(1.4)
     def corr(self, col1, col2, method=None):

http://git-wip-us.apache.org/repos/asf/spark/blob/b0985764/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 2fea4ac..86cad4b 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -895,11 +895,32 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
 
     def test_approxQuantile(self):
-        df = self.sc.parallelize([Row(a=i) for i in range(10)]).toDF()
+        df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
         aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1)
         self.assertTrue(isinstance(aq, list))
         self.assertEqual(len(aq), 3)
         self.assertTrue(all(isinstance(q, float) for q in aq))
+        aqs = df.stat.approxQuantile(["a", "b"], [0.1, 0.5, 0.9], 0.1)
+        self.assertTrue(isinstance(aqs, list))
+        self.assertEqual(len(aqs), 2)
+        self.assertTrue(isinstance(aqs[0], list))
+        self.assertEqual(len(aqs[0]), 3)
+        self.assertTrue(all(isinstance(q, float) for q in aqs[0]))
+        self.assertTrue(isinstance(aqs[1], list))
+        self.assertEqual(len(aqs[1]), 3)
+        self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
+        aqt = df.stat.approxQuantile(("a", "b"), [0.1, 0.5, 0.9], 0.1)
+        self.assertTrue(isinstance(aqt, list))
+        self.assertEqual(len(aqt), 2)
+        self.assertTrue(isinstance(aqt[0], list))
+        self.assertEqual(len(aqt[0]), 3)
+        self.assertTrue(all(isinstance(q, float) for q in aqt[0]))
+        self.assertTrue(isinstance(aqt[1], list))
+        self.assertEqual(len(aqt[1]), 3)
+        self.assertTrue(all(isinstance(q, float) for q in aqt[1]))
+        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1))
+        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1))
+        self.assertRaises(ValueError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1))
 
     def test_corr(self):
         import math

http://git-wip-us.apache.org/repos/asf/spark/blob/b0985764/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 7294532..2b782fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
 import org.apache.spark.annotation.InterfaceStability
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.stat._
+import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.types._
 import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
 
@@ -75,13 +76,43 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
   }
 
   /**
+   * Calculates the approximate quantiles of numerical columns of a DataFrame.
+   * @see [[DataFrameStatFunctions.approxQuantile(col:Str* approxQuantile]] for
+   *     detailed description.
+   *
+   * Note that rows containing any null or NaN values will be removed before
+   * calculation.
+   * @param cols the names of the numerical columns
+   * @param probabilities a list of quantile probabilities
+   *   Each number must belong to [0, 1].
+   *   For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
+   * @param relativeError The relative target precision to achieve (>= 0).
+   *   If set to zero, the exact quantiles are computed, which could be very expensive.
+   *   Note that values greater than 1 are accepted but give the same result as 1.
+   * @return the approximate quantiles at the given probabilities of each column
+   *
+   * @note Rows containing any NaN values will be removed before calculation
+   *
+   * @since 2.2.0
+   */
+  def approxQuantile(
+      cols: Array[String],
+      probabilities: Array[Double],
+      relativeError: Double): Array[Array[Double]] = {
+    StatFunctions.multipleApproxQuantiles(df.select(cols.map(col): _*).na.drop(), cols,
+      probabilities, relativeError).map(_.toArray).toArray
+  }
+
+
+  /**
    * Python-friendly version of [[approxQuantile()]]
    */
   private[spark] def approxQuantile(
-      col: String,
+      cols: List[String],
       probabilities: List[Double],
-      relativeError: Double): java.util.List[Double] = {
-    approxQuantile(col, probabilities.toArray, relativeError).toList.asJava
+      relativeError: Double): java.util.List[java.util.List[Double]] = {
+    approxQuantile(cols.toArray, probabilities.toArray, relativeError)
+        .map(_.toList.asJava).toList.asJava
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/b0985764/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 1383208..f52b18e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -149,11 +149,26 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
       assert(math.abs(s2 - q2 * n) < error_single)
       assert(math.abs(d1 - 2 * q1 * n) < error_double)
       assert(math.abs(d2 - 2 * q2 * n) < error_double)
+
+      // Multiple columns
+      val Array(Array(ms1, ms2), Array(md1, md2)) =
+        df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilon)
+
+      assert(math.abs(ms1 - q1 * n) < error_single)
+      assert(math.abs(ms2 - q2 * n) < error_single)
+      assert(math.abs(md1 - 2 * q1 * n) < error_double)
+      assert(math.abs(md2 - 2 * q2 * n) < error_double)
     }
     // test approxQuantile on NaN values
     val dfNaN = Seq(Double.NaN, 1.0, Double.NaN, Double.NaN).toDF("input")
     val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons.head)
     assert(resNaN.count(_.isNaN) === 0)
+    // test approxQuantile on multi-column NaN values
+    val dfNaN2 = Seq((Double.NaN, 1.0), (1.0, 1.0), (-1.0, Double.NaN), (Double.NaN, Double.NaN))
+      .toDF("input1", "input2")
+    val resNaN2 = dfNaN2.stat.approxQuantile(Array("input1", "input2"),
+      Array(q1, q2), epsilons.head)
+    assert(resNaN2.flatten.count(_.isNaN) === 0)
   }
 
   test("crosstab") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org