You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/11/22 00:44:13 UTC

[spark] branch master updated: [SPARK-41005][COLLECT][FOLLOWUP] Remove JSON code path and use `RDD.collect` in Arrow code path

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 908adca2b22 [SPARK-41005][COLLECT][FOLLOWUP] Remove JSON code path and use `RDD.collect` in Arrow code path
908adca2b22 is described below

commit 908adca2b229b05b2ae0dd31cbaaa1fdcde16290
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Tue Nov 22 09:43:44 2022 +0900

    [SPARK-41005][COLLECT][FOLLOWUP] Remove JSON code path and use `RDD.collect` in Arrow code path
    
    ### What changes were proposed in this pull request?
    1. Remove the JSON code path;
    2. Use `RDD.collect` in the Arrow code path, since existing tests were already broken in the Arrow code path;
    3. Re-enable `test_fill_na`.
    
    ### Why are the changes needed?
    The existing Arrow code path is still problematic: it fails and falls back to the JSON code path, which changes the output data types of `test_fill_na`.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Re-enabled the test and added a unit test.
    
    Closes #38706 from zhengruifeng/collect_disable_json.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../src/main/protobuf/spark/connect/base.proto     |  14 +-
 .../service/SparkConnectStreamHandler.scala        | 156 ++-------------------
 python/pyspark/sql/connect/client.py               |   5 -
 python/pyspark/sql/connect/proto/base_pb2.py       |  41 ++----
 python/pyspark/sql/connect/proto/base_pb2.pyi      |  51 +------
 .../sql/tests/connect/test_connect_basic.py        |  55 +++++++-
 6 files changed, 82 insertions(+), 240 deletions(-)

diff --git a/connector/connect/src/main/protobuf/spark/connect/base.proto b/connector/connect/src/main/protobuf/spark/connect/base.proto
index 66e27187153..277da6b2431 100644
--- a/connector/connect/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/base.proto
@@ -139,11 +139,7 @@ message ExecutePlanRequest {
 message ExecutePlanResponse {
   string client_id = 1;
 
-  // Result type
-  oneof result_type {
-    ArrowBatch arrow_batch = 2;
-    JSONBatch json_batch = 3;
-  }
+  ArrowBatch arrow_batch = 2;
 
   // Metrics for the query execution. Typically, this field is only present in the last
   // batch of results and then represent the overall state of the query execution.
@@ -155,14 +151,6 @@ message ExecutePlanResponse {
     bytes data = 2;
   }
 
-  // Message type when the result is returned as JSON. This is essentially a bulk wrapper
-  // for the JSON result of a Spark DataFrame. All rows are returned in the JSON record format
-  // of `{col -> row}`.
-  message JSONBatch {
-    int64 row_count = 1;
-    bytes data = 2;
-  }
-
   message Metrics {
 
     repeated MetricObject metrics = 1;
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 50ff08f997c..092bdd00dc1 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -18,12 +18,10 @@
 package org.apache.spark.sql.connect.service
 
 import scala.collection.JavaConverters._
-import scala.util.control.NonFatal
 
 import com.google.protobuf.ByteString
 import io.grpc.stub.StreamObserver
 
-import org.apache.spark.SparkException
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse}
 import org.apache.spark.internal.Logging
@@ -34,7 +32,6 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.ThreadUtils
 
 class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResponse])
     extends Logging {
@@ -57,75 +54,7 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
     // Extract the plan from the request and convert it to a logical plan
     val planner = new SparkConnectPlanner(session)
     val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot))
-    try {
-      processAsArrowBatches(request.getClientId, dataframe)
-    } catch {
-      case e: Exception =>
-        logWarning(e.getMessage)
-        processAsJsonBatches(request.getClientId, dataframe)
-    }
-  }
-
-  def processAsJsonBatches(clientId: String, dataframe: DataFrame): Unit = {
-    // Only process up to 10MB of data.
-    val sb = new StringBuilder
-    var rowCount = 0
-    dataframe.toJSON
-      .collect()
-      .foreach(row => {
-
-        // There are a few cases to cover here.
-        // 1. The aggregated buffer size is larger than the MAX_BATCH_SIZE
-        //     -> send the current batch and reset.
-        // 2. The aggregated buffer size is smaller than the MAX_BATCH_SIZE
-        //     -> append the row to the buffer.
-        // 3. The row in question is larger than the MAX_BATCH_SIZE
-        //     -> fail the query.
-
-        // Case 3. - Fail
-        if (row.size > MAX_BATCH_SIZE) {
-          throw SparkException.internalError(
-            s"Serialized row is larger than MAX_BATCH_SIZE: ${row.size} > ${MAX_BATCH_SIZE}")
-        }
-
-        // Case 1 - FLush and send.
-        if (sb.size + row.size > MAX_BATCH_SIZE) {
-          val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
-          val batch = proto.ExecutePlanResponse.JSONBatch
-            .newBuilder()
-            .setData(ByteString.copyFromUtf8(sb.toString()))
-            .setRowCount(rowCount)
-            .build()
-          response.setJsonBatch(batch)
-          responseObserver.onNext(response.build())
-          sb.clear()
-          sb.append(row)
-          rowCount = 1
-        } else {
-          // Case 2 - Append.
-          // Make sure to put the newline delimiters only between items and not at the end.
-          if (rowCount > 0) {
-            sb.append("\n")
-          }
-          sb.append(row)
-          rowCount += 1
-        }
-      })
-
-    // If the last batch is not empty, send out the data to the client.
-    if (sb.size > 0) {
-      val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
-      val batch = proto.ExecutePlanResponse.JSONBatch
-        .newBuilder()
-        .setData(ByteString.copyFromUtf8(sb.toString()))
-        .setRowCount(rowCount)
-        .build()
-      response.setJsonBatch(batch)
-      responseObserver.onNext(response.build())
-    }
-
-    responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
-    responseObserver.onCompleted()
+    processAsArrowBatches(request.getClientId, dataframe)
   }
 
   def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
