You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ka...@apache.org on 2022/11/29 07:13:47 UTC
[spark] branch master updated: [SPARK-41260][PYTHON][SS] Cast NumPy instances to Python primitive types in GroupState update
This is an automated email from the ASF dual-hosted git repository.
kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7bc71910f9b [SPARK-41260][PYTHON][SS] Cast NumPy instances to Python primitive types in GroupState update
7bc71910f9b is described below
commit 7bc71910f9b08183d1e0572eef880e996892fa7d
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Tue Nov 29 16:13:32 2022 +0900
[SPARK-41260][PYTHON][SS] Cast NumPy instances to Python primitive types in GroupState update
### What changes were proposed in this pull request?
This PR proposes to cast NumPy instances to Python primitive types in `GroupState.update`. Previously, passing a NumPy instance to `GroupState.update` failed with the exception below:
```
net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype). This happens when an unsupported/unregistered class is being unpickled that requires construction arguments. Fix it by registering a custom IObjectConstructor for this class.
at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:759)
at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:199)
at net.razorvine.pickle.Unpickler.load(Unpickler.java:109)
at net.razorvine.pickle.Unpickler.loads(Unpickler.java:122)
at org.apache.spark.sql.api.python.PythonSQLUtils$.$anonfun$toJVMRow$1(PythonSQLUtils.scala:125)
```
### Why are the changes needed?
`applyInPandasWithState` operates on pandas instances, so it is very common to extract a NumPy instance from a pandas object and set it as the group state.
### Does this PR introduce _any_ user-facing change?
No, because this new API has not been released yet.
### How was this patch tested?
Manually tested, and a unit test was added.
Closes #38796 from HyukjinKwon/SPARK-41260.
Authored-by: Hyukjin Kwon <gu...@apache.org>
Signed-off-by: Jungtaek Lim <ka...@gmail.com>
---
python/pyspark/sql/streaming/state.py | 21 +++++++-
.../FlatMapGroupsInPandasWithStateSuite.scala | 63 ++++++++++++++++++++++
2 files changed, 83 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py
index 66b225e1b10..f0ac427cbea 100644
--- a/python/pyspark/sql/streaming/state.py
+++ b/python/pyspark/sql/streaming/state.py
@@ -19,6 +19,7 @@ import json
from typing import Tuple, Optional
from pyspark.sql.types import DateType, Row, StructType
+from pyspark.sql.utils import has_numpy
__all__ = ["GroupState", "GroupStateTimeout"]
@@ -130,7 +131,25 @@ class GroupState:
if newValue is None:
raise ValueError("'None' is not a valid state value")
- self._value = Row(*newValue)
+ converted = []
+ if has_numpy:
+ import numpy as np
+
+ # In order to convert NumPy types to Python primitive types.
+ for v in newValue:
+ if isinstance(v, np.generic):
+ converted.append(v.tolist())
+ # Address a couple of pandas dtypes too.
+ elif hasattr(v, "to_pytimedelta"):
+ converted.append(v.to_pytimedelta())
+ elif hasattr(v, "to_pydatetime"):
+ converted.append(v.to_pydatetime())
+ else:
+ converted.append(v)
+ else:
+ converted = list(newValue)
+
+ self._value = Row(*converted)
self._defined = True
self._updated = True
self._removed = False
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
index a83f7cce4c1..ca738b805eb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsInPandasWithStateSuite.scala
@@ -886,4 +886,67 @@ class FlatMapGroupsInPandasWithStateSuite extends StateStoreMetricsTest {
total = Seq(1), updated = Seq(1), droppedByWatermark = Seq(0), removed = Some(Seq(1)))
)
}
+
+ test("SPARK-41260: applyInPandasWithState - NumPy instances to JVM rows in state") {
+ assume(shouldTestPandasUDFs)
+
+ val pythonScript =
+ """
+ |import pandas as pd
+ |import numpy as np
+ |import datetime
+ |from pyspark.sql.types import StructType, StructField, StringType
+ |
+ |tpe = StructType([StructField("key", StringType())])
+ |
+ |def func(key, pdf_iter, state):
+ | pdf = pd.DataFrame({
+ | 'int32': [np.int32(1)],
+ | 'int64': [np.int64(1)],
+ | 'float32': [np.float32(1)],
+ | 'float64': [np.float64(1)],
+ | 'bool': [True],
+ | 'datetime': [datetime.datetime(1990, 1, 1, 0, 0, 0)],
+ | 'timedelta': [datetime.timedelta(1)]
+ | })
+ |
+ | state.update(tuple(pdf.iloc[0]))
+ | # Assert on Python primitive type comparison.
+ | assert state.get == (
+ | 1, 1, 1.0, 1.0, True,
+ | datetime.datetime(1990, 1, 1, 0, 0, 0), datetime.timedelta(1)
+ | )
+ | yield pd.DataFrame({'key': [key[0]]})
+ |""".stripMargin
+ val pythonFunc = TestGroupedMapPandasUDFWithState(
+ name = "pandas_grouped_map_with_state", pythonScript = pythonScript)
+
+ val inputData = MemoryStream[String]
+ val outputStructType = StructType(Seq(StructField("key", StringType)))
+ val stateStructType = StructType(Seq(
+ StructField("int32", IntegerType),
+ StructField("int64", LongType),
+ StructField("float32", FloatType),
+ StructField("float64", DoubleType),
+ StructField("bool", BooleanType),
+ StructField("datetime", DateType),
+ StructField("timedelta", DayTimeIntervalType())
+ ))
+ val inputDataDS = inputData.toDS()
+ val result =
+ inputDataDS
+ .groupBy("value")
+ .applyInPandasWithState(
+ pythonFunc(inputDataDS("value")).expr.asInstanceOf[PythonUDF],
+ outputStructType,
+ stateStructType,
+ "Update",
+ "NoTimeout")
+
+ testStream(result, Update)(
+ AddData(inputData, "a"),
+ CheckNewAnswer("a"),
+ assertNumStateRows(total = 1, updated = 1)
+ )
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org