Posted to commits@spark.apache.org by ka...@apache.org on 2022/11/29 07:13:47 UTC

[spark] branch master updated: [SPARK-41260][PYTHON][SS] Cast NumPy instances to Python primitive types in GroupState update

This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 7bc71910f9b [SPARK-41260][PYTHON][SS] Cast NumPy instances to Python primitive types in GroupState update
7bc71910f9b is described below

commit 7bc71910f9b08183d1e0572eef880e996892fa7d
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Tue Nov 29 16:13:32 2022 +0900

    [SPARK-41260][PYTHON][SS] Cast NumPy instances to Python primitive types in GroupState update
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to cast NumPy instances to Python primitive types in `GroupState.update`. Previously, passing a NumPy instance to `GroupState.update` failed with an exception like the one below:
    
    ```
    net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype). This happens when an unsupported/unregistered class is being unpickled that requires construction arguments. Fix it by registering a custom IObjectConstructor for this class.
            at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
            at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:759)
            at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:199)
            at net.razorvine.pickle.Unpickler.load(Unpickler.java:109)
            at net.razorvine.pickle.Unpickler.loads(Unpickler.java:122)
            at org.apache.spark.sql.api.python.PythonSQLUtils$.$anonfun$toJVMRow$1(PythonSQLUtils.scala:125)
    ```
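
    For illustration, a minimal sketch of a query that could hit this before the fix; `sdf` is an assumed streaming DataFrame with a string `value` column, and the column names are hypothetical:
    
    ```
    import pandas as pd
    import numpy as np
    from pyspark.sql.types import StructType, StructField, StringType, LongType
    
    output_type = StructType([StructField("key", StringType())])
    state_type = StructType([StructField("count", LongType())])
    
    def func(key, pdf_iter, state):
        # np.int64 is what pandas typically yields; before this fix,
        # update() pickled it as-is and the JVM side failed to unpickle it.
        state.update((np.int64(1),))
        yield pd.DataFrame({"key": [key[0]]})
    
    # `sdf` is assumed: a streaming DataFrame with a string `value` column.
    sdf.groupBy("value").applyInPandasWithState(
        func, output_type, state_type, "Update", "NoTimeout")
    ```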
    
    ### Why are the changes needed?
    
    `applyInPandasWithState` works with pandas instances, so it is very common to extract a NumPy value from a pandas DataFrame and set it in the group state.
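
    For example, plain scalar access on a pandas column already yields a NumPy scalar rather than a Python primitive (a small standalone sketch, pandas only):
    
    ```
    import pandas as pd
    
    pdf = pd.DataFrame({"count": [1]})
    v = pdf["count"].iloc[0]
    print(type(v))  # <class 'numpy.int64'>, not a Python int
    ```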
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, because this new API has not been released yet.
    
    ### How was this patch tested?
    
    Manually tested, and a unit test was added.
    
    Closes #38796 from HyukjinKwon/SPARK-41260.
    
    Authored-by: Hyukjin Kwon <gu...@apache.org>
    Signed-off-by: Jungtaek Lim <ka...@gmail.com>
---
 python/pyspark/sql/streaming/state.py              | 21 +++++++-
 .../FlatMapGroupsInPandasWithStateSuite.scala      | 63 ++++++++++++++++++++++
 2 files changed, 83 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py
index 66b225e1b10..f0ac427cbea 100644
--- a/python/pyspark/sql/streaming/state.py
+++ b/python/pyspark/sql/streaming/state.py
@@ -19,6 +19,7 @@ import json
 from typing import Tuple, Optional
 
 from pyspark.sql.types import DateType, Row, StructType
+from pyspark.sql.utils import has_numpy
 
 __all__ = ["GroupState", "GroupStateTimeout"]
 
@@ -130,7 +131,25 @@ class GroupState:
         if newValue is None:
             raise ValueError("'None' is not a valid state value")
 
-        self._value = Row(*newValue)
+        converted = []
+        if has_numpy:
+            import numpy as np
+
+            # Convert NumPy types to Python primitive types.
+            for v in newValue:
+                if isinstance(v, np.generic):
+                    converted.append(v.tolist())
+                # Also handle pandas scalar types such as Timestamp and Timedelta.
+                elif hasattr(v, "to_pytimedelta"):
+                    converted.append(v.to_pytimedelta())
+                elif hasattr(v, "to_pydatetime"):
+                    converted.append(v.to_pydatetime())
+                else:
+                    converted.append(v)
+        else:
+            converted = list(newValue)
+
+        self._value = Row(*converted)
         self._defined = True
         self._updated = True
         self._removed = False
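
For reference, a quick standalone sketch (plain NumPy/pandas, outside Spark) of the conversions the new code relies on:

```
import datetime
import numpy as np
import pandas as pd

# np.generic covers all NumPy scalar types; tolist() on a NumPy scalar
# returns the equivalent Python primitive.
assert type(np.int64(1).tolist()) is int
assert type(np.float32(1).tolist()) is float

# pandas datetime-like scalars provide dedicated converters.
assert type(pd.Timestamp("1990-01-01").to_pydatetime()) is datetime.datetime
assert type(pd.Timedelta(days=1).to_pytimedelta()) is datetime.timedelta
```
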
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
index a83f7cce4c1..ca738b805eb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
@@ -886,4 +886,67 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest {
         total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1)))
     )
   }
+
+  test("SPARK-41260: applyInPandasWithState - NumPy instances to JVM rows in state") {
+    assume(shouldTestPandasUDFs)
+
+    val pythonScript =
+      """
+        |import pandas as pd
+        |import numpy as np
+        |import datetime
+        |from pyspark.sql.types import StructType, StructField, StringType
+        |
+        |tpe = StructType([StructField("key", StringType())])
+        |
+        |def func(key, pdf_iter, state):
+        |    pdf = pd.DataFrame({
+        |        'int32': [np.int32(1)],
+        |        'int64': [np.int64(1)],
+        |        'float32': [np.float32(1)],
+        |        'float64': [np.float64(1)],
+        |        'bool': [True],
+        |        'datetime': [datetime.datetime(1990, 1, 1, 0, 0, 0)],
+        |        'timedelta': [datetime.timedelta(1)]
+        |    })
+        |
+        |    state.update(tuple(pdf.iloc[0]))
+        |    # Assert on Python primitive type comparison.
+        |    assert state.get == (
+        |        1, 1, 1.0, 1.0, True,
+        |        datetime.datetime(1990, 1, 1, 0, 0, 0), datetime.timedelta(1)
+        |    )
+        |    yield pd.DataFrame({'key': [key[0]]})
+        |""".stripMargin
+    val pythonFunc = TestGroupedMapPandasUDFWithState(
+      name = "pandas_grouped_map_with_state", pythonScript = pythonScript)
+
+    val inputData = MemoryStream[String]
+    val outputStructType = StructType(Seq(StructField("key", StringType)))
+    val stateStructType = StructType(Seq(
+      StructField("int32", IntegerType),
+      StructField("int64", LongType),
+      StructField("float32", FloatType),
+      StructField("float64", DoubleType),
+      StructField("bool", BooleanType),
+      StructField("datetime", DateType),
+      StructField("timedelta", DayTimeIntervalType())
+    ))
+    val inputDataDS = inputData.toDS()
+    val result =
+      inputDataDS
+        .groupBy("value")
+        .applyInPandasWithState(
+          pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF],
+          outputStructType,
+          stateStructType,
+          "Update",
+          "NoTimeout")
+
+    testStream(result, Update)(
+      AddData(inputData, "a"),
+      CheckNewAnswer("a"),
+      assertNumStateRows(total = 1, updated = 1)
+    )
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org