@@ -142,83 +71,20 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
       var numSent = 0
 
       if (numPartitions > 0) {
-        type Batch = (Array[Byte], Long)
-
         val batches = rows.mapPartitionsInternal(
           SparkConnectStreamHandler
             .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId))
 
-        val signal = new Object
-        val partitions = collection.mutable.Map.empty[Int, Array[Batch]]
-        var error: Throwable = null
-
-        val processPartition = (iter: Iterator[Batch]) => iter.toArray
-
-        // This callback is executed by the DAGScheduler thread.
-        // After fetching a partition, it inserts the partition into the Map, and then
-        // wakes up the main thread.
-        val resultHandler = (partitionId: Int, partition: Array[Batch]) => {
-          signal.synchronized {
-            partitions(partitionId) = partition
-            signal.notify()
-          }
-          ()
-        }
-
-        val future = spark.sparkContext.submitJob(
-          rdd = batches,
-          processPartition = processPartition,
-          partitions = Seq.range(0, numPartitions),
-          resultHandler = resultHandler,
-          resultFunc = () => ())
-
-        // Collect errors and propagate them to the main thread.
-        future.onComplete { result =>
-          result.failed.foreach { throwable =>
-            signal.synchronized {
-              error = throwable
-              signal.notify()
-            }
-          }
-        }(ThreadUtils.sameThread)
-
-        // The main thread will wait until 0-th partition is available,
-        // then send it to client and wait for the next partition.
-        // Different from the implementation of [[Dataset#collectAsArrowToPython]], it sends
-        // the arrow batches in main thread to avoid DAGScheduler thread been blocked for
-        // tasks not related to scheduling. This is particularly important if there are
-        // multiple users or clients running code at the same time.
-        var currentPartitionId = 0
-        while (currentPartitionId < numPartitions) {
-          val partition = signal.synchronized {
-            var result = partitions.remove(currentPartitionId)
-            while (result.isEmpty && error == null) {
-              signal.wait()
-              result = partitions.remove(currentPartitionId)
-            }
-            error match {
-              case NonFatal(e) =>
-                responseObserver.onError(error)
-                logError("Error while processing query.", e)
-                return
-              case fatal: Throwable => throw fatal
-              case null => result.get
-            }
-          }
-
-          partition.foreach { case (bytes, count) =>
-            val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
-            val batch = proto.ExecutePlanResponse.ArrowBatch
-              .newBuilder()
-              .setRowCount(count)
-              .setData(ByteString.copyFrom(bytes))
-              .build()
-            response.setArrowBatch(batch)
-            responseObserver.onNext(response.build())
-            numSent += 1
-          }
-
-          currentPartitionId += 1
+        batches.collect().foreach { case (bytes, count) =>
+          val response = proto.ExecutePlanResponse.newBuilder().setClientId(clientId)
+          val batch = proto.ExecutePlanResponse.ArrowBatch
+            .newBuilder()
+            .setRowCount(count)
+            .setData(ByteString.copyFrom(bytes))
+            .build()
+          response.setArrowBatch(batch)
+          responseObserver.onNext(response.build())
+          numSent += 1
         }
       }
 
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 5bdf01afc99..fdcf34b7a47 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -16,7 +16,6 @@
 #
 
 
-import io
 import logging
 import os
 import typing
@@ -446,13 +445,9 @@ class RemoteSparkSession(object):
         return AnalyzeResult.fromProto(resp)
 
     def _process_batch(self, b: pb2.ExecutePlanResponse) -> Optional[pandas.DataFrame]:
