Posted to commits@spark.apache.org by ru...@apache.org on 2023/03/01 08:56:53 UTC

[spark] branch branch-3.4 updated: [SPARK-41725][CONNECT] Eager Execution of DF.sql()

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 330b03cddb6 [SPARK-41725][CONNECT] Eager Execution of DF.sql()
330b03cddb6 is described below

commit 330b03cddb6e30e0097c754e12d52e8768bfb52a
Author: Martin Grund <ma...@databricks.com>
AuthorDate: Wed Mar 1 16:55:57 2023 +0800

    [SPARK-41725][CONNECT] Eager Execution of DF.sql()
    
    ### What changes were proposed in this pull request?
    
    This patch enables eager execution of SQL statements through the Spark Connect DataFrame API. The implementation works as follows: when `spark.sql` is called, the client sends a command containing the SQL statement to the server. The server evaluates the query and executes any side effects. If the query was a command, the server returns the results to the client as a `Relation.LocalRelation`; otherwise it returns a `Relation.SQL`. The clien [...]
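
    For illustration, a minimal sketch of the resulting client-side behavior in PySpark (not part of the patch; the connection string and table name are placeholders):

        from pyspark.sql.connect.session import SparkSession

        # Connect to a Spark Connect server (endpoint is illustrative).
        spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

        # A command: the server executes it eagerly and returns its (empty) result
        # as a LocalRelation wrapped in a SqlCommandResult.
        spark.sql("CREATE TABLE IF NOT EXISTS eager_demo (id INT)")

        # A plain query: the server answers with a Relation.SQL, so execution
        # stays lazy until an action such as collect() is called.
        rows = spark.sql("SELECT * FROM eager_demo").collect()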
    
    ### Why are the changes needed?
    Compatibility
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    UT
    
    Closes #40160 from grundprinzip/eager_sql_v2.
    
    Authored-by: Martin Grund <ma...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
    (cherry picked from commit 51a87ac549120d9fe1fe4503ca8825785d9e886d)
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 .../scala/org/apache/spark/sql/SparkSession.scala  |  14 ++-
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |  23 ++++-
 .../apache/spark/sql/PlanGenerationTestSuite.scala |   8 --
 .../connect/client/util/RemoteSparkSession.scala   |   4 +
 .../src/main/protobuf/spark/connect/base.proto     |  17 ++-
 .../src/main/protobuf/spark/connect/commands.proto |  15 +++
 .../explain-results/parameterized_sql.explain      |   2 -
 .../query-tests/explain-results/sql.explain        |   2 -
 .../query-tests/queries/parameterized_sql.json     |  12 ---
 .../queries/parameterized_sql.proto.bin            | Bin 41 -> 0 bytes
 .../test/resources/query-tests/queries/sql.json    |   8 --
 .../resources/query-tests/queries/sql.proto.bin    | Bin 16 -> 0 bytes
 .../sql/connect/planner/SparkConnectPlanner.scala  |  85 ++++++++++++++-
 .../service/SparkConnectStreamHandler.scala        |  66 ++++++------
 .../connect/planner/SparkConnectPlannerSuite.scala |  10 +-
 .../plugin/SparkConnectPluginRegistrySuite.scala   |   2 +-
 python/pyspark/sql/connect/client.py               |  32 ++++--
 python/pyspark/sql/connect/plan.py                 |  19 ++++
 python/pyspark/sql/connect/proto/base_pb2.py       | 115 ++++++++++++---------
 python/pyspark/sql/connect/proto/base_pb2.pyi      |  61 ++++++++++-
 python/pyspark/sql/connect/proto/commands_pb2.py   |  77 +++++++++-----
 python/pyspark/sql/connect/proto/commands_pb2.pyi  |  56 ++++++++++
 python/pyspark/sql/connect/session.py              |   9 +-
 python/pyspark/sql/tests/connect/test_client.py    |  17 +++
 .../sql/execution/arrow/ArrowConverters.scala      |   6 +-
 25 files changed, 495 insertions(+), 165 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 84731072ebc..e72dc264727 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -123,8 +123,18 @@ class SparkSession private[sql] (
   @Experimental
   def sql(sqlText: String, args: java.util.Map[String, String]): DataFrame = newDataFrame {
     builder =>
-      builder
-        .setSql(proto.SQL.newBuilder().setQuery(sqlText).putAllArgs(args))
+      // Send the SQL once to the server and then check the output.
+      val cmd = newCommand(b =>
+        b.setSqlCommand(proto.SqlCommand.newBuilder().setSql(sqlText).putAllArgs(args)))
+      val plan = proto.Plan.newBuilder().setCommand(cmd)
+      val responseIter = client.execute(plan.build())
+
+      val response = responseIter.asScala
+        .find(_.hasSqlCommandResult)
+        .getOrElse(throw new RuntimeException("SQLCommandResult must be present"))
+
+      // Update the builder with the values from the result.
+      builder.mergeFrom(response.getSqlCommandResult.getRelation)
   }
 
   /**
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index d47cc3858ab..274424f0f0d 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -52,6 +52,23 @@ class ClientE2ETestSuite extends RemoteSparkSession {
     assert(result(1).getString(0) == "World")
   }
 
+  test("eager execution of sql") {
+    withTable("test_martin") {
+      // Fails, because table does not exist.
+      assertThrows[StatusRuntimeException] {
+        spark.sql("select * from test_martin").collect()
+      }
+      // Execute eager, DML
+      spark.sql("create table test_martin (id int)")
+      // Execute read again.
+      val rows = spark.sql("select * from test_martin").collect()
+      assert(rows.length == 0)
+      spark.sql("insert into test_martin values (1), (2)")
+      val rows_new = spark.sql("select * from test_martin").collect()
+      assert(rows_new.length == 2)
+    }
+  }
+
   test("simple dataset") {
     val df = spark.range(10).limit(3)
     val result = df.collect()
@@ -189,10 +206,8 @@ class ClientE2ETestSuite extends RemoteSparkSession {
   //  e.g. spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
   test("writeTo with create") {
     withTable("myTableV2") {
-      assertThrows[StatusRuntimeException] {
-        // Failed to create as Hive support is required.
-        spark.range(3).writeTo("myTableV2").create()
-      }
+      // Succeeds now that the test server is started with Hive support enabled.
+      spark.range(3).writeTo("myTableV2").create()
     }
   }
 
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index bc7111e9bf8..0b198ab8f70 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -230,14 +230,6 @@ class PlanGenerationTestSuite
   private def temporals = createLocalRelation(temporalsSchemaString)
 
   /* Spark Session API */
-  test("sql") {
-    session.sql("select 1")
-  }
-
-  test("parameterized sql") {
-    session.sql("select 1", Map("minId" -> "7", "maxId" -> "20"))
-  }
-
   test("range") {
     session.range(1, 10, 1, 2)
   }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
index 96b3ab4e9ef..0ec31ee9943 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -69,6 +69,10 @@ object SparkConnectServerUtils {
         jar,
         "--conf",
         s"spark.connect.grpc.binding.port=$port",
+        "--conf",
+        "spark.sql.catalog.testcat=org.apache.spark.sql.connect.catalog.InMemoryTableCatalog",
+        "--conf",
+        "spark.sql.catalogImplementation=hive",
         "--class",
         "org.apache.spark.sql.connect.SimpleSparkConnectService",
         jar),
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 3eacd0cc482..066a63d58ba 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -221,12 +221,27 @@ message ExecutePlanRequest {
 message ExecutePlanResponse {
   string client_id = 1;
 
-  ArrowBatch arrow_batch = 2;
+  // Union type for the different response messages.
+  oneof response_type {
+    ArrowBatch arrow_batch = 2;
+
+    // Special case for executing SQL commands.
+    SqlCommandResult sql_command_result = 5;
+
+    // Support arbitrary result objects.
+    google.protobuf.Any extension = 999;
+  }
 
   // Metrics for the query execution. Typically, this field is only present in the last
   // batch of results and then represent the overall state of the query execution.
   Metrics metrics = 4;
 
+  // A SQL command returns an opaque Relation that can be directly used as input for the next
+  // call.
+  message SqlCommandResult {
+    Relation relation = 1;
+  }
+
   // Batch results of metrics.
   message ArrowBatch {
     int64 row_count = 1;
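
A hedged sketch of how a client can dispatch on the new `response_type` oneof using the generated Python bindings (the request/stub plumbing around it is assumed):

    from pyspark.sql.connect.proto import base_pb2 as pb2

    def handle(resp: pb2.ExecutePlanResponse) -> None:
        # The oneof separates regular Arrow result batches from the
        # SqlCommandResult produced by eagerly executed SQL commands.
        kind = resp.WhichOneof("response_type")
        if kind == "sql_command_result":
            relation = resp.sql_command_result.relation  # opaque Relation, reusable as plan input
        elif kind == "arrow_batch":
            payload = resp.arrow_batch.data              # serialized Arrow IPC stream
        elif kind == "extension":
            pass                                         # plugin-defined payload
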
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index 1f2f473a050..e553dcb1bc4 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -35,6 +35,7 @@ message Command {
     WriteOperation write_operation = 2;
     CreateDataFrameViewCommand create_dataframe_view = 3;
     WriteOperationV2 write_operation_v2 = 4;
+    SqlCommand sql_command = 5;
 
     // This field is used to mark extensions to the protocol. When plugins generate arbitrary
     // Commands they can add them here. During the planning the correct resolution is done.
@@ -43,6 +44,20 @@ message Command {
   }
 }
 
+// A SQL Command is used to trigger the eager evaluation of SQL commands in Spark.
+//
+// When the SQL provided as part of the message is a command, it will be immediately
+// evaluated and the result will be collected and returned as part of a LocalRelation.
+// If it is not a command, the operation will simply return a SQL Relation. This allows
+// the client to be almost oblivious to the server-side behavior.
+message SqlCommand {
+  // (Required) SQL Query.
+  string sql = 1;
+
+  // (Optional) A map of parameter names to literal values.
+  map<string, string> args = 2;
+}
+
 // A command that can create DataFrame global temp view or local temp view.
 message CreateDataFrameViewCommand {
   // (Required) The relation that this view will be built on.
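
A minimal sketch of populating the new message via the Python bindings (it mirrors the `SQL.command()` helper added to plan.py below; the query and argument values come from the removed golden-file test and are purely illustrative):

    from pyspark.sql.connect.proto import commands_pb2

    cmd = commands_pb2.Command()
    cmd.sql_command.sql = "select 1"
    cmd.sql_command.args["minId"] = "7"
    cmd.sql_command.args["maxId"] = "20"
    # Wrapped in Plan.command and sent via ExecutePlanRequest, the server answers
    # with an ExecutePlanResponse carrying a SqlCommandResult (see base.proto above).
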
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/parameterized_sql.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/parameterized_sql.explain
deleted file mode 100644
index 7f5aafb1943..00000000000
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/parameterized_sql.explain
+++ /dev/null
@@ -1,2 +0,0 @@
-Project [1 AS 1#0]
-+- OneRowRelation
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/sql.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/sql.explain
deleted file mode 100644
index 7f5aafb1943..00000000000
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/sql.explain
+++ /dev/null
@@ -1,2 +0,0 @@
-Project [1 AS 1#0]
-+- OneRowRelation
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/parameterized_sql.json b/connector/connect/common/src/test/resources/query-tests/queries/parameterized_sql.json
deleted file mode 100644
index 5ceb1d5a087..00000000000
--- a/connector/connect/common/src/test/resources/query-tests/queries/parameterized_sql.json
+++ /dev/null
@@ -1,12 +0,0 @@
-{
-  "common": {
-    "planId": "0"
-  },
-  "sql": {
-    "query": "select 1",
-    "args": {
-      "minId": "7",
-      "maxId": "20"
-    }
-  }
-}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/parameterized_sql.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/parameterized_sql.proto.bin
deleted file mode 100644
index 50bc8457f31..00000000000
Binary files a/connector/connect/common/src/test/resources/query-tests/queries/parameterized_sql.proto.bin and /dev/null differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/sql.json b/connector/connect/common/src/test/resources/query-tests/queries/sql.json
deleted file mode 100644
index c4bc9b2c082..00000000000
--- a/connector/connect/common/src/test/resources/query-tests/queries/sql.json
+++ /dev/null
@@ -1,8 +0,0 @@
-{
-  "common": {
-    "planId": "0"
-  },
-  "sql": {
-    "query": "select 1"
-  }
-}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/sql.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/sql.proto.bin
deleted file mode 100644
index 3d4394f23af..00000000000
Binary files a/connector/connect/common/src/test/resources/query-tests/queries/sql.proto.bin and /dev/null differ
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index d52117b469c..c8b1b3125f9 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -21,11 +21,14 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 import com.google.common.collect.{Lists, Maps}
-import com.google.protobuf.{Any => ProtoAny}
+import com.google.protobuf.{Any => ProtoAny, ByteString}
+import io.grpc.stub.StreamObserver
 
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.{ExecutePlanResponse, SqlCommand}
+import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
 import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
@@ -34,11 +37,13 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, UdfPacket}
+import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
 import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.{toCatalystExpression, toCatalystValue}
 import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
+import org.apache.spark.sql.connect.service.SparkConnectStreamHandler
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.arrow.ArrowConverters
@@ -1438,7 +1443,10 @@ class SparkConnectPlanner(val session: SparkSession) {
     }
   }
 
-  def process(command: proto.Command): Unit = {
+  def process(
+      command: proto.Command,
+      clientId: String,
+      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     command.getCommandTypeCase match {
       case proto.Command.CommandTypeCase.REGISTER_FUNCTION =>
         handleRegisterUserDefinedFunction(command.getRegisterFunction)
@@ -1450,10 +1458,79 @@ class SparkConnectPlanner(val session: SparkSession) {
         handleWriteOperationV2(command.getWriteOperationV2)
       case proto.Command.CommandTypeCase.EXTENSION =>
         handleCommandPlugin(command.getExtension)
+      case proto.Command.CommandTypeCase.SQL_COMMAND =>
+        handleSqlCommand(command.getSqlCommand, clientId, responseObserver)
       case _ => throw new UnsupportedOperationException(s"$command not supported.")
     }
   }
 
+  def handleSqlCommand(
+      getSqlCommand: SqlCommand,
+      clientId: String,
+      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
+    // Eagerly execute commands of the provided SQL string.
+    val df = session.sql(getSqlCommand.getSql, getSqlCommand.getArgsMap)
+    // Check if commands have been executed.
+    val isCommand = df.queryExecution.commandExecuted.isInstanceOf[CommandResult]
+    val rows = df.logicalPlan match {
+      case lr: LocalRelation => lr.data
+      case cr: CommandResult => cr.rows
+      case _ => Seq.empty
+    }
+
+    // Convert the results to Arrow.
+    val schema = df.schema
+    val maxRecordsPerBatch = session.sessionState.conf.arrowMaxRecordsPerBatch
+    val maxBatchSize = (SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong
+    val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
+
+    // Convert the data.
+    val bytes = if (rows.isEmpty) {
+      ArrowConverters.createEmptyArrowBatch(schema, timeZoneId)
+    } else {
+      val batches = ArrowConverters.toBatchWithSchemaIterator(
+        rows.iterator,
+        schema,
+        maxRecordsPerBatch,
+        maxBatchSize,
+        timeZoneId)
+      assert(batches.size == 1)
+      batches.next()
+    }
+
+    // To avoid explicit handling of the result on the client, we build the expected input
+    // of the relation on the server. The client has to simply forward the result.
+    val result = SqlCommandResult.newBuilder()
+    if (isCommand) {
+      result.setRelation(
+        proto.Relation
+          .newBuilder()
+          .setLocalRelation(
+            proto.LocalRelation
+              .newBuilder()
+              .setData(ByteString.copyFrom(bytes))))
+    } else {
+      result.setRelation(
+        proto.Relation
+          .newBuilder()
+          .setSql(
+            proto.SQL
+              .newBuilder()
+              .setQuery(getSqlCommand.getSql)
+              .putAllArgs(getSqlCommand.getArgsMap)))
+    }
+    // Exactly one SQL Command Result Batch
+    responseObserver.onNext(
+      ExecutePlanResponse
+        .newBuilder()
+        .setClientId(clientId)
+        .setSqlCommandResult(result)
+        .build())
+
+    // Send Metrics
+    SparkConnectStreamHandler.sendMetricsToResponse(clientId, df)
+  }
+
   private def handleRegisterUserDefinedFunction(
       fun: proto.CommonInlineUserDefinedFunction): Unit = {
     fun.getFunctionCase match {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index f46aca9c8cf..41ca564e6d3 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SparkConnectStreamHandler.processAsArrowBatches
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
@@ -58,10 +59,41 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
     // Extract the plan from the request and convert it to a logical plan
     val planner = new SparkConnectPlanner(session)
     val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot))
-    processAsArrowBatches(request.getClientId, dataframe)
+    processAsArrowBatches(request.getClientId, dataframe, responseObserver)
+    responseObserver.onNext(
+      SparkConnectStreamHandler.sendMetricsToResponse(request.getClientId, dataframe))
+    responseObserver.onCompleted()
+  }
+
+  private def handleCommand(session: SparkSession, request: ExecutePlanRequest): Unit = {
+    val command = request.getPlan.getCommand
+    val planner = new SparkConnectPlanner(session)
+    planner.process(command, request.getClientId, responseObserver)
+    responseObserver.onCompleted()
+  }
+}
+
+object SparkConnectStreamHandler {
+  type Batch = (Array[Byte], Long)
+
+  def rowToArrowConverter(
+      schema: StructType,
+      maxRecordsPerBatch: Int,
+      maxBatchSize: Long,
+      timeZoneId: String): Iterator[InternalRow] => Iterator[Batch] = { rows =>
+    val batches = ArrowConverters.toBatchWithSchemaIterator(
+      rows,
+      schema,
+      maxRecordsPerBatch,
+      maxBatchSize,
+      timeZoneId)
+    batches.map(b => b -> batches.rowCountInLastBatch)
   }
 
-  private def processAsArrowBatches(clientId: String, dataframe: DataFrame): Unit = {
+  def processAsArrowBatches(
+      clientId: String,
+      dataframe: DataFrame,
+      responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
     val spark = dataframe.sparkSession
     val schema = dataframe.schema
     val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
@@ -163,13 +195,10 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
         response.setArrowBatch(batch)
         responseObserver.onNext(response.build())
       }
-
-      responseObserver.onNext(sendMetricsToResponse(clientId, dataframe))
-      responseObserver.onCompleted()
     }
   }
 
-  private def sendMetricsToResponse(clientId: String, rows: DataFrame): ExecutePlanResponse = {
+  def sendMetricsToResponse(clientId: String, rows: DataFrame): ExecutePlanResponse = {
     // Send a last batch with the metrics
     ExecutePlanResponse
       .newBuilder()
@@ -177,31 +206,6 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
       .setMetrics(MetricGenerator.buildMetrics(rows.queryExecution.executedPlan))
       .build()
   }
-
-  private def handleCommand(session: SparkSession, request: ExecutePlanRequest): Unit = {
-    val command = request.getPlan.getCommand
-    val planner = new SparkConnectPlanner(session)
-    planner.process(command)
-    responseObserver.onCompleted()
-  }
-}
-
-object SparkConnectStreamHandler {
-  type Batch = (Array[Byte], Long)
-
-  private def rowToArrowConverter(
-      schema: StructType,
-      maxRecordsPerBatch: Int,
-      maxBatchSize: Long,
-      timeZoneId: String): Iterator[InternalRow] => Iterator[Batch] = { rows =>
-    val batches = ArrowConverters.toBatchWithSchemaIterator(
-      rows,
-      schema,
-      maxRecordsPerBatch,
-      maxBatchSize,
-      timeZoneId)
-    batches.map(b => b -> batches.rowCountInLastBatch)
-  }
 }
 
 object MetricGenerator extends AdaptiveSparkPlanHelper {
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index 83056c27729..b79d91d2d10 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -20,9 +20,11 @@ package org.apache.spark.sql.connect.planner
 import scala.collection.JavaConverters._
 
 import com.google.protobuf.ByteString
+import io.grpc.stub.StreamObserver
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.ExecutePlanResponse
 import org.apache.spark.connect.proto.Expression.{Alias, ExpressionString, UnresolvedStar}
 import org.apache.spark.sql.{AnalysisException, Dataset, Row}
 import org.apache.spark.sql.catalyst.InternalRow
@@ -41,12 +43,18 @@ import org.apache.spark.unsafe.types.UTF8String
  */
 trait SparkConnectPlanTest extends SharedSparkSession {
 
+  class MockObserver extends StreamObserver[proto.ExecutePlanResponse] {
+    override def onNext(value: ExecutePlanResponse): Unit = {}
+    override def onError(t: Throwable): Unit = {}
+    override def onCompleted(): Unit = {}
+  }
+
   def transform(rel: proto.Relation): logical.LogicalPlan = {
     new SparkConnectPlanner(spark).transformRelation(rel)
   }
 
   def transform(cmd: proto.Command): Unit = {
-    new SparkConnectPlanner(spark).process(cmd)
+    new SparkConnectPlanner(spark).process(cmd, "clientId", new MockObserver())
   }
 
   def readRel: proto.Relation =
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
index 7abe4a4d085..39fc90fd002 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/plugin/SparkConnectPluginRegistrySuite.scala
@@ -195,7 +195,7 @@ class SparkConnectPluginRegistrySuite extends SharedSparkSession with SparkConne
               .build()))
         .build()
 
-      new SparkConnectPlanner(spark).process(plan)
+      new SparkConnectPlanner(spark).process(plan, "clientId", new MockObserver())
       assert(spark.sparkContext.getLocalProperty("testingProperty").equals("Martin"))
     }
   }
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 8046da409d7..3f70ca6ad15 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -568,7 +568,8 @@ class SparkConnectClient(object):
         logger.info(f"Executing plan {self._proto_to_string(plan)}")
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
-        table, _ = self._execute_and_fetch(req)
+        table, _, _2 = self._execute_and_fetch(req)
+        assert table is not None
         return table
 
     def to_pandas(self, plan: pb2.Plan) -> "pd.DataFrame":
@@ -578,7 +579,8 @@ class SparkConnectClient(object):
         logger.info(f"Executing plan {self._proto_to_string(plan)}")
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
-        table, metrics = self._execute_and_fetch(req)
+        table, metrics, _ = self._execute_and_fetch(req)
+        assert table is not None
         column_names = table.column_names
         table = table.rename_columns([f"col_{i}" for i in range(len(column_names))])
         pdf = table.to_pandas()
@@ -641,7 +643,9 @@ class SparkConnectClient(object):
         assert result is not None
         return result
 
-    def execute_command(self, command: pb2.Command) -> None:
+    def execute_command(
+        self, command: pb2.Command
+    ) -> Tuple[Optional[pd.DataFrame], Dict[str, Any]]:
         """
         Execute given command.
         """
@@ -650,8 +654,11 @@ class SparkConnectClient(object):
         if self._user_id:
             req.user_context.user_id = self._user_id
         req.plan.command.CopyFrom(command)
-        self._execute(req)
-        return
+        data, _, properties = self._execute_and_fetch(req)
+        if data is not None:
+            return (data.to_pandas(), properties)
+        else:
+            return (None, properties)
 
     def close(self) -> None:
         """
@@ -774,12 +781,12 @@ class SparkConnectClient(object):
 
     def _execute_and_fetch(
         self, req: pb2.ExecutePlanRequest
-    ) -> Tuple["pa.Table", List[PlanMetrics]]:
+    ) -> Tuple[Optional["pa.Table"], List[PlanMetrics], Dict[str, Any]]:
         logger.info("ExecuteAndFetch")
 
         m: Optional[pb2.ExecutePlanResponse.Metrics] = None
         batches: List[pa.RecordBatch] = []
-
+        properties = {}
         try:
             for attempt in Retrying(
                 can_retry=SparkConnectClient.retry_exception, **self._retry_policy
@@ -795,6 +802,8 @@ class SparkConnectClient(object):
                         if b.metrics is not None:
                             logger.debug("Received metric batch.")
                             m = b.metrics
+                        if b.HasField("sql_command_result"):
+                            properties["sql_command_result"] = b.sql_command_result.relation
                         if b.HasField("arrow_batch"):
                             logger.debug(
                                 f"Received arrow batch rows={b.arrow_batch.row_count} "
@@ -807,10 +816,13 @@ class SparkConnectClient(object):
                                     batches.append(batch)
         except grpc.RpcError as rpc_error:
             self._handle_error(rpc_error)
-        assert len(batches) > 0
-        table = pa.Table.from_batches(batches=batches)
         metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None else []
-        return table, metrics
+
+        if len(batches) > 0:
+            table = pa.Table.from_batches(batches=batches)
+            return table, metrics, properties
+        else:
+            return None, metrics, properties
 
     def _config_request_with_metadata(self) -> pb2.ConfigRequest:
         req = pb2.ConfigRequest()
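
For context, a hedged sketch of consuming the new `execute_command` return value (mirroring what session.py does further below; the endpoint and table name are illustrative):

    from pyspark.sql.connect.client import SparkConnectClient
    from pyspark.sql.connect.plan import SQL

    client = SparkConnectClient("sc://localhost")            # endpoint is illustrative
    cmd = SQL("create table t_demo (id int)").command(client)
    # `data` is a pandas DataFrame when the server streamed Arrow batches, otherwise
    # None; `properties` may carry the relation from a SqlCommandResult.
    data, properties = client.execute_command(cmd)
    if "sql_command_result" in properties:
        relation = properties["sql_command_result"]          # proto.Relation for a follow-up plan
    elif data is not None:
        print(data)
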
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index f82cf9167cb..7e767885793 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -955,6 +955,14 @@ class SQL(LogicalPlan):
 
         return plan
 
+    def command(self, session: "SparkConnectClient") -> proto.Command:
+        cmd = proto.Command()
+        cmd.sql_command.sql = self._query
+        if self._args is not None and len(self._args) > 0:
+            for k, v in self._args.items():
+                cmd.sql_command.args[k] = v
+        return cmd
+
 
 class Range(LogicalPlan):
     def __init__(
@@ -1880,3 +1888,14 @@ class FrameMap(LogicalPlan):
         plan.frame_map.input.CopyFrom(self._child.plan(session))
         plan.frame_map.func.CopyFrom(self._func.to_plan_udf(session))
         return plan
+
+
+class CachedRelation(LogicalPlan):
+    def __init__(self, plan: proto.Relation) -> None:
+        super(CachedRelation, self).__init__(None)
+        self._plan = plan
+        # Update the plan ID based on the incremented counter.
+        self._plan.common.plan_id = self._plan_id
+
+    def plan(self, session: "SparkConnectClient") -> proto.Relation:
+        return self._plan
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index c43619facb6..628f7ebdd46 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x0 [...]
+    b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x0 [...]
 )
 
 
@@ -62,6 +62,9 @@ _ANALYZEPLANRESPONSE_SPARKVERSION = _ANALYZEPLANRESPONSE.nested_types_by_name["S
 _ANALYZEPLANRESPONSE_DDLPARSE = _ANALYZEPLANRESPONSE.nested_types_by_name["DDLParse"]
 _EXECUTEPLANREQUEST = DESCRIPTOR.message_types_by_name["ExecutePlanRequest"]
 _EXECUTEPLANRESPONSE = DESCRIPTOR.message_types_by_name["ExecutePlanResponse"]
+_EXECUTEPLANRESPONSE_SQLCOMMANDRESULT = _EXECUTEPLANRESPONSE.nested_types_by_name[
+    "SqlCommandResult"
+]
 _EXECUTEPLANRESPONSE_ARROWBATCH = _EXECUTEPLANRESPONSE.nested_types_by_name["ArrowBatch"]
 _EXECUTEPLANRESPONSE_METRICS = _EXECUTEPLANRESPONSE.nested_types_by_name["Metrics"]
 _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT = _EXECUTEPLANRESPONSE_METRICS.nested_types_by_name[
@@ -319,6 +322,15 @@ ExecutePlanResponse = _reflection.GeneratedProtocolMessageType(
     "ExecutePlanResponse",
     (_message.Message,),
     {
+        "SqlCommandResult": _reflection.GeneratedProtocolMessageType(
+            "SqlCommandResult",
+            (_message.Message,),
+            {
+                "DESCRIPTOR": _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT,
+                "__module__": "spark.connect.base_pb2"
+                # @@protoc_insertion_point(class_scope:spark.connect.ExecutePlanResponse.SqlCommandResult)
+            },
+        ),
         "ArrowBatch": _reflection.GeneratedProtocolMessageType(
             "ArrowBatch",
             (_message.Message,),
@@ -370,6 +382,7 @@ ExecutePlanResponse = _reflection.GeneratedProtocolMessageType(
     },
 )
 _sym_db.RegisterMessage(ExecutePlanResponse)
+_sym_db.RegisterMessage(ExecutePlanResponse.SqlCommandResult)
 _sym_db.RegisterMessage(ExecutePlanResponse.ArrowBatch)
 _sym_db.RegisterMessage(ExecutePlanResponse.Metrics)
 _sym_db.RegisterMessage(ExecutePlanResponse.Metrics.MetricObject)
@@ -613,53 +626,55 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _EXECUTEPLANREQUEST._serialized_start = 2919
     _EXECUTEPLANREQUEST._serialized_end = 3126
     _EXECUTEPLANRESPONSE._serialized_start = 3129
-    _EXECUTEPLANRESPONSE._serialized_end = 3912
-    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 3331
-    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 3392
-    _EXECUTEPLANRESPONSE_METRICS._serialized_start = 3395
-    _EXECUTEPLANRESPONSE_METRICS._serialized_end = 3912
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 3490
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 3822
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 3699
-    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 3822
-    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 3824
-    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 3912
-    _KEYVALUE._serialized_start = 3914
-    _KEYVALUE._serialized_end = 3979
-    _CONFIGREQUEST._serialized_start = 3982
-    _CONFIGREQUEST._serialized_end = 5008
-    _CONFIGREQUEST_OPERATION._serialized_start = 4200
-    _CONFIGREQUEST_OPERATION._serialized_end = 4698
-    _CONFIGREQUEST_SET._serialized_start = 4700
-    _CONFIGREQUEST_SET._serialized_end = 4752
-    _CONFIGREQUEST_GET._serialized_start = 4754
-    _CONFIGREQUEST_GET._serialized_end = 4779
-    _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 4781
-    _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 4844
-    _CONFIGREQUEST_GETOPTION._serialized_start = 4846
-    _CONFIGREQUEST_GETOPTION._serialized_end = 4877
-    _CONFIGREQUEST_GETALL._serialized_start = 4879
-    _CONFIGREQUEST_GETALL._serialized_end = 4927
-    _CONFIGREQUEST_UNSET._serialized_start = 4929
-    _CONFIGREQUEST_UNSET._serialized_end = 4956
-    _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 4958
-    _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 4992
-    _CONFIGRESPONSE._serialized_start = 5010
-    _CONFIGRESPONSE._serialized_end = 5130
-    _ADDARTIFACTSREQUEST._serialized_start = 5133
-    _ADDARTIFACTSREQUEST._serialized_end = 5948
-    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 5480
-    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 5533
-    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 5535
-    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 5646
-    _ADDARTIFACTSREQUEST_BATCH._serialized_start = 5648
-    _ADDARTIFACTSREQUEST_BATCH._serialized_end = 5741
-    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 5744
-    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 5937
-    _ADDARTIFACTSRESPONSE._serialized_start = 5951
-    _ADDARTIFACTSRESPONSE._serialized_end = 6139
-    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 6058
-    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 6139
-    _SPARKCONNECTSERVICE._serialized_start = 6142
-    _SPARKCONNECTSERVICE._serialized_end = 6507
+    _EXECUTEPLANRESPONSE._serialized_end = 4160
+    _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_start = 3489
+    _EXECUTEPLANRESPONSE_SQLCOMMANDRESULT._serialized_end = 3560
+    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_start = 3562
+    _EXECUTEPLANRESPONSE_ARROWBATCH._serialized_end = 3623
+    _EXECUTEPLANRESPONSE_METRICS._serialized_start = 3626
+    _EXECUTEPLANRESPONSE_METRICS._serialized_end = 4143
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_start = 3721
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT._serialized_end = 4053
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_start = 3930
+    _EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_end = 4053
+    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_start = 4055
+    _EXECUTEPLANRESPONSE_METRICS_METRICVALUE._serialized_end = 4143
+    _KEYVALUE._serialized_start = 4162
+    _KEYVALUE._serialized_end = 4227
+    _CONFIGREQUEST._serialized_start = 4230
+    _CONFIGREQUEST._serialized_end = 5256
+    _CONFIGREQUEST_OPERATION._serialized_start = 4448
+    _CONFIGREQUEST_OPERATION._serialized_end = 4946
+    _CONFIGREQUEST_SET._serialized_start = 4948
+    _CONFIGREQUEST_SET._serialized_end = 5000
+    _CONFIGREQUEST_GET._serialized_start = 5002
+    _CONFIGREQUEST_GET._serialized_end = 5027
+    _CONFIGREQUEST_GETWITHDEFAULT._serialized_start = 5029
+    _CONFIGREQUEST_GETWITHDEFAULT._serialized_end = 5092
+    _CONFIGREQUEST_GETOPTION._serialized_start = 5094
+    _CONFIGREQUEST_GETOPTION._serialized_end = 5125
+    _CONFIGREQUEST_GETALL._serialized_start = 5127
+    _CONFIGREQUEST_GETALL._serialized_end = 5175
+    _CONFIGREQUEST_UNSET._serialized_start = 5177
+    _CONFIGREQUEST_UNSET._serialized_end = 5204
+    _CONFIGREQUEST_ISMODIFIABLE._serialized_start = 5206
+    _CONFIGREQUEST_ISMODIFIABLE._serialized_end = 5240
+    _CONFIGRESPONSE._serialized_start = 5258
+    _CONFIGRESPONSE._serialized_end = 5378
+    _ADDARTIFACTSREQUEST._serialized_start = 5381
+    _ADDARTIFACTSREQUEST._serialized_end = 6196
+    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_start = 5728
+    _ADDARTIFACTSREQUEST_ARTIFACTCHUNK._serialized_end = 5781
+    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_start = 5783
+    _ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT._serialized_end = 5894
+    _ADDARTIFACTSREQUEST_BATCH._serialized_start = 5896
+    _ADDARTIFACTSREQUEST_BATCH._serialized_end = 5989
+    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_start = 5992
+    _ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT._serialized_end = 6185
+    _ADDARTIFACTSRESPONSE._serialized_start = 6199
+    _ADDARTIFACTSRESPONSE._serialized_end = 6387
+    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 6306
+    _ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 6387
+    _SPARKCONNECTSERVICE._serialized_start = 6390
+    _SPARKCONNECTSERVICE._serialized_end = 6755
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 677f101aa47..0e800947975 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -759,6 +759,28 @@ class ExecutePlanResponse(google.protobuf.message.Message):
 
     DESCRIPTOR: google.protobuf.descriptor.Descriptor
 
+    class SqlCommandResult(google.protobuf.message.Message):
+        """A SQL command returns an opaque Relation that can be directly used as input for the next
+        call.
+        """
+
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        RELATION_FIELD_NUMBER: builtins.int
+        @property
+        def relation(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: ...
+        def __init__(
+            self,
+            *,
+            relation: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
+        ) -> None: ...
+        def HasField(
+            self, field_name: typing_extensions.Literal["relation", b"relation"]
+        ) -> builtins.bool: ...
+        def ClearField(
+            self, field_name: typing_extensions.Literal["relation", b"relation"]
+        ) -> None: ...
+
     class ArrowBatch(google.protobuf.message.Message):
         """Batch results of metrics."""
 
@@ -885,11 +907,19 @@ class ExecutePlanResponse(google.protobuf.message.Message):
 
     CLIENT_ID_FIELD_NUMBER: builtins.int
     ARROW_BATCH_FIELD_NUMBER: builtins.int
+    SQL_COMMAND_RESULT_FIELD_NUMBER: builtins.int
+    EXTENSION_FIELD_NUMBER: builtins.int
     METRICS_FIELD_NUMBER: builtins.int
     client_id: builtins.str
     @property
     def arrow_batch(self) -> global___ExecutePlanResponse.ArrowBatch: ...
     @property
+    def sql_command_result(self) -> global___ExecutePlanResponse.SqlCommandResult:
+        """Special case for executing SQL commands."""
+    @property
+    def extension(self) -> google.protobuf.any_pb2.Any:
+        """Support arbitrary result objects."""
+    @property
     def metrics(self) -> global___ExecutePlanResponse.Metrics:
         """Metrics for the query execution. Typically, this field is only present in the last
         batch of results and then represent the overall state of the query execution.
@@ -899,18 +929,45 @@ class ExecutePlanResponse(google.protobuf.message.Message):
         *,
         client_id: builtins.str = ...,
         arrow_batch: global___ExecutePlanResponse.ArrowBatch | None = ...,
+        sql_command_result: global___ExecutePlanResponse.SqlCommandResult | None = ...,
+        extension: google.protobuf.any_pb2.Any | None = ...,
         metrics: global___ExecutePlanResponse.Metrics | None = ...,
     ) -> None: ...
     def HasField(
         self,
-        field_name: typing_extensions.Literal["arrow_batch", b"arrow_batch", "metrics", b"metrics"],
+        field_name: typing_extensions.Literal[
+            "arrow_batch",
+            b"arrow_batch",
+            "extension",
+            b"extension",
+            "metrics",
+            b"metrics",
+            "response_type",
+            b"response_type",
+            "sql_command_result",
+            b"sql_command_result",
+        ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "arrow_batch", b"arrow_batch", "client_id", b"client_id", "metrics", b"metrics"
+            "arrow_batch",
+            b"arrow_batch",
+            "client_id",
+            b"client_id",
+            "extension",
+            b"extension",
+            "metrics",
+            b"metrics",
+            "response_type",
+            b"response_type",
+            "sql_command_result",
+            b"sql_command_result",
         ],
     ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["response_type", b"response_type"]
+    ) -> typing_extensions.Literal["arrow_batch", "sql_command_result", "extension"] | None: ...
 
 global___ExecutePlanResponse = ExecutePlanResponse
 
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index c8ade1ea81b..823ed81aa07 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -36,11 +36,13 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...]
+    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xe9\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...]
 )
 
 
 _COMMAND = DESCRIPTOR.message_types_by_name["Command"]
+_SQLCOMMAND = DESCRIPTOR.message_types_by_name["SqlCommand"]
+_SQLCOMMAND_ARGSENTRY = _SQLCOMMAND.nested_types_by_name["ArgsEntry"]
 _CREATEDATAFRAMEVIEWCOMMAND = DESCRIPTOR.message_types_by_name["CreateDataFrameViewCommand"]
 _WRITEOPERATION = DESCRIPTOR.message_types_by_name["WriteOperation"]
 _WRITEOPERATION_OPTIONSENTRY = _WRITEOPERATION.nested_types_by_name["OptionsEntry"]
@@ -67,6 +69,27 @@ Command = _reflection.GeneratedProtocolMessageType(
 )
 _sym_db.RegisterMessage(Command)
 
+SqlCommand = _reflection.GeneratedProtocolMessageType(
+    "SqlCommand",
+    (_message.Message,),
+    {
+        "ArgsEntry": _reflection.GeneratedProtocolMessageType(
+            "ArgsEntry",
+            (_message.Message,),
+            {
+                "DESCRIPTOR": _SQLCOMMAND_ARGSENTRY,
+                "__module__": "spark.connect.commands_pb2"
+                # @@protoc_insertion_point(class_scope:spark.connect.SqlCommand.ArgsEntry)
+            },
+        ),
+        "DESCRIPTOR": _SQLCOMMAND,
+        "__module__": "spark.connect.commands_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.SqlCommand)
+    },
+)
+_sym_db.RegisterMessage(SqlCommand)
+_sym_db.RegisterMessage(SqlCommand.ArgsEntry)
+
 CreateDataFrameViewCommand = _reflection.GeneratedProtocolMessageType(
     "CreateDataFrameViewCommand",
     (_message.Message,),
@@ -154,6 +177,8 @@ if _descriptor._USE_C_DESCRIPTORS == False:
 
     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
+    _SQLCOMMAND_ARGSENTRY._options = None
+    _SQLCOMMAND_ARGSENTRY._serialized_options = b"8\001"
     _WRITEOPERATION_OPTIONSENTRY._options = None
     _WRITEOPERATION_OPTIONSENTRY._serialized_options = b"8\001"
     _WRITEOPERATIONV2_OPTIONSENTRY._options = None
@@ -161,27 +186,31 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._options = None
     _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_options = b"8\001"
     _COMMAND._serialized_start = 166
-    _COMMAND._serialized_end = 593
-    _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 596
-    _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 746
-    _WRITEOPERATION._serialized_start = 749
-    _WRITEOPERATION._serialized_end = 1800
-    _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1224
-    _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1282
-    _WRITEOPERATION_SAVETABLE._serialized_start = 1285
-    _WRITEOPERATION_SAVETABLE._serialized_end = 1543
-    _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_start = 1419
-    _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_end = 1543
-    _WRITEOPERATION_BUCKETBY._serialized_start = 1545
-    _WRITEOPERATION_BUCKETBY._serialized_end = 1636
-    _WRITEOPERATION_SAVEMODE._serialized_start = 1639
-    _WRITEOPERATION_SAVEMODE._serialized_end = 1776
-    _WRITEOPERATIONV2._serialized_start = 1803
-    _WRITEOPERATIONV2._serialized_end = 2616
-    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1224
-    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1282
-    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2375
-    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2441
-    _WRITEOPERATIONV2_MODE._serialized_start = 2444
-    _WRITEOPERATIONV2_MODE._serialized_end = 2603
+    _COMMAND._serialized_end = 655
+    _SQLCOMMAND._serialized_start = 658
+    _SQLCOMMAND._serialized_end = 802
+    _SQLCOMMAND_ARGSENTRY._serialized_start = 747
+    _SQLCOMMAND_ARGSENTRY._serialized_end = 802
+    _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 805
+    _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 955
+    _WRITEOPERATION._serialized_start = 958
+    _WRITEOPERATION._serialized_end = 2009
+    _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1433
+    _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1491
+    _WRITEOPERATION_SAVETABLE._serialized_start = 1494
+    _WRITEOPERATION_SAVETABLE._serialized_end = 1752
+    _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_start = 1628
+    _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_end = 1752
+    _WRITEOPERATION_BUCKETBY._serialized_start = 1754
+    _WRITEOPERATION_BUCKETBY._serialized_end = 1845
+    _WRITEOPERATION_SAVEMODE._serialized_start = 1848
+    _WRITEOPERATION_SAVEMODE._serialized_end = 1985
+    _WRITEOPERATIONV2._serialized_start = 2012
+    _WRITEOPERATIONV2._serialized_end = 2825
+    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1433
+    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1491
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2584
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2650
+    _WRITEOPERATIONV2_MODE._serialized_start = 2653
+    _WRITEOPERATIONV2_MODE._serialized_end = 2812
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index fb767ead329..d2bfaf9ed89 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -63,6 +63,7 @@ class Command(google.protobuf.message.Message):
     WRITE_OPERATION_FIELD_NUMBER: builtins.int
     CREATE_DATAFRAME_VIEW_FIELD_NUMBER: builtins.int
     WRITE_OPERATION_V2_FIELD_NUMBER: builtins.int
+    SQL_COMMAND_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     @property
     def register_function(
@@ -75,6 +76,8 @@ class Command(google.protobuf.message.Message):
     @property
     def write_operation_v2(self) -> global___WriteOperationV2: ...
     @property
+    def sql_command(self) -> global___SqlCommand: ...
+    @property
     def extension(self) -> google.protobuf.any_pb2.Any:
         """This field is used to mark extensions to the protocol. When plugins generate arbitrary
         Commands they can add them here. During the planning the correct resolution is done.
@@ -87,6 +90,7 @@ class Command(google.protobuf.message.Message):
         write_operation: global___WriteOperation | None = ...,
         create_dataframe_view: global___CreateDataFrameViewCommand | None = ...,
         write_operation_v2: global___WriteOperationV2 | None = ...,
+        sql_command: global___SqlCommand | None = ...,
         extension: google.protobuf.any_pb2.Any | None = ...,
     ) -> None: ...
     def HasField(
@@ -100,6 +104,8 @@ class Command(google.protobuf.message.Message):
             b"extension",
             "register_function",
             b"register_function",
+            "sql_command",
+            b"sql_command",
             "write_operation",
             b"write_operation",
             "write_operation_v2",
@@ -117,6 +123,8 @@ class Command(google.protobuf.message.Message):
             b"extension",
             "register_function",
             b"register_function",
+            "sql_command",
+            b"sql_command",
             "write_operation",
             b"write_operation",
             "write_operation_v2",
@@ -130,11 +138,59 @@ class Command(google.protobuf.message.Message):
         "write_operation",
         "create_dataframe_view",
         "write_operation_v2",
+        "sql_command",
         "extension",
     ] | None: ...
 
 global___Command = Command
 
+class SqlCommand(google.protobuf.message.Message):
+    """A SQL Command is used to trigger the eager evaluation of SQL commands in Spark.
+
+    When the SQL provided as part of the message is a command, it will be immediately
+    evaluated and the result will be collected and returned as part of a LocalRelation.
+    If it is not a command, the operation will simply return a SQL Relation. This allows
+    the client to be almost oblivious to the server-side behavior.
+    """
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    class ArgsEntry(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        KEY_FIELD_NUMBER: builtins.int
+        VALUE_FIELD_NUMBER: builtins.int
+        key: builtins.str
+        value: builtins.str
+        def __init__(
+            self,
+            *,
+            key: builtins.str = ...,
+            value: builtins.str = ...,
+        ) -> None: ...
+        def ClearField(
+            self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
+        ) -> None: ...
+
+    SQL_FIELD_NUMBER: builtins.int
+    ARGS_FIELD_NUMBER: builtins.int
+    sql: builtins.str
+    """(Required) SQL Query."""
+    @property
+    def args(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]:
+        """(Optional) A map of parameter names to literal values."""
+    def __init__(
+        self,
+        *,
+        sql: builtins.str = ...,
+        args: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
+    ) -> None: ...
+    def ClearField(
+        self, field_name: typing_extensions.Literal["args", b"args", "sql", b"sql"]
+    ) -> None: ...
+
+global___SqlCommand = SqlCommand
+
 class CreateDataFrameViewCommand(google.protobuf.message.Message):
     """A command that can create DataFrame global temp view or local temp view."""
 
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index d82dbcb2db0..6b501b7c375 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -50,7 +50,7 @@ from pyspark import SparkContext, SparkConf, __version__
 from pyspark.sql.connect.client import SparkConnectClient
 from pyspark.sql.connect.conf import RuntimeConf
 from pyspark.sql.connect.dataframe import DataFrame
-from pyspark.sql.connect.plan import SQL, Range, LocalRelation
+from pyspark.sql.connect.plan import SQL, Range, LocalRelation, CachedRelation
 from pyspark.sql.connect.readwriter import DataFrameReader
 from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
 from pyspark.sql.pandas.types import to_arrow_type, _get_local_timezone
@@ -347,7 +347,12 @@ class SparkSession:
     createDataFrame.__doc__ = PySparkSession.createDataFrame.__doc__
 
     def sql(self, sqlQuery: str, args: Optional[Dict[str, str]] = None) -> "DataFrame":
-        return DataFrame.withPlan(SQL(sqlQuery, args), self)
+        cmd = SQL(sqlQuery, args)
+        data, properties = self.client.execute_command(cmd.command(self._client))
+        if "sql_command_result" in properties:
+            return DataFrame.withPlan(CachedRelation(properties["sql_command_result"]), self)
+        else:
+            return DataFrame.withPlan(SQL(sqlQuery, args), self)
 
     sql.__doc__ = PySparkSession.sql.__doc__
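
The user-visible effect, sketched for PySpark (it mirrors the new JVM E2E test above; a connected `spark` session is assumed):

    # The statement runs when spark.sql() is called, not when the result is
    # collected; the returned DataFrame wraps the command output as a
    # CachedRelation (the LocalRelation sent back by the server).
    spark.sql("create table test_martin (id int)")
    spark.sql("insert into test_martin values (1), (2)")

    # Plain queries still come back as a lazy Relation.SQL.
    assert len(spark.sql("select * from test_martin").collect()) == 2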
 
diff --git a/python/pyspark/sql/tests/connect/test_client.py b/python/pyspark/sql/tests/connect/test_client.py
index 41b2888eb74..84281a6764f 100644
--- a/python/pyspark/sql/tests/connect/test_client.py
+++ b/python/pyspark/sql/tests/connect/test_client.py
@@ -20,6 +20,11 @@ from typing import Optional
 
 from pyspark.sql.connect.client import SparkConnectClient
 import pyspark.sql.connect.proto as proto
+from pyspark.testing.connectutils import should_test_connect
+
+if should_test_connect:
+    import pandas as pd
+    import pyarrow as pa
 
 
 class SparkConnectClientTestCase(unittest.TestCase):
@@ -60,6 +65,18 @@ class MockService:
         self.req = req
         resp = proto.ExecutePlanResponse()
         resp.client_id = self._session_id
+
+        pdf = pd.DataFrame(data={"col1": [1, 2]})
+        schema = pa.Schema.from_pandas(pdf)
+        table = pa.Table.from_pandas(pdf)
+        sink = pa.BufferOutputStream()
+
+        writer = pa.ipc.new_stream(sink, schema=schema)
+        writer.write(table)
+        writer.close()
+
+        buf = sink.getvalue()
+        resp.arrow_batch.data = buf.to_pybytes()
         return [resp]
 
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 40de117f6f6..b22c80d17e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -158,7 +158,11 @@ private[sql] object ArrowConverters extends Logging {
             rowCountInLastBatch < maxRecordsPerBatch)) {
           val row = rowIter.next()
           arrowWriter.write(row)
-          estimatedBatchSize += row.asInstanceOf[UnsafeRow].getSizeInBytes
+          estimatedBatchSize += (row match {
+            case ur: UnsafeRow => ur.getSizeInBytes
+            // Trying to estimate the size of the current row, assuming 16 bytes per value.
+            case ir: InternalRow => ir.numFields * 16
+          })
           rowCountInLastBatch += 1
         }
         arrowWriter.finish()

