Posted to reviews@spark.apache.org by "yaooqinn (via GitHub)" <gi...@apache.org> on 2023/09/14 08:10:18 UTC

[GitHub] [spark] yaooqinn commented on a diff in pull request #42779: [SPARK-45056][PYTHON][SS][CONNECT] Termination tests for streamingQueryListener and foreachBatch

yaooqinn commented on code in PR #42779:
URL: https://github.com/apache/spark/pull/42779#discussion_r1325551540


##########
connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala:
##########
@@ -79,4 +92,196 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
       sessionHolder.getDataFrameOrThrow(key1)
     }
   }
+
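+  // Pickles a trivial Python foreachBatch function (`lambda df, batchId: batchId`)
+  // via CloudPickle in a subprocess and returns the serialized bytes.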
+  private def streamingForeachBatchFunction(pysparkPythonPath: String): Array[Byte] = {
+    var binaryFunc: Array[Byte] = null
+    withTempPath { path =>
+      Process(
+        Seq(
+          IntegratedUDFTestUtils.pythonExec,
+          "-c",
+          "from pyspark.serializers import CloudPickleSerializer; " +
+            s"f = open('$path', 'wb');" +
+            "f.write(CloudPickleSerializer().dumps((" +
+            "lambda df, batchId: batchId)))"),
+        None,
+        "PYTHONPATH" -> pysparkPythonPath).!!
+      binaryFunc = Files.readAllBytes(path.toPath)
+    }
+    assert(binaryFunc != null)
+    binaryFunc
+  }
+
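+  // Defines a no-op Python StreamingQueryListener, pickles an instance of it via
+  // CloudPickle in a subprocess, and returns the serialized bytes.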
+  private def streamingQueryListenerFunction(pysparkPythonPath: String): Array[Byte] = {
+    var binaryFunc: Array[Byte] = null
+    val pythonScript =
+      """
+        |from pyspark.sql.streaming.listener import StreamingQueryListener
+        |
+        |class MyListener(StreamingQueryListener):
+        |    def onQueryStarted(self, e):
+        |        pass
+        |
+        |    def onQueryIdle(self, e):
+        |        pass
+        |
+        |    def onQueryProgress(self, e):
+        |        pass
+        |
+        |    def onQueryTerminated(self, e):
+        |        pass
+        |
+        |listener = MyListener()
+      """.stripMargin
+    withTempPath { codePath =>
+      Files.write(codePath.toPath, pythonScript.getBytes(StandardCharsets.UTF_8))
+      withTempPath { path =>
+        Process(
+          Seq(
+            IntegratedUDFTestUtils.pythonExec,
+            "-c",
+            "from pyspark.serializers import CloudPickleSerializer; " +
+              s"f = open('$path', 'wb');" +
+              s"exec(open('$codePath', 'r').read());" +
+              "f.write(CloudPickleSerializer().dumps(listener))"),
+          None,
+          "PYTHONPATH" -> pysparkPythonPath).!!
+        binaryFunc = Files.readAllBytes(path.toPath)
+      }
+    }
+    assert(binaryFunc != null)
+    binaryFunc
+  }
+
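+  // Wraps the pickled bytes produced by `fcn` in a SimplePythonFunction, carrying
+  // the PYTHONPATH, interpreter, and Spark Connect includes the Python runner needs.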
+  private def dummyPythonFunction(sessionHolder: SessionHolder)(
+      fcn: String => Array[Byte]): SimplePythonFunction = {
+    val sparkPythonPath =
+      s"${IntegratedUDFTestUtils.pysparkPythonPath}:${IntegratedUDFTestUtils.pythonPath}"
+
+    SimplePythonFunction(
+      command = fcn(sparkPythonPath),
+      envVars = mutable.Map("PYTHONPATH" -> sparkPythonPath).asJava,
+      pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
+      pythonExec = IntegratedUDFTestUtils.pythonExec,
+      pythonVer = IntegratedUDFTestUtils.pythonVer,
+      broadcastVars = Lists.newArrayList(),
+      accumulator = null)
+  }
+
+  test("python foreachBatch process: process terminates after query is stopped") {
+    // scalastyle:off assume
+    assume(IntegratedUDFTestUtils.shouldTestPythonUDFs)
+    // scalastyle:on assume
+
+    val sessionHolder = SessionHolder.forTesting(spark)
+    try {
+      SparkConnectService.start(spark.sparkContext)
+
+      val pythonFn = dummyPythonFunction(sessionHolder)(streamingForeachBatchFunction)
+      val (fn1, cleaner1) =
+        StreamingForeachBatchHelper.pythonForeachBatchWrapper(pythonFn, sessionHolder)
+      val (fn2, cleaner2) =
+        StreamingForeachBatchHelper.pythonForeachBatchWrapper(pythonFn, sessionHolder)
+
+      val query1 = spark.readStream
+        .format("rate")
+        .load()
+        .writeStream
+        .format("memory")
+        .queryName("foreachBatch_termination_test_q1")
+        .foreachBatch(fn1)
+        .start()
+
+      val query2 = spark.readStream
+        .format("rate")
+        .load()
+        .writeStream
+        .format("memory")
+        .queryName("foreachBatch_termination_test_q2")
+        .foreachBatch(fn2)
+        .start()
+
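+      // Register the cleaners so the session holder stops each query's Python
+      // worker process when that query is stopped.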
+      sessionHolder.streamingForeachBatchRunnerCleanerCache
+        .registerCleanerForQuery(query1, cleaner1)
+      sessionHolder.streamingForeachBatchRunnerCleanerCache
+        .registerCleanerForQuery(query2, cleaner2)
+
+      val (runner1, runner2) = (cleaner1.runner, cleaner2.runner)
+
+      // assert both python processes are running
+      assert(!runner1.isWorkerStopped().get)
+      assert(!runner2.isWorkerStopped().get)
+      // stop query1
+      query1.stop()
+      // assert query1's python process is not running
+      eventually(timeout(30.seconds)) {
+        assert(runner1.isWorkerStopped().get)
+        assert(!runner2.isWorkerStopped().get)
+      }
+
+      // stop query2
+      query2.stop()
+      eventually(timeout(30.seconds)) {
+        // assert query2's python process is not running
+        assert(runner2.isWorkerStopped().get)
+      }
+
+      assert(spark.streams.active.isEmpty) // no running query
+      assert(spark.streams.listListeners().length == 1) // only process termination listener
+    } finally {
+      SparkConnectService.stop()
+      // remove process termination listener
+      spark.streams.removeListener(spark.streams.listListeners()(0))

Review Comment:
   Hi @WweiL @HyukjinKwon, we encountered a test failure on this line:
   
   ```
   python foreachBatch process: process terminates after query is stopped *** FAILED *** (1 second, 431 milliseconds)
   [info]   java.lang.ArrayIndexOutOfBoundsException: 0
   [info]   at org.apache.spark.sql.connect.service.SparkConnectSessionHolderSuite.$anonfun$new$7(SparkConnectSessionHodlerSuite.scala:234)
   ```
   
   The stack trace suggests `listListeners()` returned an empty array at this point (for example, when the test fails before the listener is registered), so indexing element `(0)` throws. How about changing it to:
   ```scala
   spark.streams.listListeners().foreach(spark.streams.removeListener)
   ```
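   
   That `foreach` is a no-op on an empty array, so cleanup still succeeds when no listener remains. For illustration, here is a minimal sketch of the `finally` block with this change applied (based on the test in the diff above; not a verified patch):
   
   ```scala
   } finally {
     SparkConnectService.stop()
     // Remove any remaining listeners; safe even when listListeners() is empty.
     spark.streams.listListeners().foreach(spark.streams.removeListener)
   }
   ```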
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

