Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2022/09/15 04:10:59 UTC

[GitHub] [spark] HeartSaVioR opened a new pull request, #37893: [DRAFT][DO-NOT-MERGE][SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

HeartSaVioR opened a new pull request, #37893:
URL: https://github.com/apache/spark/pull/37893

   ...TBD...
   
   ### What changes were proposed in this pull request?
   
   
   ### Why are the changes needed?
   
   
   ### Does this PR introduce _any_ user-facing change?
   
   Yes. We are exposing a new public API in PySpark which performs arbitrary stateful processing.
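As a rough illustration of the kind of user function such an API takes (a grouping key, an iterator of frames, and a per-group state, returning an iterator of frames), here is a pure-Python sketch. Everything in it is hypothetical: `FakeGroupState` only loosely mimics a group state object and is not the PySpark class, and plain lists of row dicts stand in for `pandas.DataFrame` so the sketch runs without Spark or pandas.

```python
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple

# A "frame" is a list of row dicts, standing in for pandas.DataFrame (hypothetical).
Frame = List[Dict[str, Any]]


class FakeGroupState:
    """Hypothetical stand-in for the per-group state object; not the PySpark class."""

    def __init__(self, value: Optional[Tuple[Any, ...]] = None) -> None:
        self._value = value

    def exists(self) -> bool:
        return self._value is not None

    def get(self) -> Tuple[Any, ...]:
        if self._value is None:
            raise ValueError("state is not set")
        return self._value

    def update(self, value: Tuple[Any, ...]) -> None:
        # State values are tuples matching an assumed state schema such as "count long".
        self._value = value


def count_rows(
    key: Tuple[Any, ...], frames: Iterable[Frame], state: FakeGroupState
) -> Iterator[Frame]:
    """Keep a running row count per group across invocations."""
    (count,) = state.get() if state.exists() else (0,)
    # Consume every element of the iterator; never assume how many frames there are.
    for frame in frames:
        count += len(frame)
    state.update((count,))
    yield [{"key": key[0], "count": count}]
```

Calling `count_rows` twice with the same state object shows the count carrying over between invocations, which is the behavior the streaming engine provides across triggers.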
   
   ### How was this patch tested?
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974858058


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala:
##########
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io._
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.api.python._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
+
+
+/**
+ * A variant implementation of [[ArrowPythonRunner]] to serve the operation
+ * applyInPandasWithState.
+ */
+class ApplyInPandasWithStatePythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    inputSchema: StructType,
+    override protected val timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    stateEncoder: ExpressionEncoder[Row],
+    keySchema: StructType,
+    valueSchema: StructType,
+    stateValueSchema: StructType,
+    softLimitBytesPerBatch: Long,
+    minDataCountForSample: Int,
+    softTimeoutMillsPurgeBatch: Long)
+  extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
+  with PythonArrowInput[InType]
+  with PythonArrowOutput[OutType] {
+
+  override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA)
+
+  override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+
+  override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
+  require(
+    bufferSize >= 4,

Review Comment:
   Never mind. I just talked with @HyukjinKwon and now understand how the buffer works (I had misunderstood it): it's more about how many Arrow RecordBatches to buffer and flush at once for efficiency. An Arrow RecordBatch bigger than the buffer will still be treated as a single Arrow RecordBatch.
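A generic analogy in plain Python I/O (not Spark's JVM-side code path) makes the distinction concrete: length-prefixed messages are written through an `io.BufferedWriter` whose buffer is smaller than one message. The buffer size only changes when bytes are flushed to the underlying sink; the reader still sees each message as one intact unit. `RecordingRaw`, `write_frame` and `read_frames` are hypothetical helpers.

```python
import io
from typing import List


class RecordingRaw(io.RawIOBase):
    """Raw sink that remembers every byte it is handed."""

    def __init__(self) -> None:
        super().__init__()
        self.data = bytearray()

    def writable(self) -> bool:
        return True

    def write(self, b) -> int:
        self.data += bytes(b)
        return len(b)


def write_frame(stream, payload: bytes) -> None:
    # Length-prefixed message: 4-byte big-endian length, then the payload.
    stream.write(len(payload).to_bytes(4, "big"))
    stream.write(payload)


def read_frames(data: bytes) -> List[bytes]:
    frames, pos = [], 0
    while pos < len(data):
        n = int.from_bytes(data[pos:pos + 4], "big")
        frames.append(bytes(data[pos + 4:pos + 4 + n]))
        pos += 4 + n
    return frames


raw = RecordingRaw()
out = io.BufferedWriter(raw, buffer_size=4)  # buffer deliberately smaller than a message
write_frame(out, b"a-record-batch-larger-than-the-buffer")
write_frame(out, b"small")
out.flush()
```

After the flush, `read_frames(bytes(raw.data))` recovers both messages unchanged, including the one larger than the buffer.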





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r973868967


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -2705,6 +2705,44 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH =
+    buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch")
+      .internal()
+      .doc("When using applyInPandasWithState, set a soft limit of the accumulated size of " +
+        "records that can be written to a single ArrowRecordBatch in memory. This is used to " +
+        "restrict the amount of memory being used to materialize the data in both executor and " +
+        "Python worker. The accumulated size of records is calculated via sampling a set of " +
+        "records. Splitting the ArrowRecordBatch is performed per record, so unless a record " +
+        "is quite huge, the size of constructed ArrowRecordBatch will be around the " +
+        "configured value.")
+      .version("3.4.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("64MB")

Review Comment:
   Ah, SPARK-23258 is about restricting the Arrow record batch by size, which seems similar to what we propose in this PR. It's still an open question whether we should calculate the size on every row addition (accurate, but very bad for performance) or sample as we do here (not accurate, and the error might be non-trivial with variable-length columns).
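The trade-off can be sketched with a toy size model. This is a minimal illustration, not Spark's implementation: `row_size`, `split_exact` and `split_sampled` are hypothetical, the "size" of a row is just the sum of its field byte lengths, and the sampling in the actual runner (see `minDataCountForSample` in its constructor) may work differently.

```python
from typing import List, Sequence

Row = Sequence[bytes]


def row_size(row: Row) -> int:
    # Hypothetical size model: the sum of the byte lengths of the row's fields.
    return sum(len(field) for field in row)


def split_exact(rows: List[Row], limit_bytes: int) -> List[List[Row]]:
    """Measure every row; start a new batch once adding the row would exceed the limit."""
    batches: List[List[Row]] = []
    current: List[Row] = []
    current_bytes = 0
    for row in rows:
        size = row_size(row)
        if current and current_bytes + size > limit_bytes:
            batches.append(current)
            current, current_bytes = [], 0
        # A single oversized row still ends up in its own batch (the limit is soft).
        current.append(row)
        current_bytes += size
    if current:
        batches.append(current)
    return batches


def split_sampled(rows: List[Row], limit_bytes: int, sample_count: int) -> List[List[Row]]:
    """Measure only the first `sample_count` rows and derive a fixed rows-per-batch from them."""
    sample = rows[:sample_count]
    avg = sum(row_size(r) for r in sample) / len(sample) if sample else 0.0
    per_batch = max(1, int(limit_bytes // avg)) if avg > 0 else max(1, len(rows))
    return [rows[i:i + per_batch] for i in range(0, len(rows), per_batch)]


rows = [[b"x" * 10]] * 4 + [[b"x" * 100]] * 4  # short rows first, long rows later
exact = split_exact(rows, limit_bytes=200)
sampled = split_sampled(rows, limit_bytes=200, sample_count=4)
```

Here the sample sees only the 10-byte rows, so the sampled variant puts all eight rows (440 bytes) into one batch against a 200-byte limit, while the exact variant yields batches of 5, 2 and 1 rows that stay within the limit, at the cost of measuring every row.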





[GitHub] [spark] HeartSaVioR commented on pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on PR #37893:
URL: https://github.com/apache/spark/pull/37893#issuecomment-1253325944

   @HyukjinKwon @alex-balikov 
   Please take another round of review. Thanks in advance!




[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974906420


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala:
##########
@@ -311,6 +323,56 @@ object UnsupportedOperationChecker extends Logging {
             }
           }
 
+        // applyInPandasWithState

Review Comment:
   NOTE: this is just a copy-paste of the flatMapGroupsWithState check, since both operators have the same characteristics.





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r973840041


##########
python/pyspark/sql/pandas/_typing/__init__.pyi:
##########
@@ -256,6 +258,10 @@ PandasGroupedMapFunction = Union[
     Callable[[Any, DataFrameLike], DataFrameLike],
 ]
 
+PandasGroupedMapFunctionWithState = Callable[
+    [Any, Iterable[DataFrameLike], GroupStateImpl], Iterable[DataFrameLike]

Review Comment:
   One concern is what happens if we end up with a different implementation of `GroupState` in the far future. But typing in Python is dynamic anyway, so I don't worry too much.
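If that concern ever materializes, one option (not what this PR does) is to type the state parameter structurally with `typing.Protocol`, so any class with the right members satisfies the alias regardless of its concrete type. The protocol members `get_option` and `update`, and the classes below, are hypothetical and only illustrate the idea.

```python
from typing import Any, Callable, Iterable, List, Optional, Protocol, Tuple, runtime_checkable


@runtime_checkable
class GroupStateLike(Protocol):
    # Hypothetical minimal surface; member names are illustrative only.
    def get_option(self) -> Optional[Tuple[Any, ...]]: ...

    def update(self, value: Tuple[Any, ...]) -> None: ...


# The function alias refers to the protocol instead of one concrete implementation.
FunctionWithState = Callable[[Any, Iterable[Any], GroupStateLike], Iterable[Any]]


class InMemoryState:
    def __init__(self) -> None:
        self._value: Optional[Tuple[Any, ...]] = None

    def get_option(self) -> Optional[Tuple[Any, ...]]:
        return self._value

    def update(self, value: Tuple[Any, ...]) -> None:
        self._value = value


class LoggingState:
    """An unrelated implementation that also records every update."""

    def __init__(self) -> None:
        self._value: Optional[Tuple[Any, ...]] = None
        self.log: List[Tuple[Any, ...]] = []

    def get_option(self) -> Optional[Tuple[Any, ...]]:
        return self._value

    def update(self, value: Tuple[Any, ...]) -> None:
        self.log.append(value)
        self._value = value


def bump(key: Any, frames: Iterable[Any], state: GroupStateLike) -> Iterable[Any]:
    (n,) = state.get_option() or (0,)
    for _ in frames:
        n += 1
    state.update((n,))
    return [(key, n)]


f: FunctionWithState = bump  # both state classes fit this signature
```

Note that `runtime_checkable` only checks that the members exist, not their signatures; the real benefit is for static type checkers.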





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r975800900


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala:
##########
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.types.{BinaryType, BooleanType, IntegerType, StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * This class abstracts the complexity of constructing Arrow RecordBatches for data and state with
+ * bin-packing and chunking. The caller only needs to call the public methods of this class
+ * (`startNewGroup`, `writeRow`, `finalizeGroup`, `finalizeData`), and this class will write the
+ * data and state into Arrow RecordBatches, performing bin-packing and chunking internally.
+ *
+ * This class requires that the parameter `root` has been initialized with the Arrow schema like
+ * below:
+ * - data fields
+ * - state field
+ *   - nested schema (Refer to ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA)
+ *
+ * Please refer to the code comments in the implementation to see how writes of data and state
+ * into Arrow RecordBatches work with bin-packing and chunking in mind.
+ */
+class ApplyInPandasWithStateWriter(
+    root: VectorSchemaRoot,
+    writer: ArrowStreamWriter,
+    arrowMaxRecordsPerBatch: Int) {
+
+  import ApplyInPandasWithStateWriter._
+
+  // Unlike applyInPandas (and other PySpark operators), applyInPandasWithState needs to produce
+  // additional data, the `state`, along with the input data.
+  //
+  // ArrowStreamWriter supports only a single VectorSchemaRoot, which means all Arrow RecordBatches
+  // sent out from ArrowStreamWriter must have the same schema. Hence we have to construct a single
+  // Arrow schema that covers both kinds of data, and construct Arrow RecordBatches that contain
+  // both.
+  //
+  // To achieve this, we extend the schema of the input data with a column for state at the end.
+  // We also logically group the columns by family (data vs state) and initialize a writer for
+  // each separately, since it's a lot easier, and probably more performant, to write the row
+  // directly rather than projecting the row to match the overall schema.
+  //
+  // Although Arrow RecordBatch allows writing data column by column, we found that it produces
+  // strange outputs unless all columns have the same number of values. Since there is at least
+  // one data row for each grouping key (we ensure this even when handling timed-out state),
+  // whereas there is only one state per grouping key, we have to fill the empty rows on the
+  // state side so that both have the same number of rows.
+  private val arrowWriterForData = createArrowWriter(
+    root.getFieldVectors.asScala.toSeq.dropRight(1))
+  private val arrowWriterForState = createArrowWriter(
+    root.getFieldVectors.asScala.toSeq.takeRight(1))
+
+  // - Bin-packing
+  //
+  // We bin-pack the data from multiple groups into one Arrow RecordBatch to gain performance.
+  // In many cases, the amount of data per grouping key is quite small, so sending one batch per
+  // group would not make the most of the benefits of using Arrow.
+  //
+  // We have to split the record batch down to each group in the Python worker to convert the
+  // group's data to Pandas, but fortunately, Arrow RecordBatch provides a way to slice a range of
+  // data and get a "zero-copy" view. To help slice the range for each group, we provide the
+  // "start offset" and the "number of rows" in the state metadata.
+  //
+  // We don't bin-pack all groups into a single record batch: we have a limit on the number of
+  // rows in the current Arrow RecordBatch, beyond which we stop adding the next group.
+  //
+  // - Chunking
+  //
+  // We also chunk the data from a single group into multiple Arrow RecordBatches to ensure
+  // scalability. Note that we don't know the volume (number of rows, overall size) of data for
+  // a specific grouping key before we read all of it. The easiest approach that addresses both
+  // bin-packing and chunking is to check the number of rows in the current Arrow RecordBatch on
+  // every row write.
+  //
+  // - Consideration
+  //
+  // Since the number of rows in an Arrow RecordBatch does not represent its actual size (bytes),
+  // the limit should be set very conservatively. Using a small limit does not introduce
+  // correctness issues.
+
+  private var numRowsForCurGroup = 0
+  private var startOffsetForCurGroup = 0
+  private var totalNumRowsForBatch = 0
+  private var totalNumStatesForBatch = 0
+
+  private var currentGroupKeyRow: UnsafeRow = _
+  private var currentGroupState: GroupStateImpl[Row] = _
+
+  /**
+   * Signals the writer to start a new group.
+   *
+   * @param keyRow The grouping key row for current group.
+   * @param groupState The instance of GroupStateImpl for current group.
+   */
+  def startNewGroup(keyRow: UnsafeRow, groupState: GroupStateImpl[Row]): Unit = {
+    currentGroupKeyRow = keyRow
+    currentGroupState = groupState
+  }
+
+  /**
+   * Signals the writer to write a row for the current group.
+   *
+   * @param dataRow The row to write in the current group.
+   */
+  def writeRow(dataRow: InternalRow): Unit = {
+    // If it exceeds the condition of batch (number of records) and there is more data for the
+    // same group, finalize and construct a new batch.
+
+    if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+      // Provide state metadata row as intermediate
+      val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState,
+        startOffsetForCurGroup, numRowsForCurGroup, isLastChunk = false)
+      arrowWriterForState.write(stateInfoRow)
+      totalNumStatesForBatch += 1
+
+      finalizeCurrentArrowBatch()
+    }
+
+    arrowWriterForData.write(dataRow)
+    numRowsForCurGroup += 1
+    totalNumRowsForBatch += 1
+  }
+
+  /**
+   * Signals the writer that the current group has been finalized and no further rows will be
+   * bound to it.
+   */
+  def finalizeGroup(): Unit = {
+    // Provide state metadata row
+    val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState,
+      startOffsetForCurGroup, numRowsForCurGroup, isLastChunk = true)
+    arrowWriterForState.write(stateInfoRow)
+    totalNumStatesForBatch += 1
+
+    // The start offset for the next group is the total number of rows in the batch, unless the
+    // next group starts with a new batch. The row count must be reset as well, so that the next
+    // group's state metadata counts only its own rows.
+    startOffsetForCurGroup = totalNumRowsForBatch
+    numRowsForCurGroup = 0
+  }
+
+  /**
+   * Signals the writer that all groups have been processed.
+   */
+  def finalizeData(): Unit = {
+    if (totalNumRowsForBatch > 0) {
+      // We still have some rows in the current record batch. Need to finalize them as well.
+      finalizeCurrentArrowBatch()
+    }
+  }
+
+  private def createArrowWriter(fieldVectors: Seq[FieldVector]): ArrowWriter = {
+    val children = fieldVectors.map { vector =>
+      vector.allocateNew()
+      createFieldWriter(vector)
+    }
+
+    new ArrowWriter(root, children.toArray)
+  }
+
+  private def buildStateInfoRow(
+      keyRow: UnsafeRow,
+      groupState: GroupStateImpl[Row],
+      startOffset: Int,
+      numRows: Int,
+      isLastChunk: Boolean): InternalRow = {
+    // NOTE: see ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
+    val stateUnderlyingRow = new GenericInternalRow(
+      Array[Any](
+        UTF8String.fromString(groupState.json()),
+        keyRow.getBytes,
+        groupState.getOption.map(PythonSQLUtils.toPyRow).orNull,
+        startOffset,
+        numRows,
+        isLastChunk
+      )
+    )
+    new GenericInternalRow(Array[Any](stateUnderlyingRow))
+  }
+
+  private def finalizeCurrentArrowBatch(): Unit = {
+    val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch
+    (0 until remainingEmptyStateRows).foreach { _ =>
+      arrowWriterForState.write(EMPTY_STATE_METADATA_ROW)
+    }
+
+    arrowWriterForState.finish()
+    arrowWriterForData.finish()
+    writer.writeBatch()
+    arrowWriterForState.reset()
+    arrowWriterForData.reset()
+
+    startOffsetForCurGroup = 0
+    numRowsForCurGroup = 0
+    totalNumRowsForBatch = 0
+    totalNumStatesForBatch = 0
+  }
+}
+
+object ApplyInPandasWithStateWriter {
+  val STATE_METADATA_SCHEMA: StructType = StructType(

Review Comment:
   Done. I also explained why the state metadata carries the chunk metadata as well.
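The bin-packing and chunking scheme described in the writer's comments, together with the state metadata it emits (start offset, number of rows, last-chunk flag), can be modeled in plain Python. This is a hypothetical simulation, not the PR's Scala or Python code: lists stand in for Arrow vectors, the metadata keys are illustrative (the actual `STATE_METADATA_SCHEMA` fields are cut off in the quoted diff), and it assumes every group has at least one row, which the writer's comments say is ensured even for timed-out state.

```python
from typing import Any, Dict, Iterable, List, Optional, Tuple

Meta = Dict[str, Any]
# One "Arrow batch": a data column and a state column of equal length (None = padding).
Batch = Tuple[List[Any], List[Optional[Meta]]]


def write_batches(groups: Iterable[Tuple[Any, List[Any]]], max_rows: int) -> List[Batch]:
    """Bin-pack groups into batches of at most `max_rows` rows, chunking large groups."""
    batches: List[Batch] = []
    data: List[Any] = []
    states: List[Optional[Meta]] = []
    start, num = 0, 0  # start offset and row count of the current group's chunk

    def meta(key: Any, is_last_chunk: bool) -> Meta:
        return {"key": key, "startOffset": start, "numRows": num, "isLastChunk": is_last_chunk}

    def flush() -> None:
        nonlocal data, states, start, num
        states.extend([None] * (len(data) - len(states)))  # pad the state column
        batches.append((data, states))
        data, states, start, num = [], [], 0, 0

    for key, rows in groups:  # assumes every group has at least one row
        for row in rows:
            if len(data) >= max_rows:
                if num > 0:
                    # The group continues in the next batch: emit a non-final chunk.
                    states.append(meta(key, False))
                flush()
            data.append(row)
            num += 1
        states.append(meta(key, True))
        start, num = len(data), 0  # the next group starts right after this one
    if data:
        flush()
    return batches


def read_groups(batches: List[Batch]) -> Dict[Any, List[Any]]:
    """Reassemble each group's rows by slicing with startOffset/numRows across chunks."""
    result: Dict[Any, List[Any]] = {}
    pending: List[Any] = []
    for data, states in batches:
        for m in states:
            if m is None:
                continue
            s, n = m["startOffset"], m["numRows"]
            pending.extend(data[s:s + n])  # Arrow would hand out a zero-copy slice here
            if m["isLastChunk"]:
                result[m["key"]] = pending
                pending = []
    return result
```

With `max_rows=2` and groups `a: [1, 2, 3]`, `b: [4]`, `c: [5, 6]`, group `a` is chunked across the first two batches while `b` is bin-packed next to `a`'s last chunk. When a batch fills exactly at a group boundary, this sketch starts a new batch without emitting an empty chunk for the next group; that is a choice of the sketch, not a statement about the PR.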





[GitHub] [spark] HyukjinKwon commented on pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on PR #37893:
URL: https://github.com/apache/spark/pull/37893#issuecomment-1249226867

   Will take a close look next Monday in KST.




[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974862222


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala:
##########
@@ -0,0 +1,201 @@
+  override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
+  require(
+    bufferSize >= 4,

Review Comment:
   We can do both: set it to 4 if it's less than 4, and log a warning encouraging users to set it higher.
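The suggestion (clamp instead of failing, and warn) could look roughly like this; it is written in Python only for illustration, the real change would live in the Scala runner, and the function and logger names are hypothetical.

```python
import logging

logger = logging.getLogger("apply_in_pandas_with_state_sketch")  # hypothetical logger name

MIN_BUFFER_SIZE = 4  # mirrors the `bufferSize >= 4` requirement in the quoted diff


def effective_buffer_size(configured: int, minimum: int = MIN_BUFFER_SIZE) -> int:
    """Clamp the configured buffer size to `minimum` instead of failing, and warn."""
    if configured < minimum:
        logger.warning(
            "Configured buffer size %d is less than %d; using %d instead. "
            "Consider setting it higher.",
            configured, minimum, minimum,
        )
        return minimum
    return configured
```

In the runner, this would replace the `require(bufferSize >= 4, ...)` check, so a too-small setting degrades gracefully instead of failing the query.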





[GitHub] [spark] alex-balikov commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
alex-balikov commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r973478091


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked for each group repeatedly in every
+        trigger, and updates to each group's state will be saved across invocations. The function
+        will also be invoked for each timed-out state repeatedly. The sequence of the invocation
+        will be input data -> state timeout. When the function is invoked for state timeout, there
+        will be no data being presented.
+
+        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
+        return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the `pandas.DataFrame`s returned across all invocations are combined into a
+        :class:`DataFrame`. Note that the user function should loop through and process all
+        elements in the iterator. The user function should not guess the number of
+        elements in the iterator.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned value, `pandas.DataFrame` must either match the field names in the defined
+        schema if specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be a :class:`StructType` describing the schema of the
+        user-defined state. The value of the state will be presented as a tuple, and updates
+        should also be performed with a tuple. User-defined types, e.g. native Python classes, are not
+        supported. Alternatively, you can pickle the data and produce the data as BinaryType, but

Review Comment:
   'Alternatively, you can pickle the data ...' - instead, say:
   
   'For such cases, the user should pickle the data into BinaryType. Note that this approach may be sensitive to backward and forward compatibility issues of Python pickles, and Spark cannot guarantee compatibility.'
   
   Though I think you could drop the note, as that is orthogonal to Spark.
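A sketch of the suggested workaround, assuming a state schema with a single `BinaryType` field: a Python object with no Spark SQL equivalent (a `collections.Counter`, chosen only as an example) is pickled into bytes and kept as a one-element state tuple. `to_state` and `from_state` are hypothetical helpers. Pinning the pickle protocol is one small mitigation for the compatibility concern above, not a guarantee.

```python
import pickle
from collections import Counter
from typing import Optional, Tuple


def to_state(counts: Counter) -> Tuple[bytes]:
    # Assumed state schema: a single BinaryType field holding the pickled object.
    # Protocol 4 is pinned explicitly (readable by Python 3.4+) instead of HIGHEST_PROTOCOL.
    return (pickle.dumps(counts, protocol=4),)


def from_state(state: Optional[Tuple[bytes]]) -> Counter:
    # No existing state yet: start from an empty Counter.
    return pickle.loads(state[0]) if state is not None else Counter()


counts = from_state(None)
counts.update(["home", "cart", "home"])
blob = to_state(counts)      # what the user function would pass to the state update
restored = from_state(blob)  # what the next invocation would read back
```

Only bytes cross the Spark boundary; the Python type lives entirely inside the user function, so unpickling must happen with a compatible Python environment on every run.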



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked for each group repeatedly in every
+        trigger, and updates to each group's state will be saved across invocations. The function
+        will also be invoked for each timed-out state repeatedly. The sequence of the invocation
+        will be input data -> state timeout. When the function is invoked for state timeout, there
+        will be no data being presented.
+
+        The function should takes parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should loop through and process all
+        elements in the iterator. The user function should not make a guess of the number of
+        elements in the iterator.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned value, `pandas.DataFrame` must either match the field names in the defined
+        schema if specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of user-defined
+        state. The value of state will be presented as a tuple, as well as the update should be
+        performed with the tuple. User defined types e.g. native Python class types are not
+        supported. Alternatively, you can pickle the data and produce the data as BinaryType, but
+        it is tied to the backward and forward compatibility of pickle in Python, and Spark itself
+        does not guarantee the compatibility.
+
+        The length of each element in both input and returned value, `pandas.DataFrame`, can be

Review Comment:
   'The size of each DataFrame in both the input and output ...'
   
   'The number of DataFrames in both the input and output can also be arbitrary.'
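As a sketch of what "the number of DataFrames in both the input and output can be arbitrary" permits on the output side, the hypothetical helper below emits one logical result as several bounded-size `pandas.DataFrame` chunks. The helper name and its `max_rows` parameter are illustrative assumptions, not part of the proposed API.

```python
from typing import Iterator

import pandas as pd


def emit_in_chunks(result: pd.DataFrame, max_rows: int) -> Iterator[pd.DataFrame]:
    """Yield `result` as consecutive slices of at most `max_rows` rows each."""
    if max_rows <= 0:
        raise ValueError("max_rows must be positive")
    for start in range(0, len(result), max_rows):
        # iloc slicing keeps the column labels, so every chunk matches the output schema
        yield result.iloc[start:start + max_rows]
```

A user function could `yield from emit_in_chunks(result, 1000)` instead of yielding one large frame.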



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
+        trigger, and updates to each group's state will be saved across invocations. The function
+        will also be invoked for each timed-out state repeatedly. The sequence of the invocation
+        will be input data -> state timeout. When the function is invoked for state timeout, there

Review Comment:
   I do not like '->' here - this is supposed to be text. How about:
   
   'The function will be invoked first for all input groups and then for all timed-out states, where the input data will be null.'
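A minimal sketch of how a user function can tell the two kinds of invocation apart. It assumes the state object exposes a `hasTimedOut` flag (modelled on the Scala `GroupState` API) and that a timeout invocation receives an empty iterator, since the PR's docstring says no data is presented. `FakeState`, the single `count` state field and the output columns are hypothetical.

```python
from typing import Any, Iterator, Optional, Tuple

import pandas as pd


class FakeState:
    """Hypothetical stand-in for the PR's state object; only what this sketch needs."""

    def __init__(self, value: Optional[Tuple[Any, ...]] = None, timed_out: bool = False) -> None:
        self.value = value
        self.hasTimedOut = timed_out  # assumed flag name
        self.removed = False

    def update(self, new_value: Tuple[Any, ...]) -> None:
        self.value = new_value

    def remove(self) -> None:
        self.value = None
        self.removed = True


def count_events(
    key: Tuple[Any, ...], pdf_iter: Iterator[pd.DataFrame], state: FakeState
) -> Iterator[pd.DataFrame]:
    if state.hasTimedOut:
        # Timeout invocation: no input rows, so emit the final count and drop the state.
        (count,) = state.value
        state.remove()
        yield pd.DataFrame({"id": [key[0]], "count": [count], "final": [True]})
        return
    # Data invocation: add the rows of every chunk to the running count.
    count = state.value[0] if state.value is not None else 0
    for pdf in pdf_iter:
        count += len(pdf)
    state.update((count,))
    yield pd.DataFrame({"id": [key[0]], "count": [count], "final": [False]})
```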



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
+        :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,

Review Comment:
   'all columns are passed together as `pandas.DataFrame` ...' - this is confusing; of course all columns will be passed together. How about:
   
   'Each group is passed as one or more `pandas.DataFrame` groups of records, with all columns packed into the DataFrame.'
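When a function genuinely needs a whole group as one `pandas.DataFrame`, it can concatenate the chunks itself. This sketch assumes the group fits in memory, and `whole_group` is a hypothetical helper. The empty-iterator branch matters because `pd.concat` raises on an empty list, which is what a call with no input data would produce.

```python
from typing import Iterable

import pandas as pd


def whole_group(pdf_iter: Iterable[pd.DataFrame]) -> pd.DataFrame:
    """Concatenate all chunks of one group into a single DataFrame."""
    chunks = list(pdf_iter)
    if not chunks:
        # No data (e.g. a timeout invocation): return an empty frame instead of failing.
        return pd.DataFrame()
    # ignore_index renumbers rows 0..n-1 across chunks
    return pd.concat(chunks, ignore_index=True)
```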



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should loop through and process all

Review Comment:
   'Note that the user function should loop through and process all elements in the iterator. The user function should not make a guess of the number of elements in the iterator.'
   
   - Why? This sounds like the user *must* process all iterator entries or otherwise something bad would happen. I would reword this to indicate that the grouped data could be split into multiple entries:
   
   'Note that the group data may be split into multiple Iterator records and the user function should not assume that it receives a single record.'
   
   I would still suggest we have a design discussion about splitting groups unnecessarily, as I believe we should not do this.
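To make the suggested note concrete, this sketch aggregates a column across every chunk of a group instead of assuming a single `pandas.DataFrame`. The helper name and the `column` parameter are hypothetical.

```python
from typing import Iterable, Tuple

import pandas as pd


def mean_over_chunks(pdf_iter: Iterable[pd.DataFrame], column: str) -> Tuple[float, int]:
    """Return (mean, row_count) of `column` across every chunk of a group.

    Reading only the first chunk (e.g. next(iter(pdf_iter))) would silently
    ignore the rest of the group whenever it has been split.
    """
    total = 0.0
    rows = 0
    for pdf in pdf_iter:
        total += float(pdf[column].sum())
        rows += len(pdf)
    return (total / rows if rows else float("nan"), rows)
```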
   
   



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
+        For a streaming Dataset, the function will be invoked for each group repeatedly in every

Review Comment:
   Remove 'repeatedly': 'for each group' implies that.



##########
python/pyspark/sql/pandas/_typing/__init__.pyi:
##########
@@ -256,6 +258,10 @@ PandasGroupedMapFunction = Union[
     Callable[[Any, DataFrameLike], DataFrameLike],
 ]
 
+PandasGroupedMapFunctionWithState = Callable[
+    [Any, Iterable[DataFrameLike], GroupStateImpl], Iterable[DataFrameLike]

Review Comment:
   Can the type be GroupState without the 'Impl'? It looks bad in a public API.



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
+        The `stateStructType` should be :class:`StructType` describing the schema of user-defined
+        state. The value of state will be presented as a tuple, as well as the update should be
+        performed with the tuple. User defined types e.g. native Python class types are not

Review Comment:
   'Non-StructType types, e.g. user-defined or native Python types, are not supported.'
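Since only `StructType` state is supported, a Python class can be flattened into a tuple whose positions follow the state schema. A minimal sketch, assuming a hypothetical state schema `num_events LONG, last_event_ms LONG`; `SessionState` and the two conversion helpers are illustrative names only.

```python
from dataclasses import astuple, dataclass
from typing import Tuple

# Hypothetical state schema; its field order must match the dataclass field order.
STATE_SCHEMA_DDL = "num_events LONG, last_event_ms LONG"


@dataclass
class SessionState:
    num_events: int
    last_event_ms: int


def to_state_tuple(s: SessionState) -> Tuple[int, int]:
    # astuple follows the dataclass field order, which mirrors STATE_SCHEMA_DDL
    return astuple(s)


def from_state_tuple(t: Tuple[int, int]) -> SessionState:
    return SessionState(*t)
```

Note that `astuple` also converts nested dataclasses into nested tuples, which would then need a matching nested struct in the schema.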



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
+        The length of each element in both input and returned value, `pandas.DataFrame`, can be
+        arbitrary. The length of iterator in both input and returned value can be also arbitrary.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        func : function
+            a Python native function to be called on every group. It should takes parameters
+            (key, Iterator[`pandas.DataFrame`], state) and returns Iterator[`pandas.DataFrame`].
+            Note that the type of key is tuple, and the type of state is
+            :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+        outputStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the output records. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        stateStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the user-defined state. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

Review Comment:
   Same: can you provide an example of the string?
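For reference, a DDL-formatted type string is a comma-separated list of `name TYPE` pairs. The sketch below uses a hypothetical state schema and a deliberately simplified parser (not Spark's) to show that the state tuple is positional: element *i* corresponds to the *i*-th field of the string.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical DDL-formatted state schema.
STATE_SCHEMA_DDL = "num_events LONG, last_event_ms LONG, recent_ids ARRAY<STRING>"


def top_level_field_names(ddl: str) -> List[str]:
    """Extract field names from a flat 'name TYPE, ...' DDL string.

    Commas nested inside <...> (e.g. MAP<STRING, LONG>) are not treated as
    separators. Quoted names and COMMENT clauses are not handled.
    """
    names: List[str] = []
    depth = 0
    current: List[str] = []
    for ch in ddl:
        if ch == "<":
            depth += 1
        elif ch == ">":
            depth -= 1
        if ch == "," and depth == 0:
            names.append("".join(current).split()[0])
            current = []
        else:
            current.append(ch)
    if "".join(current).strip():
        names.append("".join(current).split()[0])
    return names


def state_as_dict(ddl: str, state_tuple: Tuple[Any, ...]) -> Dict[str, Any]:
    # Pair each positional tuple element with the field name at the same position.
    names = top_level_field_names(ddl)
    if len(names) != len(state_tuple):
        raise ValueError("state tuple does not match the schema arity")
    return dict(zip(names, state_tuple))
```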



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
+        outputStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the output records. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

Review Comment:
   can you provide an example here of the string?
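An illustration of an output schema given as a DDL string, together with a hypothetical check that a returned `pandas.DataFrame` uses string column labels matching the schema's field names, as the docstring requires for string labels. The schema, field list and `check_output_labels` helper are assumptions for illustration; positional matching for non-string labels is not covered.

```python
from typing import List

import pandas as pd

# Hypothetical output schema as a DDL-formatted string. An equivalent StructType is
# StructType([StructField("id", StringType()), StructField("count", LongType())])
# built from pyspark.sql.types.
OUTPUT_SCHEMA_DDL = "id STRING, count LONG"
OUTPUT_FIELD_NAMES: List[str] = ["id", "count"]


def check_output_labels(pdf: pd.DataFrame, field_names: List[str]) -> None:
    """Raise if a returned DataFrame's string labels differ from the schema's field names."""
    labels = [str(c) for c in pdf.columns]
    missing = [n for n in field_names if n not in labels]
    extra = [c for c in labels if c not in field_names]
    if missing or extra:
        raise ValueError(f"missing columns {missing}, unexpected columns {extra}")


# A DataFrame the user function could yield for this schema.
example = pd.DataFrame({"id": ["user-1"], "count": [3]})
check_output_labels(example, OUTPUT_FIELD_NAMES)
```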



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
+        For a streaming Dataset, the function will be invoked for each group repeatedly in every
+        trigger, and updates to each group's state will be saved across invocations. The function
+        will also be invoked for each timed-out state repeatedly. The sequence of the invocation

Review Comment:
   Again, I think 'repeatedly' is unnecessary.



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in returned value, `pandas.DataFrame`. The column labels of all elements in

Review Comment:
   in the returned value



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
+        Parameters
+        ----------
+        func : function
+            a Python native function to be called on every group. It should takes parameters

Review Comment:
   It should *take* parameters ... and *return* Iterator...



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
+        The `stateStructType` should be :class:`StructType` describing the schema of user-defined

Review Comment:
   ... describing the schema of *the* user-defined state. The value of *the* state will be presented as a tuple and the update should be performed with a tuple.
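A minimal sketch of the tuple convention described in the suggested wording: read the current state as a tuple, unpack it positionally, and write back a new tuple in schema order. The two-field state schema and the `FakeState` stand-in (with `exists` and `get` as properties, plus `update`) are hypothetical.

```python
from typing import Any, List, Optional, Tuple

# Hypothetical state schema: "num_events LONG, max_value DOUBLE"


class FakeState:
    """Hypothetical stand-in exposing exists/get/update with tuple values."""

    def __init__(self, value: Optional[Tuple[Any, ...]] = None) -> None:
        self._value = value

    @property
    def exists(self) -> bool:
        return self._value is not None

    @property
    def get(self) -> Tuple[Any, ...]:
        if self._value is None:
            raise ValueError("state is not set")
        return self._value

    def update(self, new_value: Tuple[Any, ...]) -> None:
        if not isinstance(new_value, tuple):
            raise TypeError("state must be updated with a tuple")
        self._value = new_value


def fold_batch(state: FakeState, values: List[float]) -> Tuple[int, float]:
    # Unpack positionally: element order follows the state schema.
    num_events, max_value = state.get if state.exists else (0, float("-inf"))
    num_events += len(values)
    max_value = max([max_value, *values])
    state.update((num_events, max_value))  # write back a tuple, not a dict or object
    return num_events, max_value
```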



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
+        The function should takes parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupStateImpl`.

Review Comment:
   Again, can we drop the 'Impl' from the state class name?



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
+        The function should takes parameters (key, Iterator[`pandas.DataFrame`], state) and

Review Comment:
   The function takes parameters ... and returns Iterator[...]
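Below is a minimal sketch of a user function with the shape described in this docstring. It takes a key tuple, an iterator of `pandas.DataFrame` and a state object, and yields `pandas.DataFrame`s. The schemas ("count LONG" for state, "id STRING, count LONG" for output) are assumptions chosen for illustration. `LocalGroupState` is a hypothetical in-memory stand-in used only to exercise the function locally. It assumes `exists` and `get` are properties and that `get` returns the state tuple.

```python
from typing import Any, Iterator, Optional, Tuple

import pandas as pd


class LocalGroupState:
    """Hypothetical stand-in for the PySpark state object (local testing only).

    Mimics only `exists`, `get` and `update`; `get` is assumed to be a property
    returning the state tuple.
    """

    def __init__(self, value: Optional[Tuple] = None):
        self._value = value

    @property
    def exists(self) -> bool:
        return self._value is not None

    @property
    def get(self) -> Tuple:
        if self._value is None:
            raise ValueError("state is not set")
        return self._value

    def update(self, new_value: Tuple) -> None:
        self._value = new_value


def count_rows(key: Tuple[Any, ...], pdfs: Iterator[pd.DataFrame], state) -> Iterator[pd.DataFrame]:
    # The grouping key is a tuple, even with a single grouping column.
    (group_id,) = key
    # The state value is a tuple laid out like stateStructType (assumed "count LONG").
    total = state.get[0] if state.exists else 0
    # Consume every DataFrame; the number of elements in the iterator is not fixed.
    for pdf in pdfs:
        total += len(pdf)
    # State updates are passed as tuples as well.
    state.update((total,))
    # Columns match outputStructType (assumed "id STRING, count LONG").
    yield pd.DataFrame({"id": [group_id], "count": [total]})


# Local usage: two input chunks for the same key within one trigger.
state = LocalGroupState()
chunks = iter([pd.DataFrame({"v": [1, 2]}), pd.DataFrame({"v": [3]})])
result = pd.concat(count_rows(("a",), chunks, state))
```

Because the function is a generator, it naturally returns an iterator and can emit zero or more frames per invocation.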



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked for each group repeatedly in every
+        trigger, and updates to each group's state will be saved across invocations. The function
+        will also be invoked for each timed-out state repeatedly. The sequence of the invocation
+        will be input data -> state timeout. When the function is invoked for state timeout, there
+        will be no data being presented.
+
+        The function should takes parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should loop through and process all
+        elements in the iterator. The user function should not make a guess of the number of
+        elements in the iterator.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned value, `pandas.DataFrame` must either match the field names in the defined
+        schema if specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of user-defined
+        state. The value of state will be presented as a tuple, as well as the update should be
+        performed with the tuple. User defined types e.g. native Python class types are not
+        supported. Alternatively, you can pickle the data and produce the data as BinaryType, but
+        it is tied to the backward and forward compatibility of pickle in Python, and Spark itself
+        does not guarantee the compatibility.
+
+        The length of each element in both input and returned value, `pandas.DataFrame`, can be
+        arbitrary. The length of iterator in both input and returned value can be also arbitrary.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        func : function
+            a Python native function to be called on every group. It should takes parameters
+            (key, Iterator[`pandas.DataFrame`], state) and returns Iterator[`pandas.DataFrame`].
+            Note that the type of key is tuple, and the type of state is

Review Comment:
   Note that the type of *the* key is tuple and the type of *the* state is ...
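The docstring above says state values are tuples and that native Python classes are not supported. As a workaround, it suggests pickling the data into a BinaryType field. The sketch below assumes a hypothetical stateStructType of "payload BINARY" and round-trips an arbitrary nested Python object through a one-element state tuple. The pickle compatibility caveat from the docstring still applies.

```python
import pickle
from typing import Any, Tuple

# Assumed stateStructType for this sketch: "payload BINARY" (one binary field),
# so the state value is a one-element tuple holding bytes.


def encode_state(obj: Any) -> Tuple[bytes]:
    """Pickle an arbitrary Python object into a state tuple for a BINARY field."""
    # Pinning a protocol helps older readers, but does not remove the
    # cross-version compatibility caveat of pickle.
    return (pickle.dumps(obj, protocol=4),)


def decode_state(value: Tuple[bytes]) -> Any:
    """Unpickle the object stored in the state tuple."""
    (payload,) = value
    return pickle.loads(payload)


# An arbitrary nested Python object.
session = {"events": ["open", "click"], "meta": {"device": "ios"}}
state_value = encode_state(session)   # what would be passed to state.update(...)
restored = decode_state(state_value)  # what would be rebuilt from state.get
```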



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked for each group repeatedly in every
+        trigger, and updates to each group's state will be saved across invocations. The function
+        will also be invoked for each timed-out state repeatedly. The sequence of the invocation
+        will be input data -> state timeout. When the function is invoked for state timeout, there
+        will be no data being presented.
+
+        The function should takes parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should loop through and process all
+        elements in the iterator. The user function should not make a guess of the number of
+        elements in the iterator.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned value, `pandas.DataFrame` must either match the field names in the defined

Review Comment:
   returned pandas.DataFrame must ...
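To make the column-matching rule in the quoted docstring concrete, here is a simplified local check. It is a hypothetical helper, not Spark's verification code. String labels must match the schema's field names, and non-string labels (such as the default integer labels) are matched by position, so only the column count is checked. The output schema "id STRING, count LONG" is an assumption.

```python
from typing import Sequence

import pandas as pd


def check_output_labels(pdf: pd.DataFrame, field_names: Sequence[str]) -> None:
    """Hypothetical local check mirroring the documented rule (not Spark's code)."""
    labels = list(pdf.columns)
    if all(isinstance(label, str) for label in labels):
        # String labels: matched to the schema fields by name.
        if set(labels) != set(field_names):
            raise ValueError(f"column labels {labels} do not match fields {list(field_names)}")
    elif len(labels) != len(field_names):
        # Non-string labels: matched by position, so the counts must agree.
        raise ValueError(f"expected {len(field_names)} columns, got {len(labels)}")


# Assumed outputStructType for this sketch: "id STRING, count LONG".
FIELDS = ["id", "count"]

check_output_labels(pd.DataFrame({"id": ["a"], "count": [3]}), FIELDS)  # by name
check_output_labels(pd.DataFrame([["a", 3]]), FIELDS)  # labels 0, 1: by position
```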



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r975770234


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala:
##########
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout}
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
+import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData
+import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.CompletionIterator
+
+/**
+ * Physical operator for executing
+ * [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]]
+ *
+ * @param functionExpr function called on each group
+ * @param groupingAttributes used to group the data
+ * @param outAttributes used to define the output rows
+ * @param stateType used to serialize/deserialize state before calling `functionExpr`
+ * @param stateInfo `StatefulOperatorStateInfo` to identify the state store for a given operator.
+ * @param stateFormatVersion the version of state format.
+ * @param outputMode the output mode of `functionExpr`
+ * @param timeoutConf used to timeout groups that have not received data in a while
+ * @param batchTimestampMs processing timestamp of the current batch.
+ * @param eventTimeWatermark event time watermark for the current batch
+ * @param child logical plan of the underlying data
+ */
+case class FlatMapGroupsInPandasWithStateExec(

Review Comment:
   We always have separate exec implementations for Scala/Java and Python because their constructor parameters differ. (Since these are case classes, differing constructor parameters warrant a new class.) So this is intentional. As a compromise, we refactored the shared logic into FlatMapGroupsWithStateExecBase as a base class.





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974798726


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala:
##########
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io._
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.api.python._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
+
+
+/**
+ * A variant implementation of [[ArrowPythonRunner]] to serve the operation
+ * applyInPandasWithState.
+ */
+class ApplyInPandasWithStatePythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    inputSchema: StructType,
+    override protected val timeZoneId: String,
+    initialWorkerConf: Map[String, String],

Review Comment:
   We can't use `workerConf` here since the class declares `override protected val workerConf`. So the choice is between something like `_workerConf` and `initialWorkerConf`, and a descriptive prefix reads better than a leading `_`.





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r975902646


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala:
##########
@@ -311,6 +323,56 @@ object UnsupportedOperationChecker extends Logging {
             }
           }
 
+        // applyInPandasWithState
+        case m: FlatMapGroupsInPandasWithState if m.isStreaming =>
+          // Check compatibility with output modes and aggregations in query
+          val aggsInQuery = collectStreamingAggregates(plan)
+
+          if (aggsInQuery.isEmpty) {
+            // applyInPandasWithState without aggregation: operation's output mode must

Review Comment:
   Now I can imagine a case in which the current requirement of providing a separate output mode prevents unintended behavior:
   
   - They implemented the user function for flatMapGroupsWithState with append mode.
   - They ran the query with append mode.
   - Later, they changed the query's output mode to update mode for some reason.
   - The user function was left unchanged and does not account for the switch to update mode.
   
   We currently reject such a query, but we would allow it to run if we dropped the parameter.
   
   PS. I'm not convinced that end users can correctly adapt their user function to the output mode, but that is a fundamental API design issue of the original flatMapGroupsWithState and a separate topic.
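The scenario above can be modeled with a tiny check. This is a hypothetical, simplified sketch, not Spark's UnsupportedOperationChecker. It covers only a streaming query without aggregation, where the operator's declared output mode must match the query's output mode. The set of supported operator modes (append and update) is an assumption for this sketch.

```python
# Hypothetical, simplified model of the check discussed above, for a streaming
# query WITHOUT aggregation. Not Spark's actual implementation.
SUPPORTED_OPERATOR_MODES = {"append", "update"}  # assumption for this sketch


def check_output_mode(operator_mode: str, query_mode: str) -> None:
    """Raise if the operator's declared output mode cannot run under the query's mode."""
    op, query = operator_mode.lower(), query_mode.lower()
    if op not in SUPPORTED_OPERATOR_MODES:
        raise ValueError(f"unsupported operator output mode: {operator_mode}")
    if op != query:
        # E.g. a function written for append mode, but the query was switched to update.
        raise ValueError(
            f"operator output mode '{operator_mode}' does not match "
            f"query output mode '{query_mode}'"
        )
```

Dropping the separate parameter would amount to deleting the mismatch branch, which is exactly what lets the unchanged append-mode function run under update mode.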





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974072051


##########
python/pyspark/sql/pandas/_typing/__init__.pyi:
##########
@@ -256,6 +258,10 @@ PandasGroupedMapFunction = Union[
     Callable[[Any, DataFrameLike], DataFrameLike],
 ]
 
+PandasGroupedMapFunctionWithState = Callable[
+    [Any, Iterable[DataFrameLike], GroupStateImpl], Iterable[DataFrameLike]

Review Comment:
   Thanks, I just renamed GroupStateImpl to GroupState. If we find it necessary later, we can turn that name into an interface and move the implementation out (I guess this is what @HyukjinKwon meant by saying the type is dynamic, so it's no problem).
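If `GroupState` later becomes an interface, one way to express the function alias independently of the implementation is a structural `typing.Protocol`. The sketch below is hypothetical. `GroupStateLike` and `FunctionWithStateLike` are illustrative names, not PySpark's. The member names mirror the state methods a user function typically calls, and `Any` stands in for `DataFrameLike`.

```python
from typing import Any, Callable, Iterable, Protocol, Tuple, runtime_checkable


@runtime_checkable
class GroupStateLike(Protocol):
    """Hypothetical structural interface for a per-group state object."""

    @property
    def exists(self) -> bool: ...

    @property
    def get(self) -> Tuple: ...

    def update(self, newValue: Tuple) -> None: ...

    def remove(self) -> None: ...


DataFrameLike = Any  # stands in for pandas.DataFrame in this sketch

# Same shape as the alias in the diff, but typed against the interface.
FunctionWithStateLike = Callable[
    [Any, Iterable[DataFrameLike], GroupStateLike], Iterable[DataFrameLike]
]


def passthrough(key: Any, pdfs: Iterable[DataFrameLike], state: GroupStateLike) -> Iterable[DataFrameLike]:
    # Returns the input frames unchanged; only here to show the alias in use.
    return pdfs


fn: FunctionWithStateLike = passthrough
```

Any implementation class that provides these members satisfies the interface without inheriting from it, which keeps the public name stable if the implementation moves.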





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974803979


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala:
##########
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io._
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.api.python._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
+
+
+/**
+ * A variant implementation of [[ArrowPythonRunner]] to serve the operation
+ * applyInPandasWithState.
+ */
+class ApplyInPandasWithStatePythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    inputSchema: StructType,
+    override protected val timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    stateEncoder: ExpressionEncoder[Row],
+    keySchema: StructType,
+    valueSchema: StructType,
+    stateValueSchema: StructType,
+    softLimitBytesPerBatch: Long,
+    minDataCountForSample: Int,
+    softTimeoutMillsPurgeBatch: Long)
+  extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
+  with PythonArrowInput[InType]
+  with PythonArrowOutput[OutType] {
+
+  override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA)
+
+  override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+
+  override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
+  require(
+    bufferSize >= 4,

Review Comment:
   This is borrowed from ArrowPythonRunner. By the way, I realized we should not allow Arrow RecordBatches in this runner to be split because of the buffer size; this runner has to have full control over Arrow RecordBatches. We'll have to set this to a constant, something like Long.MAX_VALUE.





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r973841670


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked for each group repeatedly in every
+        trigger, and updates to each group's state will be saved across invocations. The function
+        will also be invoked for each timed-out state repeatedly. The sequence of the invocation
+        will be input data -> state timeout. When the function is invoked for state timeout, there
+        will be no data being presented.
+
+        The function should takes parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,

Review Comment:
   Same here; this follows the existing method doc in applyInPandas.
   
   I'm OK with changing it, though, as I agree it's not necessary to call out that all columns will be passed. Neither the user function nor the public API specifies columns, so all columns are implicitly expected.
   
   Probably worth discussing a bit more and changing both functions together? cc. @HyukjinKwon 



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked for each group repeatedly in every
+        trigger, and updates to each group's state will be saved across invocations. The function
+        will also be invoked for each timed-out state repeatedly. The sequence of the invocation
+        will be input data -> state timeout. When the function is invoked for state timeout, there
+        will be no data being presented.
+
+        The function should takes parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should loop through and process all
+        elements in the iterator. The user function should not make a guess of the number of
+        elements in the iterator.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned value, `pandas.DataFrame` must either match the field names in the defined
+        schema if specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of user-defined
+        state. The value of state will be presented as a tuple, as well as the update should be
+        performed with the tuple. User defined types e.g. native Python class types are not
+        supported. Alternatively, you can pickle the data and produce the data as BinaryType, but
+        it is tied to the backward and forward compatibility of pickle in Python, and Spark itself
+        does not guarantee the compatibility.
+
+        The length of each element in both input and returned value, `pandas.DataFrame`, can be
+        arbitrary. The length of iterator in both input and returned value can be also arbitrary.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        func : function
+            a Python native function to be called on every group. It should takes parameters
+            (key, Iterator[`pandas.DataFrame`], state) and returns Iterator[`pandas.DataFrame`].
+            Note that the type of key is tuple, and the type of state is
+            :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+        outputStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the output records. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

Review Comment:
   In the doc or here? None of the other PySpark method docs have an example of this string.
   
   Maybe we could add examples like other APIs do, and use a DDL-formatted type string there to compensate.
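One possible shape for such an example is written against the signature in this diff: (func, outputStructType, stateStructType, outputMode, timeoutConf). It assumes a streaming DataFrame `events` with a string column `id` and a numeric column `value` (both hypothetical). It also assumes `GroupStateTimeout` is importable from `pyspark.sql.streaming.state`. Spark imports are done inside the function so the snippet loads without Spark installed.

```python
# DDL-formatted type strings for the two schemas (field names are illustrative).
OUTPUT_DDL = "id STRING, max_value DOUBLE"
STATE_DDL = "max_value DOUBLE"


def running_max(events):
    """Return the per-id running maximum of `value` for a streaming DataFrame `events`."""
    import pandas as pd
    from pyspark.sql.streaming.state import GroupStateTimeout

    def func(key, pdfs, state):
        (group_id,) = key
        current = state.get[0] if state.exists else float("-inf")
        for pdf in pdfs:
            if not pdf.empty:
                current = max(current, float(pdf["value"].max()))
        state.update((current,))
        yield pd.DataFrame({"id": [group_id], "max_value": [current]})

    return events.groupBy("id").applyInPandasWithState(
        func,
        outputStructType=OUTPUT_DDL,
        stateStructType=STATE_DDL,
        outputMode="update",
        timeoutConf=GroupStateTimeout.NoTimeout,
    )
```

The DDL strings keep the example short. An equivalent `StructType` could be passed instead, since both parameters accept either form.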



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked for each group repeatedly in every
+        trigger, and updates to each group's state will be saved across invocations. The function
+        will also be invoked for each timed-out state repeatedly. The sequence of the invocation
+        will be input data -> state timeout. When the function is invoked for state timeout, there
+        will be no data being presented.
+
+        The function should takes parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupStateImpl`.

Review Comment:
   Let me handle it all together once the direction has been decided.



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked for each group repeatedly in every
+        trigger, and updates to each group's state will be saved across invocations. The function
+        will also be invoked for each timed-out state repeatedly. The sequence of the invocation
+        will be input data -> state timeout. When the function is invoked for state timeout, there
+        will be no data being presented.
+
+        The function should takes parameters (key, Iterator[`pandas.DataFrame`], state) and

Review Comment:
   This follows the existing method doc in applyInPandas.
   
   The "function" here refers to the user function that end users provide, not a function Spark provides as a public API, so using `should` here does not seem wrong. The mood is something like "you `should` construct a user function that ...". The "s" in "takes" should be removed, though.
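The docstring says the function is also invoked for timed-out state, after any input data, and that no data is presented on that call. A user function following that contract could branch on the timeout flag, as in the sketch below. It assumes a processing-time timeout is configured and that the state object exposes `hasTimedOut`, `exists` and `get` as properties plus `update`, `remove` and `setTimeoutDuration`. `LocalTimeoutState` is a hypothetical stand-in for local testing only. The 30-second duration and the schemas are arbitrary choices.

```python
from typing import Any, Iterator, Optional, Tuple

import pandas as pd


class LocalTimeoutState:
    """Hypothetical stand-in for the state object, for local testing only."""

    def __init__(self, value: Optional[Tuple] = None, timed_out: bool = False):
        self._value = value
        self._timed_out = timed_out
        self.timeout_ms: Optional[int] = None

    @property
    def exists(self) -> bool:
        return self._value is not None

    @property
    def get(self) -> Tuple:
        return self._value

    @property
    def hasTimedOut(self) -> bool:
        return self._timed_out

    def update(self, new_value: Tuple) -> None:
        self._value = new_value

    def remove(self) -> None:
        self._value = None

    def setTimeoutDuration(self, duration_ms: int) -> None:
        self.timeout_ms = duration_ms


def count_until_idle(key: Tuple[Any, ...], pdfs: Iterator[pd.DataFrame], state) -> Iterator[pd.DataFrame]:
    (group_id,) = key
    if state.hasTimedOut:
        # Timeout invocation: no input data is presented. Emit the final count, drop the state.
        (count,) = state.get
        state.remove()
        yield pd.DataFrame({"id": [group_id], "count": [count], "final": [True]})
        return
    count = state.get[0] if state.exists else 0
    for pdf in pdfs:  # consume every frame in the iterator
        count += len(pdf)
    state.update((count,))
    # Re-arm the (assumed processing-time) timeout on every data invocation.
    state.setTimeoutDuration(30_000)
    yield pd.DataFrame({"id": [group_id], "count": [count], "final": [False]})
```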



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked for each group repeatedly in every
+        trigger, and updates to each group's state will be saved across invocations. The function
+        will also be invoked for each timed-out state repeatedly. The sequence of the invocation
+        will be input data -> state timeout. When the function is invoked for state timeout, there
+        will be no data being presented.
+
+        The function should takes parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should loop through and process all
+        elements in the iterator. The user function should not make a guess of the number of
+        elements in the iterator.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned value, `pandas.DataFrame` must either match the field names in the defined
+        schema if specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of user-defined
+        state. The value of state will be presented as a tuple, as well as the update should be
+        performed with the tuple. User defined types e.g. native Python class types are not

Review Comment:
   It's a bit tricky: native Python types include int, float, str, ... and of course those are supported. Probably the clearest definition is "Python types are supported as long as the default encoder can convert them to the Spark SQL type". I'm not sure we have clear documentation describing the compatibility matrix. cc @HyukjinKwon, could you please help us make this clear?
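The docstring says state is read and updated as a tuple. That means the tuple positions must line up with the fields of `stateStructType`. The sketch below assumes a hypothetical schema `count LONG, last_seen STRING` and hypothetical helpers `to_state_tuple` and `from_state_tuple`. It only keeps positions aligned. It does not model which Python types Spark's encoder can actually convert, which is the open question in the comment above.

```python
from typing import Any, Dict, Tuple

# Hypothetical state schema: stateStructType = "count LONG, last_seen STRING".
# Tuple positions follow the field order of that schema.
STATE_FIELDS: Tuple[str, ...] = ("count", "last_seen")


def to_state_tuple(values: Dict[str, Any]) -> Tuple[Any, ...]:
    """Pack named values into the positional tuple handed to state.update()."""
    missing = [f for f in STATE_FIELDS if f not in values]
    if missing:
        raise KeyError(f"missing state fields: {missing}")
    return tuple(values[f] for f in STATE_FIELDS)


def from_state_tuple(state_value: Tuple[Any, ...]) -> Dict[str, Any]:
    """Unpack the tuple read from the state back into named values."""
    if len(state_value) != len(STATE_FIELDS):
        raise ValueError("state tuple does not match the state schema")
    return dict(zip(STATE_FIELDS, state_value))
```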



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked for each group repeatedly in every
+        trigger, and updates to each group's state will be saved across invocations. The function
+        will also be invoked for each timed-out state repeatedly. The sequence of the invocation
+        will be input data -> state timeout. When the function is invoked for state timeout, there
+        will be no data being presented.
+
+        The function should takes parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should loop through and process all
+        elements in the iterator. The user function should not make a guess of the number of
+        elements in the iterator.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned value, `pandas.DataFrame` must either match the field names in the defined
+        schema if specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of user-defined
+        state. The value of state will be presented as a tuple, as well as the update should be
+        performed with the tuple. User defined types e.g. native Python class types are not
+        supported. Alternatively, you can pickle the data and produce the data as BinaryType, but

Review Comment:
   Let's simply just remove the suggestion.



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked for each group repeatedly in every
+        trigger, and updates to each group's state will be saved across invocations. The function
+        will also be invoked for each timed-out state repeatedly. The sequence of the invocation
+        will be input data -> state timeout. When the function is invoked for state timeout, there
+        will be no data being presented.
+
+        The function should takes parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should loop through and process all
+        elements in the iterator. The user function should not make a guess of the number of
+        elements in the iterator.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned value, `pandas.DataFrame` must either match the field names in the defined
+        schema if specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of user-defined
+        state. The value of state will be presented as a tuple, as well as the update should be
+        performed with the tuple. User defined types e.g. native Python class types are not
+        supported. Alternatively, you can pickle the data and produce the data as BinaryType, but
+        it is tied to the backward and forward compatibility of pickle in Python, and Spark itself
+        does not guarantee the compatibility.
+
+        The length of each element in both input and returned value, `pandas.DataFrame`, can be
+        arbitrary. The length of iterator in both input and returned value can be also arbitrary.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        func : function
+            a Python native function to be called on every group. It should takes parameters
+            (key, Iterator[`pandas.DataFrame`], state) and returns Iterator[`pandas.DataFrame`].
+            Note that the type of key is tuple, and the type of state is
+            :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+        outputStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the output records. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        stateStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the user-defined state. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

Review Comment:
   same.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r973863915


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -2705,6 +2705,44 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH =
+    buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch")
+      .internal()
+      .doc("When using applyInPandasWithState, set a soft limit of the accumulated size of " +
+        "records that can be written to a single ArrowRecordBatch in memory. This is used to " +
+        "restrict the amount of memory being used to materialize the data in both executor and " +
+        "Python worker. The accumulated size of records are calculated via sampling a set of " +
+        "records. Splitting the ArrowRecordBatch is performed per record, so unless a record " +
+        "is quite huge, the size of constructed ArrowRecordBatch will be around the " +
+        "configured value.")
+      .version("3.4.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("64MB")
+
+  val MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE =
+    buildConf("spark.sql.execution.applyInPandasWithState.minDataCountForSample")
+      .internal()
+      .doc("When using applyInPandasWithState, specify the minimum number of records to sample " +
+        "the size of record. The size being retrieved from sampling will be used to estimate " +
+        "the accumulated size of records. Note that limiting by size does not work if the " +
+        "number of records are less than the configured value. For such case, ArrowRecordBatch " +
+        "will only be split for soft timeout.")
+      .version("3.4.0")
+      .intConf
+      .createWithDefault(100)
+
+  val MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH =
+    buildConf("spark.sql.execution.applyInPandasWithState.softTimeoutPurgeBatch")
+      .internal()
+      .doc("When using applyInPandasWithState, specify the soft timeout for purging the " +
+        "ArrowRecordBatch. If batching records exceeds the timeout, Spark will force splitting " +
+        "the ArrowRecordBatch regardless of estimated size. This config ensures the receiver " +
+        "of data (both executor and Python worker) to not wait indefinitely for sender to " +
+        "complete the ArrowRecordBatch, which may hurt both throughput and latency.")
+      .version("3.4.0")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("100ms")

Review Comment:
   I'm not 100% clear how `spark.sql.execution.pandas.udf.buffer.size` works. The current logic won't work if this config can split an Arrow record batch further into multiple batches, as we rely on the offset and the number of rows to slice each range of data out of the overall Arrow record batch. It relies on the logic having full control over constructing the Arrow record batch.
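Here is a simplified model of the concern. Suppose each group's rows are located inside one combined batch by (offset, row count). Those ranges are then only valid against the exact batch they were computed for. Plain lists stand in for Arrow record batches. `slice_groups` and the `(key, offset, num_rows)` layout are hypothetical, not the PR's actual code.

```python
from typing import Dict, List, Sequence, Tuple

Rec = Tuple[str, int]


def slice_groups(
    batch: Sequence[Rec], group_ranges: List[Tuple[str, int, int]]
) -> Dict[str, List[Rec]]:
    """Recover each group's rows from one combined batch via (key, offset, num_rows)."""
    groups: Dict[str, List[Rec]] = {}
    for key, offset, num_rows in group_ranges:
        if offset + num_rows > len(batch):
            raise IndexError(f"range for group {key!r} exceeds a batch of {len(batch)} rows")
        groups[key] = list(batch[offset:offset + num_rows])
    return groups


# One combined batch holding three groups back to back.
batch = [("a", 1), ("a", 2), ("b", 3), ("c", 4), ("c", 5)]
ranges = [("a", 0, 2), ("b", 2, 1), ("c", 3, 2)]
```

Now suppose another layer hands over only `batch[3:]`. The range for `"a"` still fits and silently picks up rows that belong to group `"c"`. Only the range for `"b"` raises. A silent mismatch like that is arguably worse than an error.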




[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r973805753


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked for each group repeatedly in every
+        trigger, and updates to each group's state will be saved across invocations. The function
+        will also be invoked for each timed-out state repeatedly. The sequence of the invocation
+        will be input data -> state timeout. When the function is invoked for state timeout, there
+        will be no data being presented.
+
+        The function should takes parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should loop through and process all

Review Comment:
   >> Note that the user function should loop through and process all elements in the iterator. 
   
   > Why? This sounds like the use must process all iterator entries or otherwise something bad would happen. I would reword this to indicate that the grouped data could be split into multiple entries -
   
   I agree this is too conservative, and we can remove it once we know there is technically no issue. I don't think we have ever had such a test, even for the existing flatMapGroupsWithState, so we don't actually know what happens if we pull only part of the data from a group.
   
   >>  The user function should not make a guess of the number of elements in the iterator.
   
   > I would still suggest we have a design discussion about splitting groups unnecessary as I believe we should not do this.
   
   I think there is room for discussion on how to split groups, keeping in mind that we also bin-pack for performance, but I really doubt this has to be an interface contract. For the former, it's not a first-class concern and we shouldn't block this PR on it. For the latter, I really want to see the real use case that leverages the interface contract, and how much harder it would be to implement without such a guarantee.
   
   A stricter interface contract can be loosened later without breaking anything, but a looser contract can never be made stricter without breaking compatibility. Why not stay conservative until we are very clear that there is a real use case?
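The conservative contract argued for here is "loop through all elements, and don't guess how many there are". The sketch below illustrates it with a hypothetical `split_into_chunks` that models one group's rows arriving in several chunks. A function that consumes every chunk gives the same answer however the group is split. One that assumes a single chunk silently under-counts.

```python
from typing import Iterable, Iterator, List


def split_into_chunks(rows: List[int], chunk_size: int) -> Iterator[List[int]]:
    """Hypothetical model of one group's rows being handed over in several chunks."""
    for start in range(0, len(rows), chunk_size):
        yield rows[start:start + chunk_size]


def total_all_chunks(chunks: Iterable[List[int]]) -> int:
    # Follows the contract: loop through every chunk, assume nothing about their count.
    return sum(sum(chunk) for chunk in chunks)


def total_first_chunk_only(chunks: Iterable[List[int]]) -> int:
    # Violates the contract: guesses that the whole group arrives as one chunk.
    return sum(next(iter(chunks), []))
```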




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r973839939


##########
python/pyspark/sql/pandas/_typing/__init__.pyi:
##########
@@ -256,6 +258,10 @@ PandasGroupedMapFunction = Union[
     Callable[[Any, DataFrameLike], DataFrameLike],
 ]
 
+PandasGroupedMapFunctionWithState = Callable[
+    [Any, Iterable[DataFrameLike], GroupStateImpl], Iterable[DataFrameLike]

Review Comment:
   I am fine either way too. Users aren't able to create this instance directly anyway.




[GitHub] [spark] alex-balikov commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
alex-balikov commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974672806


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -2705,6 +2705,44 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH =
+    buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch")
+      .internal()
+      .doc("When using applyInPandasWithState, set a soft limit of the accumulated size of " +
+        "records that can be written to a single ArrowRecordBatch in memory. This is used to " +
+        "restrict the amount of memory being used to materialize the data in both executor and " +
+        "Python worker. The accumulated size of records are calculated via sampling a set of " +
+        "records. Splitting the ArrowRecordBatch is performed per record, so unless a record " +
+        "is quite huge, the size of constructed ArrowRecordBatch will be around the " +
+        "configured value.")
+      .version("3.4.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("64MB")
+
+  val MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE =

Review Comment:
   I wonder whether we really need this param. Ultimately, if the size estimate works badly, users can just set a lower value for the batch size limit. I do not think it is useful to let them tune this parameter.
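Below is a sketch of what `minDataCountForSample` controls, under simplifying assumptions. The hypothetical `estimate_total_size` averages the size of the first `min_sample` records and extrapolates to the whole set. It uses `len` of a `bytes` record as the size. Below `min_sample` records it returns `None`, matching the config doc's note that size-based limiting does not apply in that case. Sampling from the head of the data is an assumption of this sketch; the PR's actual sampling strategy is not shown in this thread.

```python
from typing import Optional, Sequence


def estimate_total_size(records: Sequence[bytes], min_sample: int = 100) -> Optional[int]:
    """Extrapolate the total size of `records` from the average size of a sample.

    Assumes the sample is simply the first `min_sample` records and that a record's size is
    len() of its bytes. Returns None when there are too few records to sample.
    """
    if min_sample < 1:
        raise ValueError("min_sample must be at least 1")
    if len(records) < min_sample:
        return None
    sample = records[:min_sample]
    average = sum(len(r) for r in sample) / len(sample)
    return int(average * len(records))
```

With skewed data (100 small records followed by 100 large ones), this estimate comes out at 2000 bytes against an actual 101000. That is the kind of bad estimate where, per the comment, users would lower the batch size limit rather than tune the sample count.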



##########
sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala:
##########
@@ -620,6 +622,35 @@ class RelationalGroupedDataset protected[sql](
     Dataset.ofRows(df.sparkSession, plan)
   }
 
+  private[sql] def applyInPandasWithState(

Review Comment:
   method level comment



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -2705,6 +2705,44 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH =
+    buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch")
+      .internal()
+      .doc("When using applyInPandasWithState, set a soft limit of the accumulated size of " +
+        "records that can be written to a single ArrowRecordBatch in memory. This is used to " +
+        "restrict the amount of memory being used to materialize the data in both executor and " +
+        "Python worker. The accumulated size of records are calculated via sampling a set of " +
+        "records. Splitting the ArrowRecordBatch is performed per record, so unless a record " +
+        "is quite huge, the size of constructed ArrowRecordBatch will be around the " +
+        "configured value.")
+      .version("3.4.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("64MB")

Review Comment:
   I agree that expressing the limit in terms of bytes is more meaningful than records. However, we estimate the byte size efficiently. Specifically, I would rename 'softLimitSizePerBatch' by removing 'soft' (we can clarify that in the comment) and adding 'Bytes', i.e. 'batchSizeLimitBytes'. I also wonder whether the property should be specific to applyInPandasWithState or general. We could remove the applyInPandasWithState scoping even if we do not support this limit elsewhere initially; it seems generally meaningful, and we can follow up by fixing the other places as a bug.
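The original name said "soft" because batches are split only at record boundaries. A batch closes on the first record that pushes it to or past the limit, so it can overshoot by up to one record. The sketch below uses a hypothetical `batch_by_bytes` and takes `len` of each `bytes` record as its size, where the original config used an estimated size.

```python
from typing import Iterable, Iterator, List


def batch_by_bytes(records: Iterable[bytes], limit_bytes: int) -> Iterator[List[bytes]]:
    """Close a batch once its accumulated size reaches `limit_bytes`.

    Splitting happens only between records, so a batch may exceed the limit by up to one
    record's size.
    """
    batch: List[bytes] = []
    size = 0
    for record in records:
        batch.append(record)
        size += len(record)
        if size >= limit_bytes:
            yield batch
            batch, size = [], 0
    if batch:
        yield batch
```

With a 64-byte limit and records of 30, 30, 50 and 10 bytes, the first batch closes at 110 bytes and the last record forms a second batch of 10 bytes.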



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -2705,6 +2705,44 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH =
+    buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch")
+      .internal()
+      .doc("When using applyInPandasWithState, set a soft limit of the accumulated size of " +
+        "records that can be written to a single ArrowRecordBatch in memory. This is used to " +
+        "restrict the amount of memory being used to materialize the data in both executor and " +
+        "Python worker. The accumulated size of records are calculated via sampling a set of " +
+        "records. Splitting the ArrowRecordBatch is performed per record, so unless a record " +
+        "is quite huge, the size of constructed ArrowRecordBatch will be around the " +
+        "configured value.")
+      .version("3.4.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("64MB")
+
+  val MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE =
+    buildConf("spark.sql.execution.applyInPandasWithState.minDataCountForSample")
+      .internal()
+      .doc("When using applyInPandasWithState, specify the minimum number of records to sample " +
+        "the size of record. The size being retrieved from sampling will be used to estimate " +
+        "the accumulated size of records. Note that limiting by size does not work if the " +
+        "number of records are less than the configured value. For such case, ArrowRecordBatch " +
+        "will only be split for soft timeout.")
+      .version("3.4.0")
+      .intConf
+      .createWithDefault(100)
+
+  val MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH =

Review Comment:
   Again, should we really expose this? Let's start with a reasonable constant value and not expose a config. It is impossible to understand what this means unless you know the implementation intimately.
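Here is roughly what a soft purge timeout means, sketched with a hypothetical `batch_with_timeout`. A batch is emitted when it reaches `max_rows`, or when the time since its first record reaches `timeout_s`. Two simplifications keep the sketch deterministic:

- Each record carries its own arrival time instead of the sketch reading a clock.
- The timeout is checked only when a record arrives.

A real implementation would also need to flush when no further records come, which a simple pull-based loop like this cannot do.

```python
from typing import Iterable, Iterator, List, Optional, Tuple, TypeVar

T = TypeVar("T")


def batch_with_timeout(
    arrivals: Iterable[Tuple[float, T]], max_rows: int, timeout_s: float
) -> Iterator[List[T]]:
    """Emit a batch on reaching max_rows or once timeout_s has elapsed since it opened."""
    batch: List[T] = []
    opened_at: Optional[float] = None
    for arrived_at, record in arrivals:
        if opened_at is None:
            opened_at = arrived_at  # the batch "opens" with its first record
        batch.append(record)
        if len(batch) >= max_rows or arrived_at - opened_at >= timeout_s:
            yield batch
            batch, opened_at = [], None
    if batch:
        yield batch  # flush the remainder at end of input
```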




[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974854298


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -2705,6 +2705,44 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH =
+    buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch")
+      .internal()
+      .doc("When using applyInPandasWithState, set a soft limit of the accumulated size of " +
+        "records that can be written to a single ArrowRecordBatch in memory. This is used to " +
+        "restrict the amount of memory being used to materialize the data in both executor and " +
+        "Python worker. The accumulated size of records are calculated via sampling a set of " +
+        "records. Splitting the ArrowRecordBatch is performed per record, so unless a record " +
+        "is quite huge, the size of constructed ArrowRecordBatch will be around the " +
+        "configured value.")
+      .version("3.4.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("64MB")

Review Comment:
   (Closing the loop) We decided to simply use the number of rows as the condition for constructing an Arrow RecordBatch. This removes all the new configs introduced here and eliminates a lot of complexity.
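Under that decision, the batching condition reduces to a row count. Below is a minimal sketch with a hypothetical `batch_by_rows`. The actual config and code path in the final PR are not shown in this thread.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batch_by_rows(records: Iterable[T], max_rows: int) -> Iterator[List[T]]:
    """Cut the input into batches of at most `max_rows` records each."""
    if max_rows < 1:
        raise ValueError("max_rows must be positive")
    it = iter(records)
    while True:
        batch = list(islice(it, max_rows))
        if not batch:
            return
        yield batch
```

Compared with the size-based and timeout-based variants, this needs no sampling and no clock, which is where the reduction in complexity comes from.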




[GitHub] [spark] HyukjinKwon closed pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HyukjinKwon closed pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark
URL: https://github.com/apache/spark/pull/37893



[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [DRAFT][DO-NOT-MERGE][SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r971651375


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala:
##########
@@ -793,6 +812,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
           hasInitialState, planLater(initialState), planLater(child)
         ) :: Nil
+      case _: FlatMapGroupsInPandasWithState =>
+        throw new UnsupportedOperationException(

Review Comment:
   Updated.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala:
##########
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout}
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
+import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData
+import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.CompletionIterator
+
+/**
+ * Physical operator for executing
+ * [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]]
+ *
+ * @param functionExpr function called on each group
+ * @param groupingAttributes used to group the data
+ * @param outAttributes used to define the output rows
+ * @param stateType used to serialize/deserialize state before calling `functionExpr`
+ * @param stateInfo `StatefulOperatorStateInfo` to identify the state store for a given operator.
+ * @param stateFormatVersion the version of state format.
+ * @param outputMode the output mode of `functionExpr`
+ * @param timeoutConf used to timeout groups that have not received data in a while
+ * @param batchTimestampMs processing timestamp of the current batch.
+ * @param eventTimeWatermark event time watermark for the current batch
+ * @param child logical plan of the underlying data
+ */
+case class FlatMapGroupsInPandasWithStateExec(
+    functionExpr: Expression,
+    groupingAttributes: Seq[Attribute],
+    outAttributes: Seq[Attribute],
+    stateType: StructType,
+    stateInfo: Option[StatefulOperatorStateInfo],
+    stateFormatVersion: Int,
+    outputMode: OutputMode,
+    timeoutConf: GroupStateTimeout,
+    batchTimestampMs: Option[Long],
+    eventTimeWatermark: Option[Long],
+    child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase {
+
+  // TODO(SPARK-XXXXX): Add the support of initial state.

Review Comment:
   Updated.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974072051


##########
python/pyspark/sql/pandas/_typing/__init__.pyi:
##########
@@ -256,6 +258,10 @@ PandasGroupedMapFunction = Union[
     Callable[[Any, DataFrameLike], DataFrameLike],
 ]
 
+PandasGroupedMapFunctionWithState = Callable[
+    [Any, Iterable[DataFrameLike], GroupStateImpl], Iterable[DataFrameLike]

Review Comment:
   Thanks, I just renamed GroupStateImpl to GroupState. Once we find it necessary, we can keep the same name for the interface and move the implementation out. (I guess this is what @HyukjinKwon meant by saying the type is dynamic, but please let me know if I'm missing something.)





[GitHub] [spark] alex-balikov commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
alex-balikov commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r976868040


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,125 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked first for all input groups and then
+        for all timed out states where the input data is set to be empty. Updates to each group's
+        state will be saved across invocations.
+
+        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple

Review Comment:
   return another ...



##########
python/pyspark/worker.py:
##########
@@ -207,6 +209,89 @@ def wrapped(key_series, value_series):
     return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
 
 
+def wrap_grouped_map_pandas_udf_with_state(f, return_type):
+    """
+    Provides a new lambda instance wrapping user function of applyInPandasWithState.
+
+    The lambda instance receives (key series, iterator of value series, state) and performs
+    some conversion to be adapted with the signature of user function.
+
+    See the function doc of inner function `wrapped` for more details on what adapter does.
+    See the function doc of `mapper` function for
+    `eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE` for more details on
+    the input parameters of lambda function.
+
+    Along with the returned iterator, the lambda instance will also produce the return_type as
+    converted to the arrow schema.
+    """
+
+    def wrapped(key_series, value_series_gen, state):
+        """
+        Provide an adapter of the user function performing below:
+
+        - Extract the first value of all columns in key series and produce as a tuple.
+        - If the state has timed out, call the user function with empty pandas DataFrame.
+        - If not, construct a new generator which converts each element of value series to
+          pandas DataFrame (lazy evaluation), and call the user function with the generator
+        - Verify each element of returned iterator to check the schema of pandas DataFrame.
+        """
+        import pandas as pd
+
+        key = tuple(s[0] for s in key_series)
+
+        if state.hasTimedOut:
+            # Timeout processing pass empty iterator. Here we return an empty DataFrame instead.
+            values = [
+                pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns),
+            ]
+        else:
+            values = (pd.concat(x, axis=1) for x in value_series_gen)
+
+        result_iter = f(key, values, state)
+
+        def verify_element(result):
+            if not isinstance(result, pd.DataFrame):
+                raise TypeError(
+                    "The type of element in return iterator of the user-defined function "
+                    "should be pandas.DataFrame, but is {}".format(type(result))
+                )
+            # the number of columns of result have to match the return type
+            # but it is fine for result to have no columns at all if it is empty
+            if not (
+                len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty

Review Comment:
   Maybe it is just me, but I would suggest adding parentheses so we do not rely on and/or precedence.
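   For reference, Python's `and` binds more tightly than `or`, so the unparenthesized condition already groups as `a or (b and c)`; parentheses only make that explicit. A small stdlib-only sketch, where `check_result_shape` and `unparenthesized` are hypothetical helpers mirroring the `verify_element` condition:

```python
from itertools import product


def check_result_shape(num_cols: int, num_return_fields: int, is_empty: bool) -> bool:
    # Explicitly parenthesized form: either the column count matches the
    # return type, or the result has no columns at all and is empty.
    return (num_cols == num_return_fields) or (num_cols == 0 and is_empty)


def unparenthesized(num_cols: int, num_return_fields: int, is_empty: bool) -> bool:
    # Same expression relying on operator precedence (`and` before `or`).
    return num_cols == num_return_fields or num_cols == 0 and is_empty


# Both forms agree on every combination, confirming the implicit grouping.
for cols, fields, empty in product([0, 2], [0, 2], [True, False]):
    assert check_result_shape(cols, fields, empty) == unparenthesized(cols, fields, empty)
```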



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -371,3 +373,354 @@ def load_stream(self, stream):
                 raise ValueError(
                     "Invalid number of pandas.DataFrames in group {0}".format(dataframes_in_group)
                 )
+
+
+class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
+    """
+    Serializer used by Python worker to evaluate UDF for applyInPandasWithState.
+
+    Parameters
+    ----------
+    timezone : str
+        A timezone to respect when handling timestamp values
+    safecheck : bool
+        If True, conversion from Arrow to Pandas checks for overflow/truncation
+    assign_cols_by_name : bool
+        If True, then Pandas DataFrames will get columns by name
+    state_object_schema : StructType
+        The type of state object represented as Spark SQL type
+    arrow_max_records_per_batch : int
+        Limit of the number of records that can be written to a single ArrowRecordBatch in memory.
+    """
+
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        state_object_schema,
+        arrow_max_records_per_batch,
+    ):
+        super(ApplyInPandasWithStateSerializer, self).__init__(
+            timezone, safecheck, assign_cols_by_name
+        )
+        self.pickleSer = CPickleSerializer()
+        self.utf8_deserializer = UTF8Deserializer()
+        self.state_object_schema = state_object_schema
+
+        self.result_state_df_type = StructType(
+            [
+                StructField("properties", StringType()),
+                StructField("keyRowAsUnsafe", BinaryType()),
+                StructField("object", BinaryType()),
+                StructField("oldTimeoutTimestamp", LongType()),
+            ]
+        )
+
+        self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type)
+        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+
+    def load_stream(self, stream):
+        """
+        Read ArrowRecordBatches from stream, deserialize them to populate a list of pair
+        (data chunk, state), and convert the data into a list of pandas.Series.
+
+        Please refer the doc of inner function `gen_data_and_state` for more details how
+        this function works in overall.
+
+        In addition, this function further groups the return of `gen_data_and_state` by the state
+        instance (same semantic as grouping by grouping key) and produces an iterator of data
+        chunks for each group, so that the caller can lazily materialize the data chunk.
+        """
+
+        import pyarrow as pa
+        import json
+        from itertools import groupby
+        from pyspark.sql.streaming.state import GroupState
+
+        def construct_state(state_info_col):
+            """
+            Construct state instance from the value of state information column.
+            """
+
+            state_info_col_properties = state_info_col["properties"]
+            state_info_col_key_row = state_info_col["keyRowAsUnsafe"]
+            state_info_col_object = state_info_col["object"]
+
+            state_properties = json.loads(state_info_col_properties)
+            if state_info_col_object:
+                state_object = self.pickleSer.loads(state_info_col_object)
+            else:
+                state_object = None
+            state_properties["optionalValue"] = state_object
+
+            return GroupState(
+                keyAsUnsafe=state_info_col_key_row,
+                valueSchema=self.state_object_schema,
+                **state_properties,
+            )
+
+        def gen_data_and_state(batches):
+            """
+            Deserialize ArrowRecordBatches and return a generator of
+            `(a list of pandas.Series, state)`.
+
+            The logic on deserialization is following:
+
+            1. Read the entire data part from Arrow RecordBatch.
+            2. Read the entire state information part from Arrow RecordBatch.
+            3. Loop through each state information:
+               3.A. Extract the data out from entire data via the information of data range.
+               3.B. Construct a new state instance if the state information is the first occurrence
+                    for the current grouping key.
+               3.C. Leverage existing new state instance if the state instance is already available

Review Comment:
   Leverage the existing state instance if it is already available for the current grouping key...
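   Steps 3.B/3.C, plus the later grouping by state instance described in the `load_stream` docstring, can be sketched with plain Python. This assumes chunks for the same grouping key arrive contiguously; `HypotheticalState`, `gen_data_and_state`, and `group_by_state` here are simplified stand-ins, not the serializer's real code:

```python
from itertools import groupby
from typing import Any, Hashable, Iterable, Iterator, Optional, Tuple


class HypotheticalState:
    """Stand-in for pyspark.sql.streaming.state.GroupState (illustration only)."""

    def __init__(self, key: Hashable) -> None:
        self.key = key


def gen_data_and_state(
    chunks: Iterable[Tuple[Hashable, Any]]
) -> Iterator[Tuple[Any, HypotheticalState]]:
    # Assumes chunks of the same grouping key arrive contiguously.
    state: Optional[HypotheticalState] = None
    for key, data in chunks:
        if state is None or state.key != key:
            # 3.B: first occurrence of this key, so construct a new state instance.
            state = HypotheticalState(key)
        # 3.C: otherwise, reuse the existing state instance for this key.
        yield data, state


def group_by_state(
    pairs: Iterable[Tuple[Any, HypotheticalState]]
) -> Iterator[Tuple[HypotheticalState, Iterator[Any]]]:
    # groupby compares state objects by identity (default __eq__); with
    # contiguous keys this is the same as grouping by key.
    for state, group in groupby(pairs, key=lambda pair: pair[1]):
        # Lazy inner generator: chunks are pulled only when the caller consumes them.
        yield state, (data for data, _ in group)
```

   Because `itertools.groupby` shares one underlying iterator, each group's chunks must be consumed before advancing to the next group, which matches the "lazily materialize" intent.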





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r976963424


##########
python/pyspark/worker.py:
##########
@@ -207,6 +209,89 @@ def wrapped(key_series, value_series):
     return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
 
 
+def wrap_grouped_map_pandas_udf_with_state(f, return_type):
+    """
+    Provides a new lambda instance wrapping user function of applyInPandasWithState.
+
+    The lambda instance receives (key series, iterator of value series, state) and performs
+    some conversion to be adapted with the signature of user function.
+
+    See the function doc of inner function `wrapped` for more details on what adapter does.
+    See the function doc of `mapper` function for
+    `eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE` for more details on
+    the input parameters of lambda function.
+
+    Along with the returned iterator, the lambda instance will also produce the return_type as
+    converted to the arrow schema.
+    """
+
+    def wrapped(key_series, value_series_gen, state):
+        """
+        Provide an adapter of the user function performing below:
+
+        - Extract the first value of all columns in key series and produce as a tuple.
+        - If the state has timed out, call the user function with empty pandas DataFrame.
+        - If not, construct a new generator which converts each element of value series to
+          pandas DataFrame (lazy evaluation), and call the user function with the generator
+        - Verify each element of returned iterator to check the schema of pandas DataFrame.
+        """
+        import pandas as pd
+
+        key = tuple(s[0] for s in key_series)
+
+        if state.hasTimedOut:
+            # Timeout processing pass empty iterator. Here we return an empty DataFrame instead.
+            values = [
+                pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns),
+            ]
+        else:
+            values = (pd.concat(x, axis=1) for x in value_series_gen)
+
+        result_iter = f(key, values, state)
+
+        def verify_element(result):
+            if not isinstance(result, pd.DataFrame):
+                raise TypeError(
+                    "The type of element in return iterator of the user-defined function "
+                    "should be pandas.DataFrame, but is {}".format(type(result))
+                )
+            # the number of columns of result have to match the return type
+            # but it is fine for result to have no columns at all if it is empty
+            if not (
+                len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty

Review Comment:
   No, it's not just you. I planned to do it but forgot. Thanks for the pointer.





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r975770234


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala:
##########
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout}
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
+import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData
+import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.CompletionIterator
+
+/**
+ * Physical operator for executing
+ * [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]]
+ *
+ * @param functionExpr function called on each group
+ * @param groupingAttributes used to group the data
+ * @param outAttributes used to define the output rows
+ * @param stateType used to serialize/deserialize state before calling `functionExpr`
+ * @param stateInfo `StatefulOperatorStateInfo` to identify the state store for a given operator.
+ * @param stateFormatVersion the version of state format.
+ * @param outputMode the output mode of `functionExpr`
+ * @param timeoutConf used to timeout groups that have not received data in a while
+ * @param batchTimestampMs processing timestamp of the current batch.
+ * @param eventTimeWatermark event time watermark for the current batch
+ * @param child logical plan of the underlying data
+ */
+case class FlatMapGroupsInPandasWithStateExec(

Review Comment:
   We always have separate exec implementations for Scala/Java and Python since the constructor parameters are different. (We are leveraging case classes, so a difference in constructor parameters warrants a new class.) So this is intentional. As a compromise, we refactored the shared logic into FlatMapGroupsWithStateExecBase as a base class.
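   Python's frozen dataclasses behave somewhat like Scala case classes (constructor, equality, and hashing are derived from the declared fields), so the same trade-off can be sketched as an analogy: two operator classes with different fields, sharing behavior through a plain base class. All names below are hypothetical and only loosely mirror the Scala operators:

```python
from dataclasses import dataclass
from typing import Callable


class FlatMapGroupsWithStateBaseSketch:
    """Hypothetical shared base: behavior common to both operators lives here."""

    timeout_conf: str  # supplied by each dataclass subclass

    def timeout_enabled(self) -> bool:
        return self.timeout_conf != "NoTimeout"


@dataclass(frozen=True)
class JvmFlatMapGroupsExecSketch(FlatMapGroupsWithStateBaseSketch):
    # JVM-side variant: carries a deserialized function object.
    func: Callable
    timeout_conf: str


@dataclass(frozen=True)
class PandasFlatMapGroupsExecSketch(FlatMapGroupsWithStateBaseSketch):
    # Python-side variant: carries an opaque function expression instead, so
    # its field list (and hence its generated constructor) differs.
    function_expr: str
    timeout_conf: str
```

   Since the generated constructor and equality depend on the field list, differing parameters force separate classes, and the base class is where shared logic can still be reused.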





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r975939807


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,104 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked first for all input groups and then
+        for all timed out states where the input data is set to be empty. Updates to each group's
+        state will be saved across invocations.
+
+        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupState`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should loop through and process all
+        elements in the iterator. The user function should not make a guess of the number of
+        elements in the iterator.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in the returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned `pandas.DataFrame` must either match the field names in the defined schema if
+        specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of the
+        user-defined state. The value of the state will be presented as a tuple, as well as the
+        update should be performed with the tuple. The corresponding Python types for
+        :class:DataType are supported. Please refer to the page
+        https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab).
+
+        The size of each DataFrame in both the input and output can be arbitrary. The number of
+        DataFrames in both the input and output can also be arbitrary.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        func : function
+            a Python native function to be called on every group. It should take parameters
+            (key, Iterator[`pandas.DataFrame`], state) and return Iterator[`pandas.DataFrame`].
+            Note that the type of the key is tuple and the type of the state is
+            :class:`pyspark.sql.streaming.state.GroupState`.
+        outputStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the output records. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        stateStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the user-defined state. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        outputMode : str
+            the output mode of the function.
+        timeoutConf : str
+            timeout configuration for groups that do not receive data for a while. valid values
+            are defined in :class:`pyspark.sql.streaming.state.GroupStateTimeout`.
+
+        # TODO: Examples

Review Comment:
   https://issues.apache.org/jira/browse/SPARK-40509





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r975301743


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,104 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked first for all input groups and then
+        for all timed out states where the input data is set to be empty. Updates to each group's
+        state will be saved across invocations.
+
+        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupState`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should loop through and process all
+        elements in the iterator. The user function should not make a guess of the number of
+        elements in the iterator.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in the returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned `pandas.DataFrame` must either match the field names in the defined schema if
+        specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of the
+        user-defined state. The value of the state will be presented as a tuple, as well as the
+        update should be performed with the tuple. The corresponding Python types for
+        :class:DataType are supported. Please refer to the page
+        https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab).
+
+        The size of each DataFrame in both the input and output can be arbitrary. The number of
+        DataFrames in both the input and output can also be arbitrary.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        func : function
+            a Python native function to be called on every group. It should take parameters
+            (key, Iterator[`pandas.DataFrame`], state) and return Iterator[`pandas.DataFrame`].
+            Note that the type of the key is tuple and the type of the state is
+            :class:`pyspark.sql.streaming.state.GroupState`.
+        outputStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the output records. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        stateStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the user-defined state. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        outputMode : str
+            the output mode of the function.
+        timeoutConf : str
+            timeout configuration for groups that do not receive data for a while. valid values
+            are defined in :class:`pyspark.sql.streaming.state.GroupStateTimeout`.
+
+        # TODO: Examples

Review Comment:
   This is something I still need to do; let me come up with some examples. I guess we probably can't run automated tests from the example section, though.
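   A possible shape for such an example, as a sketch only: a running row count per key. It assumes the state object exposes `hasTimedOut`, `exists`, `get` (a tuple), `update(tuple)`, and `remove()`, consistent with the docstring's "state is presented as a tuple"; `count_fn` and the column names are hypothetical:

```python
from typing import Any, Iterator, Tuple

import pandas as pd


def count_fn(key: Tuple[Any, ...], pdf_iter: Iterator[pd.DataFrame], state) -> Iterator[pd.DataFrame]:
    # `state` is expected to be a pyspark.sql.streaming.state.GroupState
    # holding a one-field tuple (running_count,) per grouping key.
    if state.hasTimedOut:
        # Timed-out group: emit nothing and drop its state.
        state.remove()
        return
    running = state.get[0] if state.exists else 0
    # Loop over every chunk; the number of DataFrames per call is not fixed.
    for pdf in pdf_iter:
        running += len(pdf)
    state.update((running,))
    yield pd.DataFrame({"id": [key[0]], "count": [running]})


# Hypothetical wiring (needs a streaming DataFrame `df` with an `id` column):
# df.groupBy("id").applyInPandasWithState(
#     count_fn,
#     outputStructType="id long, count long",
#     stateStructType="count long",
#     outputMode="Update",
#     timeoutConf="NoTimeout",
# )
```

   Since doctests cannot easily drive a streaming query, the function itself can be unit-tested with a small fake state object outside Spark.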





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r975902646


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala:
##########
@@ -311,6 +323,56 @@ object UnsupportedOperationChecker extends Logging {
             }
           }
 
+        // applyInPandasWithState
+        case m: FlatMapGroupsInPandasWithState if m.isStreaming =>
+          // Check compatibility with output modes and aggregations in query
+          val aggsInQuery = collectStreamingAggregates(plan)
+
+          if (aggsInQuery.isEmpty) {
+            // applyInPandasWithState without aggregation: operation's output mode must

Review Comment:
   Now I can imagine a case where the current requirement of providing a separate output mode prevents unintended behavior:
   
   - They implemented the user function for flatMapGroupsWithState with append mode.
   - They ran the query with append mode.
   - After that, they changed the output mode of the query to update mode for some reason.
   - The user function was not changed to account for the switch to update mode.
   
   We don't allow such a query to run as of now, but we would allow it to run if we dropped the parameter.
   
   PS. I'm not convinced that end users can implement their user function correctly based on the output mode, but that is a fundamental API design issue of the original flatMapGroupsWithState, which is a separate concern.
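   The scenario above amounts to a compatibility check. A simplified stdlib-only sketch, assuming only the rule the quoted diff comment starts to state (without aggregation, the function's declared output mode must match the query's output mode); `check_mode_without_aggregation` and `UnsupportedOperationError` are hypothetical names, and the real check is the Scala `UnsupportedOperationChecker`:

```python
class UnsupportedOperationError(Exception):
    """Hypothetical error type for this sketch."""


def check_mode_without_aggregation(func_output_mode: str, query_output_mode: str) -> None:
    # Simplified rule for queries without aggregation: the mode the user
    # function was written for must match the query's output mode.
    if func_output_mode.lower() != query_output_mode.lower():
        raise UnsupportedOperationError(
            f"function declared with {func_output_mode} mode cannot run "
            f"in a query with {query_output_mode} output mode"
        )
```

   In the scenario, a function written for append mode in a query later switched to update mode fails this check instead of silently running with the wrong semantics.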





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r975902646


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala:
##########
@@ -311,6 +323,56 @@ object UnsupportedOperationChecker extends Logging {
             }
           }
 
+        // applyInPandasWithState
+        case m: FlatMapGroupsInPandasWithState if m.isStreaming =>
+          // Check compatibility with output modes and aggregations in query
+          val aggsInQuery = collectStreamingAggregates(plan)
+
+          if (aggsInQuery.isEmpty) {
+            // applyInPandasWithState without aggregation: operation's output mode must

Review Comment:
   Now I can imagine a case where the current requirement prevents unintended behavior:
   
   - They implemented the user function for flatMapGroupsWithState with append mode.
   - They ran the query with append mode.
   - After that, they changed the output mode to update mode for some reason.
   - The user function was not updated to account for the switch to update mode.
   
   As of now we don't allow such a query to run, but it would be allowed if we dropped the parameter.
   
   PS. I'm not convinced that end users can implement their user function correctly for each output mode, but that is a separate, fundamental API design issue.





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r973864734


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -2705,6 +2705,44 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH =
+    buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch")
+      .internal()
+      .doc("When using applyInPandasWithState, set a soft limit of the accumulated size of " +
+        "records that can be written to a single ArrowRecordBatch in memory. This is used to " +
+        "restrict the amount of memory being used to materialize the data in both executor and " +
+        "Python worker. The accumulated size of records are calculated via sampling a set of " +
+        "records. Splitting the ArrowRecordBatch is performed per record, so unless a record " +
+        "is quite huge, the size of constructed ArrowRecordBatch will be around the " +
+        "configured value.")
+      .version("3.4.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("64MB")
+
+  val MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE =
+    buildConf("spark.sql.execution.applyInPandasWithState.minDataCountForSample")
+      .internal()
+      .doc("When using applyInPandasWithState, specify the minimum number of records to sample " +
+        "the size of record. The size being retrieved from sampling will be used to estimate " +
+        "the accumulated size of records. Note that limiting by size does not work if the " +
+        "number of records are less than the configured value. For such case, ArrowRecordBatch " +
+        "will only be split for soft timeout.")
+      .version("3.4.0")
+      .intConf
+      .createWithDefault(100)
+
+  val MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH =
+    buildConf("spark.sql.execution.applyInPandasWithState.softTimeoutPurgeBatch")
+      .internal()
+      .doc("When using applyInPandasWithState, specify the soft timeout for purging the " +
+        "ArrowRecordBatch. If batching records exceeds the timeout, Spark will force splitting " +
+        "the ArrowRecordBatch regardless of estimated size. This config ensures the receiver " +
+        "of data (both executor and Python worker) to not wait indefinitely for sender to " +
+        "complete the ArrowRecordBatch, which may hurt both throughput and latency.")
+      .version("3.4.0")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("100ms")

Review Comment:
   These configs cover two different criteria for closing an Arrow record batch: 1) accumulated size and 2) time spent on batching.
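   A minimal Python sketch of how the two closing criteria could interact, assuming the behavior described in the config docs above: record size is estimated from the average of the first `min_count_for_sample` records, a batch is closed after the record that pushes the estimate past the soft limit, and the timeout closes a batch regardless of size. The class name, the `size_of` callback and the injectable clock are hypothetical; this is not the ApplyInPandasWithStateWriter implementation, and for simplicity it only checks the timeout when a record arrives.

```python
import time
from typing import Callable, Iterable, Iterator, List, Optional


class SoftLimitBatcher:
    """Hypothetical batcher closing a batch on estimated size or on elapsed time."""

    def __init__(self, soft_limit_bytes: int, min_count_for_sample: int,
                 soft_timeout_ms: int, size_of: Callable[[object], int],
                 clock: Callable[[], float] = time.monotonic):
        self.soft_limit_bytes = soft_limit_bytes
        self.min_count_for_sample = min_count_for_sample
        self.soft_timeout_ms = soft_timeout_ms
        self.size_of = size_of
        self.clock = clock  # returns seconds
        self._sampled_bytes = 0
        self._sampled_count = 0

    def _estimated_record_size(self) -> Optional[float]:
        # No size estimate until enough records have been sampled.
        if self._sampled_count < self.min_count_for_sample:
            return None
        return self._sampled_bytes / self._sampled_count

    def batches(self, records: Iterable[object]) -> Iterator[List[object]]:
        batch: List[object] = []
        started = self.clock()
        for record in records:
            if self._sampled_count < self.min_count_for_sample:
                # Only the first min_count_for_sample records are measured.
                self._sampled_bytes += self.size_of(record)
                self._sampled_count += 1
            batch.append(record)
            estimate = self._estimated_record_size()
            over_size = estimate is not None and estimate * len(batch) >= self.soft_limit_bytes
            over_time = (self.clock() - started) * 1000.0 >= self.soft_timeout_ms
            if over_size or over_time:
                # Split per record: the batch may slightly exceed the soft limit.
                yield batch
                batch = []
                started = self.clock()
        if batch:
            yield batch
```

   Until `min_count_for_sample` records have been seen, `estimate` is `None`, so only the timeout can close a batch, which matches the note in the `minDataCountForSample` doc.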





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r973862455


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -2705,6 +2705,44 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH =
+    buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch")
+      .internal()
+      .doc("When using applyInPandasWithState, set a soft limit of the accumulated size of " +
+        "records that can be written to a single ArrowRecordBatch in memory. This is used to " +
+        "restrict the amount of memory being used to materialize the data in both executor and " +
+        "Python worker. The accumulated size of records are calculated via sampling a set of " +
+        "records. Splitting the ArrowRecordBatch is performed per record, so unless a record " +
+        "is quite huge, the size of constructed ArrowRecordBatch will be around the " +
+        "configured value.")
+      .version("3.4.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("64MB")

Review Comment:
   Batching serves multiple purposes; here we batch for scalability, so batching by size is closer to that purpose than batching by the number of rows. That said, I'm OK with changing the condition for cutting an Arrow batch to the number of rows: it's configurable, so users can lower it if they run into memory issues.
   
   cc. @alex-balikov Does this make sense to you?
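   For comparison, a minimal Python sketch of the row-count alternative discussed above: batches are cut purely by the number of records, so no size sampling is needed, and lowering `max_records_per_batch` (a hypothetical parameter name) bounds memory at the cost of more, smaller batches.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batch_by_row_count(records: Iterable[T], max_records_per_batch: int) -> Iterator[List[T]]:
    """Split records into consecutive batches of at most max_records_per_batch rows."""
    if max_records_per_batch <= 0:
        raise ValueError("max_records_per_batch must be positive")
    it = iter(records)
    while True:
        batch = list(islice(it, max_records_per_batch))
        if not batch:
            return
        yield batch
```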





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974856402


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala:
##########
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.types.{BinaryType, BooleanType, IntegerType, StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * This class abstracts the complexity on constructing Arrow RecordBatches for data and state with
+ * bin-packing and chunking. The caller only need to call the proper public methods of this class
+ * `startNewGroup`, `writeRow`, `finalizeGroup`, `finalizeData` and this class will write the data
+ * and state into Arrow RecordBatches with performing bin-pack and chunk internally.
+ *
+ * This class requires that the parameter `root` has initialized with the Arrow schema like below:
+ * - data fields
+ * - state field
+ *   - nested schema (Refer ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA)
+ *
+ * Please refer the code comment in the implementation to see how the writes of data and state
+ * against Arrow RecordBatch work with consideration of bin-packing and chunking.
+ */
+class ApplyInPandasWithStateWriter(
+    root: VectorSchemaRoot,
+    writer: ArrowStreamWriter,
+    softLimitBytesPerBatch: Long,
+    minDataCountForSample: Int,
+    softTimeoutMillsPurgeBatch: Long) {
+
+  import ApplyInPandasWithStateWriter._
+
+  // We logically group the columns by family (data vs state) and initialize writer separately,
+  // since it's lot more easier and probably performant to write the row directly rather than
+  // projecting the row to match up with the overall schema.
+  //
+  // The number of data rows and state metadata rows can be different which could be problematic

Review Comment:
   We use a single Arrow RecordBatch for both data and state. I'll mention this, along with the rationale, explicitly in the code comment.
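   To illustrate what sharing one RecordBatch between data and state implies, here is a plain-Python sketch that models a batch as equal-length columns. It assumes each group contributes its data rows plus one state-metadata entry recording where that group's rows start and how many there are; because the two counts differ, the shorter column is padded with `None` (null). The field names `start_offset` and `num_rows` and both functions are hypothetical and do not reproduce STATE_METADATA_SCHEMA.

```python
from typing import Any, Dict, List, Optional, Tuple


def pack_groups(groups: List[Tuple[Any, List[Any]]]) -> Dict[str, List[Optional[Any]]]:
    """Pack (state, data rows) pairs into one column-oriented batch of equal-length columns."""
    data: List[Any] = []
    state: List[Optional[Dict[str, Any]]] = []
    for state_obj, rows in groups:
        # One metadata entry per group, pointing at that group's slice of the data column.
        state.append({"state": state_obj, "start_offset": len(data), "num_rows": len(rows)})
        data.extend(rows)
    num_rows = max(len(data), len(state))
    # Pad the shorter column with nulls so both columns have the same length.
    data += [None] * (num_rows - len(data))
    state += [None] * (num_rows - len(state))
    return {"data": data, "state": state}


def unpack_groups(batch: Dict[str, List[Optional[Any]]]) -> List[Tuple[Any, List[Any]]]:
    """Recover (state, data rows) pairs, skipping null metadata rows."""
    result = []
    for meta in batch["state"]:
        if meta is None:
            continue
        start = meta["start_offset"]
        result.append((meta["state"], batch["data"][start:start + meta["num_rows"]]))
    return result
```

   The padding rows carry no information: the reader skips null metadata entries and locates each group's data only through the offsets.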





[GitHub] [spark] alex-balikov commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
alex-balikov commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974517188


##########
python/pyspark/worker.py:
##########
@@ -361,6 +429,32 @@ def read_udfs(pickleSer, infile, eval_type):
 
         if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
             ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
+        elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
+            soft_limit_bytes_per_batch = runner_conf.get(
+                "spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch",
+                (64 * 1024 * 1024),
+            )
+            soft_limit_bytes_per_batch = int(soft_limit_bytes_per_batch)
+
+            min_data_count_for_sample = runner_conf.get(
+                "spark.sql.execution.applyInPandasWithState.minDataCountForSample", 100

Review Comment:
   Similar comment about the property names and default values here and everywhere else: can they be defined in a more prominent place?



##########
python/pyspark/worker.py:
##########
@@ -361,6 +429,32 @@ def read_udfs(pickleSer, infile, eval_type):
 
         if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
             ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
+        elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
+            soft_limit_bytes_per_batch = runner_conf.get(
+                "spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch",
+                (64 * 1024 * 1024),

Review Comment:
   Can the default value be defined in a more prominent place? The same goes for the property names.
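   One way to address this, sketched in Python under the assumption that the keys and defaults quoted in this diff stay as they are: define each key and its default once as module-level constants and read them through a single helper, instead of repeating string literals at each call site. The constant and helper names are hypothetical. The `int(...)` conversion mirrors what the diff already does, since runner_conf values may arrive as strings.

```python
from typing import Mapping, Union

# Hypothetical central definitions; keys and defaults copied from the diff above.
CONF_SOFT_LIMIT_SIZE_PER_BATCH = "spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch"
DEFAULT_SOFT_LIMIT_SIZE_PER_BATCH = 64 * 1024 * 1024  # bytes

CONF_MIN_DATA_COUNT_FOR_SAMPLE = "spark.sql.execution.applyInPandasWithState.minDataCountForSample"
DEFAULT_MIN_DATA_COUNT_FOR_SAMPLE = 100

CONF_SOFT_TIMEOUT_PURGE_BATCH = "spark.sql.execution.applyInPandasWithState.softTimeoutPurgeBatch"
DEFAULT_SOFT_TIMEOUT_PURGE_BATCH_MS = 100


def get_int_conf(runner_conf: Mapping[str, Union[str, int]], key: str, default: int) -> int:
    """Read an integer config from runner_conf, falling back to the default."""
    return int(runner_conf.get(key, default))


# Example call sites, e.g. inside read_udfs:
runner_conf = {CONF_MIN_DATA_COUNT_FOR_SAMPLE: "50"}
soft_limit_bytes_per_batch = get_int_conf(
    runner_conf, CONF_SOFT_LIMIT_SIZE_PER_BATCH, DEFAULT_SOFT_LIMIT_SIZE_PER_BATCH
)
min_data_count_for_sample = get_int_conf(
    runner_conf, CONF_MIN_DATA_COUNT_FOR_SAMPLE, DEFAULT_MIN_DATA_COUNT_FOR_SAMPLE
)
```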



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala:
##########
@@ -311,6 +323,56 @@ object UnsupportedOperationChecker extends Logging {
             }
           }
 
+        // applyInPandasWithState
+        case m: FlatMapGroupsInPandasWithState if m.isStreaming =>
+          // Check compatibility with output modes and aggregations in query
+          val aggsInQuery = collectStreamingAggregates(plan)
+
+          if (aggsInQuery.isEmpty) {
+            // applyInPandasWithState without aggregation: operation's output mode must

Review Comment:
   Why do we even have an operation output mode? We are defining a new API; can we just drop this parameter from the API if we are going to enforce that it matches the query's output mode?



##########
python/pyspark/worker.py:
##########
@@ -361,6 +429,32 @@ def read_udfs(pickleSer, infile, eval_type):
 
         if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
             ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
+        elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
+            soft_limit_bytes_per_batch = runner_conf.get(
+                "spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch",
+                (64 * 1024 * 1024),
+            )
+            soft_limit_bytes_per_batch = int(soft_limit_bytes_per_batch)
+
+            min_data_count_for_sample = runner_conf.get(
+                "spark.sql.execution.applyInPandasWithState.minDataCountForSample", 100
+            )
+            min_data_count_for_sample = int(min_data_count_for_sample)
+
+            soft_timeout_millis_purge_batch = runner_conf.get(
+                "spark.sql.execution.applyInPandasWithState.softTimeoutPurgeBatch", 100

Review Comment:
   same



##########
python/pyspark/worker.py:
##########
@@ -207,6 +209,65 @@ def wrapped(key_series, value_series):
     return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
 
 
+def wrap_grouped_map_pandas_udf_with_state(f, return_type):
+    def wrapped(key_series, value_series_gen, state):
+        import pandas as pd
+
+        key = tuple(s[0] for s in key_series)
+
+        if state.hasTimedOut:
+            # Timeout processing pass empty iterator. Here we return an empty DataFrame instead.
+            values = [
+                pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns),
+            ]
+        else:
+            values = (pd.concat(x, axis=1) for x in value_series_gen)
+
+        result_iter = f(key, values, state)
+
+        def verify_element(result):
+            if not isinstance(result, pd.DataFrame):
+                raise TypeError(
+                    "The type of element in return iterator of the user-defined function "
+                    "should be pandas.DataFrame, but is {}".format(type(result))
+                )
+            # the number of columns of result have to match the return type
+            # but it is fine for result to have no columns at all if it is empty
+            if not (

Review Comment:
   if not ... ?
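   The two comment lines before the truncated `if not (` state the rule being checked: the number of result columns must match the return type, except that an empty result may have no columns at all. A minimal Python sketch of that rule as a standalone predicate; the function name is hypothetical, it works on plain counts rather than a pandas.DataFrame, and its docstring shows the kind of parameter documentation requested in this review.

```python
def is_valid_result_shape(num_columns: int, num_rows: int, expected_num_columns: int) -> bool:
    """Return True if a UDF result has an acceptable number of columns.

    num_columns: number of columns in the returned frame.
    num_rows: number of rows in the returned frame.
    expected_num_columns: number of fields in the declared return type.
    """
    matches_return_type = num_columns == expected_num_columns
    # An empty result is allowed to have no columns at all.
    empty_without_columns = num_rows == 0 and num_columns == 0
    return matches_return_type or empty_without_columns
```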



##########
python/pyspark/worker.py:
##########
@@ -486,6 +580,35 @@ def mapper(a):
             vals = [a[o] for o in parsed_offsets[0][1]]
             return f(keys, vals)
 
+    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
+        # We assume there is only one UDF here because grouped map doesn't
+        # support combining multiple UDFs.
+        assert num_udfs == 1
+
+        # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to
+        # distinguish between grouping attributes and data attributes
+        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
+        parsed_offsets = extract_key_value_indexes(arg_offsets)
+
+        def mapper(a):

Review Comment:
   Please add a method-level comment with parameter types and semantics.



##########
python/pyspark/worker.py:
##########
@@ -361,6 +429,32 @@ def read_udfs(pickleSer, infile, eval_type):
 
         if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
             ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
+        elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
+            soft_limit_bytes_per_batch = runner_conf.get(
+                "spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch",

Review Comment:
   I don't think 'soft' is necessary in the parameter name. Leave that to the comment, which can describe it as a soft limit.



##########
python/pyspark/worker.py:
##########
@@ -207,6 +209,65 @@ def wrapped(key_series, value_series):
     return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
 
 
+def wrap_grouped_map_pandas_udf_with_state(f, return_type):

Review Comment:
   Please add method-level comments with parameter types and semantics.





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974784396


##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -371,3 +375,292 @@ def load_stream(self, stream):
                 raise ValueError(
                     "Invalid number of pandas.DataFrames in group {0}".format(dataframes_in_group)
                 )
+
+
+class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):

Review Comment:
   Done.





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r975780795


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala:
##########
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout}
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
+import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData
+import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.CompletionIterator
+
+/**
+ * Physical operator for executing
+ * [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]]
+ *
+ * @param functionExpr function called on each group
+ * @param groupingAttributes used to group the data
+ * @param outAttributes used to define the output rows
+ * @param stateType used to serialize/deserialize state before calling `functionExpr`
+ * @param stateInfo `StatefulOperatorStateInfo` to identify the state store for a given operator.
+ * @param stateFormatVersion the version of state format.
+ * @param outputMode the output mode of `functionExpr`
+ * @param timeoutConf used to timeout groups that have not received data in a while
+ * @param batchTimestampMs processing timestamp of the current batch.
+ * @param eventTimeWatermark event time watermark for the current batch
+ * @param child logical plan of the underlying data
+ */
+case class FlatMapGroupsInPandasWithStateExec(
+    functionExpr: Expression,
+    groupingAttributes: Seq[Attribute],
+    outAttributes: Seq[Attribute],
+    stateType: StructType,
+    stateInfo: Option[StatefulOperatorStateInfo],
+    stateFormatVersion: Int,
+    outputMode: OutputMode,
+    timeoutConf: GroupStateTimeout,
+    batchTimestampMs: Option[Long],
+    eventTimeWatermark: Option[Long],
+    child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase {
+
+  // TODO(SPARK-40444): Add the support of initial state.
+  override protected val initialStateDeserializer: Expression = null
+  override protected val initialStateGroupAttrs: Seq[Attribute] = null
+  override protected val initialStateDataAttrs: Seq[Attribute] = null
+  override protected val initialState: SparkPlan = null
+  override protected val hasInitialState: Boolean = false
+
+  override protected val stateEncoder: ExpressionEncoder[Any] =
+    RowEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]]
+
+  override def output: Seq[Attribute] = outAttributes
+
+  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+
+  private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func
+  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
+  private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets(
+    groupingAttributes ++ child.output, groupingAttributes)
+  private lazy val unsafeProj = UnsafeProjection.create(dedupAttributes, child.output)
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    StatefulOperatorPartitioning.getCompatibleDistribution(
+      groupingAttributes, getStateInfo, conf) :: Nil
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
+    groupingAttributes.map(SortOrder(_, Ascending)))
+
+  override def shortName: String = "applyInPandasWithState"
+
+  override protected def withNewChildInternal(
+      newChild: SparkPlan): FlatMapGroupsInPandasWithStateExec = copy(child = newChild)
+
+  override def createInputProcessor(
+      store: StateStore): InputProcessor = new InputProcessor(store: StateStore) {
+
+    override def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
+      val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output)
+      val processIter = groupedIter.map { case (keyRow, valueRowIter) =>
+        val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
+        val stateData = stateManager.getState(store, keyUnsafeRow)
+        (keyUnsafeRow, stateData, valueRowIter.map(unsafeProj))
+      }
+
+      process(processIter, hasTimedOut = false)
+    }
+
+    override def processNewDataWithInitialState(
+        childDataIter: Iterator[InternalRow],
+        initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = {
+      throw new UnsupportedOperationException("Should not reach here!")
+    }
+
+    override def processTimedOutState(): Iterator[InternalRow] = {
+      if (isTimeoutEnabled) {
+        val timeoutThreshold = timeoutConf match {
+          case ProcessingTimeTimeout => batchTimestampMs.get
+          case EventTimeTimeout => eventTimeWatermark.get
+          case _ =>
+            throw new IllegalStateException(
+              s"Cannot filter timed out keys for $timeoutConf")
+        }
+        val timingOutPairs = stateManager.getAllState(store).filter { state =>
+          state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
+        }
+
+        val processIter = timingOutPairs.map { stateData =>
+          val joinedKeyRow = unsafeProj(
+            new JoinedRow(
+              stateData.keyRow,
+              new GenericInternalRow(Array.fill(dedupAttributes.length)(null: Any))))
+
+          (stateData.keyRow, stateData, Iterator.single(joinedKeyRow))
+        }
+
+        process(processIter, hasTimedOut = true)
+      } else Iterator.empty
+    }
+
+    private def process(
+        iter: Iterator[(UnsafeRow, StateData, Iterator[InternalRow])],
+        hasTimedOut: Boolean): Iterator[InternalRow] = {
+      val runner = new ApplyInPandasWithStatePythonRunner(
+        chainedFunc,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+        Array(argOffsets),
+        StructType.fromAttributes(dedupAttributes),
+        sessionLocalTimeZone,
+        pythonRunnerConf,
+        stateEncoder.asInstanceOf[ExpressionEncoder[Row]],
+        groupingAttributes.toStructType,
+        child.output.toStructType,
+        stateType)
+
+      val context = TaskContext.get()
+
+      val processIter = iter.map { case (keyRow, stateData, valueIter) =>
+        val groupedState = GroupStateImpl.createForStreaming(
+          Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r },
+          batchTimestampMs.getOrElse(NO_TIMESTAMP),
+          eventTimeWatermark.getOrElse(NO_TIMESTAMP),
+          timeoutConf,
+          hasTimedOut = hasTimedOut,
+          watermarkPresent).asInstanceOf[GroupStateImpl[Row]]
+        (keyRow, groupedState, valueIter)
+      }
+      runner.compute(processIter, context.partitionId(), context).flatMap {
+        case (stateIter, outputIter) =>
+          // When the iterator is consumed, then write changes to state.
+          // state does not affect each others, hence when to update does not affect to the result.
+          def onIteratorCompletion: Unit = {
+            stateIter.foreach { case (keyRow, newGroupState, oldTimeoutTimestamp) =>
+              if (newGroupState.isRemoved && !newGroupState.getTimeoutTimestampMs.isPresent()) {
+                stateManager.removeState(store, keyRow)
+                numRemovedStateRows += 1
+              } else {
+                val currentTimeoutTimestamp = newGroupState.getTimeoutTimestampMs
+                  .orElse(NO_TIMESTAMP)
+                val hasTimeoutChanged = currentTimeoutTimestamp != oldTimeoutTimestamp
+                val shouldWriteState = newGroupState.isUpdated || newGroupState.isRemoved ||
+                  hasTimeoutChanged
+
+                if (shouldWriteState) {

Review Comment:
   > basically if the state was removed but there is still timeout set? Will you keep the user state object around till the timeout fires?
   
   I don't fully understand the intention of the original codebase, but it seems so.
   
   Here, removing the state means removing the "value object" of the state. We don't allow users to set the value object to null, so removing the state is the only way to clear it. Meanwhile, we still seem to allow setting a timeout on a state whose value object is undefined.
   
   The state ends up in the same status if you start with a new state and only set the timeout without setting the value object. Since we allow that, the case above probably has to be allowed as well.
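   The decision discussed here can be read off the `onIteratorCompletion` code in the diff. A Python sketch of the same branch logic, assuming `NO_TIMESTAMP = -1` as a stand-in for `GroupStateImpl.NO_TIMESTAMP` and assuming the truncated `if (shouldWriteState)` branch persists the state row; the function name is hypothetical. It shows that a removed state whose timeout is still set is written rather than removed, the same outcome as a new state that only has a timeout.

```python
from typing import Optional

NO_TIMESTAMP = -1  # assumption: mirrors GroupStateImpl.NO_TIMESTAMP


def decide_state_action(is_removed: bool, is_updated: bool,
                        timeout_ms: Optional[int], old_timeout_ms: int) -> str:
    """Return "remove", "write" or "skip" for one group after the user function ran.

    timeout_ms: the timeout currently set on the state, or None if unset.
    old_timeout_ms: the timeout stored before this batch (NO_TIMESTAMP if none).
    """
    if is_removed and timeout_ms is None:
        # Value removed and no pending timeout: drop the state row entirely.
        return "remove"
    current_timeout = NO_TIMESTAMP if timeout_ms is None else timeout_ms
    has_timeout_changed = current_timeout != old_timeout_ms
    if is_updated or is_removed or has_timeout_changed:
        return "write"
    return "skip"
```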





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r975770234


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala:
##########
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout}
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
+import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData
+import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.CompletionIterator
+
+/**
+ * Physical operator for executing
+ * [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]]
+ *
+ * @param functionExpr function called on each group
+ * @param groupingAttributes used to group the data
+ * @param outAttributes used to define the output rows
+ * @param stateType used to serialize/deserialize state before calling `functionExpr`
+ * @param stateInfo `StatefulOperatorStateInfo` to identify the state store for a given operator.
+ * @param stateFormatVersion the version of state format.
+ * @param outputMode the output mode of `functionExpr`
+ * @param timeoutConf used to timeout groups that have not received data in a while
+ * @param batchTimestampMs processing timestamp of the current batch.
+ * @param eventTimeWatermark event time watermark for the current batch
+ * @param child logical plan of the underlying data
+ */
+case class FlatMapGroupsInPandasWithStateExec(

Review Comment:
   We always have separate exec implementations for Scala/Java and Python because their constructor parameters differ. (Logical and physical plans are case classes, so different constructor parameters warrant a new class.) So this is intentional. As a compromise, we refactored the common logic into FlatMapGroupsWithStateExecBase as a base class.
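As a rough Python analogy (not Spark code), frozen dataclasses behave like Scala case classes here: the constructor, equality, and copy behavior are all derived from the fields, so operators with different parameters become different classes, and shared behavior moves into a common base class. All class and field names below are hypothetical; `StatefulExecBase` only loosely mirrors the role of FlatMapGroupsWithStateExecBase.

```python
from dataclasses import dataclass, replace
from typing import Tuple


class StatefulExecBase:
    """Shared behavior, loosely analogous to FlatMapGroupsWithStateExecBase."""

    grouping: Tuple[str, ...]  # provided by each concrete dataclass

    def required_ordering(self) -> Tuple[str, ...]:
        # Both variants sort their input by the grouping keys.
        return self.grouping

    def with_new_child(self, child: str) -> "StatefulExecBase":
        # Like copy(child = newChild) on a case class: only `child` changes.
        return replace(self, child=child)


@dataclass(frozen=True)
class JvmFuncExec(StatefulExecBase):
    # Hypothetical Scala/Java variant: carries (de)serializer expressions.
    grouping: Tuple[str, ...]
    deserializer: str
    serializer: str
    child: str


@dataclass(frozen=True)
class PandasFuncExec(StatefulExecBase):
    # Hypothetical Python variant: carries a Python UDF and a state schema instead.
    grouping: Tuple[str, ...]
    python_udf: str
    state_schema: str
    child: str
```

`with_new_child` plays the role of `copy(child = newChild)` in `withNewChildInternal`: it works for both variants because it only touches the field they share, while everything else about each class is generated from its own parameter list.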




[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r975902646


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala:
##########
@@ -311,6 +323,56 @@ object UnsupportedOperationChecker extends Logging {
             }
           }
 
+        // applyInPandasWithState
+        case m: FlatMapGroupsInPandasWithState if m.isStreaming =>
+          // Check compatibility with output modes and aggregations in query
+          val aggsInQuery = collectStreamingAggregates(plan)
+
+          if (aggsInQuery.isEmpty) {
+            // applyInPandasWithState without aggregation: operation's output mode must

Review Comment:
   Now I can imagine a case where this check prevents unintended behavior:
   
   - They implemented the user function for flatMapGroupsWithState with append mode.
   - They ran the query with append mode.
   - Later, they changed the output mode to update mode for some reason.
   - The user function was not changed to account for the switch to update mode.
   
   Currently we do not allow this query to run; if we drop the parameter, we will allow it.
   
   PS. I'm not convinced that end users actually implement their user function differently depending on the output mode, but that is a fundamental API design issue of the original flatMapGroupsWithState and a separate concern.
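A minimal sketch of the check for the no-aggregation case, assuming it follows the existing flatMapGroupsWithState rule quoted in the diff (the operation's output mode must match the query's output mode). The function name, the string-typed modes, and `AnalysisError` are hypothetical stand-ins, not Spark APIs.

```python
class AnalysisError(Exception):
    """Hypothetical stand-in for Spark's AnalysisException."""


def check_output_mode_without_aggregation(operator_mode: str, query_mode: str) -> None:
    # Assumption: with no streaming aggregation in the query, the output mode
    # passed to applyInPandasWithState must equal the query's output mode.
    if operator_mode != query_mode:
        raise AnalysisError(
            f"applyInPandasWithState in {operator_mode} mode is not supported "
            f"with {query_mode} output mode"
        )


# Matching modes pass the check.
check_output_mode_without_aggregation("append", "append")

# Scenario from the comment: function written for append, query switched to update.
try:
    check_output_mode_without_aggregation("append", "update")
except AnalysisError as e:
    print(f"rejected: {e}")
```

The second call corresponds to the scenario above: instead of silently running a function written for append semantics under update mode, the query is rejected up front.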




[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974905452


##########
python/pyspark/worker.py:
##########
@@ -207,6 +209,65 @@ def wrapped(key_series, value_series):
     return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
 
 
+def wrap_grouped_map_pandas_udf_with_state(f, return_type):
+    def wrapped(key_series, value_series_gen, state):
+        import pandas as pd
+
+        key = tuple(s[0] for s in key_series)
+
+        if state.hasTimedOut:
+            # Timeout processing passes an empty iterator. Here we return an empty DataFrame instead.
+            values = [
+                pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns),
+            ]
+        else:
+            values = (pd.concat(x, axis=1) for x in value_series_gen)
+
+        result_iter = f(key, values, state)
+
+        def verify_element(result):
+            if not isinstance(result, pd.DataFrame):
+                raise TypeError(
+                    "The type of element in return iterator of the user-defined function "
+                    "should be pandas.DataFrame, but is {}".format(type(result))
+                )
+            # the number of columns of result has to match the return type
+            # but it is fine for result to have no columns at all if it is empty
+            if not (

Review Comment:
   This is borrowed from the function above. I think we used `if not` here because it is more intuitive to describe the "valid" case and negate it with `not` than to manually rewrite the conditions into their negated form (via De Morgan's laws).
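To illustrate the equivalence, here is a stdlib-only sketch that replaces the pandas DataFrame with plain counts. The full condition is truncated in the diff, so the validity rule below is an assumption based on the two comments above it: the column count must match the return type, or the result may have no columns at all if it is empty. The function names are hypothetical.

```python
from itertools import product


def is_valid(num_cols: int, expected_cols: int, num_rows: int) -> bool:
    # "Valid" case, written positively: the column count matches the return
    # type, or the result has no columns at all and is empty.
    return num_cols == expected_cols or (num_cols == 0 and num_rows == 0)


def is_invalid_with_not(num_cols: int, expected_cols: int, num_rows: int) -> bool:
    # Style used in the PR: state the valid case, then negate it.
    return not is_valid(num_cols, expected_cols, num_rows)


def is_invalid_hand_negated(num_cols: int, expected_cols: int, num_rows: int) -> bool:
    # Same condition negated by hand with De Morgan's laws; harder to read.
    return num_cols != expected_cols and (num_cols != 0 or num_rows != 0)


# Exhaustively confirm both forms agree on a small grid of inputs.
for cols, expected, rows in product(range(4), range(4), range(3)):
    assert is_invalid_with_not(cols, expected, rows) == is_invalid_hand_negated(
        cols, expected, rows
    )
```

Both forms reject the same inputs; the `if not (...)` version just keeps the readable "valid" predicate intact.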




[GitHub] [spark] HeartSaVioR commented on pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on PR #37893:
URL: https://github.com/apache/spark/pull/37893#issuecomment-1254478824

   Thanks @HyukjinKwon and @alex-balikov for the thoughtful reviews and for merging!
   I'll handle the latest comments in a follow-up PR.



[GitHub] [spark] HyukjinKwon commented on pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on PR #37893:
URL: https://github.com/apache/spark/pull/37893#issuecomment-1254477957

   My comments are just nits. I will merge this in first to move forward.
   
   
   Merged to master.



[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r973864021


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala:
##########
@@ -0,0 +1,197 @@
+
+package org.apache.spark.sql.execution.python
+
+import java.io._
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.api.python._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
+
+
+/**
+ * [[ArrowPythonRunner]] with [[org.apache.spark.sql.streaming.GroupState]].
+ */
+class ApplyInPandasWithStatePythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    inputSchema: StructType,
+    override protected val timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    stateEncoder: ExpressionEncoder[Row],
+    keySchema: StructType,
+    valueSchema: StructType,
+    stateValueSchema: StructType,
+    softLimitBytesPerBatch: Long,
+    minDataCountForSample: Int,
+    softTimeoutMillsPurgeBatch: Long)
+  extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
+  with PythonArrowInput[InType]
+  with PythonArrowOutput[OutType] {
+
+  override protected val schema: StructType = inputSchema.add("!__state__!", STATE_METADATA_SCHEMA)

Review Comment:
   Yeah, that's a good point. I'll remove the `!` on both sides, and probably also the trailing `__`.




[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r973868967


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -2705,6 +2705,44 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH =
+    buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch")
+      .internal()
+      .doc("When using applyInPandasWithState, set a soft limit of the accumulated size of " +
+        "records that can be written to a single ArrowRecordBatch in memory. This is used to " +
+        "restrict the amount of memory being used to materialize the data in both executor and " +
+        "Python worker. The accumulated size of records is calculated via sampling a set of " +
+        "records. Splitting the ArrowRecordBatch is performed per record, so unless a record " +
+        "is quite huge, the size of constructed ArrowRecordBatch will be around the " +
+        "configured value.")
+      .version("3.4.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("64MB")

Review Comment:
   Ah, SPARK-23258 is about limiting Arrow record batches by size, which seems similar to what we propose in this PR. It's still an open question whether we should calculate the size on every row addition (accurate, but very costly for performance) or sample as we do here (not exact, and the error may be non-trivial with variable-length columns).
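A minimal sketch of the sampling approach, assuming records are plain byte strings: measure the first few records exactly, use their average size as the per-record estimate, and cut a batch once the estimated accumulated size reaches the soft limit. The parameter names echo `softLimitBytesPerBatch` and `minDataCountForSample` from the runner constructor, but the function itself is hypothetical and not Spark's implementation.

```python
from typing import Iterable, Iterator, List


def split_by_estimated_size(
    records: Iterable[bytes],
    soft_limit_bytes: int,
    min_count_for_sample: int,
) -> Iterator[List[bytes]]:
    """Yield batches whose *estimated* total size is around soft_limit_bytes."""
    if min_count_for_sample < 1:
        raise ValueError("min_count_for_sample must be at least 1")
    sampled_bytes = 0
    sampled_count = 0
    batch: List[bytes] = []
    for record in records:
        if sampled_count < min_count_for_sample:
            # Only the first few records are measured exactly (the "sample").
            sampled_bytes += len(record)
            sampled_count += 1
        batch.append(record)
        # Estimate the batch size from the sampled average instead of
        # measuring every record.
        estimated_size = (sampled_bytes / sampled_count) * len(batch)
        # Batches are cut only at record boundaries, never mid-record.
        if estimated_size >= soft_limit_bytes:
            yield batch
            batch = []
    if batch:
        yield batch
```

With fixed-size records the estimate is exact. Feeding it a couple of tiny records followed by large ones shows the trade-off mentioned above: the sample underestimates the average, so the resulting batch is far larger than the soft limit.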




[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974858058


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala:
##########
@@ -0,0 +1,201 @@
+
+package org.apache.spark.sql.execution.python
+
+import java.io._
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.api.python._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
+
+
+/**
+ * A variant implementation of [[ArrowPythonRunner]] to serve the operation
+ * applyInPandasWithState.
+ */
+class ApplyInPandasWithStatePythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    inputSchema: StructType,
+    override protected val timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    stateEncoder: ExpressionEncoder[Row],
+    keySchema: StructType,
+    valueSchema: StructType,
+    stateValueSchema: StructType,
+    softLimitBytesPerBatch: Long,
+    minDataCountForSample: Int,
+    softTimeoutMillsPurgeBatch: Long)
+  extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
+  with PythonArrowInput[InType]
+  with PythonArrowOutput[OutType] {
+
+  override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA)
+
+  override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+
+  override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
+  require(
+    bufferSize >= 4,

Review Comment:
   Never mind. I just talked with @HyukjinKwon and now understand how the buffer works (I had misunderstood it): it is about buffering small Arrow RecordBatches and flushing them together for efficiency. An Arrow RecordBatch larger than the buffer is still treated as a single Arrow RecordBatch.
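A toy model of that buffering behavior, loosely following how `java.io.BufferedOutputStream` handles writes: small serialized batches accumulate until the next one would overflow the buffer, then everything pending is flushed at once, and a batch at least as large as the buffer is written through whole rather than split. `BatchBuffer` and its `sink` list are hypothetical names for illustration only.

```python
from typing import List


class BatchBuffer:
    """Toy model of a byte buffer in front of a socket (hypothetical)."""

    def __init__(self, buffer_size: int) -> None:
        self.buffer_size = buffer_size
        self.pending = bytearray()
        # Each entry represents one write to the underlying stream.
        self.sink: List[bytes] = []

    def flush(self) -> None:
        if self.pending:
            self.sink.append(bytes(self.pending))
            self.pending.clear()

    def write(self, batch: bytes) -> None:
        if len(batch) >= self.buffer_size:
            # Large batch: flush what is pending, then write it through as one unit.
            self.flush()
            self.sink.append(batch)
            return
        if len(self.pending) + len(batch) > self.buffer_size:
            # Would overflow: flush the pending small batches together first.
            self.flush()
        self.pending += batch
```

So the buffer size only decides how many small batches travel together; it never caps the size of an individual batch.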




[GitHub] [spark] alex-balikov commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
alex-balikov commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r975687838


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala:
##########
@@ -0,0 +1,214 @@
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout}
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
+import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData
+import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.CompletionIterator
+
+/**
+ * Physical operator for executing
+ * [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]]
+ *
+ * @param functionExpr function called on each group
+ * @param groupingAttributes used to group the data
+ * @param outAttributes used to define the output rows
+ * @param stateType used to serialize/deserialize state before calling `functionExpr`
+ * @param stateInfo `StatefulOperatorStateInfo` to identify the state store for a given operator.
+ * @param stateFormatVersion the version of state format.
+ * @param outputMode the output mode of `functionExpr`
+ * @param timeoutConf used to timeout groups that have not received data in a while
+ * @param batchTimestampMs processing timestamp of the current batch.
+ * @param eventTimeWatermark event time watermark for the current batch
+ * @param child logical plan of the underlying data
+ */
+case class FlatMapGroupsInPandasWithStateExec(

Review Comment:
   I wonder if this can be merged with the regular FlatMapGroupsWithStateExec. Maybe as a follow-up cleanup.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala:
##########
@@ -0,0 +1,214 @@
+case class FlatMapGroupsInPandasWithStateExec(
+    functionExpr: Expression,
+    groupingAttributes: Seq[Attribute],
+    outAttributes: Seq[Attribute],
+    stateType: StructType,
+    stateInfo: Option[StatefulOperatorStateInfo],
+    stateFormatVersion: Int,
+    outputMode: OutputMode,
+    timeoutConf: GroupStateTimeout,
+    batchTimestampMs: Option[Long],
+    eventTimeWatermark: Option[Long],
+    child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase {
+
+  // TODO(SPARK-40444): Add the support of initial state.
+  override protected val initialStateDeserializer: Expression = null
+  override protected val initialStateGroupAttrs: Seq[Attribute] = null
+  override protected val initialStateDataAttrs: Seq[Attribute] = null
+  override protected val initialState: SparkPlan = null
+  override protected val hasInitialState: Boolean = false
+
+  override protected val stateEncoder: ExpressionEncoder[Any] =
+    RowEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]]
+
+  override def output: Seq[Attribute] = outAttributes
+
+  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+  private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
+
+  private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func
+  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
+  private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets(
+    groupingAttributes ++ child.output, groupingAttributes)
+  private lazy val unsafeProj = UnsafeProjection.create(dedupAttributes, child.output)
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    StatefulOperatorPartitioning.getCompatibleDistribution(
+      groupingAttributes, getStateInfo, conf) :: Nil
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
+    groupingAttributes.map(SortOrder(_, Ascending)))
+
+  override def shortName: String = "applyInPandasWithState"
+
+  override protected def withNewChildInternal(
+      newChild: SparkPlan): FlatMapGroupsInPandasWithStateExec = copy(child = newChild)
+
+  override def createInputProcessor(
+      store: StateStore): InputProcessor = new InputProcessor(store: StateStore) {
+
+    override def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
+      val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output)
+      val processIter = groupedIter.map { case (keyRow, valueRowIter) =>
+        val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
+        val stateData = stateManager.getState(store, keyUnsafeRow)
+        (keyUnsafeRow, stateData, valueRowIter.map(unsafeProj))
+      }
+
+      process(processIter, hasTimedOut = false)
+    }
+
+    override def processNewDataWithInitialState(
+        childDataIter: Iterator[InternalRow],
+        initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = {
+      throw new UnsupportedOperationException("Should not reach here!")
+    }
+
+    override def processTimedOutState(): Iterator[InternalRow] = {
+      if (isTimeoutEnabled) {
+        val timeoutThreshold = timeoutConf match {
+          case ProcessingTimeTimeout => batchTimestampMs.get
+          case EventTimeTimeout => eventTimeWatermark.get
+          case _ =>
+            throw new IllegalStateException(
+              s"Cannot filter timed out keys for $timeoutConf")
+        }
+        val timingOutPairs = stateManager.getAllState(store).filter { state =>
+          state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
+        }
+
+        val processIter = timingOutPairs.map { stateData =>
+          val joinedKeyRow = unsafeProj(
+            new JoinedRow(
+              stateData.keyRow,
+              new GenericInternalRow(Array.fill(dedupAttributes.length)(null: Any))))
+
+          (stateData.keyRow, stateData, Iterator.single(joinedKeyRow))
+        }
+
+        process(processIter, hasTimedOut = true)
+      } else Iterator.empty
+    }
+
+    private def process(
+        iter: Iterator[(UnsafeRow, StateData, Iterator[InternalRow])],
+        hasTimedOut: Boolean): Iterator[InternalRow] = {
+      val runner = new ApplyInPandasWithStatePythonRunner(
+        chainedFunc,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+        Array(argOffsets),
+        StructType.fromAttributes(dedupAttributes),
+        sessionLocalTimeZone,
+        pythonRunnerConf,
+        stateEncoder.asInstanceOf[ExpressionEncoder[Row]],
+        groupingAttributes.toStructType,
+        child.output.toStructType,
+        stateType)
+
+      val context = TaskContext.get()
+
+      val processIter = iter.map { case (keyRow, stateData, valueIter) =>
+        val groupedState = GroupStateImpl.createForStreaming(
+          Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r },
+          batchTimestampMs.getOrElse(NO_TIMESTAMP),
+          eventTimeWatermark.getOrElse(NO_TIMESTAMP),
+          timeoutConf,
+          hasTimedOut = hasTimedOut,
+          watermarkPresent).asInstanceOf[GroupStateImpl[Row]]
+        (keyRow, groupedState, valueIter)
+      }
+      runner.compute(processIter, context.partitionId(), context).flatMap {
+        case (stateIter, outputIter) =>
+          // When the iterator is consumed, write the changes to the state store.
+          // States of different groups do not affect each other, so the timing of the update does not affect the result.
+          def onIteratorCompletion: Unit = {
+            stateIter.foreach { case (keyRow, newGroupState, oldTimeoutTimestamp) =>
+              if (newGroupState.isRemoved && !newGroupState.getTimeoutTimestampMs.isPresent()) {
+                stateManager.removeState(store, keyRow)
+                numRemovedStateRows += 1
+              } else {
+                val currentTimeoutTimestamp = newGroupState.getTimeoutTimestampMs
+                  .orElse(NO_TIMESTAMP)
+                val hasTimeoutChanged = currentTimeoutTimestamp != oldTimeoutTimestamp
+                val shouldWriteState = newGroupState.isUpdated || newGroupState.isRemoved ||
+                  hasTimeoutChanged
+
+                if (shouldWriteState) {

Review Comment:
   What happens if
   
   `newGroupState.isRemoved && newGroupState.getTimeoutTimestampMs.isPresent()`
   
   i.e., the state was removed but a timeout is still set? Will you keep the user state object around until the timeout fires?
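
A minimal Python model of the branch logic in the quoted `onIteratorCompletion` may help frame the question. It assumes `NO_TIMESTAMP` is -1 (mirroring `GroupStateImpl.NO_TIMESTAMP`) and represents an absent timeout as `None`; `decide_state_action` is a hypothetical name, not part of the PR.

```python
# Hypothetical model of the state-write decision in the quoted onIteratorCompletion.
# Assumption: NO_TIMESTAMP mirrors GroupStateImpl.NO_TIMESTAMP and is taken as -1.
NO_TIMESTAMP = -1


def decide_state_action(is_removed, is_updated, timeout_ts, old_timeout_ts):
    """Return "remove", "write" or "skip" for one group.

    timeout_ts is None when no timeout is set (an empty Optional in the Scala code).
    """
    if is_removed and timeout_ts is None:
        return "remove"
    current_timeout_ts = NO_TIMESTAMP if timeout_ts is None else timeout_ts
    has_timeout_changed = current_timeout_ts != old_timeout_ts
    if is_updated or is_removed or has_timeout_changed:
        return "write"
    return "skip"


# The case in question: the state was removed, but a timeout is still set.
print(decide_state_action(True, False, 5000, NO_TIMESTAMP))  # write
```

Under this reading, a removed state that still has a timeout falls through to the write path (`isRemoved` alone makes `shouldWriteState` true); what is actually persisted in that case is outside the quoted hunk.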



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala:
##########
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.types.{BinaryType, BooleanType, IntegerType, StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * This class abstracts the complexity on constructing Arrow RecordBatches for data and state with
+ * bin-packing and chunking. The caller only need to call the proper public methods of this class
+ * `startNewGroup`, `writeRow`, `finalizeGroup`, `finalizeData` and this class will write the data
+ * and state into Arrow RecordBatches with performing bin-pack and chunk internally.
+ *
+ * This class requires that the parameter `root` has been initialized with the Arrow schema like
+ * below:
+ * - data fields
+ * - state field
+ *   - nested schema (Refer ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA)
+ *
+ * Please refer the code comment in the implementation to see how the writes of data and state
+ * against Arrow RecordBatch work with consideration of bin-packing and chunking.
+ */
+class ApplyInPandasWithStateWriter(
+    root: VectorSchemaRoot,
+    writer: ArrowStreamWriter,
+    arrowMaxRecordsPerBatch: Int) {
+
+  import ApplyInPandasWithStateWriter._
+
+  // Unlike applyInPandas (and other PySpark operators), applyInPandasWithState requires to produce
+  // the additional data `state`, along with the input data.
+  //
+  // ArrowStreamWriter supports only single VectorSchemaRoot, which means all Arrow RecordBatches
+  // being sent out from ArrowStreamWriter should have same schema. That said, we have to construct
+  // "an" Arrow schema to contain both types of data, and also construct Arrow RecordBatches to

Review Comment:
   to contain both data and state, and also construct ArrowBatches to contain both data and state.
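
The single-schema constraint in the quoted comment (every RecordBatch from one `ArrowStreamWriter` shares a schema, so data and state travel in the same batch, with the state column padded by empty rows so all columns have the same length) can be modeled with plain Python lists standing in for Arrow columns. `build_batch` and the metadata field names are hypothetical; the real layout is `ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA`.

```python
# Hypothetical model: a "batch" is a dict of equal-length columns.
# Data columns come first; one trailing "state" column holds a metadata entry
# per group and is padded with None so every column has the same row count.


def build_batch(data_columns, state_entries):
    num_rows = len(next(iter(data_columns.values())))
    if any(len(col) != num_rows for col in data_columns.values()):
        raise ValueError("all data columns must have the same length")
    if len(state_entries) > num_rows:
        raise ValueError("more state entries than data rows")
    padded_state = list(state_entries) + [None] * (num_rows - len(state_entries))
    batch = dict(data_columns)
    batch["state"] = padded_state
    return batch


batch = build_batch(
    {"id": ["a", "a", "b"], "value": [1, 2, 3]},
    [{"key": "a", "startOffset": 0, "numRows": 2},
     {"key": "b", "startOffset": 2, "numRows": 1}],
)
print(batch["state"][2])  # None (padding row)
```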



##########
python/pyspark/worker.py:
##########
@@ -207,6 +209,65 @@ def wrapped(key_series, value_series):
     return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
 
 
+def wrap_grouped_map_pandas_udf_with_state(f, return_type):
+    def wrapped(key_series, value_series_gen, state):
+        import pandas as pd
+
+        key = tuple(s[0] for s in key_series)
+
+        if state.hasTimedOut:
+            # Timeout processing pass empty iterator. Here we return an empty DataFrame instead.
+            values = [
+                pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns),
+            ]
+        else:
+            values = (pd.concat(x, axis=1) for x in value_series_gen)
+
+        result_iter = f(key, values, state)
+
+        def verify_element(result):
+            if not isinstance(result, pd.DataFrame):
+                raise TypeError(
+                    "The type of element in return iterator of the user-defined function "
+                    "should be pandas.DataFrame, but is {}".format(type(result))
+                )
+            # the number of columns of result have to match the return type
+            # but it is fine for result to have no columns at all if it is empty
+            if not (

Review Comment:
   ah, nevermind, I just misread the code.
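
For context on the timed-out branch of `wrapped`: it pulls one element from `value_series_gen` only to learn the column labels, then hands the user function a single empty `pandas.DataFrame` with those columns. The snippet reproduces that expression with made-up input; it assumes each element of `value_series_gen` is a list of named `pandas.Series`, one per input column.

```python
import pandas as pd

# Hypothetical input: one element of value_series_gen for a timed-out group,
# i.e. a list of named pandas.Series (one per input column).
value_series_gen = iter([[pd.Series([], dtype="int64", name="id"),
                          pd.Series([], dtype="float64", name="v")]])

# Same expression as the timed-out branch of `wrapped`:
values = [pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns)]

print(list(values[0].columns), len(values[0]))  # ['id', 'v'] 0
```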



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala:
##########
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.types.{BinaryType, BooleanType, IntegerType, StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * This class abstracts the complexity on constructing Arrow RecordBatches for data and state with
+ * bin-packing and chunking. The caller only need to call the proper public methods of this class
+ * `startNewGroup`, `writeRow`, `finalizeGroup`, `finalizeData` and this class will write the data
+ * and state into Arrow RecordBatches with performing bin-pack and chunk internally.
+ *
+ * This class requires that the parameter `root` has been initialized with the Arrow schema like
+ * below:
+ * - data fields
+ * - state field
+ *   - nested schema (Refer ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA)
+ *
+ * Please refer the code comment in the implementation to see how the writes of data and state
+ * against Arrow RecordBatch work with consideration of bin-packing and chunking.
+ */
+class ApplyInPandasWithStateWriter(
+    root: VectorSchemaRoot,
+    writer: ArrowStreamWriter,
+    arrowMaxRecordsPerBatch: Int) {
+
+  import ApplyInPandasWithStateWriter._
+
+  // Unlike applyInPandas (and other PySpark operators), applyInPandasWithState requires to produce
+  // the additional data `state`, along with the input data.
+  //
+  // ArrowStreamWriter supports only single VectorSchemaRoot, which means all Arrow RecordBatches
+  // being sent out from ArrowStreamWriter should have same schema. That said, we have to construct
+  // "an" Arrow schema to contain both types of data, and also construct Arrow RecordBatches to
+  // contain both data.
+  //
+  // To achieve this, we extend the schema for input data to have a column for state at the end.
+  // But also, we logically group the columns by family (data vs state) and initialize writer
+  // separately, since it's lot more easier and probably performant to write the row directly
+  // rather than projecting the row to match up with the overall schema.
+  //
+  // Although Arrow RecordBatch enables to write the data as columnar, we figure out it gives
+  // strange outputs if we don't ensure that all columns have the same number of values. Since
+  // there are at least one data for a grouping key (we ensure this for the case of handling timed
+  // out state as well) whereas there is only one state for a grouping key, we have to fill up the
+  // empty rows in state side to ensure both have the same number of rows.
+  private val arrowWriterForData = createArrowWriter(
+    root.getFieldVectors.asScala.toSeq.dropRight(1))
+  private val arrowWriterForState = createArrowWriter(
+    root.getFieldVectors.asScala.toSeq.takeRight(1))
+
+  // - Bin-packing
+  //
+  // We apply bin-packing the data from multiple groups into one Arrow RecordBatch to
+  // gain the performance. In many cases, the amount of data per grouping key is quite
+  // small, which does not seem to maximize the benefits of using Arrow.
+  //
+  // We have to split the record batch down to each group in Python worker to convert the
+  // data for group to Pandas, but hopefully, Arrow RecordBatch provides the way to split
+  // the range of data and give a view, say, "zero-copy". To help splitting the range for
+  // data, we provide the "start offset" and the "number of data" in the state metadata.
+  //
+  // We don't bin-pack all groups into a single record batch - we have a limit on the number
+  // of rows in the current Arrow RecordBatch to stop adding next group.
+  //
+  // - Chunking
+  //
+  // We also chunk the data from single group into multiple Arrow RecordBatch to ensure
+  // scalability. Note that we don't know the volume (number of rows, overall size) of data for
+  // specific group key before we read the entire data. The easiest approach to address both
+  // bin-pack and chunk is to check the number of rows in the current Arrow RecordBatch for each
+  // write of row.
+  //
+  // - Consideration
+  //
+  // Since the number of rows in Arrow RecordBatch does not represent the actual size (bytes),
+  // the limit should be set very conservatively. Using a small number of limit does not introduce
+  // correctness issues.
+
+  private var numRowsForCurGroup = 0
+  private var startOffsetForCurGroup = 0
+  private var totalNumRowsForBatch = 0
+  private var totalNumStatesForBatch = 0
+
+  private var currentGroupKeyRow: UnsafeRow = _
+  private var currentGroupState: GroupStateImpl[Row] = _
+
+  /**
+   * Indicates writer to start with new grouping key.
+   *
+   * @param keyRow The grouping key row for current group.
+   * @param groupState The instance of GroupStateImpl for current group.
+   */
+  def startNewGroup(keyRow: UnsafeRow, groupState: GroupStateImpl[Row]): Unit = {
+    currentGroupKeyRow = keyRow
+    currentGroupState = groupState
+  }
+
+  /**
+   * Indicates writer to write a row in the current group.
+   *
+   * @param dataRow The row to write in the current group.
+   */
+  def writeRow(dataRow: InternalRow): Unit = {
+    // If it exceeds the condition of batch (number of records) and there is more data for the
+    // same group, finalize and construct a new batch.
+
+    if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+      // Provide state metadata row as intermediate
+      val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState,
+        startOffsetForCurGroup, numRowsForCurGroup, isLastChunk = false)
+      arrowWriterForState.write(stateInfoRow)
+      totalNumStatesForBatch += 1
+
+      finalizeCurrentArrowBatch()
+    }
+
+    arrowWriterForData.write(dataRow)
+    numRowsForCurGroup += 1
+    totalNumRowsForBatch += 1
+  }
+
+  /**
+   * Indicates writer that current group has finalized and there will be no further row bound to
+   * the current group.
+   */
+  def finalizeGroup(): Unit = {
+    // Provide state metadata row
+    val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState,
+      startOffsetForCurGroup, numRowsForCurGroup, isLastChunk = true)
+    arrowWriterForState.write(stateInfoRow)
+    totalNumStatesForBatch += 1
+
+    // The start offset for next group would be same as the total number of rows for batch,
+    // unless the next group starts with new batch.
+    startOffsetForCurGroup = totalNumRowsForBatch
+  }
+
+  /**
+   * Indicates writer that all groups have been processed.
+   */
+  def finalizeData(): Unit = {
+    if (numRowsForCurGroup > 0) {
+      // We still have some rows in the current record batch. Need to finalize them as well.
+      finalizeCurrentArrowBatch()
+    }
+  }
+
+  private def createArrowWriter(fieldVectors: Seq[FieldVector]): ArrowWriter = {
+    val children = fieldVectors.map { vector =>
+      vector.allocateNew()
+      createFieldWriter(vector)
+    }
+
+    new ArrowWriter(root, children.toArray)
+  }
+
+  private def buildStateInfoRow(
+      keyRow: UnsafeRow,
+      groupState: GroupStateImpl[Row],
+      startOffset: Int,
+      numRows: Int,
+      isLastChunk: Boolean): InternalRow = {
+    // NOTE: see ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
+    val stateUnderlyingRow = new GenericInternalRow(
+      Array[Any](
+        UTF8String.fromString(groupState.json()),
+        keyRow.getBytes,
+        groupState.getOption.map(PythonSQLUtils.toPyRow).orNull,
+        startOffset,
+        numRows,
+        isLastChunk
+      )
+    )
+    new GenericInternalRow(Array[Any](stateUnderlyingRow))
+  }
+
+  private def finalizeCurrentArrowBatch(): Unit = {
+    val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch
+    (0 until remainingEmptyStateRows).foreach { _ =>
+      arrowWriterForState.write(EMPTY_STATE_METADATA_ROW)
+    }
+
+    arrowWriterForState.finish()
+    arrowWriterForData.finish()
+    writer.writeBatch()
+    arrowWriterForState.reset()
+    arrowWriterForData.reset()
+
+    startOffsetForCurGroup = 0
+    numRowsForCurGroup = 0
+    totalNumRowsForBatch = 0
+    totalNumStatesForBatch = 0
+  }
+}
+
+object ApplyInPandasWithStateWriter {
+  val STATE_METADATA_SCHEMA: StructType = StructType(

Review Comment:
   Please comment on the semantics of each column. `isLastChunk` in particular is not obvious, but it is important for the operation of the protocol.
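
To make the bin-packing and chunking protocol concrete, including what `isLastChunk` signals, here is a hypothetical pure-Python model of the writer plus a reader that reassembles groups from the state metadata. Lists stand in for Arrow vectors, and `WriterModel`, `reassemble`, and the dict field names are illustrative only. Unlike the quoted draft, the model resets the per-group row count in `finalize_group` and skips the intermediate metadata row when a batch fills up exactly at a group boundary; both are choices made for this sketch.

```python
class WriterModel:
    """Hypothetical model of ApplyInPandasWithStateWriter; lists stand in for Arrow vectors."""

    def __init__(self, max_records_per_batch):
        self.max_records = max_records_per_batch
        self.batches = []  # emitted batches: {"data": [...], "state": [...]}
        self.key = None
        self._reset_batch()

    def _reset_batch(self):
        self.data, self.state = [], []
        self.start_offset = 0
        self.num_rows_cur_group = 0

    def _write_state(self, is_last_chunk):
        self.state.append({"key": self.key, "startOffset": self.start_offset,
                           "numRows": self.num_rows_cur_group,
                           "isLastChunk": is_last_chunk})

    def _finalize_batch(self):
        # Pad the state column with empty rows so both columns have equal length.
        self.state += [None] * (len(self.data) - len(self.state))
        self.batches.append({"data": self.data, "state": self.state})
        self._reset_batch()

    def start_new_group(self, key):
        self.key = key

    def write_row(self, row):
        if len(self.data) >= self.max_records:
            # Batch is full: if the current group already has rows here,
            # emit an intermediate chunk marker (isLastChunk=False).
            if self.num_rows_cur_group > 0:
                self._write_state(is_last_chunk=False)
            self._finalize_batch()
        self.data.append(row)
        self.num_rows_cur_group += 1

    def finalize_group(self):
        self._write_state(is_last_chunk=True)
        self.start_offset = len(self.data)  # next group starts after this one
        self.num_rows_cur_group = 0

    def finalize_data(self):
        if self.data:
            self._finalize_batch()


def reassemble(batches):
    """Rebuild {key: rows} using startOffset, numRows and isLastChunk."""
    groups, pending = {}, {}
    for batch in batches:
        for meta in batch["state"]:
            if meta is None:  # padding row
                continue
            start, n = meta["startOffset"], meta["numRows"]
            pending.setdefault(meta["key"], []).extend(batch["data"][start:start + n])
            if meta["isLastChunk"]:
                groups[meta["key"]] = pending.pop(meta["key"])
    return groups


writer = WriterModel(max_records_per_batch=3)
for key, rows in [("a", [1, 2]), ("b", [3, 4, 5, 6]), ("c", [7])]:
    writer.start_new_group(key)
    for row in rows:
        writer.write_row(row)
    writer.finalize_group()
writer.finalize_data()

print(len(writer.batches), reassemble(writer.batches))
# 3 {'a': [1, 2], 'b': [3, 4, 5, 6], 'c': [7]}
```

In this model, `isLastChunk = False` tells the reader to keep accumulating rows for the key across batches, and `True` marks the point where the group is complete and can be handed to the user function.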



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [DRAFT][DO-NOT-MERGE][SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r971613565


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala:
##########
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout}
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
+import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData
+import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.CompletionIterator
+
+/**
+ * Physical operator for executing
+ * [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]]
+ *
+ * @param functionExpr function called on each group
+ * @param groupingAttributes used to group the data
+ * @param outAttributes used to define the output rows
+ * @param stateType used to serialize/deserialize state before calling `functionExpr`
+ * @param stateInfo `StatefulOperatorStateInfo` to identify the state store for a given operator.
+ * @param stateFormatVersion the version of state format.
+ * @param outputMode the output mode of `functionExpr`
+ * @param timeoutConf used to timeout groups that have not received data in a while
+ * @param batchTimestampMs processing timestamp of the current batch.
+ * @param eventTimeWatermark event time watermark for the current batch
+ * @param child logical plan of the underlying data
+ */
+case class FlatMapGroupsInPandasWithStateExec(
+    functionExpr: Expression,
+    groupingAttributes: Seq[Attribute],
+    outAttributes: Seq[Attribute],
+    stateType: StructType,
+    stateInfo: Option[StatefulOperatorStateInfo],
+    stateFormatVersion: Int,
+    outputMode: OutputMode,
+    timeoutConf: GroupStateTimeout,
+    batchTimestampMs: Option[Long],
+    eventTimeWatermark: Option[Long],
+    child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase {
+
+  // TODO(SPARK-XXXXX): Add the support of initial state.

Review Comment:
   self-comment: This does not seem to be a first-priority feature; I think we can deal with it later. I'll file a JIRA ticket and replace the placeholder ticket number.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala:
##########
@@ -793,6 +812,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout,
           hasInitialState, planLater(initialState), planLater(child)
         ) :: Nil
+      case _: FlatMapGroupsInPandasWithState =>
+        throw new UnsupportedOperationException(

Review Comment:
   In Scala/Java, the flatMapGroupsWithState API uses a different physical plan for batch queries: it wraps the user function so that it effectively matches the function expected by MapGroupsExec.
   
   This seems non-trivial to do in PySpark, or at least seems to require major work. Since there is another API suited to batch queries (applyInPandas), I think we can leave this as a TODO and guide users to applyInPandas for now.
   
   self-comment: file a JIRA ticket and replace the number.
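
The batch-query wrapping described in that comment can be pictured as calling the user function exactly once per group with a fresh, empty state, since a batch query has no state from earlier micro-batches. This is a hypothetical pure-Python sketch under that assumption; `SimpleState` and `as_batch_function` are illustrative names, not PySpark or Spark APIs.

```python
class SimpleState:
    """Hypothetical stand-in for a per-group state with no prior value."""

    def __init__(self):
        self.exists = False
        self.value = None

    def update(self, value):
        self.exists = True
        self.value = value


def as_batch_function(user_func):
    # No state survives from earlier batches, so each group gets a fresh,
    # empty state and the user function is invoked once with all its rows.
    def wrapped(key, rows):
        return list(user_func(key, iter(rows), SimpleState()))
    return wrapped


def count_fn(key, rows, state):
    total = sum(1 for _ in rows)
    state.update((total,))
    yield (key, total)


batch_fn = as_batch_function(count_fn)
print([batch_fn(k, v) for k, v in {"a": [1, 2], "b": [3]}.items()])
# [[('a', 2)], [('b', 1)]]
```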





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r977150660


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,125 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked first for all input groups and then
+        for all timed out states where the input data is set to be empty. Updates to each group's
+        state will be saved across invocations.

Review Comment:
   ```suggestion
           For a streaming :class:`DataFrame`, the function will be invoked first for all input groups
           and then for all timed out states where the input data is set to be empty. Updates to
           each group's state will be saved across invocations.
   ```
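
The invocation order in that paragraph (every group with input first, then timed-out groups with empty input) can be modeled as follows. The expiry check `timeout_ts <= now_ms`, the rule of skipping keys that already received data in the batch, and all names here are simplifying assumptions for illustration; the real operator decides timeouts from processing time or the watermark according to `timeoutConf`.

```python
def run_micro_batch(input_groups, timeouts, now_ms, func):
    """Hypothetical model of one micro-batch.

    input_groups: {key: [rows]} for keys that received data in this batch.
    timeouts: {key: timeout timestamp in ms} for keys with a timeout set.
    Returns the (key, has_timed_out) calls in the order they were made.
    """
    calls = []
    # 1) Every group that has input data in this batch.
    for key, rows in input_groups.items():
        func(key, iter(rows), has_timed_out=False)
        calls.append((key, False))
    # 2) Groups without input whose timeout has expired get an empty iterator.
    for key, ts in timeouts.items():
        if key not in input_groups and ts <= now_ms:
            func(key, iter([]), has_timed_out=True)
            calls.append((key, True))
    return calls


seen = []
calls = run_micro_batch(
    {"a": [1], "c": [2]},
    {"a": 50, "b": 90, "d": 200},
    now_ms=100,
    func=lambda key, rows, has_timed_out: seen.append((key, list(rows))),
)
print(calls)  # [('a', False), ('c', False), ('b', True)]
```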



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala:
##########
@@ -98,6 +98,16 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) {
     count += 1
   }
 
+  def sizeInBytes(): Int = {

Review Comment:
   I think we don't need `sizeInBytes` and `getSizeInBytes` anymore.



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,125 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked first for all input groups and then
+        for all timed out states where the input data is set to be empty. Updates to each group's
+        state will be saved across invocations.
+
+        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
+        return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupState`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should not make a guess of the number of
+        elements in the iterator. To process all data, the user function needs to iterate all
+        elements and process them. On the other hand, the user function is not strictly required to
+        iterate through all elements in the iterator if it intends to read a part of data.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in the returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned `pandas.DataFrame` must either match the field names in the defined schema if
+        specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of the
+        user-defined state. The value of the state will be presented as a tuple, as well as the
+        update should be performed with the tuple. The corresponding Python types for
+        :class:DataType are supported. Please refer to the page
+        https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab).
+
+        The size of each DataFrame in both the input and output can be arbitrary. The number of
+        DataFrames in both the input and output can also be arbitrary.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        func : function
+            a Python native function to be called on every group. It should take parameters
+            (key, Iterator[`pandas.DataFrame`], state) and return Iterator[`pandas.DataFrame`].
+            Note that the type of the key is tuple and the type of the state is
+            :class:`pyspark.sql.streaming.state.GroupState`.
+        outputStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the output records. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        stateStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the user-defined state. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        outputMode : str
+            the output mode of the function.
+        timeoutConf : str
+            timeout configuration for groups that do not receive data for a while. valid values
+            are defined in :class:`pyspark.sql.streaming.state.GroupStateTimeout`.
+
+        Examples
+        --------
+        >>> import pandas as pd  # doctest: +SKIP
+        >>> from pyspark.sql.streaming.state import GroupStateTimeout
+        >>> def count_fn(key, pdf_iter, state):
+        ...     assert isinstance(state, GroupStateImpl)
+        ...     total_len = 0
+        ...     for pdf in pdf_iter:
+        ...         total_len += len(pdf)
+        ...     state.update((total_len,))
+        ...     yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]})

Review Comment:
   ```suggestion
           ...     yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]})
           ...
   ```
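
To illustrate the documented contract outside Spark, the docstring's `count_fn` can be driven by hand. `FakeState` is a hypothetical stand-in for `pyspark.sql.streaming.state.GroupState` exposing only `update`; the `isinstance` assertion from the doctest is dropped here because `GroupStateImpl` is not imported in the quoted example.

```python
import pandas as pd


class FakeState:
    """Hypothetical stand-in for GroupState; only records the last update."""

    def __init__(self):
        self.value = None

    def update(self, value):
        self.value = value


def count_fn(key, pdf_iter, state):
    total_len = 0
    for pdf in pdf_iter:
        total_len += len(pdf)
    state.update((total_len,))
    yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]})


state = FakeState()
chunks = [pd.DataFrame({"id": ["a"] * 2}), pd.DataFrame({"id": ["a"] * 3})]
out = pd.concat(count_fn(("a",), iter(chunks), state))
print(state.value, out["countAsString"].tolist())  # (5,) ['5']
```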



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,125 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked first for all input groups and then
+        for all timed out states where the input data is set to be empty. Updates to each group's
+        state will be saved across invocations.
+
+        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
+        return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupState`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should not make a guess of the number of
+        elements in the iterator. To process all data, the user function needs to iterate all
+        elements and process them. On the other hand, the user function is not strictly required to
+        iterate through all elements in the iterator if it intends to read a part of data.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in the returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned `pandas.DataFrame` must either match the field names in the defined schema if
+        specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of the
+        user-defined state. The value of the state will be presented as a tuple, as well as the
+        update should be performed with the tuple. The corresponding Python types for
+        :class:DataType are supported. Please refer to the page
+        https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab).
+
+        The size of each DataFrame in both the input and output can be arbitrary. The number of
+        DataFrames in both the input and output can also be arbitrary.

Review Comment:
   ```suggestion
           The size of each `pandas.DataFrame` in both the input and output can be arbitrary.
           The number of DataFrames in both the input and output can also be arbitrary.
   ```
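
Because the number and size of the input `pandas.DataFrame`s are arbitrary, a user function that needs all the data should aggregate across the whole iterator rather than assume a single frame (one that only reads `next(pdf_iter)` would undercount). A minimal pandas check, using a hypothetical helper `sum_values`, that the result does not depend on how the same rows are split:

```python
import pandas as pd


def sum_values(pdf_iter):
    # Aggregate over every frame in the iterator; never assume just one.
    return sum(int(pdf["v"].sum()) for pdf in pdf_iter)


full = pd.DataFrame({"v": range(10)})
one_frame = sum_values(iter([full]))
# The same rows delivered as uneven chunks (sizes 3, 0 and 7).
chunks = [full.iloc[0:3], full.iloc[3:3], full.iloc[3:10]]
many_frames = sum_values(iter(chunks))
print(one_frame, many_frames)  # 45 45
```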



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,125 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked first for all input groups and then
+        for all timed out states where the input data is set to be empty. Updates to each group's
+        state will be saved across invocations.
+
+        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
+        return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupState`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should not make a guess of the number of
+        elements in the iterator. To process all data, the user function needs to iterate all
+        elements and process them. On the other hand, the user function is not strictly required to
+        iterate through all elements in the iterator if it intends to read a part of data.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in the returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned `pandas.DataFrame` must either match the field names in the defined schema if
+        specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of the
+        user-defined state. The value of the state will be presented as a tuple, as well as the
+        update should be performed with the tuple. The corresponding Python types for
+        :class:DataType are supported. Please refer to the page
+        https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab).
+
+        The size of each DataFrame in both the input and output can be arbitrary. The number of
+        DataFrames in both the input and output can also be arbitrary.

Review Comment:
   I think we could move some of these notes from the description into a `Notes` section, but it's not a big deal.
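Below is a minimal sketch of a user function that follows the contract in the docstring above. It takes `(key, Iterator[pandas.DataFrame], state)` and iterates over every frame without assuming how many there are. It keeps a running count as a one-field state tuple and yields one output frame. `FakeGroupState` is a hypothetical stand-in that exposes only the `exists`, `get` and `update` members the sketch touches, so the code can run without a Spark session. In a real query, Spark passes its own `GroupState` instance.

```python
from typing import Any, Iterator, Optional, Tuple

import pandas as pd


class FakeGroupState:
    """Hypothetical stand-in for pyspark.sql.streaming.state.GroupState."""

    def __init__(self, value: Optional[Tuple[Any, ...]] = None) -> None:
        self._value = value

    @property
    def exists(self) -> bool:
        return self._value is not None

    @property
    def get(self) -> Tuple[Any, ...]:
        if self._value is None:
            raise ValueError("state is not defined")
        return self._value

    def update(self, new_value: Tuple[Any, ...]) -> None:
        # State values are plain tuples in stateStructType field order.
        self._value = tuple(new_value)


def running_count(
    key: Tuple[Any, ...], pdf_iter: Iterator[pd.DataFrame], state: Any
) -> Iterator[pd.DataFrame]:
    # key is a tuple of grouping values, one entry per grouping column.
    (count,) = state.get if state.exists else (0,)
    for pdf in pdf_iter:  # consume every frame; their number is not known up front
        count += len(pdf)
    state.update((count,))  # the update is a tuple too, e.g. stateStructType "count long"
    yield pd.DataFrame({"id": [key[0]], "count": [count]})
```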



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,125 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked first for all input groups and then
+        for all timed out states where the input data is set to be empty. Updates to each group's
+        state will be saved across invocations.
+
+        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
+        return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupState`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should not make a guess of the number of
+        elements in the iterator. To process all data, the user function needs to iterate all
+        elements and process them. On the other hand, the user function is not strictly required to
+        iterate through all elements in the iterator if it intends to read a part of data.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in the returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned `pandas.DataFrame` must either match the field names in the defined schema if
+        specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of the
+        user-defined state. The value of the state will be presented as a tuple, as well as the
+        update should be performed with the tuple. The corresponding Python types for
+        :class:DataType are supported. Please refer to the page
+        https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab).

Review Comment:
   ```suggestion
           https://spark.apache.org/docs/latest/sql-ref-datatypes.html (Python tab).
   ```
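Here is one way to make "the corresponding Python types for :class:`DataType`" concrete. The sketch checks a state tuple against a subset of the Python-tab mapping on the linked page: `LongType` to `int`, `DoubleType` to `float`, `StringType` to `str`, `BooleanType` to `bool`, `DateType` to `datetime.date`, `TimestampType` to `datetime.datetime`, and so on. The `check_state_tuple` helper and its DDL-style type names are hypothetical, for illustration only. Spark's own converters perform the real validation, and this sketch treats every field as nullable.

```python
import datetime
import decimal
from typing import Any, Dict, List, Tuple

# Subset of the Python-tab mapping on the Spark SQL data types page.
PYTHON_TYPES: Dict[str, Tuple[type, ...]] = {
    "int": (int,),
    "long": (int,),
    "float": (float,),
    "double": (float,),
    "decimal": (decimal.Decimal,),
    "string": (str,),
    "binary": (bytes, bytearray),
    "boolean": (bool,),
    "date": (datetime.date,),
    "timestamp": (datetime.datetime,),
}


def check_state_tuple(field_types: List[str], value: Tuple[Any, ...]) -> List[str]:
    """Return a list of problems; an empty list means the tuple looks compatible."""
    if len(field_types) != len(value):
        return [f"expected {len(field_types)} fields, got {len(value)}"]
    problems = []
    for i, (type_name, v) in enumerate(zip(field_types, value)):
        if v is None:
            continue  # this sketch treats every field as nullable
        expected = PYTHON_TYPES[type_name]
        # bool is a subclass of int in Python, so reject it explicitly for non-boolean fields.
        if isinstance(v, bool) and bool not in expected:
            problems.append(f"field {i}: bool is not valid for {type_name}")
        elif not isinstance(v, expected):
            problems.append(f"field {i}: {type(v).__name__} is not valid for {type_name}")
    return problems
```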



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r973797335


##########
python/pyspark/sql/pandas/_typing/__init__.pyi:
##########
@@ -256,6 +258,10 @@ PandasGroupedMapFunction = Union[
     Callable[[Any, DataFrameLike], DataFrameLike],
 ]
 
+PandasGroupedMapFunctionWithState = Callable[
+    [Any, Iterable[DataFrameLike], GroupStateImpl], Iterable[DataFrameLike]

Review Comment:
   We can either split the interface from the implementation or just change the name; I'm fine with either direction.
   cc. @HyukjinKwon What would be the best practice in such a case?
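One way to read the "split out interface and implementation" option is sketched below. The user-facing type alias is written against an abstract `GroupState` interface, and `GroupStateImpl` stays a concrete subclass that can carry internal bookkeeping. The classes are hypothetical and cover only two members to show the shape of the split; they are not the PySpark classes. `DataFrameLike` is aliased to `Any` here so the sketch runs without pandas.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, Optional, Tuple

DataFrameLike = Any  # stands in for pandas.DataFrame in this sketch


class GroupState(ABC):
    """User-facing interface: only what a user function may rely on."""

    @property
    @abstractmethod
    def exists(self) -> bool: ...

    @abstractmethod
    def update(self, new_value: Tuple[Any, ...]) -> None: ...


class GroupStateImpl(GroupState):
    """Internal implementation: may carry bookkeeping that users never see."""

    def __init__(self, value: Optional[Tuple[Any, ...]] = None) -> None:
        self._value = value
        self._updated = False  # internal detail, not part of the interface

    @property
    def exists(self) -> bool:
        return self._value is not None

    def update(self, new_value: Tuple[Any, ...]) -> None:
        self._value = tuple(new_value)
        self._updated = True


# The public alias names the interface, so the implementation can be renamed
# or restructured without changing user-visible type hints.
PandasGroupedMapFunctionWithState = Callable[
    [Any, Iterable[DataFrameLike], GroupState], Iterable[DataFrameLike]
]
```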





[GitHub] [spark] HyukjinKwon commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974870249


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked for each group repeatedly in every
+        trigger, and updates to each group's state will be saved across invocations. The function
+        will also be invoked for each timed-out state repeatedly. The sequence of the invocation
+        will be input data -> state timeout. When the function is invoked for state timeout, there
+        will be no data being presented.
+
+        The function should takes parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should loop through and process all
+        elements in the iterator. The user function should not make a guess of the number of
+        elements in the iterator.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned value, `pandas.DataFrame` must either match the field names in the defined
+        schema if specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of user-defined
+        state. The value of state will be presented as a tuple, as well as the update should be
+        performed with the tuple. User defined types e.g. native Python class types are not

Review Comment:
   I think we can just say that "the corresponding Python types for :class:`DataType` are supported", as documented at https://spark.apache.org/docs/latest/sql-ref-datatypes.html (click the Python tab).
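The docstring above gives the invocation order within a trigger: all groups with input data first, then timed-out groups, which are invoked with no data. That order can be simulated in plain Python. `run_trigger` and `SimState` are hypothetical names, and plain lists stand in for `pandas.DataFrame`. In Spark, the engine decides which states have timed out. This sketch assumes, as with `flatMapGroupsWithState`, that a group receiving data in the trigger is not also treated as timed out.

```python
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple

Key = Tuple[Any, ...]


class SimState:
    """Hypothetical minimal state: a value plus a timed-out flag."""

    def __init__(self, value: Optional[Tuple[Any, ...]] = None) -> None:
        self.value = value
        self.hasTimedOut = False


def run_trigger(
    func: Callable[[Key, Iterator[List[Any]], SimState], Iterable[Any]],
    input_groups: Dict[Key, List[List[Any]]],
    states: Dict[Key, SimState],
    timed_out_keys: Set[Key],
) -> List[Any]:
    output: List[Any] = []
    # 1) Groups that received input in this trigger.
    for key, frames in input_groups.items():
        state = states.setdefault(key, SimState())
        state.hasTimedOut = False
        output.extend(func(key, iter(frames), state))
    # 2) Timed-out groups without new input: invoked with an empty iterator.
    for key in sorted(timed_out_keys - set(input_groups)):
        state = states[key]
        state.hasTimedOut = True
        output.extend(func(key, iter([]), state))
    return output
```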



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] alex-balikov commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
alex-balikov commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974707164


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala:
##########
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io._
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.api.python._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
+
+
+/**
+ * A variant implementation of [[ArrowPythonRunner]] to serve the operation
+ * applyInPandasWithState.
+ */
+class ApplyInPandasWithStatePythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    inputSchema: StructType,
+    override protected val timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    stateEncoder: ExpressionEncoder[Row],
+    keySchema: StructType,
+    valueSchema: StructType,
+    stateValueSchema: StructType,
+    softLimitBytesPerBatch: Long,
+    minDataCountForSample: Int,
+    softTimeoutMillsPurgeBatch: Long)
+  extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
+  with PythonArrowInput[InType]
+  with PythonArrowOutput[OutType] {
+
+  override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA)
+
+  override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+
+  override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
+  require(
+    bufferSize >= 4,
+    "Pandas execution requires more than 4 bytes. Please set higher buffer. " +
+      s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")
+
+  // applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance.
+  // Configurations are both applied to executor and Python worker, set them to the worker conf
+  // to let Python worker read the config properly.
+  override protected val workerConf: Map[String, String] = initialWorkerConf +
+    (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH.key ->
+      softLimitBytesPerBatch.toString) +
+    (SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE.key ->
+      minDataCountForSample.toString) +
+    (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH.key ->
+      softTimeoutMillsPurgeBatch.toString)
+
+  private val stateRowDeserializer = stateEncoder.createDeserializer()
+
+  override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
+    super.handleMetadataBeforeExec(stream)
+    // Also write the schema for state value
+    PythonRDD.writeUTF(stateValueSchema.json, stream)
+  }
+
+  protected def writeIteratorToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator: Iterator[InType]): Unit = {
+    val w = new ApplyInPandasWithStateWriter(root, writer, softLimitBytesPerBatch,
+      minDataCountForSample, softTimeoutMillsPurgeBatch)
+
+    while (inputIterator.hasNext) {
+      val (keyRow, groupState, dataIter) = inputIterator.next()
+      assert(dataIter.hasNext, "should have at least one data row!")
+      w.startNewGroup(keyRow, groupState)
+
+      while (dataIter.hasNext) {
+        val dataRow = dataIter.next()
+        w.writeRow(dataRow)
+      }
+
+      w.finalizeGroup()
+    }
+
+    w.finalizeData()
+  }
+
+  protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OutType = {

Review Comment:
   Please add method-level comments.
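`writeIteratorToArrowStream` above follows a fixed call protocol. For each group it calls `startNewGroup`, then `writeRow` once per data row, then `finalizeGroup`; `finalizeData` is called once at the end. The sketch below mirrors that protocol in Python with a writer that only records calls. `RecordingWriter` and `write_groups` are hypothetical names, and no Arrow data is produced. The assertion mirrors the Scala `assert(dataIter.hasNext, ...)`.

```python
from typing import Any, Iterator, List, Tuple

_MISSING = object()  # sentinel so a legitimately falsy row is not mistaken for "no rows"


class RecordingWriter:
    """Hypothetical writer that just records the calls it receives."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, Any]] = []

    def start_new_group(self, key_row: Any, group_state: Any) -> None:
        self.calls.append(("startNewGroup", key_row))

    def write_row(self, data_row: Any) -> None:
        self.calls.append(("writeRow", data_row))

    def finalize_group(self) -> None:
        self.calls.append(("finalizeGroup", None))

    def finalize_data(self) -> None:
        self.calls.append(("finalizeData", None))


def write_groups(
    writer: RecordingWriter, groups: Iterator[Tuple[Any, Any, Iterator[Any]]]
) -> None:
    """Every group must contribute at least one data row."""
    for key_row, group_state, data_iter in groups:
        first = next(data_iter, _MISSING)
        assert first is not _MISSING, "should have at least one data row!"
        writer.start_new_group(key_row, group_state)
        writer.write_row(first)
        for data_row in data_iter:
            writer.write_row(data_row)
        writer.finalize_group()
    writer.finalize_data()
```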



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala:
##########
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.types.{BinaryType, BooleanType, IntegerType, StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * This class abstracts the complexity on constructing Arrow RecordBatches for data and state with
+ * bin-packing and chunking. The caller only need to call the proper public methods of this class
+ * `startNewGroup`, `writeRow`, `finalizeGroup`, `finalizeData` and this class will write the data
+ * and state into Arrow RecordBatches with performing bin-pack and chunk internally.
+ *
+ * This class requires that the parameter `root` has initialized with the Arrow schema like below:

Review Comment:
   the parameter 'root' has *been* initialized
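The class comment above mentions "bin-packing and chunking". A simplified model illustrates it: several small groups share one batch, and a group too large for the remaining budget is split so its remaining rows continue in the next batch. This sketch counts rows against a hypothetical `soft_limit_rows` budget. Judging by its parameters, the real writer uses a byte-based soft limit (`softLimitBytesPerBatch`), size estimates after sampling (`minDataCountForSample`) and a time-based flush (`softTimeoutMillsPurgeBatch`). None of those is modeled here.

```python
from typing import Dict, List, Tuple

Batch = List[Tuple[str, int]]  # (group key, row value) pairs


def pack_groups(groups: Dict[str, List[int]], soft_limit_rows: int) -> List[Batch]:
    """Pack groups into batches of at most soft_limit_rows rows.

    Small groups are bin-packed into the same batch; a large group is chunked
    and its remaining rows continue in the following batch(es).
    """
    batches: List[Batch] = []
    current: Batch = []
    for key, rows in groups.items():
        for value in rows:
            if len(current) >= soft_limit_rows:
                batches.append(current)  # batch is full: flush and start a new one
                current = []
            current.append((key, value))
    if current:
        batches.append(current)
    return batches
```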



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala:
##########
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot}
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow}
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.types.{BinaryType, BooleanType, IntegerType, StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * This class abstracts the complexity on constructing Arrow RecordBatches for data and state with
+ * bin-packing and chunking. The caller only need to call the proper public methods of this class
+ * `startNewGroup`, `writeRow`, `finalizeGroup`, `finalizeData` and this class will write the data
+ * and state into Arrow RecordBatches with performing bin-pack and chunk internally.
+ *
+ * This class requires that the parameter `root` has initialized with the Arrow schema like below:
+ * - data fields
+ * - state field
+ *   - nested schema (Refer ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA)
+ *
+ * Please refer the code comment in the implementation to see how the writes of data and state
+ * against Arrow RecordBatch work with consideration of bin-packing and chunking.
+ */
+class ApplyInPandasWithStateWriter(
+    root: VectorSchemaRoot,
+    writer: ArrowStreamWriter,
+    softLimitBytesPerBatch: Long,
+    minDataCountForSample: Int,
+    softTimeoutMillsPurgeBatch: Long) {
+
+  import ApplyInPandasWithStateWriter._
+
+  // We logically group the columns by family (data vs state) and initialize writer separately,
+  // since it's lot more easier and probably performant to write the row directly rather than
+  // projecting the row to match up with the overall schema.
+  //
+  // The number of data rows and state metadata rows can be different which could be problematic

Review Comment:
   It is unclear what would be problematic. If we are maintaining separate batches for data and state, it is not clear why these batches need to have the same number of rows. If we are packing everything into the same batch (and we should explain why we are doing this), then it makes sense that some rows may contain only data and no state.
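This sketch illustrates the case raised in the comment above, where data and state share one Arrow RecordBatch. Every column in a RecordBatch must have the same number of values. A group chunk with N data rows but a single state value therefore needs N-1 nulls in the state column. The sketch builds both columns as plain Python lists. Placing the state value on the first row of each chunk is an assumption of this sketch, not necessarily what the writer does.

```python
from typing import Any, Dict, List, Optional, Tuple


def build_columns(
    chunks: List[Tuple[Dict[str, Any], List[Any]]]
) -> Tuple[List[Any], List[Optional[Dict[str, Any]]]]:
    """Return equal-length data and state columns for a list of group chunks.

    Each chunk is (state_metadata, data_rows). There is one state value per
    chunk but possibly many data rows, so the state column is padded with None.
    """
    data_col: List[Any] = []
    state_col: List[Optional[Dict[str, Any]]] = []
    for state_meta, rows in chunks:
        assert rows, "each chunk carries at least one data row"
        data_col.extend(rows)
        # One state entry on the chunk's first row, nulls for the remaining rows.
        state_col.append(state_meta)
        state_col.extend([None] * (len(rows) - 1))
    assert len(data_col) == len(state_col)  # columns in one RecordBatch must line up
    return data_col, state_col
```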





[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r975910639


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,104 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked first for all input groups and then
+        for all timed out states where the input data is set to be empty. Updates to each group's
+        state will be saved across invocations.
+
+        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupState`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should loop through and process all
+        elements in the iterator. The user function should not make a guess of the number of
+        elements in the iterator.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in the returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned `pandas.DataFrame` must either match the field names in the defined schema if
+        specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of the
+        user-defined state. The value of the state will be presented as a tuple, as well as the
+        update should be performed with the tuple. The corresponding Python types for
+        :class:DataType are supported. Please refer to the page
+        https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab).
+
+        The size of each DataFrame in both the input and output can be arbitrary. The number of
+        DataFrames in both the input and output can also be arbitrary.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        func : function
+            a Python native function to be called on every group. It should take parameters
+            (key, Iterator[`pandas.DataFrame`], state) and return Iterator[`pandas.DataFrame`].
+            Note that the type of the key is tuple and the type of the state is
+            :class:`pyspark.sql.streaming.state.GroupState`.
+        outputStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the output records. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        stateStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the user-defined state. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
+        outputMode : str
+            the output mode of the function.
+        timeoutConf : str
+            timeout configuration for groups that do not receive data for a while. valid values
+            are defined in :class:`pyspark.sql.streaming.state.GroupStateTimeout`.
+
+        # TODO: Examples

Review Comment:
   I just added a simple example. I'll follow up with fuller example code in the examples directory and file a new JIRA ticket for it.
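Below is a slightly fuller sketch of how the parameters documented above fit together. It assumes a streaming DataFrame `df` with an `id` column and a Spark release that ships this API (the docstring says 3.4.0). `count_with_timeout` and `build_query` are illustrative names. The function counts rows per key and sets a 10-second processing-time timeout. When a group times out, it emits the final count and drops the state. The state members used (`hasTimedOut`, `exists`, `get`, `update`, `remove`, `setTimeoutDuration`) and `GroupStateTimeout.ProcessingTimeTimeout` are what I'd expect on the PySpark classes, but please check them against the released API.

```python
from typing import Any, Iterator, Tuple

import pandas as pd


def count_with_timeout(
    key: Tuple[Any, ...], pdf_iter: Iterator[pd.DataFrame], state: Any
) -> Iterator[pd.DataFrame]:
    if state.hasTimedOut:
        # Invoked without data because the group timed out: emit and drop the state.
        (count,) = state.get
        state.remove()
        yield pd.DataFrame({"id": [key[0]], "count": [count], "expired": [True]})
        return
    (count,) = state.get if state.exists else (0,)
    for pdf in pdf_iter:
        count += len(pdf)
    state.update((count,))
    state.setTimeoutDuration(10_000)  # milliseconds of processing time without new data
    yield pd.DataFrame({"id": [key[0]], "count": [count], "expired": [False]})


def build_query(df: Any) -> Any:
    """Wire the function into a streaming query (needs PySpark with this API)."""
    from pyspark.sql.streaming.state import GroupStateTimeout

    return df.groupBy("id").applyInPandasWithState(
        count_with_timeout,
        outputStructType="id long, count long, expired boolean",
        stateStructType="count long",
        outputMode="Update",
        timeoutConf=GroupStateTimeout.ProcessingTimeTimeout,
    )
```

`build_query` imports PySpark lazily, so `count_with_timeout` can be unit-tested on its own with a stand-in state object.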





[GitHub] [spark] HeartSaVioR commented on pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on PR #37893:
URL: https://github.com/apache/spark/pull/37893#issuecomment-1254210194

   https://github.com/HeartSaVioR/spark/actions/runs/3098349789/jobs/5019498380
   
   ```
   BasicSchedulerIntegrationSuite.super simple job
   org.scalatest.exceptions.TestFailedException: Map() did not equal Map(0 -> 42, 5 -> 42, 1 -> 42, 6 -> 42, 9 -> 42, 2 -> 42, 7 -> 42, 3 -> 42, 8 -> 42, 4 -> 42)
    [Check failure on line 210 in YarnClusterSuite](https://github.com/HeartSaVioR/spark/commit/f1000487960fa19aff9979211db68e63ec4384e0#annotation_4680752580) 
   
   YarnClusterSuite.run Spark in yarn-client mode with different configurations, ensuring redaction
   org.scalatest.exceptions.TestFailedDueToTimeoutException: The code passed to eventually never returned normally. Attempted 190 times over 3.001220665 minutes. Last failure message: handle.getState().isFinal() was false.
    [Check failure on line 210 in YarnClusterSuite](https://github.com/HeartSaVioR/spark/commit/f1000487960fa19aff9979211db68e63ec4384e0#annotation_4680752582) 
   
   YarnClusterSuite.run Spark in yarn-cluster mode with different configurations, ensuring redaction
   org.scalatest.exceptions.TestFailedDueToTimeoutException: The code passed to eventually never returned normally. Attempted 190 times over 3.001046460266666 minutes. Last failure message: handle.getState().isFinal() was false.
    [Check failure on line 210 in YarnClusterSuite](https://github.com/HeartSaVioR/spark/commit/f1000487960fa19aff9979211db68e63ec4384e0#annotation_4680752584) 
   
   YarnClusterSuite.yarn-cluster should respect conf overrides in SparkHadoopUtil (SPARK-16414, SPARK-23630)
   org.scalatest.exceptions.TestFailedDueToTimeoutException: The code passed to eventually never returned normally. Attempted 190 times over 3.0008878969166664 minutes. Last failure message: handle.getState().isFinal() was false.
    [Check failure on line 210 in YarnClusterSuite](https://github.com/HeartSaVioR/spark/commit/f1000487960fa19aff9979211db68e63ec4384e0#annotation_4680752586) 
   
   
   YarnClusterSuite.SPARK-35672: run Spark in yarn-client mode with additional jar using URI scheme 'local'
   org.scalatest.exceptions.TestFailedDueToTimeoutException: The code passed to eventually never returned normally. Attempted 190 times over 3.0009406496 minutes. Last failure message: handle.getState().isFinal() was false.
   ```
   
   None of the test failures are related to the changes in this PR. Since we've updated the PR again, let's see how the build goes.




[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r973843912


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,105 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked for each group repeatedly in every
+        trigger, and updates to each group's state will be saved across invocations. The function
+        will also be invoked for each timed-out state repeatedly. The sequence of the invocation
+        will be input data -> state timeout. When the function is invoked for state timeout, there
+        will be no data being presented.
+
+        The function should takes parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple
+        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as
+        :class:`pyspark.sql.streaming.state.GroupStateImpl`.
+
+        For each group, all columns are passed together as `pandas.DataFrame` to the user-function,
+        and the returned `pandas.DataFrame` across all invocations are combined as a
+        :class:`DataFrame`. Note that the user function should loop through and process all
+        elements in the iterator. The user function should not make a guess of the number of
+        elements in the iterator.
+
+        The `outputStructType` should be a :class:`StructType` describing the schema of all
+        elements in returned value, `pandas.DataFrame`. The column labels of all elements in
+        returned value, `pandas.DataFrame` must either match the field names in the defined
+        schema if specified as strings, or match the field data types by position if not strings,
+        e.g. integer indices.
+
+        The `stateStructType` should be :class:`StructType` describing the schema of user-defined
+        state. The value of state will be presented as a tuple, as well as the update should be
+        performed with the tuple. User defined types e.g. native Python class types are not

Review Comment:
   It's a bit tricky: native Python types include int, float, str, ..., and of course those are supported. Probably the clearest definition is "Python types are supported as long as the default encoder can convert them to the Spark SQL type". I'm not sure we have clear documentation describing the compatibility matrix. 
   
   cc. @HyukjinKwon could you please help us make this clear?
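For context on the contract described in the quoted docstring, here is a minimal local sketch. It uses a hypothetical `StubGroupState` in place of PySpark's state class and plain lists of dicts in place of `pandas.DataFrame` chunks, so it runs without Spark. It shows state kept as a tuple of native Python values, a user function that loops over every chunk without guessing their number, and the invocation order: groups with input data first, then timed-out groups with an empty iterator.

```python
from typing import Dict, Iterator, List, Optional


class StubGroupState:
    """Hypothetical in-memory stand-in for PySpark's group state (not the real class)."""

    def __init__(self, value: Optional[tuple] = None) -> None:
        self._value = value
        self.hasTimedOut = False

    @property
    def exists(self) -> bool:
        return self._value is not None

    def get(self) -> tuple:
        if self._value is None:
            raise ValueError("state is not set")
        return self._value

    def update(self, value: tuple) -> None:
        # State is exchanged as a tuple matching stateStructType, e.g. "count long".
        self._value = value

    def remove(self) -> None:
        self._value = None


def count_rows(key: tuple, chunks: Iterator[List[dict]],
               state: StubGroupState) -> Iterator[List[dict]]:
    """User function: keep a running row count per key."""
    if state.hasTimedOut:
        # Invoked for timeout: no data is presented, so emit the final count and drop state.
        (count,) = state.get()
        state.remove()
        yield [{"key": key[0], "count": count, "final": True}]
        return
    count = state.get()[0] if state.exists else 0
    for chunk in chunks:  # process every chunk; never assume how many there are
        count += len(chunk)
    state.update((count,))  # a tuple of native Python values
    yield [{"key": key[0], "count": count, "final": False}]


def run_trigger(data_by_key: Dict[tuple, List[List[dict]]],
                states: Dict[tuple, StubGroupState],
                timed_out_keys: List[tuple]) -> List[dict]:
    out: List[dict] = []
    # 1. Groups that received input data in this trigger.
    for key, chunks in data_by_key.items():
        state = states.setdefault(key, StubGroupState())
        for frame in count_rows(key, iter(chunks), state):
            out.extend(frame)
    # 2. Timed-out groups, invoked with an empty iterator.
    for key in timed_out_keys:
        state = states[key]
        state.hasTimedOut = True
        for frame in count_rows(key, iter([]), state):
            out.extend(frame)
        state.hasTimedOut = False
    return out
```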



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] HeartSaVioR commented on pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on PR #37893:
URL: https://github.com/apache/spark/pull/37893#issuecomment-1248815511

   cc. @viirya @HyukjinKwon 
   Please take a look into this. Thanks. I understand this is huge and a bit complicated in some parts, especially the logic around bin-packing/chunking. Please feel free to leave comments if the code comments aren't sufficient to understand it; I'll try my best to cover it.



[GitHub] [spark] alex-balikov commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
alex-balikov commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r973844357


##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -371,3 +375,292 @@ def load_stream(self, stream):
                 raise ValueError(
                     "Invalid number of pandas.DataFrames in group {0}".format(dataframes_in_group)
                 )
+
+
+class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):

Review Comment:
   Can we have some class- and method-level comments here?



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -371,3 +375,292 @@ def load_stream(self, stream):
                 raise ValueError(
                     "Invalid number of pandas.DataFrames in group {0}".format(dataframes_in_group)
                 )
+
+
+class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        state_object_schema,
+        soft_limit_bytes_per_batch,
+        min_data_count_for_sample,
+        soft_timeout_millis_purge_batch,
+    ):
+        super(ApplyInPandasWithStateSerializer, self).__init__(
+            timezone, safecheck, assign_cols_by_name
+        )
+        self.pickleSer = CPickleSerializer()
+        self.utf8_deserializer = UTF8Deserializer()
+        self.state_object_schema = state_object_schema
+
+        self.result_state_df_type = StructType(
+            [
+                StructField("properties", StringType()),
+                StructField("keyRowAsUnsafe", BinaryType()),
+                StructField("object", BinaryType()),
+                StructField("oldTimeoutTimestamp", LongType()),
+            ]
+        )
+
+        self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type)
+        self.soft_limit_bytes_per_batch = soft_limit_bytes_per_batch
+        self.min_data_count_for_sample = min_data_count_for_sample
+        self.soft_timeout_millis_purge_batch = soft_timeout_millis_purge_batch
+
+    def load_stream(self, stream):

Review Comment:
   Method-level comments here and everywhere else, please. Specifically, since the parameters are untyped, a comment about each parameter's type would be helpful.
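To make the request concrete, here is a sketch of a row shaped like the quoted `result_state_df_type`, the kind of structure such comments could spell out. The four field names and types come from the hunk. The JSON property keys (`exists`, `timeoutTimestamp`) are hypothetical, the key bytes are treated as opaque, and the `object` field is assumed to carry a pickled state tuple, which is consistent with the serializer holding a `CPickleSerializer` but is not confirmed by the hunk.

```python
import json
import pickle
from typing import NamedTuple, Optional, Tuple


class ResultStateRow(NamedTuple):
    """Mirrors the four fields of result_state_df_type in the quoted hunk."""
    properties: str           # StringType: JSON-encoded state properties
    keyRowAsUnsafe: bytes     # BinaryType: grouping key as serialized by the JVM (opaque here)
    object: bytes             # BinaryType: assumed to be the pickled state value
    oldTimeoutTimestamp: int  # LongType


def build_state_row(key_bytes: bytes, state_value: Optional[tuple], exists: bool,
                    timeout_ts: int, old_timeout_ts: int) -> ResultStateRow:
    # Hypothetical property names; the real JSON keys are defined by the PR, not here.
    props = json.dumps({"exists": exists, "timeoutTimestamp": timeout_ts})
    payload = pickle.dumps(state_value) if state_value is not None else b""
    return ResultStateRow(props, key_bytes, payload, old_timeout_ts)


def read_state_row(row: ResultStateRow) -> Tuple[dict, Optional[tuple]]:
    props = json.loads(row.properties)
    value = pickle.loads(row.object) if row.object else None
    return props, value
```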




[GitHub] [spark] HyukjinKwon commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HyukjinKwon commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r973844153


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -2705,6 +2705,44 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH =
+    buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch")
+      .internal()
+      .doc("When using applyInPandasWithState, set a soft limit of the accumulated size of " +
+        "records that can be written to a single ArrowRecordBatch in memory. This is used to " +
+        "restrict the amount of memory being used to materialize the data in both executor and " +
+        "Python worker. The accumulated size of records is calculated via sampling a set of " +
+        "records. Splitting the ArrowRecordBatch is performed per record, so unless a record " +
+        "is quite huge, the size of constructed ArrowRecordBatch will be around the " +
+        "configured value.")
+      .version("3.4.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("64MB")
+
+  val MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE =
+    buildConf("spark.sql.execution.applyInPandasWithState.minDataCountForSample")
+      .internal()
+      .doc("When using applyInPandasWithState, specify the minimum number of records to sample " +
+        "the size of record. The size being retrieved from sampling will be used to estimate " +
+        "the accumulated size of records. Note that limiting by size does not work if the " +
+        "number of records is less than the configured value. In such a case, ArrowRecordBatch " +
+        "will only be split for soft timeout.")
+      .version("3.4.0")
+      .intConf
+      .createWithDefault(100)
+
+  val MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH =
+    buildConf("spark.sql.execution.applyInPandasWithState.softTimeoutPurgeBatch")
+      .internal()
+      .doc("When using applyInPandasWithState, specify the soft timeout for purging the " +
+        "ArrowRecordBatch. If batching records exceeds the timeout, Spark will force splitting " +
+        "the ArrowRecordBatch regardless of estimated size. This config ensures the receiver " +
+        "of data (both executor and Python worker) to not wait indefinitely for sender to " +
+        "complete the ArrowRecordBatch, which may hurt both throughput and latency.")
+      .version("3.4.0")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("100ms")

Review Comment:
   For this, can we just leverage `spark.sql.execution.pandas.udf.buffer.size` (the feature this PR adds already respects it) if the flush timing matters? That configuration exists for exactly this purpose.
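The three configuration docs quoted above describe a batching policy: estimate record size from a sample, cut a batch once the estimated accumulated size reaches a soft limit, and cut it anyway once a soft timeout elapses. Below is a minimal sketch of that policy, not Spark's implementation. It assumes records are byte strings whose size is `len(record)`, samples only the first `min_count_for_sample` records, and takes an injectable `clock` so the timeout can be tested deterministically.

```python
import time
from typing import Callable, Iterable, Iterator, List, Optional


def batch_records(
    records: Iterable[bytes],
    soft_limit_bytes: int,
    min_count_for_sample: int,
    soft_timeout_ms: float,
    clock: Callable[[], float] = time.monotonic,
) -> Iterator[List[bytes]]:
    """Group records into batches, cutting on estimated size or elapsed time."""
    batch: List[bytes] = []
    sampled_bytes = 0
    sampled_count = 0
    avg_size: Optional[float] = None  # unknown until enough records have been sampled
    started = clock()
    for record in records:
        batch.append(record)
        if avg_size is None:
            sampled_bytes += len(record)
            sampled_count += 1
            if sampled_count >= min_count_for_sample:
                avg_size = sampled_bytes / sampled_count
        # Size-based cutting only kicks in once the sample is complete.
        size_exceeded = avg_size is not None and avg_size * len(batch) >= soft_limit_bytes
        timed_out = (clock() - started) * 1000 >= soft_timeout_ms
        if size_exceeded or timed_out:
            yield batch  # cut per record, right after the record that crossed a threshold
            batch = []
            started = clock()
    if batch:
        yield batch
```

Because the cut happens right after appending a record, the estimated batch size overshoots the soft limit by at most one record, in line with the "unless a record is quite huge" caveat in the doc. With fewer records than `min_count_for_sample`, only the timeout can cut a batch, as the second config doc notes.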



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala:
##########
@@ -142,6 +143,17 @@ object UnsupportedOperationChecker extends Logging {
           " or the output mode is not append on a streaming DataFrames/Datasets")(plan)
     }
 
+    val applyInPandasWithStates = plan.collect {
+      case f: FlatMapGroupsInPandasWithState if f.isStreaming => f
+    }
+
+    // Disallow multiple `applyInPandasWithState`s.
+    if (applyInPandasWithStates.size >= 2) {

Review Comment:
   No biggie, but:
   
   ```suggestion
       if (applyInPandasWithStates.size > 1) {
   ```
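The check being discussed collects the streaming `FlatMapGroupsInPandasWithState` nodes in the plan and rejects the query if there is more than one. A toy version, using a hypothetical `PlanNode` class rather than Catalyst's tree API (the error message is also made up for illustration):

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class PlanNode:
    """Hypothetical stand-in for a logical plan node."""
    name: str
    is_streaming: bool = True
    children: List["PlanNode"] = field(default_factory=list)

    def collect(self, pred: Callable[["PlanNode"], bool]) -> List["PlanNode"]:
        # Visit this node, then its children, gathering every match.
        found = [self] if pred(self) else []
        for child in self.children:
            found.extend(child.collect(pred))
        return found


def check_single_apply_in_pandas_with_state(plan: PlanNode) -> None:
    ops = plan.collect(
        lambda n: n.name == "FlatMapGroupsInPandasWithState" and n.is_streaming)
    if len(ops) > 1:  # equivalent to ">= 2"
        raise ValueError("Multiple applyInPandasWithState operators are not supported")
```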
   
   



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala:
##########
@@ -2705,6 +2705,44 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH =
+    buildConf("spark.sql.execution.applyInPandasWithState.softLimitSizePerBatch")
+      .internal()
+      .doc("When using applyInPandasWithState, set a soft limit of the accumulated size of " +
+        "records that can be written to a single ArrowRecordBatch in memory. This is used to " +
+        "restrict the amount of memory being used to materialize the data in both executor and " +
+        "Python worker. The accumulated size of records is calculated via sampling a set of " +
+        "records. Splitting the ArrowRecordBatch is performed per record, so unless a record " +
+        "is quite huge, the size of constructed ArrowRecordBatch will be around the " +
+        "configured value.")
+      .version("3.4.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("64MB")

Review Comment:
   I think we should later have a general configuration for this that applies to all Arrow batches (SPARK-23258). For the time being, I think we should reuse `spark.sql.execution.arrow.maxRecordsPerBatch`.
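For comparison, reusing `spark.sql.execution.arrow.maxRecordsPerBatch` means cutting batches by record count instead of estimated bytes. A minimal sketch of count-based chunking; treating zero or a negative value as "no limit" is assumed to mirror that config's documented behavior and should be double-checked:

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def chunk_by_count(records: Iterable[T], max_records_per_batch: int) -> Iterator[List[T]]:
    """Split records into batches of at most max_records_per_batch records."""
    it = iter(records)
    if max_records_per_batch <= 0:
        # Assumed "no limit" semantics: everything goes into a single batch.
        batch = list(it)
        if batch:
            yield batch
        return
    while True:
        batch = list(islice(it, max_records_per_batch))
        if not batch:
            return
        yield batch
```

Count-based cuts need no size estimation, but the byte size of each batch then varies with record width, which is what the size-based configs in this PR try to bound.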
   



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala:
##########
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io._
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.api.python._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.api.python.PythonSQLUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
+import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
+import org.apache.spark.sql.execution.streaming.GroupStateImpl
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch}
+
+
+/**
+ * [[ArrowPythonRunner]] with [[org.apache.spark.sql.streaming.GroupState]].
+ */
+class ApplyInPandasWithStatePythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    inputSchema: StructType,
+    override protected val timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    stateEncoder: ExpressionEncoder[Row],
+    keySchema: StructType,
+    valueSchema: StructType,
+    stateValueSchema: StructType,
+    softLimitBytesPerBatch: Long,
+    minDataCountForSample: Int,
+    softTimeoutMillsPurgeBatch: Long)
+  extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
+  with PythonArrowInput[InType]
+  with PythonArrowOutput[OutType] {
+
+  override protected val schema: StructType = inputSchema.add("!__state__!", STATE_METADATA_SCHEMA)

Review Comment:
   I suspect it's using `!` here because `!` cannot be part of an identifier in Spark SQL (?). Strictly speaking, though, such column names are allowed in some places of the DataFrame API (e.g., `spark.range(1).toDF("!__state__!")`).
   
   I believe we generally use internal column names such as `__grouping__id`, `__file_source_metadata_col`, `__metadata_col` and `_groupingexpression`. We're retrieving them positionally on the Python worker side, so I assume it's fine to have a duplicate name...
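The point about duplicate names can be illustrated with plain tuples: if the state metadata column is always appended last and read back by position, a user column that happens to share its name does no harm. A minimal sketch with hypothetical helpers (`add_state_column`, `split_row`) in which a schema is just a list of `(name, type)` pairs:

```python
from typing import Any, List, Sequence, Tuple

Field = Tuple[str, str]  # (name, type) pair standing in for a StructField


def add_state_column(input_fields: List[Field], state_name: str = "__state") -> List[Field]:
    # Append the state metadata column last; no check for name clashes.
    return input_fields + [(state_name, "struct")]


def split_row(row: Sequence[Any], num_data_fields: int) -> Tuple[tuple, Any]:
    # Retrieve by position: data columns first, state metadata last.
    return tuple(row[:num_data_fields]), row[num_data_fields]
```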



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala:
##########
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout}
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
+import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData
+import org.apache.spark.sql.execution.streaming.state.StateStore
+import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.CompletionIterator
+
+/**
+ * Physical operator for executing
+ * [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]]
+ *
+ * @param functionExpr function called on each group
+ * @param groupingAttributes used to group the data
+ * @param outAttributes used to define the output rows
+ * @param stateType used to serialize/deserialize state before calling `functionExpr`
+ * @param stateInfo `StatefulOperatorStateInfo` to identify the state store for a given operator.
+ * @param stateFormatVersion the version of state format.
+ * @param outputMode the output mode of `functionExpr`
+ * @param timeoutConf used to timeout groups that have not received data in a while
+ * @param batchTimestampMs processing timestamp of the current batch.
+ * @param eventTimeWatermark event time watermark for the current batch
+ * @param child logical plan of the underlying data
+ */
+case class FlatMapGroupsInPandasWithStateExec(
+    functionExpr: Expression,
+    groupingAttributes: Seq[Attribute],
+    outAttributes: Seq[Attribute],
+    stateType: StructType,
+    stateInfo: Option[StatefulOperatorStateInfo],
+    stateFormatVersion: Int,
+    outputMode: OutputMode,
+    timeoutConf: GroupStateTimeout,
+    batchTimestampMs: Option[Long],
+    eventTimeWatermark: Option[Long],
+    child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase {
+
+  // TODO(SPARK-40444): Add the support of initial state.
+  override protected val initialStateDeserializer: Expression = null
+  override protected val initialStateGroupAttrs: Seq[Attribute] = null
+  override protected val initialStateDataAttrs: Seq[Attribute] = null
+  override protected val initialState: SparkPlan = null

Review Comment:
   ```suggestion
     override protected val initialStateDeserializer: Expression = _
     override protected val initialStateGroupAttrs: Seq[Attribute] = _
     override protected val initialStateDataAttrs: Seq[Attribute] = _
     override protected val initialState: SparkPlan = _
   ```




[GitHub] [spark] HeartSaVioR commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974805894


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala:
##########
@@ -0,0 +1,201 @@
+/**
+ * A variant implementation of [[ArrowPythonRunner]] to serve the operation
+ * applyInPandasWithState.
+ */
+class ApplyInPandasWithStatePythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    inputSchema: StructType,
+    override protected val timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    stateEncoder: ExpressionEncoder[Row],
+    keySchema: StructType,
+    valueSchema: StructType,
+    stateValueSchema: StructType,
+    softLimitBytesPerBatch: Long,
+    minDataCountForSample: Int,
+    softTimeoutMillsPurgeBatch: Long)
+  extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
+  with PythonArrowInput[InType]
+  with PythonArrowOutput[OutType] {
+
+  override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA)
+
+  override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+
+  override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
+  require(
+    bufferSize >= 4,
+    "Pandas execution requires more than 4 bytes. Please set higher buffer. " +
+      s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")
+
+  // applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance.
+  // Configurations are both applied to executor and Python worker, set them to the worker conf
+  // to let Python worker read the config properly.
+  override protected val workerConf: Map[String, String] = initialWorkerConf +
+    (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH.key ->
+      softLimitBytesPerBatch.toString) +
+    (SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE.key ->
+      minDataCountForSample.toString) +
+    (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH.key ->
+      softTimeoutMillsPurgeBatch.toString)
+
+  private val stateRowDeserializer = stateEncoder.createDeserializer()
+
+  override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
+    super.handleMetadataBeforeExec(stream)
+    // Also write the schema for state value
+    PythonRDD.writeUTF(stateValueSchema.json, stream)
+  }
+
+  protected def writeIteratorToArrowStream(

Review Comment:
   I agree the parameter order doesn't strictly follow well-known best practices, but changing it would require changing the base class rather than this one. We may need a follow-up JIRA ticket/PR to address this in general.
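Separately from parameter order, the quoted `handleMetadataBeforeExec` sends the state value schema as JSON via `PythonRDD.writeUTF`, and the Python-side serializer holds a `UTF8Deserializer`. A sketch of that framing, assuming a 4-byte big-endian length prefix followed by UTF-8 bytes (what `DataOutputStream.writeInt` plus `write` would produce); the helper names are hypothetical:

```python
import io
import json
import struct


def write_utf(s: str, stream: io.BufferedIOBase) -> None:
    # 4-byte big-endian signed length, then the UTF-8 bytes.
    data = s.encode("utf-8")
    stream.write(struct.pack("!i", len(data)))
    stream.write(data)


def read_utf(stream: io.BufferedIOBase) -> str:
    (length,) = struct.unpack("!i", stream.read(4))
    return stream.read(length).decode("utf-8")
```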




[GitHub] [spark] HeartSaVioR commented on pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
HeartSaVioR commented on PR #37893:
URL: https://github.com/apache/spark/pull/37893#issuecomment-1252277251

   I tried to add method-level docs wherever possible, except where I think they're unnecessary (I may still have missed some pieces). 
   
   I didn't go with documenting every parameter along with its type, though, for these reasons:
   
   - For Python code, it'd be really hard to keep the doc in sync with the code.
   - For Scala/Java, we have been omitting explanations of parameters, or even of methods, when they're obvious from the name.
   
   In both languages, we strongly encourage method docs and parameter explanations for public APIs. Here we technically add only one public method in group_ops and one public class, GroupState, in PySpark. Everything else is internal and private.



[GitHub] [spark] alex-balikov commented on a diff in pull request #37893: [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

Posted by GitBox <gi...@apache.org>.
alex-balikov commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r974680745


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala:
##########
@@ -0,0 +1,201 @@
+/**
+ * A variant implementation of [[ArrowPythonRunner]] to serve the operation
+ * applyInPandasWithState.
+ */
+class ApplyInPandasWithStatePythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    inputSchema: StructType,
+    override protected val timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    stateEncoder: ExpressionEncoder[Row],
+    keySchema: StructType,
+    valueSchema: StructType,
+    stateValueSchema: StructType,
+    softLimitBytesPerBatch: Long,
+    minDataCountForSample: Int,
+    softTimeoutMillsPurgeBatch: Long)
+  extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
+  with PythonArrowInput[InType]
+  with PythonArrowOutput[OutType] {
+
+  override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA)
+
+  override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+
+  override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
+  require(
+    bufferSize >= 4,

Review Comment:
   Why do we need to throw this exception? We could just ensure that the buffer size is big enough:
   
   override val bufferSize: Int = math.max(4, SQLConf.get.pandasUDFBufferSize)
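The two options differ in how a misconfiguration surfaces. A sketch contrasting them, with hypothetical function names: the strict version fails fast and names the config key, while the clamped version silently raises the value to the 4-byte minimum.

```python
MIN_BUFFER_SIZE = 4  # bytes, the minimum enforced by the quoted require()


def buffer_size_strict(configured: int) -> int:
    # Current behavior: reject a too-small value with an error naming the config.
    if configured < MIN_BUFFER_SIZE:
        raise ValueError(
            f"Pandas execution requires at least {MIN_BUFFER_SIZE} bytes; "
            "please set 'spark.sql.execution.pandas.udf.buffer.size' higher."
        )
    return configured


def buffer_size_clamped(configured: int) -> int:
    # Suggested alternative: silently raise the value to the minimum.
    return max(MIN_BUFFER_SIZE, configured)
```

Clamping never fails a query, but a user who sets a tiny buffer on purpose gets no signal that the value was overridden.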



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala:
##########
@@ -0,0 +1,201 @@
+/**
+ * A variant implementation of [[ArrowPythonRunner]] to serve the operation
+ * applyInPandasWithState.

Review Comment:
   Maybe add a bit more context on why applyInPandasWithState requires a different runner.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala:
##########
@@ -0,0 +1,201 @@
+/**
+ * A variant implementation of [[ArrowPythonRunner]] to serve the operation
+ * applyInPandasWithState.
+ */
+class ApplyInPandasWithStatePythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    inputSchema: StructType,
+    override protected val timeZoneId: String,
+    initialWorkerConf: Map[String, String],

Review Comment:
   'initial' seems redundant in the name. This is an implementation detail that is unnecessary in the API.
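
   In the class body, `initialWorkerConf` is combined with three applyInPandasWithState-specific entries (soft limit bytes per batch, min data count for sample, soft timeout for purging a batch) to form the overridden `workerConf`. Here is a minimal sketch of that layering in plain Scala. The key strings are hypothetical placeholders for the `SQLConf.MAP_PANDAS_UDF_WITH_STATE_*` keys, and the parameter is named `base` only for illustration.

```scala
object WorkerConfSketch {
  // Hypothetical key names standing in for SQLConf.MAP_PANDAS_UDF_WITH_STATE_*.key.
  val SoftLimitKey = "hypothetical.softLimitBytesPerBatch"
  val MinDataCountKey = "hypothetical.minDataCountForSample"
  val SoftTimeoutKey = "hypothetical.softTimeoutMillisPurgeBatch"

  // Layers the operation-specific settings on top of the base worker conf.
  // With Map ++, entries on the right win if the base already holds the same key.
  def workerConf(
      base: Map[String, String],
      softLimitBytesPerBatch: Long,
      minDataCountForSample: Int,
      softTimeoutMillisPurgeBatch: Long): Map[String, String] =
    base ++ Map(
      SoftLimitKey -> softLimitBytesPerBatch.toString,
      MinDataCountKey -> minDataCountForSample.toString,
      SoftTimeoutKey -> softTimeoutMillisPurgeBatch.toString)
}
```

   Every value is passed as a string because the worker conf is a `Map[String, String]`. The Python side would need to parse these strings back into numbers.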



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala:
##########
@@ -0,0 +1,201 @@
+/**
+ * A variant implementation of [[ArrowPythonRunner]] to serve the operation
+ * applyInPandasWithState.
+ */
+class ApplyInPandasWithStatePythonRunner(
+    funcs: Seq[ChainedPythonFunctions],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    inputSchema: StructType,
+    override protected val timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    stateEncoder: ExpressionEncoder[Row],
+    keySchema: StructType,
+    valueSchema: StructType,
+    stateValueSchema: StructType,
+    softLimitBytesPerBatch: Long,
+    minDataCountForSample: Int,
+    softTimeoutMillsPurgeBatch: Long)
+  extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets)
+  with PythonArrowInput[InType]
+  with PythonArrowOutput[OutType] {
+
+  override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA)
+
+  override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+
+  override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
+  require(
+    bufferSize >= 4,
+    "Pandas execution requires at least 4 bytes of buffer. " +
+      s"Please set a higher value for '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")
+
+  // applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance.
+  // These configurations apply to both the executor and the Python worker, so set them in the
+  // worker conf to let the Python worker read them properly.
+  override protected val workerConf: Map[String, String] = initialWorkerConf +
+    (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_LIMIT_SIZE_PER_BATCH.key ->
+      softLimitBytesPerBatch.toString) +
+    (SQLConf.MAP_PANDAS_UDF_WITH_STATE_MIN_DATA_COUNT_FOR_SAMPLE.key ->
+      minDataCountForSample.toString) +
+    (SQLConf.MAP_PANDAS_UDF_WITH_STATE_SOFT_TIMEOUT_PURGE_BATCH.key ->
+      softTimeoutMillsPurgeBatch.toString)
+
+  private val stateRowDeserializer = stateEncoder.createDeserializer()
+
+  override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
+    super.handleMetadataBeforeExec(stream)
+    // Also write the schema for state value
+    PythonRDD.writeUTF(stateValueSchema.json, stream)
+  }
+
+  protected def writeIteratorToArrowStream(

Review Comment:
   Please add a method-level comment. Also, output params are usually placed after the input ones.
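
   The `require` guarding `bufferSize` accepts a value of exactly 4 (`bufferSize >= 4`), so the error message is clearest when it says "at least 4 bytes". Below is a minimal sketch of the check with wording that matches the condition. `BufferSizeKey` is a hypothetical placeholder for `SQLConf.PANDAS_UDF_BUFFER_SIZE.key`.

```scala
object BufferSizeCheckSketch {
  // Placeholder for SQLConf.PANDAS_UDF_BUFFER_SIZE.key; not the real key string.
  val BufferSizeKey = "hypothetical.pandas.udf.buffer.size"

  // Returns the size unchanged when valid. Otherwise require throws an
  // IllegalArgumentException whose message Scala prefixes with "requirement failed: ".
  def validatedBufferSize(bufferSize: Int): Int = {
    require(
      bufferSize >= 4,
      "Pandas execution requires at least 4 bytes of buffer. " +
        s"Please set a higher value for '$BufferSizeKey'.")
    bufferSize
  }
}
```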

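   `handleMetadataBeforeExec` first calls the parent implementation and then writes `stateValueSchema.json` with `PythonRDD.writeUTF`. As we understand it, `writeUTF` frames a string as a 4-byte big-endian length followed by its UTF-8 bytes. The worker therefore reads the state value schema after the base metadata. Below is a plain-Scala sketch of that framing. The helper names are hypothetical, not Spark APIs.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

object Utf8FramingSketch {
  // Writes a 4-byte big-endian length, then the string's UTF-8 bytes.
  def writeLengthPrefixedUtf8(str: String, out: DataOutputStream): Unit = {
    val bytes = str.getBytes(StandardCharsets.UTF_8)
    out.writeInt(bytes.length)
    out.write(bytes)
  }

  // Reads back one length-prefixed UTF-8 string.
  def readLengthPrefixedUtf8(in: DataInputStream): String = {
    val bytes = new Array[Byte](in.readInt())
    in.readFully(bytes)
    new String(bytes, StandardCharsets.UTF_8)
  }

  // Serializes the messages in order, e.g. base metadata followed by the state schema JSON.
  def frame(messages: Seq[String]): Array[Byte] = {
    val buf = new ByteArrayOutputStream()
    val out = new DataOutputStream(buf)
    messages.foreach(writeLengthPrefixedUtf8(_, out))
    out.flush()
    buf.toByteArray
  }

  // Reads `count` framed strings back in the order they were written.
  def unframe(bytes: Array[Byte], count: Int): Seq[String] = {
    val in = new DataInputStream(new ByteArrayInputStream(bytes))
    Seq.fill(count)(readLengthPrefixedUtf8(in))
  }
}
```

   This framing is different from `java.io.DataOutputStream.writeUTF`, which uses a 2-byte length and modified UTF-8. A reader on the Python side has to expect the 4-byte prefix.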


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


For additional commands, e-mail: reviews-help@spark.apache.org