You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by hv...@apache.org on 2016/10/11 05:33:24 UTC

spark git commit: [SPARK-17844] Simplify DataFrame API for defining frame boundaries in window functions

Repository: spark
Updated Branches:
  refs/heads/master 0c0ad436a -> b515768f2


[SPARK-17844] Simplify DataFrame API for defining frame boundaries in window functions

## What changes were proposed in this pull request?
When I was creating the example code for SPARK-10496, I realized it was pretty convoluted to define the frame boundaries for window functions when there is no partition column or ordering column. The reason is that we don't provide a way to create a WindowSpec directly with the frame boundaries. We can trivially improve this by adding rowsBetween and rangeBetween to the Window object.

As an example, to compute a cumulative sum using the natural ordering, before this PR:
```
df.select('key, sum("value").over(Window.partitionBy(lit(1)).rowsBetween(Long.MinValue, 0)))
```

After this PR:
```
df.select('key, sum("value").over(Window.rowsBetween(Long.MinValue, 0)))
```

Note that you could argue there is no point specifying a window frame without partitionBy/orderBy -- but it is strange that rowsBetween and rangeBetween are the only two WindowSpec APIs not available directly on the Window object.

This also fixes https://issues.apache.org/jira/browse/SPARK-17656 (removing _root_.scala).

## How was this patch tested?
Added test cases to compute cumulative sum in DataFrameWindowSuite for Scala/Java and tests.py for Python.

Author: Reynold Xin <rx...@databricks.com>

Closes #15412 from rxin/SPARK-17844.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b515768f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b515768f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b515768f

Branch: refs/heads/master
Commit: b515768f2668749ad37a3bdf9d265ce45ec447b1
Parents: 0c0ad43
Author: Reynold Xin <rx...@databricks.com>
Authored: Mon Oct 10 22:33:20 2016 -0700
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Mon Oct 10 22:33:20 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                     |  9 ++++
 python/pyspark/sql/window.py                    | 48 ++++++++++++++++++++
 .../apache/spark/sql/expressions/Window.scala   | 46 +++++++++++++++++--
 .../spark/sql/expressions/WindowSpec.scala      | 10 ++--
 .../org/apache/spark/sql/expressions/udaf.scala |  4 +-
 .../apache/spark/sql/DataFrameWindowSuite.scala | 12 +++++
 6 files changed, 119 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b515768f/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a9e4555..7b6f9f0 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1859,6 +1859,15 @@ class HiveContextSQLTests(ReusedPySparkTestCase):
         for r, ex in zip(rs, expected):
             self.assertEqual(tuple(r), ex[:len(r)])
 
