You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2018/12/18 01:19:09 UTC

[GitHub] asfgit closed pull request #22305: [SPARK-24561][SQL][Python] User-defined window aggregation functions with Pandas UDF (bounded window)

asfgit closed pull request #22305: [SPARK-24561][SQL][Python] User-defined window aggregation functions with Pandas UDF (bounded window)
URL: https://github.com/apache/spark/pull/22305
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f98e550e39da8..d188de39e21c7 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2982,8 +2982,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
        |  2|        6.0|
        +---+-----------+
 
-       This example shows using grouped aggregated UDFs as window functions. Note that only
-       unbounded window frame is supported at the moment:
+       This example shows using grouped aggregated UDFs as window functions.
 
        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
        >>> from pyspark.sql import Window
@@ -2993,20 +2992,24 @@ def pandas_udf(f=None, returnType=None, functionType=None):
        >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG)  # doctest: +SKIP
        ... def mean_udf(v):
        ...     return v.mean()
-       >>> w = Window \\
-       ...     .partitionBy('id') \\
-       ...     .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+       >>> w = (Window.partitionBy('id')
+       ...            .orderBy('v')
+       ...            .rowsBetween(-1, 0))
        >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show()  # doctest: +SKIP
        +---+----+------+
        | id|   v|mean_v|
        +---+----+------+
-       |  1| 1.0|   1.5|
+       |  1| 1.0|   1.0|
        |  1| 2.0|   1.5|
-       |  2| 3.0|   6.0|
-       |  2| 5.0|   6.0|
-       |  2|10.0|   6.0|
+       |  2| 3.0|   3.0|
+       |  2| 5.0|   4.0|
+       |  2|10.0|   7.5|
        +---+----+------+
 