-        import pandas as pd
-
         if b.arrow_batch is not None and len(b.arrow_batch.data) > 0:
             with pa.ipc.open_stream(b.arrow_batch.data) as rd:
                 return rd.read_pandas()
-        elif b.json_batch is not None and len(b.json_batch.data) > 0:
-            return pd.read_json(io.BytesIO(b.json_batch.data), lines=True)
         return None
 
     def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> typing.Optional[pandas.DataFrame]:
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 7d9f98b243e..daa1c25cc8f 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain. [...]
+    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"\xb5\x01\n\x07\x45xplain\x12\x45\n\x0c\x65xplain_mode\x18\x01 \x01(\x0e\x32".spark.connect.Explain. [...]
 )
 
 
@@ -48,7 +48,6 @@ _ANALYZEPLANRESPONSE = DESCRIPTOR.message_types_by_name["AnalyzePlanResponse"]
 _EXECUTEPLANREQUEST = DESCRIPTOR.message_types_by_name["ExecutePlanRequest"]
 _EXECUTEPLANRESPONSE = DESCRIPTOR.message_types_by_name["ExecutePlanResponse"]
 _EXECUTEPLANRESPONSE_ARROWBATCH = _EXECUTEPLANRESPONSE.nested_types_by_name["ArrowBatch"]
-_EXECUTEPLANRESPONSE_JSONBATCH = _EXECUTEPLANRESPONSE.nested_types_by_name["JSONBatch"]
 _EXECUTEPLANRESPONSE_METRICS = _EXECUTEPLANRESPONSE.nested_types_by_name["Metrics"]
 _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT = _EXECUTEPLANRESPONSE_METRICS.nested_types_by_name[
     "MetricObject"
@@ -139,15 +138,6 @@ ExecutePlanResponse = _reflection.GeneratedProtocolMessageType(
                 # @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse.ArrowBatch)
             },
         ),
-        "JSONBatch": _reflection.GeneratedProtocolMessageType(
-            "JSONBatch",
-            (_message.Message,),
-            {
-                "DESCRIPTOR": _EXECUTEPLANRESPONSE_JSONBATCH,
-                "__module__": "spark.connect.base_pb2"
-                # @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse.JSONBatch)
-            },
-        ),
         "Metrics": _reflection.GeneratedProtocolMessageType(
             "Metrics",
             (_message.Message,),
@@ -191,7 +181,6 @@ ExecutePlanResponse = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(ExecutePlanResponse)
 _sym_db.RegisterMessage(ExecutePlanResponse.ArrowBatch)
-_sym_db.RegisterMessage(ExecutePlanResponse.JSONBatch)
 _sym_db.RegisterMessage(ExecutePlanResponse.Metrics)
 _sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject)
 _sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntry)
@@ -219,19 +208,17 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXECUTEPLANREQUEST._serialized_start = 986
     _EXECUTEPLANREQUEST._serialized_end = 1193
     _EXECUTEPLANRESPONSE._serialized_start = 1196
-    _EXECUTEPLANRESPONSE._serialized_end = 2137
-    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 1479
-    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 1540
-    _EXECUTEPLANRESPONSE_JSONBATCH._serialized_start = 1542
-    _EXECUTEPLANRESPONSE_JSONBATCH._serialized_end = 1602
-    _EXECUTEPLANRESPONSE_METRICS._serialized_start = 1605
-    _EXECUTEPLANRESPONSE_METRICS._serialized_end = 2122
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 1700
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 2032
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1909
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 2032
-    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 2034
-    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 2122
-    _SPARKCONNECTSERVICE._serialized_start = 2140
-    _SPARKCONNECTSERVICE._serialized_end = 2339
+    _EXECUTEPLANRESPONSE._serialized_end = 1979
+    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 1398
+    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 1459
+    _EXECUTEPLANRESPONSE_METRICS._serialized_start = 1462
+    _EXECUTEPLANRESPONSE_METRICS._serialized_end = 1979
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 1557
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 1889
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 1766
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 1889
+    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 1891
+    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 1979
+    _SPARKCONNECTSERVICE._serialized_start = 1982
+    _SPARKCONNECTSERVICE._serialized_end = 2181
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 18b70de57a3..64bb51d4c0b 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -401,28 +401,6 @@ class ExecutePlanResponse(google.protobuf.message.Message):
             self, field_name: typing_extensions.Literal["data", b"data", "row_count", b"row_count"]
         ) -> None: ...
 