+    def test_window_functions_cumulative_sum(self):
+        df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"])
+        from pyspark.sql import functions as F
+        sel = df.select(df.key, F.sum(df.value).over(Window.rowsBetween(-sys.maxsize, 0)))
+        rs = sorted(sel.collect())
+        expected = [("one", 1), ("two", 3)]
+        for r, ex in zip(rs, expected):
+            self.assertEqual(tuple(r), ex[:len(r)])
+
     def test_collect_functions(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
         from pyspark.sql import functions

http://git-wip-us.apache.org/repos/asf/spark/blob/b515768f/python/pyspark/sql/window.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/window.py b/python/pyspark/sql/window.py
index 46663f6..87e9a98 100644
--- a/python/pyspark/sql/window.py
+++ b/python/pyspark/sql/window.py
@@ -66,6 +66,54 @@ class Window(object):
         jspec = sc._jvm.org.apache.spark.sql.expressions.Window.orderBy(_to_java_cols(cols))
         return WindowSpec(jspec)
 
+    @staticmethod
+    @since(2.1)
+    def rowsBetween(start, end):
+        """
+        Creates a :class:`WindowSpec` with the frame boundaries defined,
+        from `start` (inclusive) to `end` (inclusive).
+
+        Both `start` and `end` are relative positions from the current row.
+        For example, "0" means "current row", while "-1" means the row before
+        the current row, and "5" means the fifth row after the current row.
+
+        :param start: boundary start, inclusive.
+                      The frame is unbounded if this is ``-sys.maxsize`` (or lower).
+        :param end: boundary end, inclusive.
+                    The frame is unbounded if this is ``sys.maxsize`` (or higher).
+        """
+        if start <= -sys.maxsize:
+            start = WindowSpec._JAVA_MIN_LONG
+        if end >= sys.maxsize:
+            end = WindowSpec._JAVA_MAX_LONG
+        sc = SparkContext._active_spark_context
+        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rowsBetween(start, end)
+        return WindowSpec(jspec)
+
+    @staticmethod
+    @since(2.1)
+    def rangeBetween(start, end):
+        """
+        Creates a :class:`WindowSpec` with the frame boundaries defined,
+        from `start` (inclusive) to `end` (inclusive).
+
+        Both `start` and `end` are relative from the current row. For example,
+        "0" means "current row", while "-1" means one off before the current row,
+        and "5" means the five off after the current row.
+
+        :param start: boundary start, inclusive.
+                      The frame is unbounded if this is ``-sys.maxsize`` (or lower).
+        :param end: boundary end, inclusive.
+                    The frame is unbounded if this is ``sys.maxsize`` (or higher).
+        """
+        if start <= -sys.maxsize:
+            start = WindowSpec._JAVA_MIN_LONG
+        if end >= sys.maxsize:
+            end = WindowSpec._JAVA_MAX_LONG
+        sc = SparkContext._active_spark_context
+        jspec = sc._jvm.org.apache.spark.sql.expressions.Window.rangeBetween(start, end)
+        return WindowSpec(jspec)
+
 
 class WindowSpec(object):
     """

http://git-wip-us.apache.org/repos/asf/spark/blob/b515768f/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
index c29ec6f..e8a0c5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala
@@ -42,7 +42,7 @@ object Window {
    * Creates a [[WindowSpec]] with the partitioning defined.
    * @since 1.4.0
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def partitionBy(colName: String, colNames: String*): WindowSpec = {
     spec.partitionBy(colName, colNames : _*)
   }
@@ -51,7 +51,7 @@ object Window {
    * Creates a [[WindowSpec]] with the partitioning defined.
    * @since 1.4.0
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def partitionBy(cols: Column*): WindowSpec = {
     spec.partitionBy(cols : _*)
   }
@@ -60,7 +60,7 @@ object Window {
    * Creates a [[WindowSpec]] with the ordering defined.
    * @since 1.4.0
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def orderBy(colName: String, colNames: String*): WindowSpec = {
     spec.orderBy(colName, colNames : _*)
   }
@@ -69,11 +69,49 @@ object Window {
    * Creates a [[WindowSpec]] with the ordering defined.
    * @since 1.4.0
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def orderBy(cols: Column*): WindowSpec = {
     spec.orderBy(cols : _*)
   }
 
+  /**
+   * Creates a [[WindowSpec]] with the frame boundaries defined,
+   * from `start` (inclusive) to `end` (inclusive).
+   *
+   * Both `start` and `end` are relative positions from the current row. For example, "0" means
+   * "current row", while "-1" means the row before the current row, and "5" means the fifth row
+   * after the current row.
+   *
+   * @param start boundary start, inclusive.
+   *              The frame is unbounded if this is the minimum long value.
+   * @param end boundary end, inclusive.
+   *            The frame is unbounded if this is the maximum long value.
+   * @since 2.1.0
+   */
+  // Note: when updating the doc for this method, also update WindowSpec.rowsBetween.
+  def rowsBetween(start: Long, end: Long): WindowSpec = {
+    spec.rowsBetween(start, end)
+  }
+
+  /**
+   * Creates a [[WindowSpec]] with the frame boundaries defined,
+   * from `start` (inclusive) to `end` (inclusive).
+   *
+   * Both `start` and `end` are relative from the current row. For example, "0" means "current row",
+   * while "-1" means one off before the current row, and "5" means the five off after the
+   * current row.
+   *
+   * @param start boundary start, inclusive.
+   *              The frame is unbounded if this is the minimum long value.
+   * @param end boundary end, inclusive.
+   *            The frame is unbounded if this is the maximum long value.
+   * @since 2.1.0
+   */
+  // Note: when updating the doc for this method, also update WindowSpec.rangeBetween.
+  def rangeBetween(start: Long, end: Long): WindowSpec = {
+    spec.rangeBetween(start, end)
+  }
+
   private[sql] def spec: WindowSpec = {
     new WindowSpec(Seq.empty, Seq.empty, UnspecifiedFrame)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/b515768f/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index d716da2..82bc8f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -39,7 +39,7 @@ class WindowSpec private[sql](
    * Defines the partitioning columns in a [[WindowSpec]].
    * @since 1.4.0
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def partitionBy(colName: String, colNames: String*): WindowSpec = {
     partitionBy((colName +: colNames).map(Column(_)): _*)
   }
@@ -48,7 +48,7 @@ class WindowSpec private[sql](
    * Defines the partitioning columns in a [[WindowSpec]].
    * @since 1.4.0
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def partitionBy(cols: Column*): WindowSpec = {
     new WindowSpec(cols.map(_.expr), orderSpec, frame)
   }
@@ -57,7 +57,7 @@ class WindowSpec private[sql](
    * Defines the ordering columns in a [[WindowSpec]].
    * @since 1.4.0
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def orderBy(colName: String, colNames: String*): WindowSpec = {
     orderBy((colName +: colNames).map(Column(_)): _*)
   }
@@ -66,7 +66,7 @@ class WindowSpec private[sql](
    * Defines the ordering columns in a [[WindowSpec]].
    * @since 1.4.0
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def orderBy(cols: Column*): WindowSpec = {
     val sortOrder: Seq[SortOrder] = cols.map { col =>
       col.expr match {
@@ -92,6 +92,7 @@ class WindowSpec private[sql](
    *            The frame is unbounded if this is the maximum long value.
    * @since 1.4.0
    */
+  // Note: when updating the doc for this method, also update Window.rowsBetween.
   def rowsBetween(start: Long, end: Long): WindowSpec = {
     between(RowFrame, start, end)
   }
@@ -109,6 +110,7 @@ class WindowSpec private[sql](
    *            The frame is unbounded if this is the maximum long value.
    * @since 1.4.0
    */
+  // Note: when updating the doc for this method, also update Window.rangeBetween.
   def rangeBetween(start: Long, end: Long): WindowSpec = {
     between(RangeFrame, start, end)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/b515768f/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
index eac658c..5417a0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala
@@ -106,7 +106,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
   /**
    * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments.
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def apply(exprs: Column*): Column = {
     val aggregateExpression =
       AggregateExpression(
@@ -120,7 +120,7 @@ abstract class UserDefinedAggregateFunction extends Serializable {
    * Creates a [[Column]] for this UDAF using the distinct values of the given
    * [[Column]]s as input arguments.
    */
-  @_root_.scala.annotation.varargs
+  @scala.annotation.varargs
   def distinct(exprs: Column*): Column = {
     val aggregateExpression =
       AggregateExpression(

http://git-wip-us.apache.org/repos/asf/spark/blob/b515768f/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala
index c2b47ca..5bc386f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala
@@ -22,6 +22,9 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{DataType, LongType, StructType}
 
+/**
+ * Window function testing for DataFrame API.
+ */
 class DataFrameWindowSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
@@ -47,6 +50,15 @@ class DataFrameWindowSuite extends QueryTest with SharedSQLContext {
       Row(1, "1") :: Row(2, "2") :: Row(null, null) :: Row(null, null) :: Nil)
   }
 
+  test("Window.rowsBetween") {
+    val df = Seq(("one", 1), ("two", 2)).toDF("key", "value")
+    // Running (cumulative) sum
+    checkAnswer(
+      df.select('key, sum("value").over(Window.rowsBetween(Long.MinValue, 0))),
+      Row("one", 1) :: Row("two", 3) :: Nil
+    )
+  }
+
   test("lead") {
     val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value")
     df.createOrReplaceTempView("window_table")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org