+       .. note:: For performance reasons, the input series to window functions are not copied.
+            Therefore, mutating the input series is not allowed and will cause incorrect results.
+            For the same reason, users should also not rely on the index of the input series.
+
        .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window`
 
     .. note:: The user-defined functions are considered deterministic by default. Due to
diff --git a/python/pyspark/sql/tests/test_pandas_udf_window.py b/python/pyspark/sql/tests/test_pandas_udf_window.py
index f0e6d2696df62..1b7df6797e9e6 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_window.py
@@ -47,6 +47,15 @@ def pandas_scalar_time_two(self):
         from pyspark.sql.functions import pandas_udf
         return pandas_udf(lambda v: v * 2, 'double')
 
+    @property
+    def pandas_agg_count_udf(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        @pandas_udf('long', PandasUDFType.GROUPED_AGG)
+        def count(v):
+            return len(v)
+        return count
+
     @property
     def pandas_agg_mean_udf(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType
@@ -77,7 +86,7 @@ def min(v):
     @property
     def unbounded_window(self):
         return Window.partitionBy('id') \
-            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing).orderBy('v')
 
     @property
     def ordered_window(self):
@@ -87,6 +96,32 @@ def ordered_window(self):
     def unpartitioned_window(self):
         return Window.partitionBy()
 
+    @property
+    def sliding_row_window(self):
+        return Window.partitionBy('id').orderBy('v').rowsBetween(-2, 1)
+
+    @property
+    def sliding_range_window(self):
+        return Window.partitionBy('id').orderBy('v').rangeBetween(-2, 4)
+
+    @property
+    def growing_row_window(self):
+        return Window.partitionBy('id').orderBy('v').rowsBetween(Window.unboundedPreceding, 3)
+
+    @property
+    def growing_range_window(self):
+        return Window.partitionBy('id').orderBy('v') \
+            .rangeBetween(Window.unboundedPreceding, 4)
+
+    @property
+    def shrinking_row_window(self):
+        return Window.partitionBy('id').orderBy('v').rowsBetween(-2, Window.unboundedFollowing)
+
+    @property
+    def shrinking_range_window(self):
+        return Window.partitionBy('id').orderBy('v') \
+            .rangeBetween(-3, Window.unboundedFollowing)
+
     def test_simple(self):
         from pyspark.sql.functions import mean
 
@@ -111,12 +146,12 @@ def test_multiple_udfs(self):
         w = self.unbounded_window
 
         result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \
-                    .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \
-                    .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w))
+            .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \
+            .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w))
 
         expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \
-                      .withColumn('max_v', max(df['v']).over(w)) \
-                      .withColumn('min_w', min(df['w']).over(w))
+            .withColumn('max_v', max(df['v']).over(w)) \
+            .withColumn('min_w', min(df['w']).over(w))
 
         self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
 
@@ -204,16 +239,16 @@ def test_mixed_sql_and_udf(self):
 
         # Test chaining sql aggregate function and udf
         result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
-                    .withColumn('min_v', min(df['v']).over(w)) \
-                    .withColumn('v_diff', col('max_v') - col('min_v')) \
-                    .drop('max_v', 'min_v')
+            .withColumn('min_v', min(df['v']).over(w)) \
+            .withColumn('v_diff', col('max_v') - col('min_v')) \
+            .drop('max_v', 'min_v')
         expected3 = expected1
 
         # Test mixing sql window function and udf
         result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
-                    .withColumn('rank', rank().over(ow))
+            .withColumn('rank', rank().over(ow))
         expected4 = df.withColumn('max_v', max(df['v']).over(w)) \
-                      .withColumn('rank', rank().over(ow))
+            .withColumn('rank', rank().over(ow))
 
         self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
         self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
@@ -235,8 +270,6 @@ def test_invalid_args(self):
 
         df = self.data
         w = self.unbounded_window
-        ow = self.ordered_window
-        mean_udf = self.pandas_agg_mean_udf
 
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(
@@ -245,11 +278,101 @@ def test_invalid_args(self):
                 foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
                 df.withColumn('v2', foo_udf(df['v']).over(w))
 
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    AnalysisException,
-                    '.*Only unbounded window frame is supported.*'):
-                df.withColumn('mean_v', mean_udf(df['v']).over(ow))
+    def test_bounded_simple(self):
+        from pyspark.sql.functions import mean, max, min, count
+
+        df = self.data
+        w1 = self.sliding_row_window
+        w2 = self.shrinking_range_window
+
+        plus_one = self.python_plus_one
+        count_udf = self.pandas_agg_count_udf
+        mean_udf = self.pandas_agg_mean_udf
+        max_udf = self.pandas_agg_max_udf
+        min_udf = self.pandas_agg_min_udf
+
+        result1 = df.withColumn('mean_v', mean_udf(plus_one(df['v'])).over(w1)) \
+            .withColumn('count_v', count_udf(df['v']).over(w2)) \
+            .withColumn('max_v',  max_udf(df['v']).over(w2)) \
+            .withColumn('min_v', min_udf(df['v']).over(w1))
+
+        expected1 = df.withColumn('mean_v', mean(plus_one(df['v'])).over(w1)) \
+            .withColumn('count_v', count(df['v']).over(w2)) \
+            .withColumn('max_v', max(df['v']).over(w2)) \
+            .withColumn('min_v', min(df['v']).over(w1))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+    def test_growing_window(self):
+        from pyspark.sql.functions import mean
+
+        df = self.data
+        w1 = self.growing_row_window
+        w2 = self.growing_range_window
+
+        mean_udf = self.pandas_agg_mean_udf
+
+        result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \
+            .withColumn('m2', mean_udf(df['v']).over(w2))
+
+        expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
+            .withColumn('m2', mean(df['v']).over(w2))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+    def test_sliding_window(self):
+        from pyspark.sql.functions import mean
+
+        df = self.data
+        w1 = self.sliding_row_window
+        w2 = self.sliding_range_window
+
+        mean_udf = self.pandas_agg_mean_udf
+
+        result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \
+            .withColumn('m2', mean_udf(df['v']).over(w2))
+
+        expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
+            .withColumn('m2', mean(df['v']).over(w2))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+    def test_shrinking_window(self):
+        from pyspark.sql.functions import mean
+
+        df = self.data
+        w1 = self.shrinking_row_window
+        w2 = self.shrinking_range_window
+
+        mean_udf = self.pandas_agg_mean_udf
+
+        result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \
+            .withColumn('m2', mean_udf(df['v']).over(w2))
+
+        expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
+            .withColumn('m2', mean(df['v']).over(w2))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+    def test_bounded_mixed(self):
+        from pyspark.sql.functions import mean, max
+
+        df = self.data
+        w1 = self.sliding_row_window
+        w2 = self.unbounded_window
+
+        mean_udf = self.pandas_agg_mean_udf
+        max_udf = self.pandas_agg_max_udf
+
+        result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w1)) \
+            .withColumn('max_v', max_udf(df['v']).over(w2)) \
+            .withColumn('mean_unbounded_v', mean_udf(df['v']).over(w1))
+
+        expected1 = df.withColumn('mean_v', mean(df['v']).over(w1)) \
+            .withColumn('max_v', max(df['v']).over(w2)) \
+            .withColumn('mean_unbounded_v', mean(df['v']).over(w1))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 953b468e96519..bf007b0c62d8d 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -145,7 +145,18 @@ def wrapped(*series):
     return lambda *a: (wrapped(*a), arrow_return_type)
 
 
-def wrap_window_agg_pandas_udf(f, return_type):
+def wrap_window_agg_pandas_udf(f, return_type, runner_conf, udf_index):
+    window_bound_types_str = runner_conf.get('pandas_window_bound_types')
+    window_bound_type = [t.strip().lower() for t in window_bound_types_str.split(',')][udf_index]
+    if window_bound_type == 'bounded':
+        return wrap_bounded_window_agg_pandas_udf(f, return_type)
+    elif window_bound_type == 'unbounded':
+        return wrap_unbounded_window_agg_pandas_udf(f, return_type)
+    else:
+        raise RuntimeError("Invalid window bound type: {} ".format(window_bound_type))
+
+
+def wrap_unbounded_window_agg_pandas_udf(f, return_type):
     # This is similar to grouped_agg_pandas_udf, the only difference
     # is that window_agg_pandas_udf needs to repeat the return value
     # to match window length, where grouped_agg_pandas_udf just returns
@@ -160,7 +171,41 @@ def wrapped(*series):
     return lambda *a: (wrapped(*a), arrow_return_type)
 
 
-def read_single_udf(pickleSer, infile, eval_type, runner_conf):
+def wrap_bounded_window_agg_pandas_udf(f, return_type):
+    arrow_return_type = to_arrow_type(return_type)
+
+    def wrapped(begin_index, end_index, *series):
+        import pandas as pd
+        result = []
+
+        # Index operation is faster on np.ndarray,
+        # So we turn the index series into np array
+        # here for performance
+        begin_array = begin_index.values
+        end_array = end_index.values
+
+        for i in range(len(begin_array)):
+            # Note: Creating a slice from a series for each window is
+            #       actually pretty expensive. However, there
+            #       is no easy way to reduce the cost here.
+            # Note: s.iloc[i : j] is about 30% faster than s[i: j], with
+            #       the caveat that the created slices share the same
+            #       memory with s. Therefore, users are not allowed to
+            #       change the values of the input series inside the window
+            #       function. It is rare that a user needs to modify the
+            #       input series in the window function, and therefore,
+            #       it is a reasonable restriction.
+            # Note: Calling reset_index on the slices will increase the cost
+            #       of creating slices by about 100%. Therefore, for performance
+            #       reasons we don't do it here.
+            series_slices = [s.iloc[begin_array[i]: end_array[i]] for s in series]
+            result.append(f(*series_slices))
+        return pd.Series(result)
+
+    return lambda *a: (wrapped(*a), arrow_return_type)
+
+
+def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     num_arg = read_int(infile)
     arg_offsets = [read_int(infile) for i in range(num_arg)]
     row_func = None
@@ -184,7 +229,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf):
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
         return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
-        return arg_offsets, wrap_window_agg_pandas_udf(func, return_type)
+        return arg_offsets, wrap_window_agg_pandas_udf(func, return_type, runner_conf, udf_index)
     elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
         return arg_offsets, wrap_udf(func, return_type)
     else:
@@ -226,7 +271,8 @@ def read_udfs(pickleSer, infile, eval_type):
 
         # See FlatMapGroupsInPandasExec for how arg_offsets are used to
         # distinguish between grouping attributes and data attributes
-        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf)
+        arg_offsets, udf = read_single_udf(
+            pickleSer, infile, eval_type, runner_conf, udf_index=0)
         udfs['f'] = udf
         split_offset = arg_offsets[0] + 1
         arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]]
@@ -238,7 +284,8 @@ def read_udfs(pickleSer, infile, eval_type):
         # In the special case of a single UDF this will return a single result rather
         # than a tuple of results; this is the format that the JVM side expects.
         for i in range(num_udfs):
-            arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf)
+            arg_offsets, udf = read_single_udf(
+                pickleSer, infile, eval_type, runner_conf, udf_index=i)
             udfs['f%d' % i] = udf
             args = ["a[%d]" % o for o in arg_offsets]
             call_udf.append("f%d(%s)" % (i, ", ".join(args)))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 6a91d556b2f3e..88d41e8824405 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -134,11 +134,6 @@ trait CheckAnalysis extends PredicateHelper {
             failAnalysis("An offset window function can only be evaluated in an ordered " +
               s"row-based window frame with a single offset: $w")
 
-          case _ @ WindowExpression(_: PythonUDF,
-            WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame))
-              if !frame.isUnbounded =>
-            failAnalysis("Only unbounded window frame is supported with Pandas UDFs.")
-
           case w @ WindowExpression(e, s) =>
             // Only allow window functions with an aggregate expression or an offset window
             // function or a Pandas window UDF.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
index 82973307feef3..1ce1215bfdd62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
@@ -27,17 +27,64 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan}
 import org.apache.spark.sql.execution.arrow.ArrowUtils
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.execution.window._
+import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
+/**
+ * This class calculates and outputs windowed aggregates over the rows in a single partition.
+ *
+ * This is similar to [[WindowExec]]. The main difference is that this node does not compute
+ * any window aggregation values. Instead, it computes the lower and upper bound for each window
+ * (i.e. window bounds) and passes the data and indices to Python worker to do the actual window
+ * aggregation.
+ *
+ * It currently materializes all data associated with the same partition key and passes them to
+ * Python worker. This is not strictly necessary for sliding windows and can be improved (by
+ * possibly slicing data into overlapping chunks and stitching them together).
+ *
+ * This class groups window expressions by their window boundaries so that window expressions
+ * with the same window boundaries can share the same window bounds. The window bounds are
+ * prepended to the data passed to the python worker.
+ *
+ * For example, if we have:
+ *     avg(v) over specifiedwindowframe(RowFrame, -5, 5),
+ *     avg(v) over specifiedwindowframe(RowFrame, UnboundedPreceding, UnboundedFollowing),
+ *     avg(v) over specifiedwindowframe(RowFrame, -3, 3),
+ *     max(v) over specifiedwindowframe(RowFrame, -3, 3)
+ *
+ * The python input will look like:
+ * (lower_bound_w1, upper_bound_w1, lower_bound_w3, upper_bound_w3, v)
+ *
+ * where w1 is specifiedwindowframe(RowFrame, -5, 5)
+ *       w2 is specifiedwindowframe(RowFrame, UnboundedPreceding, UnboundedFollowing)
+ *       w3 is specifiedwindowframe(RowFrame, -3, 3)
+ *
+ * Note that w2 doesn't have bound indices in the python input because it's an unbounded window,
+ * so its bound indices will always be the same.
+ *
+ * Bounded window and Unbounded window are evaluated differently in Python worker:
+ * (1) Bounded window takes the window bound indices in addition to the input columns.
+ *     Unbounded window takes only input columns.
+ * (2) Bounded window evaluates the udf once per input row.
+ *     Unbounded window evaluates the udf once per window partition.
+ * This is controlled by Python runner conf "pandas_window_bound_types"
+ *
+ * The logic to compute window bounds is delegated to [[WindowFunctionFrame]] and shared with
+ * [[WindowExec]]
+ *
+ * Note this doesn't support partial aggregation and all aggregation is computed from the entire
+ * window.
+ */
 case class WindowInPandasExec(
     windowExpression: Seq[NamedExpression],
     partitionSpec: Seq[Expression],
     orderSpec: Seq[SortOrder],
-    child: SparkPlan) extends UnaryExecNode {
+    child: SparkPlan)
+  extends WindowExecBase(windowExpression, partitionSpec, orderSpec, child) {
 
   override def output: Seq[Attribute] =
     child.output ++ windowExpression.map(_.toAttribute)
@@ -60,6 +107,26 @@ case class WindowInPandasExec(
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
+  /**
+   * Helper functions and data structures for window bounds
+   *
+   * It contains:
+   * (1) Total number of window bound indices in the python input row
+   * (2) Function from frame index to its lower bound column index in the python input row
+   * (3) Function from frame index to its upper bound column index in the python input row
+   * (4) Seq from frame index to its window bound type
+   */
+  private type WindowBoundHelpers = (Int, Int => Int, Int => Int, Seq[WindowBoundType])
+
+  /**
+   * Enum for window bound types. Used only inside this class.
+   */
+  private sealed case class WindowBoundType(value: String)
+  private object UnboundedWindow extends WindowBoundType("unbounded")
+  private object BoundedWindow extends WindowBoundType("bounded")
+
+  private val windowBoundTypeConf = "pandas_window_bound_types"
+
   private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonUDF) =>
@@ -73,68 +140,150 @@ case class WindowInPandasExec(
   }
 
   /**
-   * Create the resulting projection.
-   *
-   * This method uses Code Generation. It can only be used on the executor side.
-   *
-   * @param expressions unbound ordered function expressions.
-   * @return the final resulting projection.
+   * See [[WindowBoundHelpers]] for details.
    */
-  private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = {
-    val references = expressions.zipWithIndex.map { case (e, i) =>
-      // Results of window expressions will be on the right side of child's output
-      BoundReference(child.output.size + i, e.dataType, e.nullable)
+  private def computeWindowBoundHelpers(
+      factories: Seq[InternalRow => WindowFunctionFrame]
+  ): WindowBoundHelpers = {
+    val functionFrames = factories.map(_(EmptyRow))
+
+    val windowBoundTypes = functionFrames.map {
+      case _: UnboundedWindowFunctionFrame => UnboundedWindow
+      case _: UnboundedFollowingWindowFunctionFrame |
+        _: SlidingWindowFunctionFrame |
+        _: UnboundedPrecedingWindowFunctionFrame => BoundedWindow
+      // It should be impossible to get other types of window function frame here
+      case frame => throw new RuntimeException(s"Unexpected window function frame $frame.")
     }
-    val unboundToRefMap = expressions.zip(references).toMap
-    val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
-    UnsafeProjection.create(
-      child.output ++ patchedWindowExpression,
-      child.output)
+
+    val requiredIndices = functionFrames.map {
+      case _: UnboundedWindowFunctionFrame => 0
+      case _ => 2
+    }
+
+    val upperBoundIndices = requiredIndices.scan(0)(_ + _).tail
+
+    val boundIndices = requiredIndices.zip(upperBoundIndices).map { case (num, upperBoundIndex) =>
+        if (num == 0) {
+          // Sentinel values for unbounded window
+          (-1, -1)
+        } else {
+          (upperBoundIndex - 2, upperBoundIndex - 1)
+        }
+    }
+
+    def lowerBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._1
+    def upperBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._2
+
+    (requiredIndices.sum, lowerBoundIndex, upperBoundIndex, windowBoundTypes)
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
-    val inputRDD = child.execute()
+    // Unwrap the expressions and factories from the map.
+    val expressionsWithFrameIndex =
+      windowFrameExpressionFactoryPairs.map(_._1).zipWithIndex.flatMap {
+        case (buffer, frameIndex) => buffer.map(expr => (expr, frameIndex))
+      }
+
+    val expressions = expressionsWithFrameIndex.map(_._1)
+    val expressionIndexToFrameIndex =
+      expressionsWithFrameIndex.map(_._2).zipWithIndex.map(_.swap).toMap
+
+    val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
 
+    // Helper functions
+    val (numBoundIndices, lowerBoundIndex, upperBoundIndex, frameWindowBoundTypes) =
+      computeWindowBoundHelpers(factories)
+    val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) >= 0 }
+    val numFrames = factories.length
+
+    val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
+    val spillThreshold = conf.windowExecBufferSpillThreshold
     val sessionLocalTimeZone = conf.sessionLocalTimeZone
-    val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
 
     // Extract window expressions and window functions
-    val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e })
-
-    val udfExpressions = expressions.map(_.windowFunction.asInstanceOf[PythonUDF])
+    val windowExpressions = expressions.flatMap(_.collect { case e: WindowExpression => e })
+    val udfExpressions = windowExpressions.map(_.windowFunction.asInstanceOf[PythonUDF])
 
+    // We shouldn't be chaining anything here.
+    // All chained python functions should only contain one function.
     val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
+    require(pyFuncs.length == expressions.length)
+
+    val udfWindowBoundTypes = pyFuncs.indices.map(i =>
+      frameWindowBoundTypes(expressionIndexToFrameIndex(i)))
+    val pythonRunnerConf: Map[String, String] = (ArrowUtils.getPythonRunnerConfMap(conf)
+      + (windowBoundTypeConf -> udfWindowBoundTypes.map(_.value).mkString(",")))
 
     // Filter child output attributes down to only those that are UDF inputs.
-    // Also eliminate duplicate UDF inputs.
-    val allInputs = new ArrayBuffer[Expression]
-    val dataTypes = new ArrayBuffer[DataType]
+    // Also eliminate duplicate UDF inputs. This is similar to how other Python UDF nodes
+    // handle UDF inputs.
+    val dataInputs = new ArrayBuffer[Expression]
+    val dataInputTypes = new ArrayBuffer[DataType]
     val argOffsets = inputs.map { input =>
       input.map { e =>
-        if (allInputs.exists(_.semanticEquals(e))) {
-          allInputs.indexWhere(_.semanticEquals(e))
+        if (dataInputs.exists(_.semanticEquals(e))) {
+          dataInputs.indexWhere(_.semanticEquals(e))
         } else {
-          allInputs += e
-          dataTypes += e.dataType
-          allInputs.length - 1
+          dataInputs += e
+          dataInputTypes += e.dataType
+          dataInputs.length - 1
         }
       }.toArray
     }.toArray
 
-    // Schema of input rows to the python runner
-    val windowInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
-      StructField(s"_$i", dt)
-    })
+    // In addition to UDF inputs, we will prepend window bounds for each UDF.
+    // For bounded windows, we prepend lower bound and upper bound. For unbounded windows,
+    // we do not add window bounds. (Strictly speaking, we only need the lower or upper bound
+    // if the window is bounded only on one side; this can be improved in the future.)
 
-    inputRDD.mapPartitionsInternal { iter =>
-      val context = TaskContext.get()
+    // Setting window bounds for each window frame. Each window frame has different bounds so
+    // each has its own window bound columns.
+    val windowBoundsInput = factories.indices.flatMap { frameIndex =>
+      if (isBounded(frameIndex)) {
+        Seq(
+          BoundReference(lowerBoundIndex(frameIndex), IntegerType, nullable = false),
+          BoundReference(upperBoundIndex(frameIndex), IntegerType, nullable = false)
+        )
+      } else {
+        Seq.empty
+      }
+    }
 
-      val grouped = if (partitionSpec.isEmpty) {
-        // Use an empty unsafe row as a place holder for the grouping key
-        Iterator((new UnsafeRow(), iter))
+    // Setting the window bounds argOffset for each UDF. For UDFs with bounded window, argOffset
+    // for the UDF is (lowerBoundOffset, upperBoundOffset, inputOffset1, inputOffset2, ...)
+    // For UDFs with unbounded window, argOffset is (inputOffset1, inputOffset2, ...)
+    pyFuncs.indices.foreach { exprIndex =>
+      val frameIndex = expressionIndexToFrameIndex(exprIndex)
+      if (isBounded(frameIndex)) {
+        argOffsets(exprIndex) =
+          Array(lowerBoundIndex(frameIndex), upperBoundIndex(frameIndex)) ++
+            argOffsets(exprIndex).map(_ + windowBoundsInput.length)
       } else {
-        GroupedIterator(iter, partitionSpec, child.output)
+        argOffsets(exprIndex) = argOffsets(exprIndex).map(_ + windowBoundsInput.length)
       }
+    }
+
+    val allInputs = windowBoundsInput ++ dataInputs
+    val allInputTypes = allInputs.map(_.dataType)
+
+    // Start processing.
+    child.execute().mapPartitions { iter =>
+      val context = TaskContext.get()
+
+      // Get all relevant projections.
+      val resultProj = createResultProjection(expressions)
+      val pythonInputProj = UnsafeProjection.create(
+        allInputs,
+        windowBoundsInput.map(ref =>
+          AttributeReference(s"i_${ref.ordinal}", ref.dataType)()) ++ child.output
+      )
+      val pythonInputSchema = StructType(
+        allInputTypes.zipWithIndex.map { case (dt, i) =>
+          StructField(s"_$i", dt)
+        }
+      )
+      val grouping = UnsafeProjection.create(partitionSpec, child.output)
 
       // The queue used to buffer input rows so we can drain it to
       // combine input with output from Python.
@@ -144,11 +293,94 @@ case class WindowInPandasExec(
         queue.close()
       }
 
-      val inputProj = UnsafeProjection.create(allInputs, child.output)
-      val pythonInput = grouped.map { case (_, rows) =>
-        rows.map { row =>
-          queue.add(row.asInstanceOf[UnsafeRow])
-          inputProj(row)
+      val stream = iter.map { row =>
+        queue.add(row.asInstanceOf[UnsafeRow])
+        row
+      }
+
+      val pythonInput = new Iterator[Iterator[UnsafeRow]] {
+
+        // Manage the stream and the grouping.
+        var nextRow: UnsafeRow = null
+        var nextGroup: UnsafeRow = null
+        var nextRowAvailable: Boolean = false
+        private[this] def fetchNextRow() {
+          nextRowAvailable = stream.hasNext
+          if (nextRowAvailable) {
+            nextRow = stream.next().asInstanceOf[UnsafeRow]
+            nextGroup = grouping(nextRow)
+          } else {
+            nextRow = null
+            nextGroup = null
+          }
+        }
+        fetchNextRow()
+
+        // Manage the current partition.
+        val buffer: ExternalAppendOnlyUnsafeRowArray =
+          new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)
+        var bufferIterator: Iterator[UnsafeRow] = _
+
+        val indexRow = new SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType))
+
+        val frames = factories.map(_(indexRow))
+
+        private[this] def fetchNextPartition() {
+          // Collect all the rows in the current partition.
+          // Before we start to fetch new input rows, make a copy of nextGroup.
+          val currentGroup = nextGroup.copy()
+
+          // clear last partition
+          buffer.clear()
+
+          while (nextRowAvailable && nextGroup == currentGroup) {
+            buffer.add(nextRow)
+            fetchNextRow()
+          }
+
+          // Setup the frames.
+          var i = 0
+          while (i < numFrames) {
+            frames(i).prepare(buffer)
+            i += 1
+          }
+
+          // Setup iteration
+          rowIndex = 0
+          bufferIterator = buffer.generateIterator()
+        }
+
+        // Iteration
+        var rowIndex = 0
+
+        override final def hasNext: Boolean =
+          (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable
+
+        override final def next(): Iterator[UnsafeRow] = {
+          // Load the next partition if we need to.
+          if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) {
+            fetchNextPartition()
+          }
+
+          val join = new JoinedRow
+
+          bufferIterator.zipWithIndex.map {
+            case (current, index) =>
+              var frameIndex = 0
+              while (frameIndex < numFrames) {
+                frames(frameIndex).write(index, current)
+                // If the window is unbounded we don't need to write out window bounds.
+                if (isBounded(frameIndex)) {
+                  indexRow.setInt(
+                    lowerBoundIndex(frameIndex), frames(frameIndex).currentLowerBound())
+                  indexRow.setInt(
+                    upperBoundIndex(frameIndex), frames(frameIndex).currentUpperBound())
+                }
+                frameIndex += 1
+              }
+
+              pythonInputProj(join(indexRow, current))
+          }
         }
       }
 
@@ -156,12 +388,11 @@ case class WindowInPandasExec(
         pyFuncs,
         PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
         argOffsets,
-        windowInputSchema,
+        pythonInputSchema,
         sessionLocalTimeZone,
         pythonRunnerConf).compute(pythonInput, context.partitionId(), context)
 
       val joined = new JoinedRow
-      val resultProj = createResultProjection(expressions)
 
       windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput =>
         val leftRow = queue.remove()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
index 729b8bdb3dae8..89f6edda2ef57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
@@ -83,7 +83,7 @@ case class WindowExec(
     partitionSpec: Seq[Expression],
     orderSpec: Seq[SortOrder],
     child: SparkPlan)
-  extends UnaryExecNode {
+  extends WindowExecBase(windowExpression, partitionSpec, orderSpec, child) {
 
   override def output: Seq[Attribute] =
     child.output ++ windowExpression.map(_.toAttribute)
@@ -104,193 +104,6 @@ case class WindowExec(
 
   override def outputPartitioning: Partitioning = child.outputPartitioning
 
-  /**
-   * Create a bound ordering object for a given frame type and offset. A bound ordering object is
-   * used to determine which input row lies within the frame boundaries of an output row.
-   *
-   * This method uses Code Generation. It can only be used on the executor side.
-   *
-   * @param frame to evaluate. This can either be a Row or Range frame.
-   * @param bound with respect to the row.
-   * @param timeZone the session local timezone for time related calculations.
-   * @return a bound ordering object.
-   */
-  private[this] def createBoundOrdering(
-      frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = {
-    (frame, bound) match {
-      case (RowFrame, CurrentRow) =>
-        RowBoundOrdering(0)
-
-      case (RowFrame, IntegerLiteral(offset)) =>
-        RowBoundOrdering(offset)
-
-      case (RangeFrame, CurrentRow) =>
-        val ordering = newOrdering(orderSpec, child.output)
-        RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection)
-
-      case (RangeFrame, offset: Expression) if orderSpec.size == 1 =>
-        // Use only the first order expression when the offset is non-null.
-        val sortExpr = orderSpec.head
-        val expr = sortExpr.child
-
-        // Create the projection which returns the current 'value'.
-        val current = newMutableProjection(expr :: Nil, child.output)
-
-        // Flip the sign of the offset when processing the order is descending
-        val boundOffset = sortExpr.direction match {
-          case Descending => UnaryMinus(offset)
-          case Ascending => offset
-        }
-
-        // Create the projection which returns the current 'value' modified by adding the offset.
-        val boundExpr = (expr.dataType, boundOffset.dataType) match {
-          case (DateType, IntegerType) => DateAdd(expr, boundOffset)
-          case (TimestampType, CalendarIntervalType) =>
-            TimeAdd(expr, boundOffset, Some(timeZone))
-          case (a, b) if a== b => Add(expr, boundOffset)
-        }
-        val bound = newMutableProjection(boundExpr :: Nil, child.output)
-
-        // Construct the ordering. This is used to compare the result of current value projection
-        // to the result of bound value projection. This is done manually because we want to use
-        // Code Generation (if it is enabled).
-        val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil
-        val ordering = newOrdering(boundSortExprs, Nil)
-        RangeBoundOrdering(ordering, current, bound)
-
-      case (RangeFrame, _) =>
-        sys.error("Non-Zero range offsets are not supported for windows " +
-          "with multiple order expressions.")
-    }
-  }
-
-  /**
-   * Collection containing an entry for each window frame to process. Each entry contains a frame's
-   * [[WindowExpression]]s and factory function for the WindowFrameFunction.
-   */
-  private[this] lazy val windowFrameExpressionFactoryPairs = {
-    type FrameKey = (String, FrameType, Expression, Expression)
-    type ExpressionBuffer = mutable.Buffer[Expression]
-    val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)]
-
-    // Add a function and its function to the map for a given frame.
-    def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = {
-      val key = (tpe, fr.frameType, fr.lower, fr.upper)
-      val (es, fns) = framedFunctions.getOrElseUpdate(
-        key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression]))
-      es += e
-      fns += fn
-    }
-
-    // Collect all valid window functions and group them by their frame.
-    windowExpression.foreach { x =>
-      x.foreach {
-        case e @ WindowExpression(function, spec) =>
-          val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
-          function match {
-            case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f)
-            case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f)
-            case f: OffsetWindowFunction => collect("OFFSET", frame, e, f)
-            case f => sys.error(s"Unsupported window function: $f")
-          }
-        case _ =>
-      }
-    }
-
-    // Map the groups to a (unbound) expression and frame factory pair.
-    var numExpressions = 0
-    val timeZone = conf.sessionLocalTimeZone
-    framedFunctions.toSeq.map {
-      case (key, (expressions, functionSeq)) =>
-        val ordinal = numExpressions
-        val functions = functionSeq.toArray
-
-        // Construct an aggregate processor if we need one.
-        def processor = AggregateProcessor(
-          functions,
-          ordinal,
-          child.output,
-          (expressions, schema) =>
-            newMutableProjection(expressions, schema, subexpressionEliminationEnabled))
-
-        // Create the factory
-        val factory = key match {
-          // Offset Frame
-          case ("OFFSET", _, IntegerLiteral(offset), _) =>
-            target: InternalRow =>
-              new OffsetWindowFunctionFrame(
-                target,
-                ordinal,
-                // OFFSET frame functions are guaranteed be OffsetWindowFunctions.
-                functions.map(_.asInstanceOf[OffsetWindowFunction]),
-                child.output,
-                (expressions, schema) =>
-                  newMutableProjection(expressions, schema, subexpressionEliminationEnabled),
-                offset)
-
-          // Entire Partition Frame.
-          case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) =>
-            target: InternalRow => {
-              new UnboundedWindowFunctionFrame(target, processor)
-            }
-
-          // Growing Frame.
-          case ("AGGREGATE", frameType, UnboundedPreceding, upper) =>
-            target: InternalRow => {
-              new UnboundedPrecedingWindowFunctionFrame(
-                target,
-                processor,
-                createBoundOrdering(frameType, upper, timeZone))
-            }
-
-          // Shrinking Frame.
-          case ("AGGREGATE", frameType, lower, UnboundedFollowing) =>
-            target: InternalRow => {
-              new UnboundedFollowingWindowFunctionFrame(
-                target,
-                processor,
-                createBoundOrdering(frameType, lower, timeZone))
-            }
-
-          // Moving Frame.
-          case ("AGGREGATE", frameType, lower, upper) =>
-            target: InternalRow => {
-              new SlidingWindowFunctionFrame(
-                target,
-                processor,
-                createBoundOrdering(frameType, lower, timeZone),
-                createBoundOrdering(frameType, upper, timeZone))
-            }
-        }
-
-        // Keep track of the number of expressions. This is a side-effect in a map...
-        numExpressions += expressions.size
-
-        // Create the Frame Expression - Factory pair.
-        (expressions, factory)
-    }
-  }
-
-  /**
-   * Create the resulting projection.
-   *
-   * This method uses Code Generation. It can only be used on the executor side.
-   *
-   * @param expressions unbound ordered function expressions.
-   * @return the final resulting projection.
-   */
-  private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = {
-    val references = expressions.zipWithIndex.map{ case (e, i) =>
-      // Results of window expressions will be on the right side of child's output
-      BoundReference(child.output.size + i, e.dataType, e.nullable)
-    }
-    val unboundToRefMap = expressions.zip(references).toMap
-    val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
-    UnsafeProjection.create(
-      child.output ++ patchedWindowExpression,
-      child.output)
-  }
-
   protected override def doExecute(): RDD[InternalRow] = {
     // Unwrap the expressions and factories from the map.
     val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
new file mode 100644
index 0000000000000..dcb86f48bdf32
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.window
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.types.{CalendarIntervalType, DateType, IntegerType, TimestampType}
+
+abstract class WindowExecBase(
+    windowExpression: Seq[NamedExpression],
+    partitionSpec: Seq[Expression],
+    orderSpec: Seq[SortOrder],
+    child: SparkPlan) extends UnaryExecNode {
+
+  /**
+   * Create the resulting projection.
+   *
+   * This method uses Code Generation. It can only be used on the executor side.
+   *
+   * @param expressions unbound ordered function expressions.
+   * @return the final resulting projection.
+   */
+  protected def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = {
+    val references = expressions.zipWithIndex.map { case (e, i) =>
+      // Results of window expressions will be on the right side of child's output
+      BoundReference(child.output.size + i, e.dataType, e.nullable)
+    }
+    val unboundToRefMap = expressions.zip(references).toMap
+    val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
+    UnsafeProjection.create(
+      child.output ++ patchedWindowExpression,
+      child.output)
+  }
+
+  /**
+   * Create a bound ordering object for a given frame type and offset. A bound ordering object is
+   * used to determine which input row lies within the frame boundaries of an output row.
+   *
+   * This method uses Code Generation. It can only be used on the executor side.
+   *
+   * @param frame to evaluate. This can either be a Row or Range frame.
+   * @param bound with respect to the row.
+   * @param timeZone the session local timezone for time related calculations.
+   * @return a bound ordering object.
+   */
+  private def createBoundOrdering(
+      frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = {
+    (frame, bound) match {
+      case (RowFrame, CurrentRow) =>
+        RowBoundOrdering(0)
+
+      case (RowFrame, IntegerLiteral(offset)) =>
+        RowBoundOrdering(offset)
+
+      case (RangeFrame, CurrentRow) =>
+        val ordering = newOrdering(orderSpec, child.output)
+        RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection)
+
+      case (RangeFrame, offset: Expression) if orderSpec.size == 1 =>
+        // Use only the first order expression when the offset is non-null.
+        val sortExpr = orderSpec.head
+        val expr = sortExpr.child
+
+        // Create the projection which returns the current 'value'.
+        val current = newMutableProjection(expr :: Nil, child.output)
+
+        // Flip the sign of the offset when the sort order is descending
+        val boundOffset = sortExpr.direction match {
+          case Descending => UnaryMinus(offset)
+          case Ascending => offset
+        }
+
+        // Create the projection which returns the current 'value' modified by adding the offset.
+        val boundExpr = (expr.dataType, boundOffset.dataType) match {
+          case (DateType, IntegerType) => DateAdd(expr, boundOffset)
+          case (TimestampType, CalendarIntervalType) =>
+            TimeAdd(expr, boundOffset, Some(timeZone))
+          case (a, b) if a == b => Add(expr, boundOffset)
+        }
+        val bound = newMutableProjection(boundExpr :: Nil, child.output)
+
+        // Construct the ordering. This is used to compare the result of current value projection
+        // to the result of bound value projection. This is done manually because we want to use
+        // Code Generation (if it is enabled).
+        val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil
+        val ordering = newOrdering(boundSortExprs, Nil)
+        RangeBoundOrdering(ordering, current, bound)
+
+      case (RangeFrame, _) =>
+        sys.error("Non-Zero range offsets are not supported for windows " +
+          "with multiple order expressions.")
+    }
+  }
+
+  /**
+   * Collection containing an entry for each window frame to process. Each entry contains a frame's
+   * [[WindowExpression]]s and factory function for the WindowFrameFunction.
+   */
+  protected lazy val windowFrameExpressionFactoryPairs = {
+    type FrameKey = (String, FrameType, Expression, Expression)
+    type ExpressionBuffer = mutable.Buffer[Expression]
+    val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)]
+
+    // Add an expression and its function to the map for a given frame.
+    def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = {
+      val key = (tpe, fr.frameType, fr.lower, fr.upper)
+      val (es, fns) = framedFunctions.getOrElseUpdate(
+        key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression]))
+      es += e
+      fns += fn
+    }
+
+    // Collect all valid window functions and group them by their frame.
+    windowExpression.foreach { x =>
+      x.foreach {
+        case e @ WindowExpression(function, spec) =>
+          val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
+          function match {
+            case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f)
+            case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f)
+            case f: OffsetWindowFunction => collect("OFFSET", frame, e, f)
+            case f: PythonUDF => collect("AGGREGATE", frame, e, f)
+            case f => sys.error(s"Unsupported window function: $f")
+          }
+        case _ =>
+      }
+    }
+
+    // Map the groups to a (unbound) expression and frame factory pair.
+    var numExpressions = 0
+    val timeZone = conf.sessionLocalTimeZone
+    framedFunctions.toSeq.map {
+      case (key, (expressions, functionSeq)) =>
+        val ordinal = numExpressions
+        val functions = functionSeq.toArray
+
+        // Construct an aggregate processor if we need one.
+        // Currently we don't allow mixing of Pandas UDF and SQL aggregation functions
+        // in a single Window physical node. Therefore, we can assume no SQL aggregation
+        // functions if Pandas UDF exists. In the future, we might mix Pandas UDF and SQL
+        // aggregation function in a single physical node.
+        def processor = if (functions.exists(_.isInstanceOf[PythonUDF])) {
+          null
+        } else {
+          AggregateProcessor(
+            functions,
+            ordinal,
+            child.output,
+            (expressions, schema) =>
+              newMutableProjection(expressions, schema, subexpressionEliminationEnabled))
+        }
+
+        // Create the factory
+        val factory = key match {
+          // Offset Frame
+          case ("OFFSET", _, IntegerLiteral(offset), _) =>
+            target: InternalRow =>
+              new OffsetWindowFunctionFrame(
+                target,
+                ordinal,
+                // OFFSET frame functions are guaranteed to be OffsetWindowFunctions.
+                functions.map(_.asInstanceOf[OffsetWindowFunction]),
+                child.output,
+                (expressions, schema) =>
+                  newMutableProjection(expressions, schema, subexpressionEliminationEnabled),
+                offset)
+
+          // Entire Partition Frame.
+          case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) =>
+            target: InternalRow => {
+              new UnboundedWindowFunctionFrame(target, processor)
+            }
+
+          // Growing Frame.
+          case ("AGGREGATE", frameType, UnboundedPreceding, upper) =>
+            target: InternalRow => {
+              new UnboundedPrecedingWindowFunctionFrame(
+                target,
+                processor,
+                createBoundOrdering(frameType, upper, timeZone))
+            }
+
+          // Shrinking Frame.
+          case ("AGGREGATE", frameType, lower, UnboundedFollowing) =>
+            target: InternalRow => {
+              new UnboundedFollowingWindowFunctionFrame(
+                target,
+                processor,
+                createBoundOrdering(frameType, lower, timeZone))
+            }
+
+          // Moving Frame.
+          case ("AGGREGATE", frameType, lower, upper) =>
+            target: InternalRow => {
+              new SlidingWindowFunctionFrame(
+                target,
+                processor,
+                createBoundOrdering(frameType, lower, timeZone),
+                createBoundOrdering(frameType, upper, timeZone))
+            }
+        }
+
+        // Keep track of the number of expressions. This is a side-effect in a map...
+        numExpressions += expressions.size
+
+        // Create the Frame Expression - Factory pair.
+        (expressions, factory)
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index 156002ef58fbe..a5601899ea2de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
  * Before use a frame must be prepared by passing it all the rows in the current partition. After
  * preparation the update method can be called to fill the output rows.
  */
-private[window] abstract class WindowFunctionFrame {
+abstract class WindowFunctionFrame {
   /**
    * Prepare the frame for calculating the results for a partition.
    *
@@ -42,6 +42,20 @@ private[window] abstract class WindowFunctionFrame {
    * Write the current results to the target row.
    */
   def write(index: Int, current: InternalRow): Unit
+
+  /**
+   * The row index of the current lower window bound in the row array (inclusive).
+   *
+   * This should be called after the current row is updated via [[write]].
+   */
+  def currentLowerBound(): Int
+
+  /**
+   * The row index of the current upper window bound in the row array (exclusive).
+   *
+   * This should be called after the current row is updated via [[write]].
+   */
+  def currentUpperBound(): Int
 }
 
 object WindowFunctionFrame {
@@ -62,7 +76,7 @@ object WindowFunctionFrame {
  * @param newMutableProjection function used to create the projection.
  * @param offset by which rows get moved within a partition.
  */
-private[window] final class OffsetWindowFunctionFrame(
+final class OffsetWindowFunctionFrame(
     target: InternalRow,
     ordinal: Int,
     expressions: Array[OffsetWindowFunction],
@@ -137,6 +151,10 @@ private[window] final class OffsetWindowFunctionFrame(
     }
     inputIndex += 1
   }
+
+  override def currentLowerBound(): Int = throw new UnsupportedOperationException()
+
+  override def currentUpperBound(): Int = throw new UnsupportedOperationException()
 }
 
 /**
@@ -148,7 +166,7 @@ private[window] final class OffsetWindowFunctionFrame(
  * @param lbound comparator used to identify the lower bound of an output row.
  * @param ubound comparator used to identify the upper bound of an output row.
  */
-private[window] final class SlidingWindowFunctionFrame(
+final class SlidingWindowFunctionFrame(
     target: InternalRow,
     processor: AggregateProcessor,
     lbound: BoundOrdering,
@@ -170,24 +188,24 @@ private[window] final class SlidingWindowFunctionFrame(
   private[this] val buffer = new util.ArrayDeque[InternalRow]()
 
   /**
-   * Index of the first input row with a value greater than the upper bound of the current
-   * output row.
+   * Index of the first input row with a value equal to or greater than the lower bound of the
+   * current output row.
    */
-  private[this] var inputHighIndex = 0
+  private[this] var lowerBound = 0
 
   /**
-   * Index of the first input row with a value equal to or greater than the lower bound of the
-   * current output row.
+   * Index of the first input row with a value greater than the upper bound of the current
+   * output row.
    */
-  private[this] var inputLowIndex = 0
+  private[this] var upperBound = 0
 
   /** Prepare the frame for calculating a new partition. Reset all variables. */
   override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
     input = rows
     inputIterator = input.generateIterator()
     nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
-    inputHighIndex = 0
-    inputLowIndex = 0
+    lowerBound = 0
+    upperBound = 0
     buffer.clear()
   }
 
@@ -197,27 +215,27 @@ private[window] final class SlidingWindowFunctionFrame(
 
     // Drop all rows from the buffer for which the input row value is smaller than
     // the output row lower bound.
-    while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) {
+    while (!buffer.isEmpty && lbound.compare(buffer.peek(), lowerBound, current, index) < 0) {
       buffer.remove()
-      inputLowIndex += 1
+      lowerBound += 1
       bufferUpdated = true
     }
 
     // Add all rows to the buffer for which the input row value is equal to or less than
     // the output row upper bound.
-    while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
-      if (lbound.compare(nextRow, inputLowIndex, current, index) < 0) {
-        inputLowIndex += 1
+    while (nextRow != null && ubound.compare(nextRow, upperBound, current, index) <= 0) {
+      if (lbound.compare(nextRow, lowerBound, current, index) < 0) {
+        lowerBound += 1
       } else {
         buffer.add(nextRow.copy())
         bufferUpdated = true
       }
       nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
-      inputHighIndex += 1
+      upperBound += 1
     }
 
     // Only recalculate and update when the buffer changes.
-    if (bufferUpdated) {
+    if (processor != null && bufferUpdated) {
       processor.initialize(input.length)
       val iter = buffer.iterator()
       while (iter.hasNext) {
@@ -226,6 +244,10 @@ private[window] final class SlidingWindowFunctionFrame(
       processor.evaluate(target)
     }
   }
+
+  override def currentLowerBound(): Int = lowerBound
+
+  override def currentUpperBound(): Int = upperBound
 }
 
 /**
@@ -239,27 +261,39 @@ private[window] final class SlidingWindowFunctionFrame(
  * @param target to write results to.
  * @param processor to calculate the row values with.
  */
-private[window] final class UnboundedWindowFunctionFrame(
+final class UnboundedWindowFunctionFrame(
     target: InternalRow,
     processor: AggregateProcessor)
   extends WindowFunctionFrame {
 
+  val lowerBound: Int = 0
+  var upperBound: Int = 0
+
   /** Prepare the frame for calculating a new partition. Process all rows eagerly. */
   override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
-    processor.initialize(rows.length)
-
-    val iterator = rows.generateIterator()
-    while (iterator.hasNext) {
-      processor.update(iterator.next())
+    if (processor != null) {
+      processor.initialize(rows.length)
+      val iterator = rows.generateIterator()
+      while (iterator.hasNext) {
+        processor.update(iterator.next())
+      }
     }
+
+    upperBound = rows.length
   }
 
   /** Write the frame columns for the current row to the given target row. */
   override def write(index: Int, current: InternalRow): Unit = {
     // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate
     // for each row.
-    processor.evaluate(target)
+    if (processor != null) {
+      processor.evaluate(target)
+    }
   }
+
+  override def currentLowerBound(): Int = lowerBound
+
+  override def currentUpperBound(): Int = upperBound
 }
 
 /**
@@ -276,7 +310,7 @@ private[window] final class UnboundedWindowFunctionFrame(
  * @param processor to calculate the row values with.
  * @param ubound comparator used to identify the upper bound of an output row.
  */
-private[window] final class UnboundedPrecedingWindowFunctionFrame(
+final class UnboundedPrecedingWindowFunctionFrame(
     target: InternalRow,
     processor: AggregateProcessor,
     ubound: BoundOrdering)
@@ -308,7 +342,9 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
       nextRow = inputIterator.next()
     }
 
-    processor.initialize(input.length)
+    if (processor != null) {
+      processor.initialize(input.length)
+    }
   }
 
   /** Write the frame columns for the current row to the given target row. */
@@ -318,17 +354,23 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
     // Add all rows to the aggregates for which the input row value is equal to or less than
     // the output row upper bound.
     while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) {
-      processor.update(nextRow)
+      if (processor != null) {
+        processor.update(nextRow)
+      }
       nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
       inputIndex += 1
       bufferUpdated = true
     }
 
     // Only recalculate and update when the buffer changes.
-    if (bufferUpdated) {
+    if (processor != null && bufferUpdated) {
       processor.evaluate(target)
     }
   }
+
+  override def currentLowerBound(): Int = 0
+
+  override def currentUpperBound(): Int = inputIndex
 }
 
 /**
@@ -347,7 +389,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
  * @param processor to calculate the row values with.
  * @param lbound comparator used to identify the lower bound of an output row.
  */
-private[window] final class UnboundedFollowingWindowFunctionFrame(
+final class UnboundedFollowingWindowFunctionFrame(
     target: InternalRow,
     processor: AggregateProcessor,
     lbound: BoundOrdering)
@@ -384,7 +426,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
     }
 
     // Only recalculate and update when the buffer changes.
-    if (bufferUpdated) {
+    if (processor != null && bufferUpdated) {
       processor.initialize(input.length)
       if (nextRow != null) {
         processor.update(nextRow)
@@ -395,4 +437,8 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
       processor.evaluate(target)
     }
   }
+
+  override def currentLowerBound(): Int = inputIndex
+
+  override def currentUpperBound(): Int = input.length
 }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org