-    class JSONBatch(google.protobuf.message.Message):
-        """Message type when the result is returned as JSON. This is essentially a bulk wrapper
-        for the JSON result of a Spark DataFrame. All rows are returned in the JSON record format
-        of `{col -> row}`.
-        """
-
-        DESCRIPTOR: google.protobuf.descriptor.Descriptor
-
-        ROW_COUNT_FIELD_NUMBER: builtins.int
-        DATA_FIELD_NUMBER: builtins.int
-        row_count: builtins.int
-        data: builtins.bytes
-        def __init__(
-            self,
-            *,
-            row_count: builtins.int = ...,
-            data: builtins.bytes = ...,
-        ) -> None: ...
-        def ClearField(
-            self, field_name: typing_extensions.Literal["data", b"data", "row_count", b"row_count"]
-        ) -> None: ...
-
     class Metrics(google.protobuf.message.Message):
         DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
@@ -530,14 +508,11 @@ class ExecutePlanResponse(google.protobuf.message.Message):
 
     CLIENT_ID_FIELD_NUMBER: builtins.int
     ARROW_BATCH_FIELD_NUMBER: builtins.int
-    JSON_BATCH_FIELD_NUMBER: builtins.int
     METRICS_FIELD_NUMBER: builtins.int
     client_id: builtins.str
     @property
     def arrow_batch(self) -> global___ExecutePlanResponse.ArrowBatch: ...
     @property
-    def json_batch(self) -> global___ExecutePlanResponse.JSONBatch: ...
-    @property
     def metrics(self) -> global___ExecutePlanResponse.Metrics:
         """Metrics for the query execution. Typically, this field is only present in the last
         batch of results and then represent the overall state of the query execution.
@@ -547,39 +522,17 @@ class ExecutePlanResponse(google.protobuf.message.Message):
         *,
         client_id: builtins.str = ...,
         arrow_batch: global___ExecutePlanResponse.ArrowBatch | None = ...,
-        json_batch: global___ExecutePlanResponse.JSONBatch | None = ...,
         metrics: global___ExecutePlanResponse.Metrics | None = ...,
     ) -> None: ...
     def HasField(
         self,
-        field_name: typing_extensions.Literal[
-            "arrow_batch",
-            b"arrow_batch",
-            "json_batch",
-            b"json_batch",
-            "metrics",
-            b"metrics",
-            "result_type",
-            b"result_type",
-        ],
+        field_name: typing_extensions.Literal["arrow_batch", b"arrow_batch", "metrics", b"metrics"],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "arrow_batch",
-            b"arrow_batch",
-            "client_id",
-            b"client_id",
-            "json_batch",
-            b"json_batch",
-            "metrics",
-            b"metrics",
-            "result_type",
-            b"result_type",
+            "arrow_batch", b"arrow_batch", "client_id", b"client_id", "metrics", b"metrics"
         ],
     ) -> None: ...
-    def WhichOneof(
-        self, oneof_group: typing_extensions.Literal["result_type", b"result_type"]
-    ) -> typing_extensions.Literal["arrow_batch", "json_batch"] | None: ...
 
 global___ExecutePlanResponse = ExecutePlanResponse
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index d3de94a379f..9e7a5f2f4a5 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -221,7 +221,60 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             with self.assertRaises(_MultiThreadedRendezvous):
                 self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1")
 
-    @unittest.skip("test_fill_na is flaky")
+    def test_to_pandas(self):
+        # SPARK-41005: Test to pandas
+        query = """
+            SELECT * FROM VALUES
+            (false, 1, NULL),
+            (false, NULL, float(2.0)),
+            (NULL, 3, float(3.0))
+            AS tab(a, b, c)
+            """
+
+        self.assert_eq(
+            self.connect.sql(query).toPandas(),
+            self.spark.sql(query).toPandas(),
+        )
+
+        query = """
+            SELECT * FROM VALUES
+            (1, 1, NULL),
+            (2, NULL, float(2.0)),
+            (3, 3, float(3.0))
+            AS tab(a, b, c)
+            """
+
+        self.assert_eq(
+            self.connect.sql(query).toPandas(),
+            self.spark.sql(query).toPandas(),
+        )
+
+        query = """
+            SELECT * FROM VALUES
+            (double(1.0), 1, "1"),
+            (NULL, NULL, NULL),
+            (double(2.0), 3, "3")
+            AS tab(a, b, c)
+            """
+
+        self.assert_eq(
+            self.connect.sql(query).toPandas(),
+            self.spark.sql(query).toPandas(),
+        )
+
+        query = """
+            SELECT * FROM VALUES
+            (float(1.0), double(1.0), 1, "1"),
+            (float(2.0), double(2.0), 2, "2"),
+            (float(3.0), double(3.0), 3, "3")
+            AS tab(a, b, c, d)
+            """
+
+        self.assert_eq(
+            self.connect.sql(query).toPandas(),
+            self.spark.sql(query).toPandas(),
+        )
+
     def test_fill_na(self):
         # SPARK-41128: Test fill na
         query = """


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org