Posted to reviews@spark.apache.org by "dtenedor (via GitHub)" <gi...@apache.org> on 2023/11/06 20:09:34 UTC

[PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

dtenedor opened a new pull request, #43682:
URL: https://github.com/apache/spark/pull/43682

   ### What changes were proposed in this pull request?
   
   This PR creates a Python UDTF API to stop consuming rows from the input table.
   
   If the UDTF raises a `StopIteration` exception in the `eval` method, then the UDTF stops consuming rows from the input table for that input partition, and finally calls the `terminate` method (if any).
   
   For example:
   
   ```
   @udtf
   class TestUDTF:
       def __init__(self):
           self._total = 0
   
       @staticmethod
       def analyze(_):
           return AnalyzeResult(
               schema=StructType().add("total", IntegerType()), withSinglePartition=True
           )
   
       def eval(self, _: Row):
           self._total += 1
           if self._total >= 3:
               raise StopIteration("StopIteration at self._total >= 3")
   
       def terminate(self):
           yield self._total,
   ```
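
The control flow described here can be sketched in plain Python, without Spark. This is only an illustration under stated assumptions: `run_partition` and `StopAfterThree` are hypothetical names, and the loop stands in for whatever the PySpark worker actually does. It shows `eval` being called once per row until it raises `StopIteration`, after which `terminate` still runs.

```python
from typing import Any, Iterable, Iterator, List, Tuple


class StopAfterThree:
    """Counts input rows and asks to stop once three rows have been seen."""

    def __init__(self) -> None:
        self._total = 0

    def eval(self, row: Any) -> None:
        self._total += 1
        if self._total >= 3:
            raise StopIteration("StopIteration at self._total >= 3")

    def terminate(self) -> Iterator[Tuple[int]]:
        yield (self._total,)


def run_partition(handler: Any, rows: Iterable[Any]) -> List[tuple]:
    """Hypothetical driver loop for one input partition (not Spark's worker code)."""
    output: List[tuple] = []
    for row in rows:
        try:
            result = handler.eval(row)
        except StopIteration:
            # The handler is done with this partition: skip the remaining rows.
            break
        if result is not None:
            output.extend(result)
    # `terminate` is optional, and it runs whether or not the loop stopped early.
    terminate = getattr(handler, "terminate", None)
    if terminate is not None:
        output.extend(terminate())
    return output


result = run_partition(StopAfterThree(), range(1, 21))
print(result)  # [(3,)]
```

Only three of the twenty input rows reach `eval`; the rest are never consumed.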
   
   ### Why are the changes needed?
   
   This is useful when the UDTF logic determines that it no longer needs to scan the input table, so the rest of the I/O for that partition can be skipped.
   
   ### Does this PR introduce _any_ user-facing change?
   
   Yes, see above.
   
   ### How was this patch tested?
   
   This PR adds test coverage.
   
   ### Was this patch authored or co-authored using generative AI tooling?
   
   No


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "dtenedor (via GitHub)" <gi...@apache.org>.
dtenedor commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1394630877


##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2482,6 +2533,7 @@ def tearDownClass(cls):
             super(UDTFTests, cls).tearDownClass()
 
 
+'''

Review Comment:
   My mistake on this, reverted.



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2775,7 +2827,7 @@ def tearDownClass(cls):
             cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled")
         finally:
             super(UDTFArrowTests, cls).tearDownClass()
-
+'''

Review Comment:
   My mistake on this, reverted.





Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "ueshin (via GitHub)" <gi...@apache.org>.
ueshin commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1391800182


##########
python/docs/source/user_guide/sql/python_udtf.rst:
##########
@@ -101,20 +101,29 @@ To implement a Python UDTF, you first need to define a class implementing the me
                 partitionBy: Sequence[PartitioningColumn] = field(default_factory=tuple)
                 orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
 
+            Notes
+            -----
+            - It is possible for the `analyze` method to accept the exact arguments expected,
+              mapping 1:1 with the arguments provided to the UDTF call.
+            - The `analyze` method can instead choose to accept positional arguments if desired
+              (using `*args`) or keyword arguments (using `**kwargs`).
+
             Examples
             --------
-            analyze implementation that returns one output column for each word in the input string
-            argument.
+            This is an `analyze` implementation that returns one output column for each word in the
+            input string argument.
 
-            >>> def analyze(self, text: str) -> AnalyzeResult:
+            >>> @staticmethod
+            ... def analyze(text: str) -> AnalyzeResult:

Review Comment:
   Thanks for the fix! 👍 



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2467,6 +2468,53 @@ def terminate(self):
             [Row(count=20, buffer="abc")],
         )
 
+    def test_udtf_with_skip_rest_of_input_table_exception(self):
+        @udtf(returnType="total: int")
+        class TestUDTF:
+            def __init__(self):
+                self._total = 0
+
+            def eval(self, _: Row):
+                self._total += 1
+                if self._total >= 4:
+                    raise SkipRestOfInputTableException("Stop at self._total >= 4")
+
+            def terminate(self):
+                yield self._total,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        # Run a test case including WITH SINGLE PARTITION on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows after the fourth input row is consumed.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT total
+                FROM test_udtf(TABLE(t) WITH SINGLE PARTITION)
+                """
+            ),
+            [Row(total=4)],
+        )
+        # Run a test case including PARTITION BY on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows for each of the two partitions
+        # separately.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT id / 10 AS id_divided_by_ten, total
+                FROM test_udtf(TABLE(t) PARTITION BY id / 10)
+                ORDER BY ALL
+                """
+            ),
+            [Row(id_divided_by_ten=0, total=4), Row(id_divided_by_ten=1, total=4)],

Review Comment:
   For this case, the expected results should be:
   
   ```py
   [Row(id_divided_by_ten=0, total=4), Row(id_divided_by_ten=1, total=4), Row(id_divided_by_ten=2, total=1)]
   ```
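
The expected result can be checked with a few lines of plain Python. This sketch assumes the partitioning key behaves like integer division (`id // 10`), which is what the expected output here implies; the helper name `expected_totals` is hypothetical.

```python
from collections import Counter
from typing import Dict


def expected_totals(start: int, end: int, stop_at: int) -> Dict[int, int]:
    """Per-partition totals for ids in [start, end), partitioned by id // 10.

    Each partition counts rows until `stop_at` rows are seen, mirroring the
    `self._total >= 4` check in the test UDTF.
    """
    rows_per_partition = Counter(i // 10 for i in range(start, end))
    return {key: min(count, stop_at) for key, count in sorted(rows_per_partition.items())}


totals = expected_totals(1, 21, stop_at=4)
print(totals)  # {0: 4, 1: 4, 2: 1}
```

The third partition holds only the row with `id = 20`, so it never reaches the stop threshold.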





Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "ueshin (via GitHub)" <gi...@apache.org>.
ueshin commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1389997538


##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2467,6 +2468,41 @@ def terminate(self):
             [Row(count=20, buffer="abc")],
         )
 
+    def test_udtf_with_skip_rest_of_input_table_exception(self):
+        @udtf
+        class TestUDTF:
+            def __init__(self):
+                self._total = 0
+
+            @staticmethod
+            def analyze(_):
+                return AnalyzeResult(
+                    schema=StructType().add("total", IntegerType()), withSinglePartition=True
+                )
+
+            def eval(self, _: Row):
+                self._total += 1
+                if self._total >= 4:
+                    raise SkipRestOfInputTableException("Stop at self._total >= 4")
+
+            def terminate(self):
+                yield self._total,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT total
+                FROM test_udtf(TABLE(t))
+                """
+            ),
+            [Row(total=4)],
+        )

Review Comment:
   What happens with partition by? Could you add a test for the case?



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2467,6 +2468,41 @@ def terminate(self):
             [Row(count=20, buffer="abc")],
         )
 
+    def test_udtf_with_skip_rest_of_input_table_exception(self):
+        @udtf
+        class TestUDTF:
+            def __init__(self):
+                self._total = 0
+
+            @staticmethod
+            def analyze(_):
+                return AnalyzeResult(
+                    schema=StructType().add("total", IntegerType()), withSinglePartition=True
+                )

Review Comment:
   nit: I guess we can use `@udtf(returnType=...)` for the schema and `TABLE(t) WITH SINGLE PARTITION` to simplify the test.



##########
python/pyspark/sql/udtf.py:
##########
@@ -118,6 +125,13 @@ class AnalyzeResult:
     orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
 
 
+# This represents an exception that the 'eval' method may raise to indicate that it is done
+# consuming rows from the current partition of the input table. Then the UDTF's 'terminate' method
+# runs (if any).
+class SkipRestOfInputTableException(Exception):
+    pass

Review Comment:
   The comment should be in the class definition?
   
   ```py
   class SkipRestOfInputTableException(Exception):
       # This represents ...
   ```





Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "allisonwang-db (via GitHub)" <gi...@apache.org>.
allisonwang-db commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1384255664


##########
python/pyspark/worker.py:
##########
@@ -1057,6 +1059,9 @@ def mapper(_, it):
                     yield from eval(*[a[o] for o in args_kwargs_offsets])
                 if terminate is not None:
                     yield from terminate()
+            except StopIteration:
+                if terminate is not None:
+                    yield from terminate()

Review Comment:
   Ah the exception thrown in eval will be caught by this 
   ```
   except Exception as e:
       raise PySparkRuntimeError(
          error_class="UDTF_EXEC_ERROR",
   ```
   I wonder if we should just move the whole try ... except ... block to this mapper function instead of checking this for every single input row.
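
The concern here is the order of exception handling: a broad `except Exception` wrapper around `eval` also catches a control-flow exception raised by `eval`, so an outer handler never sees it. A minimal plain-Python sketch, in which `StopSignal`, `WrappedError`, and the wrapper functions are hypothetical stand-ins rather than PySpark code:

```python
class StopSignal(Exception):
    """Hypothetical control-flow exception raised by a user's eval method."""


class WrappedError(Exception):
    """Hypothetical stand-in for a generic execution error wrapper."""


def user_eval(row: int) -> None:
    if row >= 3:
        raise StopSignal("done")


def wrap_everything(row: int) -> None:
    # A broad handler converts every exception, including the stop signal.
    try:
        user_eval(row)
    except Exception as e:
        raise WrappedError(str(e)) from e


def wrap_but_pass_signal(row: int) -> None:
    # Listing the specific exception first lets it propagate unchanged.
    try:
        user_eval(row)
    except StopSignal:
        raise
    except Exception as e:
        raise WrappedError(str(e)) from e


def outcome(func, row: int) -> str:
    """Report which exception type, if any, escapes `func(row)`."""
    try:
        func(row)
    except StopSignal:
        return "StopSignal"
    except WrappedError:
        return "WrappedError"
    return "ok"


print(outcome(wrap_everything, 3))       # WrappedError
print(outcome(wrap_but_pass_signal, 3))  # StopSignal
```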





Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "allisonwang-db (via GitHub)" <gi...@apache.org>.
allisonwang-db commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1384161755


##########
python/pyspark/worker.py:
##########
@@ -995,6 +995,8 @@ def verify_result(result):
             def func(*a: Any) -> Any:
                 try:
                     return f(*a)
+                except StopIteration:
+                    raise

Review Comment:
   Should we use something similar to `fail_on_stopiteration`?
   ```
           except StopIteration as exc:
               raise PySparkRuntimeError(
                   error_class="STOP_ITERATION_OCCURRED",
                   message_parameters={
                       "exc": str(exc),
                   },
               )
   
   ```
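
One reason a bare `StopIteration` is risky in this code path: since Python 3.7 (PEP 479), a `StopIteration` that escapes inside a generator body is converted into a `RuntimeError` instead of silently ending the generator. The sketch below uses hypothetical function names and no Spark code:

```python
from typing import Iterable, Iterator


def user_eval(row: int) -> int:
    """Hypothetical user function that uses StopIteration as a signal."""
    if row >= 3:
        raise StopIteration("stop")
    return row


def mapper(rows: Iterable[int]) -> Iterator[int]:
    # A generator, similar in shape to a worker-side mapper that yields output rows.
    for row in rows:
        yield user_eval(row)


collected = []
error = None
try:
    for value in mapper([1, 2, 3, 4]):
        collected.append(value)
except RuntimeError as e:
    # PEP 479: the StopIteration raised inside the generator surfaces here.
    error = e

print(collected)             # [1, 2]
print(type(error).__name__)  # RuntimeError
```

A dedicated exception class avoids this ambiguity, which is the approach the `SkipRestOfInputTableException` revisions quoted elsewhere in this thread take.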





Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "ueshin (via GitHub)" <gi...@apache.org>.
ueshin closed pull request #43682: [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table
URL: https://github.com/apache/spark/pull/43682




Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "dtenedor (via GitHub)" <gi...@apache.org>.
dtenedor commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1393418903


##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2467,6 +2468,53 @@ def terminate(self):
             [Row(count=20, buffer="abc")],
         )
 
+    def test_udtf_with_skip_rest_of_input_table_exception(self):
+        @udtf(returnType="total: int")
+        class TestUDTF:
+            def __init__(self):
+                self._total = 0
+
+            def eval(self, _: Row):
+                self._total += 1
+                if self._total >= 4:
+                    raise SkipRestOfInputTableException("Stop at self._total >= 4")
+
+            def terminate(self):
+                yield self._total,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        # Run a test case including WITH SINGLE PARTITION on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows after the fourth input row is consumed.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT total
+                FROM test_udtf(TABLE(t) WITH SINGLE PARTITION)
+                """
+            ),
+            [Row(total=4)],
+        )
+        # Run a test case including PARTITION BY on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows for each of the two partitions
+        # separately.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT id / 10 AS id_divided_by_ten, total
+                FROM test_udtf(TABLE(t) PARTITION BY id / 10)
+                ORDER BY ALL
+                """
+            ),
+            [Row(id_divided_by_ten=0, total=4), Row(id_divided_by_ten=1, total=4)],

Review Comment:
   You're right, this was actually a bug. The `class UDTFWithPartitions` did not have support for this new exception type before. I added support for that now, and the test passes.
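
A rough sketch of what per-partition support for such an exception could look like. This is not the actual `UDTFWithPartitions` code: `SkipRest`, `CountUpToFour`, and `run_with_partitions` are hypothetical, rows are assumed to arrive grouped by partition key, and rows emitted by `eval` are ignored for brevity. The point is that a skip request only applies until the partition key changes, at which time `terminate` runs and a fresh instance starts.

```python
from typing import Any, Callable, Iterable, List, Tuple


class SkipRest(Exception):
    """Hypothetical stand-in for SkipRestOfInputTableException."""


class CountUpToFour:
    def __init__(self) -> None:
        self._total = 0

    def eval(self, row: Any) -> None:
        self._total += 1
        if self._total >= 4:
            raise SkipRest("Stop at self._total >= 4")

    def terminate(self):
        yield (self._total,)


def run_with_partitions(
    create: Callable[[], Any],
    rows: Iterable[Any],
    key: Callable[[Any], Any],
) -> List[Tuple[Any, tuple]]:
    """Feed rows (already grouped by `key`) to one handler instance per partition."""
    output: List[Tuple[Any, tuple]] = []
    handler = None
    current_key = None
    skipping = False
    for row in rows:
        k = key(row)
        if handler is None or k != current_key:
            if handler is not None:
                # Partition boundary: finish the previous partition.
                output.extend((current_key, out) for out in handler.terminate())
            handler, current_key, skipping = create(), k, False
        if skipping:
            continue  # The handler asked to skip the rest of this partition.
        try:
            handler.eval(row)
        except SkipRest:
            skipping = True
    if handler is not None:
        output.extend((current_key, out) for out in handler.terminate())
    return output


partition_totals = run_with_partitions(CountUpToFour, range(1, 21), key=lambda i: i // 10)
print(partition_totals)  # [(0, (4,)), (1, (4,)), (2, (1,))]
```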





Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "dtenedor (via GitHub)" <gi...@apache.org>.
dtenedor commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1391840056


##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2467,6 +2468,53 @@ def terminate(self):
             [Row(count=20, buffer="abc")],
         )
 
+    def test_udtf_with_skip_rest_of_input_table_exception(self):
+        @udtf(returnType="total: int")
+        class TestUDTF:
+            def __init__(self):
+                self._total = 0
+
+            def eval(self, _: Row):
+                self._total += 1
+                if self._total >= 4:
+                    raise SkipRestOfInputTableException("Stop at self._total >= 4")
+
+            def terminate(self):
+                yield self._total,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        # Run a test case including WITH SINGLE PARTITION on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows after the fourth input row is consumed.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT total
+                FROM test_udtf(TABLE(t) WITH SINGLE PARTITION)
+                """
+            ),
+            [Row(total=4)],
+        )
+        # Run a test case including PARTITION BY on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows for each of the two partitions
+        # separately.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT id / 10 AS id_divided_by_ten, total
+                FROM test_udtf(TABLE(t) PARTITION BY id / 10)
+                ORDER BY ALL
+                """
+            ),
+            [Row(id_divided_by_ten=0, total=4), Row(id_divided_by_ten=1, total=4)],

Review Comment:
   I thought so as well, but apparently the `range` function treats its second argument as an exclusive upper bound :)
   
   ```
   > SELECT id FROM range(1, 21)
   
   id
   1
   2
   3
   4
   5
   6
   7
   8
   9
   10
   11
   12
   13
   14
   15
   16
   17
   18
   19
   20
   ```
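
Python's built-in `range` uses the same convention (inclusive start, exclusive end), so the row count is easy to confirm outside Spark:

```python
ids = list(range(1, 21))  # start is inclusive, end is exclusive

print(ids[0], ids[-1], len(ids))  # 1 20 20
```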





Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "ueshin (via GitHub)" <gi...@apache.org>.
ueshin commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1391845246


##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2467,6 +2468,53 @@ def terminate(self):
             [Row(count=20, buffer="abc")],
         )
 
+    def test_udtf_with_skip_rest_of_input_table_exception(self):
+        @udtf(returnType="total: int")
+        class TestUDTF:
+            def __init__(self):
+                self._total = 0
+
+            def eval(self, _: Row):
+                self._total += 1
+                if self._total >= 4:
+                    raise SkipRestOfInputTableException("Stop at self._total >= 4")
+
+            def terminate(self):
+                yield self._total,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        # Run a test case including WITH SINGLE PARTITION on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows after the fourth input row is consumed.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT total
+                FROM test_udtf(TABLE(t) WITH SINGLE PARTITION)
+                """
+            ),
+            [Row(total=4)],
+        )
+        # Run a test case including PARTITION BY on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows for each of the two partitions
+        # separately.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT id / 10 AS id_divided_by_ten, total
+                FROM test_udtf(TABLE(t) PARTITION BY id / 10)
+                ORDER BY ALL
+                """
+            ),
+            [Row(id_divided_by_ten=0, total=4), Row(id_divided_by_ten=1, total=4)],

Review Comment:
   So `20` should be in another group?
   
   Without the early stop, there would be 3 partitions containing 9, 10, and 1 rows respectively.





Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "ueshin (via GitHub)" <gi...@apache.org>.
ueshin commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1391845246


##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2467,6 +2468,53 @@ def terminate(self):
             [Row(count=20, buffer="abc")],
         )
 
+    def test_udtf_with_skip_rest_of_input_table_exception(self):
+        @udtf(returnType="total: int")
+        class TestUDTF:
+            def __init__(self):
+                self._total = 0
+
+            def eval(self, _: Row):
+                self._total += 1
+                if self._total >= 4:
+                    raise SkipRestOfInputTableException("Stop at self._total >= 4")
+
+            def terminate(self):
+                yield self._total,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        # Run a test case including WITH SINGLE PARTITION on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows after the fourth input row is consumed.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT total
+                FROM test_udtf(TABLE(t) WITH SINGLE PARTITION)
+                """
+            ),
+            [Row(total=4)],
+        )
+        # Run a test case including PARTITION BY on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows for each of the two partitions
+        # separately.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT id / 10 AS id_divided_by_ten, total
+                FROM test_udtf(TABLE(t) PARTITION BY id / 10)
+                ORDER BY ALL
+                """
+            ),
+            [Row(id_divided_by_ten=0, total=4), Row(id_divided_by_ten=1, total=4)],

Review Comment:
   So `20` should be in another group?
   
   Without the early stop, there would be 3 groups containing 9, 10, and 1 rows respectively.





Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "dtenedor (via GitHub)" <gi...@apache.org>.
dtenedor commented on PR #43682:
URL: https://github.com/apache/spark/pull/43682#issuecomment-1796329902

   cc @ueshin @allisonwang-db 




Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "dtenedor (via GitHub)" <gi...@apache.org>.
dtenedor commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1391771813


##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2467,6 +2468,41 @@ def terminate(self):
             [Row(count=20, buffer="abc")],
         )
 
+    def test_udtf_with_skip_rest_of_input_table_exception(self):
+        @udtf
+        class TestUDTF:
+            def __init__(self):
+                self._total = 0
+
+            @staticmethod
+            def analyze(_):
+                return AnalyzeResult(
+                    schema=StructType().add("total", IntegerType()), withSinglePartition=True
+                )
+
+            def eval(self, _: Row):
+                self._total += 1
+                if self._total >= 4:
+                    raise SkipRestOfInputTableException("Stop at self._total >= 4")
+
+            def terminate(self):
+                yield self._total,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT total
+                FROM test_udtf(TABLE(t))
+                """
+            ),
+            [Row(total=4)],
+        )

Review Comment:
   The `SkipRestOfInputTableException` stops scanning rows for just the current partition. I added a test case for this as well.



##########
python/pyspark/sql/udtf.py:
##########
@@ -118,6 +125,13 @@ class AnalyzeResult:
     orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
 
 
+# This represents an exception that the 'eval' method may raise to indicate that it is done
+# consuming rows from the current partition of the input table. Then the UDTF's 'terminate' method
+# runs (if any).
+class SkipRestOfInputTableException(Exception):
+    pass

Review Comment:
   Good point, I moved this to a pydoc string inside the class definition itself.
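
For illustration, the docstring form could look roughly like the following. The wording is adapted from the `#` comment in the diff quoted here and is an assumption, not necessarily the text that was merged.

```python
class SkipRestOfInputTableException(Exception):
    """
    Exception that the 'eval' method may raise to indicate that it is done
    consuming rows from the current partition of the input table. The UDTF's
    'terminate' method then runs (if any).
    """


# Unlike a '#' comment above the class, a docstring is available at runtime,
# for example through help() or the __doc__ attribute.
doc = SkipRestOfInputTableException.__doc__
print("terminate" in doc)  # True
```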



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2467,6 +2468,41 @@ def terminate(self):
             [Row(count=20, buffer="abc")],
         )
 
+    def test_udtf_with_skip_rest_of_input_table_exception(self):
+        @udtf
+        class TestUDTF:
+            def __init__(self):
+                self._total = 0
+
+            @staticmethod
+            def analyze(_):
+                return AnalyzeResult(
+                    schema=StructType().add("total", IntegerType()), withSinglePartition=True
+                )

Review Comment:
   Good idea, done!





Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "ueshin (via GitHub)" <gi...@apache.org>.
ueshin commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1391846444


##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2467,6 +2468,53 @@ def terminate(self):
             [Row(count=20, buffer="abc")],
         )
 
+    def test_udtf_with_skip_rest_of_input_table_exception(self):
+        @udtf(returnType="total: int")
+        class TestUDTF:
+            def __init__(self):
+                self._total = 0
+
+            def eval(self, _: Row):
+                self._total += 1
+                if self._total >= 4:
+                    raise SkipRestOfInputTableException("Stop at self._total >= 4")
+
+            def terminate(self):
+                yield self._total,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        # Run a test case including WITH SINGLE PARTITION on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows after the fourth input row is consumed.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT total
+                FROM test_udtf(TABLE(t) WITH SINGLE PARTITION)
+                """
+            ),
+            [Row(total=4)],
+        )
+        # Run a test case including PARTITION BY on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows for each of the two partitions
+        # separately.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT id / 10 AS id_divided_by_ten, total
+                FROM test_udtf(TABLE(t) PARTITION BY id / 10)

Review Comment:
   btw, `id / 10` will be `double`. I guess it should be `floor(id / 10)`?
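
A quick plain-Python sketch of the difference, with no Spark session involved. It assumes that Python's `/` and `math.floor` mirror the relevant behavior of Spark SQL's `/` (which returns a double for integer operands) and `floor`; the variable names are only for illustration.

```python
import math
from collections import defaultdict

ids = range(1, 21)  # same ids as `SELECT id FROM range(1, 21)`

# True division produces a distinct fractional key for every id.
by_division = defaultdict(list)
for i in ids:
    by_division[i / 10].append(i)

# Flooring the quotient buckets the ids into a few integral keys.
by_floor = defaultdict(list)
for i in ids:
    by_floor[math.floor(i / 10)].append(i)

print(len(by_division))  # 20 keys, one per id
print({k: len(v) for k, v in by_floor.items()})  # {0: 9, 1: 10, 2: 1}
```

Under that assumption, partitioning by the unfloored key would give every partition a single row, so the UDTF in this test would never reach its threshold of four rows and never raise the exception.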





Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "dtenedor (via GitHub)" <gi...@apache.org>.
dtenedor commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1384198085


##########
python/pyspark/worker.py:
##########
@@ -995,6 +995,8 @@ def verify_result(result):
             def func(*a: Any) -> Any:
                 try:
                     return f(*a)
+                except StopIteration:
+                    raise

Review Comment:
   I looked at this, but unlike scalar PySpark UDFs, we want to apply custom treatment for UDTFs with respect to the `eval` vs. `terminate` vs. `cleanup` methods. The idea is that if `eval` invokes `raise StopIteration()`, we then call `terminate` as normal and return a successful result for the UDTF as a whole.
   
   To implement this behavior, if we used `fail_on_stopiteration` here, we'd have to later catch a `PySparkRuntimeError` and manually check whether it has `error_class="STOP_ITERATION_OCCURRED"`, which seems pretty confusing. It seems simpler to just `raise` here and catch the `StopIteration` on L1062 and L1154 instead, and then call `terminate`.
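
The control flow described here can be sketched in plain Python, outside of Spark. This is a minimal illustration, not the code in `worker.py`: `SkipRest`, `run_partition`, `Counter`, and `tracked` are hypothetical names, and the sketch uses a dedicated exception class instead of `StopIteration` to keep the example free of generator-specific pitfalls.

```python
from typing import Any, Callable, Iterable, Iterator, List, Optional, Tuple


class SkipRest(Exception):
    """Hypothetical stand-in for the 'stop consuming input' signal."""


def run_partition(
    rows: Iterable[Any],
    eval_fn: Callable[[Any], Optional[Iterable[Tuple]]],
    terminate_fn: Optional[Callable[[], Iterable[Tuple]]] = None,
) -> Iterator[Tuple]:
    """Feed rows to eval_fn until it signals a stop, then run terminate_fn."""
    try:
        for row in rows:
            out = eval_fn(row)
            if out is not None:
                yield from out
    except SkipRest:
        pass  # Treat the signal as a normal end of input, not as a failure.
    if terminate_fn is not None:
        yield from terminate_fn()


class Counter:
    """Counts rows and asks to stop once four rows have been seen."""

    def __init__(self) -> None:
        self.total = 0

    def eval(self, row: Any) -> None:
        self.total += 1
        if self.total >= 4:
            raise SkipRest("stop at total >= 4")

    def terminate(self) -> Iterator[Tuple]:
        yield (self.total,)


consumed: List[int] = []


def tracked(n: int) -> Iterator[int]:
    """Yield 1..n while recording which rows were actually pulled."""
    for i in range(1, n + 1):
        consumed.append(i)
        yield i


counter = Counter()
results = list(run_partition(tracked(20), counter.eval, counter.terminate))
print(results)   # [(4,)]
print(consumed)  # [1, 2, 3, 4]
```

One reason to prefer a dedicated class in a sketch like this: under PEP 479, a `StopIteration` that escapes a generator body is converted to `RuntimeError`, which makes it an awkward control-flow signal anywhere generators are involved.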





Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "ueshin (via GitHub)" <gi...@apache.org>.
ueshin commented on PR #43682:
URL: https://github.com/apache/spark/pull/43682#issuecomment-1806541757

   btw, the example in the description should not use tabs; they break the code formatting.




Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "ueshin (via GitHub)" <gi...@apache.org>.
ueshin commented on PR #43682:
URL: https://github.com/apache/spark/pull/43682#issuecomment-1813313871

   Thanks! merging to master.




Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "ueshin (via GitHub)" <gi...@apache.org>.
ueshin commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1393589351


##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2482,6 +2533,7 @@ def tearDownClass(cls):
             super(UDTFTests, cls).tearDownClass()
 
 
+'''

Review Comment:
   Looks like a mistake? Could you revert this change?



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2775,7 +2827,7 @@ def tearDownClass(cls):
             cls.spark.conf.unset("spark.sql.execution.pythonUDTF.arrow.enabled")
         finally:
             super(UDTFArrowTests, cls).tearDownClass()
-
+'''

Review Comment:
   ditto.





Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "dtenedor (via GitHub)" <gi...@apache.org>.
dtenedor commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1393418317


##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2467,6 +2468,53 @@ def terminate(self):
             [Row(count=20, buffer="abc")],
         )
 
+    def test_udtf_with_skip_rest_of_input_table_exception(self):
+        @udtf(returnType="total: int")
+        class TestUDTF:
+            def __init__(self):
+                self._total = 0
+
+            def eval(self, _: Row):
+                self._total += 1
+                if self._total >= 4:
+                    raise SkipRestOfInputTableException("Stop at self._total >= 4")
+
+            def terminate(self):
+                yield self._total,
+
+        self.spark.udtf.register("test_udtf", TestUDTF)
+
+        # Run a test case including WITH SINGLE PARTITION on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows after the fourth input row is consumed.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT total
+                FROM test_udtf(TABLE(t) WITH SINGLE PARTITION)
+                """
+            ),
+            [Row(total=4)],
+        )
+        # Run a test case including PARTITION BY on the UDTF call. The
+        # SkipRestOfInputTableException stops scanning rows for each of the two partitions
+        # separately.
+        assertDataFrameEqual(
+            self.spark.sql(
+                """
+                WITH t AS (
+                  SELECT id FROM range(1, 21)
+                )
+                SELECT id / 10 AS id_divided_by_ten, total
+                FROM test_udtf(TABLE(t) PARTITION BY id / 10)

Review Comment:
   Sounds good, done.





Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "dtenedor (via GitHub)" <gi...@apache.org>.
dtenedor commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1385390475


##########
python/pyspark/worker.py:
##########
@@ -1057,6 +1059,9 @@ def mapper(_, it):
                     yield from eval(*[a[o] for o in args_kwargs_offsets])
                 if terminate is not None:
                     yield from terminate()
+            except StopIteration:
+                if terminate is not None:
+                    yield from terminate()

Review Comment:
   I tried this, but then we lose the UDTF method name (i.e. `eval` or `terminate`) in the error message, and it seems better to keep that. I will leave this alone for now. If you would prefer to make this change anyway and lose that information in the error message, let me know and we can proceed.
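
As a rough illustration of the trade-off, a per-method wrapper can pass the stop signal through untouched while still tagging any other failure with the method name. This is a sketch under assumptions: `SkipRest`, `wrap_method`, and the error message format are hypothetical and do not reproduce the wrappers in `worker.py`.

```python
from typing import Any, Callable


class SkipRest(Exception):
    """Hypothetical stand-in for the 'stop consuming input' signal."""


def wrap_method(f: Callable[..., Any], method_name: str) -> Callable[..., Any]:
    """Tag unexpected failures with the UDTF method name; pass the stop signal through."""

    def func(*a: Any) -> Any:
        try:
            return f(*a)
        except SkipRest:
            raise  # Let the caller decide how to finish the partition.
        except Exception as e:
            # Hypothetical message format; the point is that method_name survives.
            raise RuntimeError(f"UDTF method '{method_name}' failed: {e}") from e

    return func


def bad_eval(row: Any) -> None:
    raise ValueError("boom")


def stopping_eval(row: Any) -> None:
    raise SkipRest()


try:
    wrap_method(bad_eval, "eval")(1)
except RuntimeError as e:
    message = str(e)

print(message)  # UDTF method 'eval' failed: boom
```

Without the explicit `except SkipRest: raise` branch, the generic handler would wrap the stop signal too, and the caller would have to unwrap it to tell a deliberate stop from a real error.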





Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "allisonwang-db (via GitHub)" <gi...@apache.org>.
allisonwang-db commented on code in PR #43682:
URL: https://github.com/apache/spark/pull/43682#discussion_r1384244315


##########
python/pyspark/worker.py:
##########
@@ -1057,6 +1059,9 @@ def mapper(_, it):
                     yield from eval(*[a[o] for o in args_kwargs_offsets])
                 if terminate is not None:
                     yield from terminate()
+            except StopIteration:
+                if terminate is not None:
+                    yield from terminate()

Review Comment:
   If we are catching the `StopIteration` exception in `mapper`, do we still need the `try ... except` block inside `def func` and `def evaluate` below?



##########
python/pyspark/worker.py:
##########
@@ -995,6 +995,8 @@ def verify_result(result):
             def func(*a: Any) -> Any:
                 try:
                     return f(*a)
+                except StopIteration:
+                    raise

Review Comment:
   Makes sense!





Re: [PR] [SPARK-45810][Python] Create Python UDTF API to stop consuming rows from the input table [spark]

Posted by "dtenedor (via GitHub)" <gi...@apache.org>.
dtenedor commented on PR #43682:
URL: https://github.com/apache/spark/pull/43682#issuecomment-1804852645

   cc @ueshin 

