Posted to commits@spark.apache.org by gu...@apache.org on 2023/06/21 01:34:19 UTC

[spark] branch master updated: [SPARK-42941][SS][CONNECT] [1/2] StreamingQueryListener - Event Serde in JSON format

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6bfc01188e9 [SPARK-42941][SS][CONNECT] [1/2] StreamingQueryListener - Event Serde in JSON format
6bfc01188e9 is described below

commit 6bfc01188e96af065218e9f4574c3c0b8c87fde0
Author: Wei Liu <we...@databricks.com>
AuthorDate: Wed Jun 21 10:34:06 2023 +0900

    [SPARK-42941][SS][CONNECT] [1/2] StreamingQueryListener - Event Serde in JSON format
    
    ### What changes were proposed in this pull request?
    
    Following the discussion of the `foreachBatch` implementation, we decided to implement the Connect StreamingQueryListener such that the server, rather than the client, runs the listener code.
    
    Following this POC: https://github.com/apache/spark/pull/41096, this will be done as follows:
    1. The client sends serialized Python code to the server.
    2. The server initializes a Scala `StreamingQueryListener`, which starts the Python process and runs the Python code. (Details of this step still depend on the `foreachBatch` implementation.)
    3. When a new StreamingQuery event comes in, the JVM serializes it to JSON and sends it to the Python process for handling.
    
    This PR focuses on step 3, the serialization and deserialization of the events (see the sketch below).
    
    It also finishes a TODO to check the exception in QueryTerminatedEvent.
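
    A minimal sketch of the step-3 round trip (illustration only, not part of
    this commit's diff): the JVM renders an event with the new Scala `json`
    method, and the Python side rebuilds it with the matching `fromJson`
    classmethod; the payload values below are taken from the new unit tests.

        import json
        import uuid

        from pyspark.sql.streaming.listener import QueryStartedEvent

        # JSON as emitted by the Scala QueryStartedEvent.json added in this PR
        payload = """{
            "id": "78923ec2-8f4d-4266-876e-1f50cf3c283b",
            "runId": "55a95d45-e932-4e08-9caa-0a8ecd9391e8",
            "name": null,
            "timestamp": "2023-06-09T18:13:29.741Z"
        }"""

        event = QueryStartedEvent.fromJson(json.loads(payload))
        assert isinstance(event.id, uuid.UUID)
        assert event.name is None  # JSON null maps to Python None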
    
    ### Why are the changes needed?
    
    For implementing the Connect StreamingQueryListener.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New unit tests
    
    Closes #41540 from WweiL/SPARK-42941-listener-python-new-1.
    
    Authored-by: Wei Liu <we...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/sql/streaming/listener.py           | 452 +++++++++++++++++----
 .../sql/tests/streaming/test_streaming_listener.py | 221 +++++++++-
 .../sql/streaming/StreamingQueryListener.scala     |  44 +-
 3 files changed, 618 insertions(+), 99 deletions(-)

diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py
index 33482664a7b..198af0c9cbe 100644
--- a/python/pyspark/sql/streaming/listener.py
+++ b/python/pyspark/sql/streaming/listener.py
@@ -15,7 +15,8 @@
 # limitations under the License.
 #
 import uuid
-from typing import Optional, Dict, List
+import json
+from typing import Any, Dict, List, Optional
 from abc import ABC, abstractmethod
 
 from py4j.java_gateway import JavaObject
@@ -129,16 +130,16 @@ class JStreamingQueryListener:
         self.pylistener = pylistener
 
     def onQueryStarted(self, jevent: JavaObject) -> None:
-        self.pylistener.onQueryStarted(QueryStartedEvent(jevent))
+        self.pylistener.onQueryStarted(QueryStartedEvent.fromJObject(jevent))
 
     def onQueryProgress(self, jevent: JavaObject) -> None:
-        self.pylistener.onQueryProgress(QueryProgressEvent(jevent))
+        self.pylistener.onQueryProgress(QueryProgressEvent.fromJObject(jevent))
 
     def onQueryIdle(self, jevent: JavaObject) -> None:
-        self.pylistener.onQueryIdle(QueryIdleEvent(jevent))
+        self.pylistener.onQueryIdle(QueryIdleEvent.fromJObject(jevent))
 
     def onQueryTerminated(self, jevent: JavaObject) -> None:
-        self.pylistener.onQueryTerminated(QueryTerminatedEvent(jevent))
+        self.pylistener.onQueryTerminated(QueryTerminatedEvent.fromJObject(jevent))
 
     class Java:
         implements = ["org.apache.spark.sql.streaming.PythonStreamingQueryListener"]
@@ -155,11 +156,31 @@ class QueryStartedEvent:
     This API is evolving.
     """
 
-    def __init__(self, jevent: JavaObject) -> None:
-        self._id: uuid.UUID = uuid.UUID(jevent.id().toString())
-        self._runId: uuid.UUID = uuid.UUID(jevent.runId().toString())
-        self._name: Optional[str] = jevent.name()
-        self._timestamp: str = jevent.timestamp()
+    def __init__(
+        self, id: uuid.UUID, runId: uuid.UUID, name: Optional[str], timestamp: str
+    ) -> None:
+        self._id: uuid.UUID = id
+        self._runId: uuid.UUID = runId
+        self._name: Optional[str] = name
+        self._timestamp: str = timestamp
+
+    @classmethod
+    def fromJObject(cls, jevent: JavaObject) -> "QueryStartedEvent":
+        return cls(
+            id=uuid.UUID(jevent.id().toString()),
+            runId=uuid.UUID(jevent.runId().toString()),
+            name=jevent.name(),
+            timestamp=jevent.timestamp(),
+        )
+
+    @classmethod
+    def fromJson(cls, j: Dict[str, Any]) -> "QueryStartedEvent":
+        return cls(
+            id=uuid.UUID(j["id"]),
+            runId=uuid.UUID(j["runId"]),
+            name=j["name"],
+            timestamp=j["timestamp"],
+        )
 
     @property
     def id(self) -> uuid.UUID:
@@ -203,8 +224,16 @@ class QueryProgressEvent:
     This API is evolving.
     """
 
-    def __init__(self, jevent: JavaObject) -> None:
-        self._progress: StreamingQueryProgress = StreamingQueryProgress(jevent.progress())
+    def __init__(self, progress: "StreamingQueryProgress") -> None:
+        self._progress: StreamingQueryProgress = progress
+
+    @classmethod
+    def fromJObject(cls, jevent: JavaObject) -> "QueryProgressEvent":
+        return cls(progress=StreamingQueryProgress.fromJObject(jevent.progress()))
+
+    @classmethod
+    def fromJson(cls, j: Dict[str, Any]) -> "QueryProgressEvent":
+        return cls(progress=StreamingQueryProgress.fromJson(j["progress"]))
 
     @property
     def progress(self) -> "StreamingQueryProgress":
@@ -225,10 +254,22 @@ class QueryIdleEvent:
     This API is evolving.
     """
 
-    def __init__(self, jevent: JavaObject) -> None:
-        self._id: uuid.UUID = uuid.UUID(jevent.id().toString())
-        self._runId: uuid.UUID = uuid.UUID(jevent.runId().toString())
-        self._timestamp: str = jevent.timestamp()
+    def __init__(self, id: uuid.UUID, runId: uuid.UUID, timestamp: str) -> None:
+        self._id: uuid.UUID = id
+        self._runId: uuid.UUID = runId
+        self._timestamp: str = timestamp
+
+    @classmethod
+    def fromJObject(cls, jevent: JavaObject) -> "QueryIdleEvent":
+        return cls(
+            id=uuid.UUID(jevent.id().toString()),
+            runId=uuid.UUID(jevent.runId().toString()),
+            timestamp=jevent.timestamp(),
+        )
+
+    @classmethod
+    def fromJson(cls, j: Dict[str, Any]) -> "QueryIdleEvent":
+        return cls(id=uuid.UUID(j["id"]), runId=uuid.UUID(j["runId"]), timestamp=j["timestamp"])
 
     @property
     def id(self) -> uuid.UUID:
@@ -265,14 +306,36 @@ class QueryTerminatedEvent:
     This API is evolving.
     """
 
-    def __init__(self, jevent: JavaObject) -> None:
-        self._id: uuid.UUID = uuid.UUID(jevent.id().toString())
-        self._runId: uuid.UUID = uuid.UUID(jevent.runId().toString())
+    def __init__(
+        self,
+        id: uuid.UUID,
+        runId: uuid.UUID,
+        exception: Optional[str],
+        errorClassOnException: Optional[str],
+    ) -> None:
+        self._id: uuid.UUID = id
+        self._runId: uuid.UUID = runId
+        self._exception: Optional[str] = exception
+        self._errorClassOnException: Optional[str] = errorClassOnException
+
+    @classmethod
+    def fromJObject(cls, jevent: JavaObject) -> "QueryTerminatedEvent":
         jexception = jevent.exception()
-        self._exception: Optional[str] = jexception.get() if jexception.isDefined() else None
         jerrorclass = jevent.errorClassOnException()
-        self._errorClassOnException: Optional[str] = (
-            jerrorclass.get() if jerrorclass.isDefined() else None
+        return cls(
+            id=uuid.UUID(jevent.id().toString()),
+            runId=uuid.UUID(jevent.runId().toString()),
+            exception=jexception.get() if jexception.isDefined() else None,
+            errorClassOnException=jerrorclass.get() if jerrorclass.isDefined() else None,
+        )
+
+    @classmethod
+    def fromJson(cls, j: Dict[str, Any]) -> "QueryTerminatedEvent":
+        return cls(
+            id=uuid.UUID(j["id"]),
+            runId=uuid.UUID(j["runId"]),
+            exception=j["exception"],
+            errorClassOnException=j["errorClassOnException"],
         )
 
     @property
@@ -322,32 +385,97 @@ class StreamingQueryProgress:
     This API is evolving.
     """
 
-    def __init__(self, jprogress: JavaObject) -> None:
+    def __init__(
+        self,
+        id: uuid.UUID,
+        runId: uuid.UUID,
+        name: Optional[str],
+        timestamp: str,
+        batchId: int,
+        batchDuration: int,
+        durationMs: Dict[str, int],
+        eventTime: Dict[str, str],
+        stateOperators: List["StateOperatorProgress"],
+        sources: List["SourceProgress"],
+        sink: "SinkProgress",
+        numInputRows: int,
+        inputRowsPerSecond: float,
+        processedRowsPerSecond: float,
+        observedMetrics: Dict[str, Row],
+        jprogress: Optional[JavaObject] = None,
+        jdict: Optional[Dict[str, Any]] = None,
+    ):
+        self._jprogress: Optional[JavaObject] = jprogress
+        self._jdict: Optional[Dict[str, Any]] = jdict
+        self._id: uuid.UUID = id
+        self._runId: uuid.UUID = runId
+        self._name: Optional[str] = name
+        self._timestamp: str = timestamp
+        self._batchId: int = batchId
+        self._batchDuration: int = batchDuration
+        self._durationMs: Dict[str, int] = durationMs
+        self._eventTime: Dict[str, str] = eventTime
+        self._stateOperators: List[StateOperatorProgress] = stateOperators
+        self._sources: List[SourceProgress] = sources
+        self._sink: SinkProgress = sink
+        self._numInputRows: int = numInputRows
+        self._inputRowsPerSecond: float = inputRowsPerSecond
+        self._processedRowsPerSecond: float = processedRowsPerSecond
+        self._observedMetrics: Dict[str, Row] = observedMetrics
+
+    @classmethod
+    def fromJObject(cls, jprogress: JavaObject) -> "StreamingQueryProgress":
         from pyspark import SparkContext
 
-        self._jprogress: JavaObject = jprogress
-        self._id: uuid.UUID = uuid.UUID(jprogress.id().toString())
-        self._runId: uuid.UUID = uuid.UUID(jprogress.runId().toString())
-        self._name: Optional[str] = jprogress.name()
-        self._timestamp: str = jprogress.timestamp()
-        self._batchId: int = jprogress.batchId()
-        self._inputRowsPerSecond: float = jprogress.inputRowsPerSecond()
-        self._processedRowsPerSecond: float = jprogress.processedRowsPerSecond()
-        self._batchDuration: int = jprogress.batchDuration()
-        self._durationMs: Dict[str, int] = dict(jprogress.durationMs())
-        self._eventTime: Dict[str, str] = dict(jprogress.eventTime())
-        self._stateOperators: List[StateOperatorProgress] = [
-            StateOperatorProgress(js) for js in jprogress.stateOperators()
-        ]
-        self._sources: List[SourceProgress] = [SourceProgress(js) for js in jprogress.sources()]
-        self._sink: SinkProgress = SinkProgress(jprogress.sink())
-
-        self._observedMetrics: Dict[str, Row] = {
-            k: cloudpickle.loads(
-                SparkContext._jvm.PythonSQLUtils.toPyRow(jr)  # type: ignore[union-attr]
-            )
-            for k, jr in dict(jprogress.observedMetrics()).items()
-        }
+        return cls(
+            jprogress=jprogress,
+            id=uuid.UUID(jprogress.id().toString()),
+            runId=uuid.UUID(jprogress.runId().toString()),
+            name=jprogress.name(),
+            timestamp=jprogress.timestamp(),
+            batchId=jprogress.batchId(),
+            batchDuration=jprogress.batchDuration(),
+            durationMs=dict(jprogress.durationMs()),
+            eventTime=dict(jprogress.eventTime()),
+            stateOperators=[
+                StateOperatorProgress.fromJObject(js) for js in jprogress.stateOperators()
+            ],
+            sources=[SourceProgress.fromJObject(js) for js in jprogress.sources()],
+            sink=SinkProgress.fromJObject(jprogress.sink()),
+            numInputRows=jprogress.numInputRows(),
+            inputRowsPerSecond=jprogress.inputRowsPerSecond(),
+            processedRowsPerSecond=jprogress.processedRowsPerSecond(),
+            observedMetrics={
+                k: cloudpickle.loads(
+                    SparkContext._jvm.PythonSQLUtils.toPyRow(jr)  # type: ignore[union-attr]
+                )
+                for k, jr in dict(jprogress.observedMetrics()).items()
+            },
+        )
+
+    @classmethod
+    def fromJson(cls, j: Dict[str, Any]) -> "StreamingQueryProgress":
+        return cls(
+            jdict=j,
+            id=uuid.UUID(j["id"]),
+            runId=uuid.UUID(j["runId"]),
+            name=j["name"],
+            timestamp=j["timestamp"],
+            batchId=j["batchId"],
+            batchDuration=j["batchDuration"],
+            durationMs=dict(j["durationMs"]),
+            eventTime=dict(j["eventTime"]),
+            stateOperators=[StateOperatorProgress.fromJson(s) for s in j["stateOperators"]],
+            sources=[SourceProgress.fromJson(s) for s in j["sources"]],
+            sink=SinkProgress.fromJson(j["sink"]),
+            numInputRows=j["numInputRows"],
+            inputRowsPerSecond=j["inputRowsPerSecond"],
+            processedRowsPerSecond=j["processedRowsPerSecond"],
+            observedMetrics={
+                k: Row(*row_dict.keys())(*row_dict.values())  # Assume no nested rows
+                for k, row_dict in j["observedMetrics"].items()
+            },
+        )
 
     @property
     def id(self) -> uuid.UUID:
@@ -448,11 +576,11 @@ class StreamingQueryProgress:
         return self._observedMetrics
 
     @property
-    def numInputRows(self) -> Optional[str]:
+    def numInputRows(self) -> int:
         """
         The aggregate (across all sources) number of records processed in a trigger.
         """
-        return self._jprogress.numInputRows()
+        return self._numInputRows
 
     @property
     def inputRowsPerSecond(self) -> float:
@@ -464,7 +592,7 @@ class StreamingQueryProgress:
     @property
     def processedRowsPerSecond(self) -> float:
         """
-        The aggregate (across all sources) rate at which Spark is processing data..
+        The aggregate (across all sources) rate at which Spark is processing data.
         """
         return self._processedRowsPerSecond
 
@@ -473,14 +601,22 @@ class StreamingQueryProgress:
         """
         The compact JSON representation of this progress.
         """
-        return self._jprogress.json()
+        assert self._jdict is not None or self._jprogress is not None
+        if self._jprogress:
+            return self._jprogress.json()
+        else:
+            return json.dumps(self._jdict)
 
     @property
     def prettyJson(self) -> str:
         """
         The pretty (i.e. indented) JSON representation of this progress.
         """
-        return self._jprogress.prettyJson()
+        assert self._jdict is not None or self._jprogress is not None
+        if self._jprogress:
+            return self._jprogress.prettyJson()
+        else:
+            return json.dumps(self._jdict, indent=4)
 
     def __str__(self) -> str:
         return self.prettyJson
@@ -495,20 +631,73 @@ class StateOperatorProgress:
     This API is evolving.
     """
 
-    def __init__(self, jprogress: JavaObject) -> None:
-        self._jprogress: JavaObject = jprogress
-        self._operatorName: str = jprogress.operatorName()
-        self._numRowsTotal: int = jprogress.numRowsTotal()
-        self._numRowsUpdated: int = jprogress.numRowsUpdated()
-        self._allUpdatesTimeMs: int = jprogress.allUpdatesTimeMs()
-        self._numRowsRemoved: int = jprogress.numRowsRemoved()
-        self._allRemovalsTimeMs: int = jprogress.allRemovalsTimeMs()
-        self._commitTimeMs: int = jprogress.commitTimeMs()
-        self._memoryUsedBytes: int = jprogress.memoryUsedBytes()
-        self._numRowsDroppedByWatermark: int = jprogress.numRowsDroppedByWatermark()
-        self._numShufflePartitions: int = jprogress.numShufflePartitions()
-        self._numStateStoreInstances: int = jprogress.numStateStoreInstances()
-        self._customMetrics: Dict[str, int] = dict(jprogress.customMetrics())
+    def __init__(
+        self,
+        operatorName: str,
+        numRowsTotal: int,
+        numRowsUpdated: int,
+        numRowsRemoved: int,
+        allUpdatesTimeMs: int,
+        allRemovalsTimeMs: int,
+        commitTimeMs: int,
+        memoryUsedBytes: int,
+        numRowsDroppedByWatermark: int,
+        numShufflePartitions: int,
+        numStateStoreInstances: int,
+        customMetrics: Dict[str, int],
+        jprogress: Optional[JavaObject] = None,
+        jdict: Optional[Dict[str, Any]] = None,
+    ):
+        self._jprogress: Optional[JavaObject] = jprogress
+        self._jdict: Optional[Dict[str, Any]] = jdict
+        self._operatorName: str = operatorName
+        self._numRowsTotal: int = numRowsTotal
+        self._numRowsUpdated: int = numRowsUpdated
+        self._numRowsRemoved: int = numRowsRemoved
+        self._allUpdatesTimeMs: int = allUpdatesTimeMs
+        self._allRemovalsTimeMs: int = allRemovalsTimeMs
+        self._commitTimeMs: int = commitTimeMs
+        self._memoryUsedBytes: int = memoryUsedBytes
+        self._numRowsDroppedByWatermark: int = numRowsDroppedByWatermark
+        self._numShufflePartitions: int = numShufflePartitions
+        self._numStateStoreInstances: int = numStateStoreInstances
+        self._customMetrics: Dict[str, int] = customMetrics
+
+    @classmethod
+    def fromJObject(cls, jprogress: JavaObject) -> "StateOperatorProgress":
+        return cls(
+            jprogress=jprogress,
+            operatorName=jprogress.operatorName(),
+            numRowsTotal=jprogress.numRowsTotal(),
+            numRowsUpdated=jprogress.numRowsUpdated(),
+            allUpdatesTimeMs=jprogress.allUpdatesTimeMs(),
+            numRowsRemoved=jprogress.numRowsRemoved(),
+            allRemovalsTimeMs=jprogress.allRemovalsTimeMs(),
+            commitTimeMs=jprogress.commitTimeMs(),
+            memoryUsedBytes=jprogress.memoryUsedBytes(),
+            numRowsDroppedByWatermark=jprogress.numRowsDroppedByWatermark(),
+            numShufflePartitions=jprogress.numShufflePartitions(),
+            numStateStoreInstances=jprogress.numStateStoreInstances(),
+            customMetrics=dict(jprogress.customMetrics()),
+        )
+
+    @classmethod
+    def fromJson(cls, j: Dict[str, Any]) -> "StateOperatorProgress":
+        return cls(
+            jdict=j,
+            operatorName=j["operatorName"],
+            numRowsTotal=j["numRowsTotal"],
+            numRowsUpdated=j["numRowsUpdated"],
+            numRowsRemoved=j["numRowsRemoved"],
+            allUpdatesTimeMs=j["allUpdatesTimeMs"],
+            allRemovalsTimeMs=j["allRemovalsTimeMs"],
+            commitTimeMs=j["commitTimeMs"],
+            memoryUsedBytes=j["memoryUsedBytes"],
+            numRowsDroppedByWatermark=j["numRowsDroppedByWatermark"],
+            numShufflePartitions=j["numShufflePartitions"],
+            numStateStoreInstances=j["numStateStoreInstances"],
+            customMetrics=dict(j["customMetrics"]),
+        )
 
     @property
     def operatorName(self) -> str:
@@ -563,14 +752,22 @@ class StateOperatorProgress:
         """
         The compact JSON representation of this progress.
         """
-        return self._jprogress.json()
+        assert self._jdict is not None or self._jprogress is not None
+        if self._jprogress:
+            return self._jprogress.json()
+        else:
+            return json.dumps(self._jdict)
 
     @property
     def prettyJson(self) -> str:
         """
         The pretty (i.e. indented) JSON representation of this progress.
         """
-        return self._jprogress.prettyJson()
+        assert self._jdict is not None or self._jprogress is not None
+        if self._jprogress:
+            return self._jprogress.prettyJson()
+        else:
+            return json.dumps(self._jdict, indent=4)
 
     def __str__(self) -> str:
         return self.prettyJson
@@ -585,16 +782,57 @@ class SourceProgress:
     This API is evolving.
     """
 
-    def __init__(self, jprogress: JavaObject) -> None:
-        self._jprogress: JavaObject = jprogress
-        self._description: str = jprogress.description()
-        self._startOffset: str = jprogress.startOffset()
-        self._endOffset: str = jprogress.endOffset()
-        self._latestOffset: str = jprogress.latestOffset()
-        self._numInputRows: int = jprogress.numInputRows()
-        self._inputRowsPerSecond: float = jprogress.inputRowsPerSecond()
-        self._processedRowsPerSecond: float = jprogress.processedRowsPerSecond()
-        self._metrics: Dict[str, str] = dict(jprogress.metrics())
+    def __init__(
+        self,
+        description: str,
+        startOffset: str,
+        endOffset: str,
+        latestOffset: str,
+        numInputRows: int,
+        inputRowsPerSecond: float,
+        processedRowsPerSecond: float,
+        metrics: Dict[str, str],
+        jprogress: Optional[JavaObject] = None,
+        jdict: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        self._jprogress: Optional[JavaObject] = jprogress
+        self._jdict: Optional[Dict[str, Any]] = jdict
+        self._description: str = description
+        self._startOffset: str = startOffset
+        self._endOffset: str = endOffset
+        self._latestOffset: str = latestOffset
+        self._numInputRows: int = numInputRows
+        self._inputRowsPerSecond: float = inputRowsPerSecond
+        self._processedRowsPerSecond: float = processedRowsPerSecond
+        self._metrics: Dict[str, str] = metrics
+
+    @classmethod
+    def fromJObject(cls, jprogress: JavaObject) -> "SourceProgress":
+        return cls(
+            jprogress=jprogress,
+            description=jprogress.description(),
+            startOffset=str(jprogress.startOffset()),
+            endOffset=str(jprogress.endOffset()),
+            latestOffset=str(jprogress.latestOffset()),
+            numInputRows=jprogress.numInputRows(),
+            inputRowsPerSecond=jprogress.inputRowsPerSecond(),
+            processedRowsPerSecond=jprogress.processedRowsPerSecond(),
+            metrics=dict(jprogress.metrics()),
+        )
+
+    @classmethod
+    def fromJson(cls, j: Dict[str, Any]) -> "SourceProgress":
+        return cls(
+            jdict=j,
+            description=j["description"],
+            startOffset=str(j["startOffset"]),
+            endOffset=str(j["endOffset"]),
+            latestOffset=str(j["latestOffset"]),
+            numInputRows=j["numInputRows"],
+            inputRowsPerSecond=j["inputRowsPerSecond"],
+            processedRowsPerSecond=j["processedRowsPerSecond"],
+            metrics=dict(j["metrics"]),
+        )
 
     @property
     def description(self) -> str:
@@ -654,14 +892,22 @@ class SourceProgress:
         """
         The compact JSON representation of this progress.
         """
-        return self._jprogress.json()
+        assert self._jdict is not None or self._jprogress is not None
+        if self._jprogress:
+            return self._jprogress.json()
+        else:
+            return json.dumps(self._jdict)
 
     @property
     def prettyJson(self) -> str:
         """
         The pretty (i.e. indented) JSON representation of this progress.
         """
-        return self._jprogress.prettyJson()
+        assert self._jdict is not None or self._jprogress is not None
+        if self._jprogress:
+            return self._jprogress.prettyJson()
+        else:
+            return json.dumps(self._jdict, indent=4)
 
     def __str__(self) -> str:
         return self.prettyJson
@@ -676,11 +922,37 @@ class SinkProgress:
     This API is evolving.
     """
 
-    def __init__(self, jprogress: JavaObject) -> None:
-        self._jprogress: JavaObject = jprogress
-        self._description: str = jprogress.description()
-        self._numOutputRows: int = jprogress.numOutputRows()
-        self._metrics: Dict[str, str] = dict(jprogress.metrics())
+    def __init__(
+        self,
+        description: str,
+        numOutputRows: int,
+        metrics: Dict[str, str],
+        jprogress: Optional[JavaObject] = None,
+        jdict: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        self._jprogress: Optional[JavaObject] = jprogress
+        self._jdict: Optional[Dict[str, Any]] = jdict
+        self._description: str = description
+        self._numOutputRows: int = numOutputRows
+        self._metrics: Dict[str, str] = metrics
+
+    @classmethod
+    def fromJObject(cls, jprogress: JavaObject) -> "SinkProgress":
+        return cls(
+            jprogress=jprogress,
+            description=jprogress.description(),
+            numOutputRows=jprogress.numOutputRows(),
+            metrics=dict(jprogress.metrics()),
+        )
+
+    @classmethod
+    def fromJson(cls, j: Dict[str, Any]) -> "SinkProgress":
+        return cls(
+            jdict=j,
+            description=j["description"],
+            numOutputRows=j["numOutputRows"],
+            metrics=j["metrics"],
+        )
 
     @property
     def description(self) -> str:
@@ -706,14 +978,22 @@ class SinkProgress:
         """
         The compact JSON representation of this progress.
         """
-        return self._jprogress.json()
+        assert self._jdict is not None or self._jprogress is not None
+        if self._jprogress:
+            return self._jprogress.json()
+        else:
+            return json.dumps(self._jdict)
 
     @property
     def prettyJson(self) -> str:
         """
         The pretty (i.e. indented) JSON representation of this progress.
         """
-        return self._jprogress.prettyJson()
+        assert self._jdict is not None or self._jprogress is not None
+        if self._jprogress:
+            return self._jprogress.prettyJson()
+        else:
+            return json.dumps(self._jdict, indent=4)
 
     def __str__(self) -> str:
         return self.prettyJson
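
A minimal sketch (illustration only) of the dual backing introduced above: a
progress object built with `fromJson` carries no `_jprogress`, so its `json`
and `prettyJson` properties serialize the stored dict locally instead of
calling back into the JVM.

    import json

    from pyspark.sql.streaming.listener import SinkProgress

    sink = SinkProgress.fromJson(
        {"description": "sink", "numOutputRows": -1, "metrics": {}}
    )

    # With no JVM object attached, json falls back to json.dumps(self._jdict)
    assert json.loads(sink.json) == {
        "description": "sink",
        "numOutputRows": -1,
        "metrics": {},
    }
    print(sink.prettyJson)  # json.dumps(self._jdict, indent=4)
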
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
index 71d76bc4e8d..2bd6d2c6668 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -19,6 +19,7 @@ import time
 import uuid
 from datetime import datetime
 
+from pyspark import Row
 from pyspark.sql.streaming import StreamingQueryListener
 from pyspark.sql.streaming.listener import (
     QueryStartedEvent,
@@ -51,21 +52,21 @@ class StreamingListenerTests(ReusedSQLTestCase):
             get_number_of_public_methods(
                 "org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent"
             ),
-            14,
+            15,
             msg,
         )
         self.assertEquals(
             get_number_of_public_methods(
                 "org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgressEvent"
             ),
-            11,
+            12,
             msg,
         )
         self.assertEquals(
             get_number_of_public_methods(
                 "org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminatedEvent"
             ),
-            14,
+            15,
             msg,
         )
         self.assertEquals(
@@ -112,7 +113,15 @@ class StreamingListenerTests(ReusedSQLTestCase):
             self.spark.streams.addListener(test_listener)
 
             df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load()
-            q = df.writeStream.format("noop").queryName("test").start()
+
+            # check successful stateful query
+            df_stateful = df.groupBy().count()  # make query stateful
+            q = (
+                df_stateful.writeStream.format("noop")
+                .queryName("test")
+                .outputMode("complete")
+                .start()
+            )
             self.assertTrue(q.isActive)
             time.sleep(10)
             q.stop()
@@ -123,6 +132,17 @@ class StreamingListenerTests(ReusedSQLTestCase):
             self.check_start_event(start_event)
             self.check_progress_event(progress_event)
             self.check_terminated_event(terminated_event)
+
+            # Check query terminated with exception
+            from pyspark.sql.functions import col, udf
+
+            bad_udf = udf(lambda x: 1 / 0)
+            q = df.select(bad_udf(col("value"))).writeStream.format("noop").start()
+            time.sleep(5)
+            q.stop()
+            self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty()
+            self.check_terminated_event(terminated_event, "ZeroDivisionError")
+
         finally:
             self.spark.streams.removeListener(test_listener)
 
@@ -131,7 +151,7 @@ class StreamingListenerTests(ReusedSQLTestCase):
         self.assertTrue(isinstance(event, QueryStartedEvent))
         self.assertTrue(isinstance(event.id, uuid.UUID))
         self.assertTrue(isinstance(event.runId, uuid.UUID))
-        self.assertEquals(event.name, "test")
+        self.assertTrue(event.name is None or event.name == "test")
         try:
             datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S.%fZ")
         except ValueError:
@@ -142,14 +162,20 @@ class StreamingListenerTests(ReusedSQLTestCase):
         self.assertTrue(isinstance(event, QueryProgressEvent))
         self.check_streaming_query_progress(event.progress)
 
-    def check_terminated_event(self, event):
+    def check_terminated_event(self, event, exception=None, error_class=None):
         """Check QueryTerminatedEvent"""
         self.assertTrue(isinstance(event, QueryTerminatedEvent))
         self.assertTrue(isinstance(event.id, uuid.UUID))
         self.assertTrue(isinstance(event.runId, uuid.UUID))
-        # TODO: Needs a test for exception.
-        self.assertEquals(event.exception, None)
-        self.assertEquals(event.errorClassOnException, None)
+        if exception:
+            self.assertTrue(exception in event.exception)
+        else:
+            self.assertEquals(event.exception, None)
+
+        if error_class:
+            self.assertTrue(error_class in event.errorClassOnException)
+        else:
+            self.assertEquals(event.errorClassOnException, None)
 
     def check_streaming_query_progress(self, progress):
         """Check StreamingQueryProgress"""
@@ -191,13 +217,15 @@ class StreamingListenerTests(ReusedSQLTestCase):
         )
         self.assertTrue(all(map(lambda v: isinstance(v, int), progress.durationMs.values())))
 
-        self.assertEquals(progress.eventTime, {})
+        self.assertTrue(all(map(lambda v: isinstance(v, str), progress.eventTime.values())))
 
         self.assertTrue(isinstance(progress.stateOperators, list))
+        self.assertTrue(len(progress.stateOperators) >= 1)
         for so in progress.stateOperators:
             self.check_state_operator_progress(so)
 
         self.assertTrue(isinstance(progress.sources, list))
+        self.assertTrue(len(progress.sources) >= 1)
         for so in progress.sources:
             self.check_source_progress(so)
 
@@ -299,6 +327,179 @@ class StreamingListenerTests(ReusedSQLTestCase):
         self.spark.streams.removeListener(test_listener)
         self.assertEqual(num_listeners, len(self.spark.streams._jsqm.listListeners()))
 
+    def test_query_started_event_fromJson(self):
+        start_event = """
+            {
+                "id" : "78923ec2-8f4d-4266-876e-1f50cf3c283b",
+                "runId" : "55a95d45-e932-4e08-9caa-0a8ecd9391e8",
+                "name" : null,
+                "timestamp" : "2023-06-09T18:13:29.741Z"
+            }
+        """
+        start_event = QueryStartedEvent.fromJson(json.loads(start_event))
+        self.check_start_event(start_event)
+        self.assertEqual(start_event.id, uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b"))
+        self.assertEqual(start_event.runId, uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8"))
+        self.assertIsNone(start_event.name)
+        self.assertEqual(start_event.timestamp, "2023-06-09T18:13:29.741Z")
+
+    def test_query_terminated_event_fromJson(self):
+        terminated_json = """
+            {
+                "id" : "78923ec2-8f4d-4266-876e-1f50cf3c283b",
+                "runId" : "55a95d45-e932-4e08-9caa-0a8ecd9391e8",
+                "exception" : "org.apache.spark.SparkException: Job aborted due to stage failure",
+                "errorClassOnException" : null}
+        """
+        terminated_event = QueryTerminatedEvent.fromJson(json.loads(terminated_json))
+        self.check_terminated_event(terminated_event, "SparkException")
+        self.assertEqual(terminated_event.id, uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b"))
+        self.assertEqual(terminated_event.runId, uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8"))
+        self.assertIn("SparkException", terminated_event.exception)
+        self.assertIsNone(terminated_event.errorClassOnException)
+
+    def test_streaming_query_progress_fromJson(self):
+        progress_json = """
+            {
+              "id" : "00000000-0000-0001-0000-000000000001",
+              "runId" : "00000000-0000-0001-0000-000000000002",
+              "name" : "test",
+              "timestamp" : "2016-12-05T20:54:20.827Z",
+              "batchId" : 2,
+              "numInputRows" : 678,
+              "inputRowsPerSecond" : 10.0,
+              "processedRowsPerSecond" : 5.4,
+              "batchDuration": 5,
+              "durationMs" : {
+                "getBatch" : 0
+              },
+              "eventTime" : {
+                "min" : "2016-12-05T20:54:20.827Z",
+                "avg" : "2016-12-05T20:54:20.827Z",
+                "watermark" : "2016-12-05T20:54:20.827Z",
+                "max" : "2016-12-05T20:54:20.827Z"
+              },
+              "stateOperators" : [ {
+                "operatorName" : "op1",
+                "numRowsTotal" : 0,
+                "numRowsUpdated" : 1,
+                "allUpdatesTimeMs" : 1,
+                "numRowsRemoved" : 2,
+                "allRemovalsTimeMs" : 34,
+                "commitTimeMs" : 23,
+                "memoryUsedBytes" : 3,
+                "numRowsDroppedByWatermark" : 0,
+                "numShufflePartitions" : 2,
+                "numStateStoreInstances" : 2,
+                "customMetrics" : {
+                  "loadedMapCacheHitCount" : 1,
+                  "loadedMapCacheMissCount" : 0,
+                  "stateOnCurrentVersionSizeBytes" : 2
+                }
+              } ],
+              "sources" : [ {
+                "description" : "source",
+                "startOffset" : 123,
+                "endOffset" : 456,
+                "latestOffset" : 789,
+                "numInputRows" : 678,
+                "inputRowsPerSecond" : 10.0,
+                "processedRowsPerSecond" : 5.4,
+                "metrics": {}
+              } ],
+              "sink" : {
+                "description" : "sink",
+                "numOutputRows" : -1,
+                "metrics": {}
+              },
+              "observedMetrics" : {
+                "event1" : {
+                  "c1" : 1,
+                  "c2" : 3.0
+                },
+                "event2" : {
+                  "rc" : 1,
+                  "min_q" : "hello",
+                  "max_q" : "world"
+                }
+              }
+            }
+        """
+        progress = StreamingQueryProgress.fromJson(json.loads(progress_json))
+
+        self.check_streaming_query_progress(progress)
+
+        # checks for progress
+        self.assertEqual(progress.id, uuid.UUID("00000000-0000-0001-0000-000000000001"))
+        self.assertEqual(progress.runId, uuid.UUID("00000000-0000-0001-0000-000000000002"))
+        self.assertEqual(progress.name, "test")
+        self.assertEqual(progress.timestamp, "2016-12-05T20:54:20.827Z")
+        self.assertEqual(progress.batchId, 2)
+        self.assertEqual(progress.numInputRows, 678)
+        self.assertEqual(progress.inputRowsPerSecond, 10.0)
+        self.assertEqual(progress.batchDuration, 5)
+        self.assertEqual(progress.durationMs, {"getBatch": 0})
+        self.assertEqual(
+            progress.eventTime,
+            {
+                "min": "2016-12-05T20:54:20.827Z",
+                "avg": "2016-12-05T20:54:20.827Z",
+                "watermark": "2016-12-05T20:54:20.827Z",
+                "max": "2016-12-05T20:54:20.827Z",
+            },
+        )
+        self.assertEqual(
+            progress.observedMetrics,
+            {
+                "event1": Row("c1", "c2")(1, 3.0),
+                "event2": Row("rc", "min_q", "max_q")(1, "hello", "world"),
+            },
+        )
+
+        # Check stateOperators list
+        self.assertEqual(len(progress.stateOperators), 1)
+        state_operator = progress.stateOperators[0]
+        self.assertTrue(isinstance(state_operator, StateOperatorProgress))
+        self.assertEqual(state_operator.operatorName, "op1")
+        self.assertEqual(state_operator.numRowsTotal, 0)
+        self.assertEqual(state_operator.numRowsUpdated, 1)
+        self.assertEqual(state_operator.allUpdatesTimeMs, 1)
+        self.assertEqual(state_operator.numRowsRemoved, 2)
+        self.assertEqual(state_operator.allRemovalsTimeMs, 34)
+        self.assertEqual(state_operator.commitTimeMs, 23)
+        self.assertEqual(state_operator.memoryUsedBytes, 3)
+        self.assertEqual(state_operator.numRowsDroppedByWatermark, 0)
+        self.assertEqual(state_operator.numShufflePartitions, 2)
+        self.assertEqual(state_operator.numStateStoreInstances, 2)
+        self.assertEqual(
+            state_operator.customMetrics,
+            {
+                "loadedMapCacheHitCount": 1,
+                "loadedMapCacheMissCount": 0,
+                "stateOnCurrentVersionSizeBytes": 2,
+            },
+        )
+
+        # Check sources list
+        self.assertEqual(len(progress.sources), 1)
+        source = progress.sources[0]
+        self.assertTrue(isinstance(source, SourceProgress))
+        self.assertEqual(source.description, "source")
+        self.assertEqual(source.startOffset, "123")
+        self.assertEqual(source.endOffset, "456")
+        self.assertEqual(source.latestOffset, "789")
+        self.assertEqual(source.numInputRows, 678)
+        self.assertEqual(source.inputRowsPerSecond, 10.0)
+        self.assertEqual(source.processedRowsPerSecond, 5.4)
+        self.assertEqual(source.metrics, {})
+
+        # Check sink
+        sink = progress.sink
+        self.assertTrue(isinstance(sink, SinkProgress))
+        self.assertEqual(sink.description, "sink")
+        self.assertEqual(sink.numOutputRows, -1)
+        self.assertEqual(sink.metrics, {})
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
index 61a0ef1b98e..5c0027895cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
@@ -19,6 +19,11 @@ package org.apache.spark.sql.streaming
 
 import java.util.UUID
 
+import org.json4s.{JObject, JString}
+import org.json4s.JsonAST.JValue
+import org.json4s.JsonDSL.{jobject2assoc, pair2Assoc}
+import org.json4s.jackson.JsonMethods.{compact, render}
+
 import org.apache.spark.annotation.Evolving
 import org.apache.spark.scheduler.SparkListenerEvent
 
@@ -123,7 +128,17 @@ object StreamingQueryListener {
       val id: UUID,
       val runId: UUID,
       val name: String,
-      val timestamp: String) extends Event
+      val timestamp: String) extends Event {
+
+    def json: String = compact(render(jsonValue))
+
+    private def jsonValue: JValue = {
+      ("id" -> JString(id.toString)) ~
+      ("runId" -> JString(runId.toString)) ~
+      ("name" -> JString(name)) ~
+      ("timestamp" -> JString(timestamp))
+    }
+  }
 
   /**
    * Event representing any progress updates in a query.
@@ -131,7 +146,12 @@ object StreamingQueryListener {
    * @since 2.1.0
    */
   @Evolving
-  class QueryProgressEvent private[sql](val progress: StreamingQueryProgress) extends Event
+  class QueryProgressEvent private[sql](val progress: StreamingQueryProgress) extends Event {
+
+    def json: String = compact(render(jsonValue))
+
+    private def jsonValue: JValue = JObject("progress" -> progress.jsonValue)
+  }
 
   /**
   * Event representing that the query is idle and waiting for new data to process.
@@ -145,7 +165,16 @@ object StreamingQueryListener {
   class QueryIdleEvent private[sql](
       val id: UUID,
       val runId: UUID,
-      val timestamp: String) extends Event
+      val timestamp: String) extends Event {
+
+    def json: String = compact(render(jsonValue))
+
+    private def jsonValue: JValue = {
+      ("id" -> JString(id.toString)) ~
+      ("runId" -> JString(runId.toString)) ~
+      ("timestamp" -> JString(timestamp))
+    }
+  }
 
   /**
   * Event representing the termination of a query.
@@ -171,5 +200,14 @@ object StreamingQueryListener {
     def this(id: UUID, runId: UUID, exception: Option[String]) = {
       this(id, runId, exception, None)
     }
+
+    def json: String = compact(render(jsonValue))
+
+    private def jsonValue: JValue = {
+      ("id" -> JString(id.toString)) ~
+      ("runId" -> JString(runId.toString)) ~
+      ("exception" -> JString(exception.orNull)) ~
+      ("errorClassOnException" -> JString(errorClassOnException.orNull))
+    }
   }
 }
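
End to end, the Connect server can now render any listener event to JSON on
the JVM and rebuild it in the Python process (a sketch assuming the payload
is exactly what the new Scala `json` methods emit, with JSON null mapping to
Python None):

    import json
    import uuid

    from pyspark.sql.streaming.listener import QueryTerminatedEvent

    # Shape matches QueryTerminatedEvent.jsonValue above
    payload = json.dumps({
        "id": "78923ec2-8f4d-4266-876e-1f50cf3c283b",
        "runId": "55a95d45-e932-4e08-9caa-0a8ecd9391e8",
        "exception": "org.apache.spark.SparkException: Job aborted due to stage failure",
        "errorClassOnException": None,
    })

    event = QueryTerminatedEvent.fromJson(json.loads(payload))
    assert isinstance(event.runId, uuid.UUID)
    assert "SparkException" in event.exception
    assert event.errorClassOnException is None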

