You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/11/11 07:48:38 UTC

[spark] branch master updated: [SPARK-41102][CONNECT] Merge SparkConnectPlanner and SparkConnectCommandPlanner

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fd1e0d028cb [SPARK-41102][CONNECT] Merge SparkConnectPlanner and SparkConnectCommandPlanner
fd1e0d028cb is described below

commit fd1e0d028cb7e26921cd66a421c00d7260092b23
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Fri Nov 11 15:48:21 2022 +0800

    [SPARK-41102][CONNECT] Merge SparkConnectPlanner and SparkConnectCommandPlanner
    
    ### What changes were proposed in this pull request?
    
    Previously, the Connect server separated the handling of `Command` and `Relation` into two planners. However, as we add new APIs, there are cases where a `Command` takes a `Relation` as its input, so converting a `Command` requires access to the `Relation` conversion logic. View creation is one such case, and SQL DDL and DML commands will usually follow the same pattern.
    
    This PR merges the logic for handling `Command` and `Relation` into a single planner, `SparkConnectPlanner`, and removes `SparkConnectCommandPlanner`.
    
    ### Why are the changes needed?
    
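    Refactoring: commands that take a relation as input (for example, creating a view) can now reuse the relation conversion logic directly instead of instantiating a separate planner.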
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
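    Existing unit tests. The tests from `SparkConnectCommandPlannerSuite` were moved into `SparkConnectProtoSuite`.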
    
    Closes #38604 from amaliujia/refactor-planners.
    
    Authored-by: Rui Wang <ru...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../command/SparkConnectCommandPlanner.scala       | 174 ---------------------
 .../sql/connect/planner/SparkConnectPlanner.scala  | 152 +++++++++++++++++-
 .../sql/connect/service/SparkConnectService.scala  |   2 +-
 .../service/SparkConnectStreamHandler.scala        |   9 +-
 .../planner/SparkConnectCommandPlannerSuite.scala  | 160 -------------------
 .../connect/planner/SparkConnectPlannerSuite.scala |  25 +--
 .../connect/planner/SparkConnectProtoSuite.scala   | 128 ++++++++++++++-
 7 files changed, 291 insertions(+), 359 deletions(-)

diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
deleted file mode 100644
index 11090976c7f..00000000000
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
+++ /dev/null
@@ -1,174 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.connect.command
-
-import scala.collection.JavaConverters._
-
-import com.google.common.collect.{Lists, Maps}
-
-import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
-import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.WriteOperation
-import org.apache.spark.sql.{Dataset, SparkSession}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView}
-import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner}
-import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.command.CreateViewCommand
-import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
-import org.apache.spark.sql.types.StringType
-
-final case class InvalidCommandInput(
-    private val message: String = "",
-    private val cause: Throwable = null)
-    extends Exception(message, cause)
-
-class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command) {
-
-  lazy val pythonExec =
-    sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))
-
-  def process(): Unit = {
-    command.getCommandTypeCase match {
-      case proto.Command.CommandTypeCase.CREATE_FUNCTION =>
-        handleCreateScalarFunction(command.getCreateFunction)
-      case proto.Command.CommandTypeCase.WRITE_OPERATION =>
-        handleWriteOperation(command.getWriteOperation)
-      case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
-        handleCreateViewCommand(command.getCreateDataframeView)
-      case _ => throw new UnsupportedOperationException(s"$command not supported.")
-    }
-  }
-
-  /**
-   * This is a helper function that registers a new Python function in the SparkSession.
-   *
-   * Right now this function is very rudimentary and bare-bones just to showcase how it is
-   * possible to remotely serialize a Python function and execute it on the Spark cluster. If the
-   * Python version on the client and server diverge, the execution of the function that is
-   * serialized will most likely fail.
-   *
-   * @param cf
-   */
-  def handleCreateScalarFunction(cf: proto.CreateScalarFunction): Unit = {
-    val function = SimplePythonFunction(
-      cf.getSerializedFunction.toByteArray,
-      Maps.newHashMap(),
-      Lists.newArrayList(),
-      pythonExec,
-      "3.9", // TODO(SPARK-40532) This needs to be an actual Python version.
-      Lists.newArrayList(),
-      null)
-
-    val udf = UserDefinedPythonFunction(
-      cf.getPartsList.asScala.head,
-      function,
-      StringType,
-      PythonEvalType.SQL_BATCHED_UDF,
-      udfDeterministic = false)
-
-    session.udf.registerPython(cf.getPartsList.asScala.head, udf)
-  }
-
-  def handleCreateViewCommand(createView: proto.CreateDataFrameViewCommand): Unit = {
-    val viewType = if (createView.getIsGlobal) GlobalTempView else LocalTempView
-
-    val tableIdentifier =
-      try {
-        session.sessionState.sqlParser.parseTableIdentifier(createView.getName)
-      } catch {
-        case _: ParseException =>
-          throw QueryCompilationErrors.invalidViewNameError(createView.getName)
-      }
-
-    val plan = CreateViewCommand(
-      name = tableIdentifier,
-      userSpecifiedColumns = Nil,
-      comment = None,
-      properties = Map.empty,
-      originalText = None,
-      plan = new SparkConnectPlanner(createView.getInput, session).transform(),
-      allowExisting = false,
-      replace = createView.getReplace,
-      viewType = viewType,
-      isAnalyzed = true)
-
-    Dataset.ofRows(session, plan).queryExecution.commandExecuted
-  }
-
-  /**
-   * Transforms the write operation and executes it.
-   *
-   * The input write operation contains a reference to the input plan and transforms it to the
-   * corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the
-   * parameters of the WriteOperation into the corresponding methods calls.
-   *
-   * @param writeOperation
-   */
-  def handleWriteOperation(writeOperation: WriteOperation): Unit = {
-    // Transform the input plan into the logical plan.
-    val planner = new SparkConnectPlanner(writeOperation.getInput, session)
-    val plan = planner.transform()
-    // And create a Dataset from the plan.
-    val dataset = Dataset.ofRows(session, logicalPlan = plan)
-
-    val w = dataset.write
-    if (writeOperation.getMode != proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED) {
-      w.mode(DataTypeProtoConverter.toSaveMode(writeOperation.getMode))
-    }
-
-    if (writeOperation.getOptionsCount > 0) {
-      writeOperation.getOptionsMap.asScala.foreach { case (key, value) => w.option(key, value) }
-    }
-
-    if (writeOperation.getSortColumnNamesCount > 0) {
-      val names = writeOperation.getSortColumnNamesList.asScala
-      w.sortBy(names.head, names.tail.toSeq: _*)
-    }
-
-    if (writeOperation.hasBucketBy) {
-      val op = writeOperation.getBucketBy
-      val cols = op.getBucketColumnNamesList.asScala
-      if (op.getNumBuckets <= 0) {
-        throw InvalidCommandInput(
-          s"BucketBy must specify a bucket count > 0, received ${op.getNumBuckets} instead.")
-      }
-      w.bucketBy(op.getNumBuckets, cols.head, cols.tail.toSeq: _*)
-    }
-
-    if (writeOperation.getPartitioningColumnsCount > 0) {
-      val names = writeOperation.getPartitioningColumnsList.asScala
-      w.partitionBy(names.toSeq: _*)
-    }
-
-    if (writeOperation.getSource != null) {
-      w.format(writeOperation.getSource)
-    }
-
-    writeOperation.getSaveTypeCase match {
-      case proto.WriteOperation.SaveTypeCase.PATH => w.save(writeOperation.getPath)
-      case proto.WriteOperation.SaveTypeCase.TABLE_NAME =>
-        w.saveAsTable(writeOperation.getTableName)
-      case _ =>
-        throw new UnsupportedOperationException(
-          "WriteOperation:SaveTypeCase not supported "
-            + s"${writeOperation.getSaveTypeCase.getNumber}")
-    }
-  }
-
-}
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index b91fef58a11..f8ccc7b62e7 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -19,18 +19,25 @@ package org.apache.spark.sql.connect.planner
 
 import scala.collection.JavaConverters._
 
+import com.google.common.collect.{Lists, Maps}
+
+import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.WriteOperation
 import org.apache.spark.sql.{Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.AliasIdentifier
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression}
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
-import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
 import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LogicalPlan, Sample, SubqueryAlias, Union}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.command.CreateViewCommand
+import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -39,14 +46,17 @@ final case class InvalidPlanInput(
     private val cause: Throwable = None.orNull)
     extends Exception(message, cause)
 
-class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
+final case class InvalidCommandInput(
+    private val message: String = "",
+    private val cause: Throwable = null)
+    extends Exception(message, cause)
 
-  def transform(): LogicalPlan = {
-    transformRelation(plan)
-  }
+class SparkConnectPlanner(session: SparkSession) {
+  lazy val pythonExec =
+    sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))
 
   // The root of the query plan is a relation and we apply the transformations to it.
-  private def transformRelation(rel: proto.Relation): LogicalPlan = {
+  def transformRelation(rel: proto.Relation): LogicalPlan = {
     rel.getRelTypeCase match {
       case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead)
       case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject)
@@ -446,4 +456,132 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
     }
   }
 
+  def process(command: proto.Command): Unit = {
+    command.getCommandTypeCase match {
+      case proto.Command.CommandTypeCase.CREATE_FUNCTION =>
+        handleCreateScalarFunction(command.getCreateFunction)
+      case proto.Command.CommandTypeCase.WRITE_OPERATION =>
+        handleWriteOperation(command.getWriteOperation)
+      case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
+        handleCreateViewCommand(command.getCreateDataframeView)
+      case _ => throw new UnsupportedOperationException(s"$command not supported.")
+    }
+  }
+
+  /**
+   * This is a helper function that registers a new Python function in the SparkSession.
+   *
+   * Right now this function is very rudimentary and bare-bones just to showcase how it is
+   * possible to remotely serialize a Python function and execute it on the Spark cluster. If the
+   * Python version on the client and server diverge, the execution of the function that is
+   * serialized will most likely fail.
+   *
+   * @param cf
+   */
+  def handleCreateScalarFunction(cf: proto.CreateScalarFunction): Unit = {
+    val function = SimplePythonFunction(
+      cf.getSerializedFunction.toByteArray,
+      Maps.newHashMap(),
+      Lists.newArrayList(),
+      pythonExec,
+      "3.9", // TODO(SPARK-40532) This needs to be an actual Python version.
+      Lists.newArrayList(),
+      null)
+
+    val udf = UserDefinedPythonFunction(
+      cf.getPartsList.asScala.head,
+      function,
+      StringType,
+      PythonEvalType.SQL_BATCHED_UDF,
+      udfDeterministic = false)
+
+    session.udf.registerPython(cf.getPartsList.asScala.head, udf)
+  }
+
+  def handleCreateViewCommand(createView: proto.CreateDataFrameViewCommand): Unit = {
+    val viewType = if (createView.getIsGlobal) GlobalTempView else LocalTempView
+
+    val tableIdentifier =
+      try {
+        session.sessionState.sqlParser.parseTableIdentifier(createView.getName)
+      } catch {
+        case _: ParseException =>
+          throw QueryCompilationErrors.invalidViewNameError(createView.getName)
+      }
+
+    val plan = CreateViewCommand(
+      name = tableIdentifier,
+      userSpecifiedColumns = Nil,
+      comment = None,
+      properties = Map.empty,
+      originalText = None,
+      plan = transformRelation(createView.getInput),
+      allowExisting = false,
+      replace = createView.getReplace,
+      viewType = viewType,
+      isAnalyzed = true)
+
+    Dataset.ofRows(session, plan).queryExecution.commandExecuted
+  }
+
+  /**
+   * Transforms the write operation and executes it.
+   *
+   * The input write operation contains a reference to the input plan and transforms it to the
+   * corresponding logical plan. Afterwards, creates the DataFrameWriter and translates the
+   * parameters of the WriteOperation into the corresponding methods calls.
+   *
+   * @param writeOperation
+   */
+  def handleWriteOperation(writeOperation: WriteOperation): Unit = {
+    // Transform the input plan into the logical plan.
+    val planner = new SparkConnectPlanner(session)
+    val plan = planner.transformRelation(writeOperation.getInput)
+    // And create a Dataset from the plan.
+    val dataset = Dataset.ofRows(session, logicalPlan = plan)
+
+    val w = dataset.write
+    if (writeOperation.getMode != proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED) {
+      w.mode(DataTypeProtoConverter.toSaveMode(writeOperation.getMode))
+    }
+
+    if (writeOperation.getOptionsCount > 0) {
+      writeOperation.getOptionsMap.asScala.foreach { case (key, value) => w.option(key, value) }
+    }
+
+    if (writeOperation.getSortColumnNamesCount > 0) {
+      val names = writeOperation.getSortColumnNamesList.asScala
+      w.sortBy(names.head, names.tail.toSeq: _*)
+    }
+
+    if (writeOperation.hasBucketBy) {
+      val op = writeOperation.getBucketBy
+      val cols = op.getBucketColumnNamesList.asScala
+      if (op.getNumBuckets <= 0) {
+        throw InvalidCommandInput(
+          s"BucketBy must specify a bucket count > 0, received ${op.getNumBuckets} instead.")
+      }
+      w.bucketBy(op.getNumBuckets, cols.head, cols.tail.toSeq: _*)
+    }
+
+    if (writeOperation.getPartitioningColumnsCount > 0) {
+      val names = writeOperation.getPartitioningColumnsList.asScala
+      w.partitionBy(names.toSeq: _*)
+    }
+
+    if (writeOperation.getSource != null) {
+      w.format(writeOperation.getSource)
+    }
+
+    writeOperation.getSaveTypeCase match {
+      case proto.WriteOperation.SaveTypeCase.PATH => w.save(writeOperation.getPath)
+      case proto.WriteOperation.SaveTypeCase.TABLE_NAME =>
+        w.saveAsTable(writeOperation.getTableName)
+      case _ =>
+        throw new UnsupportedOperationException(
+          "WriteOperation:SaveTypeCase not supported "
+            + s"${writeOperation.getSaveTypeCase.getNumber}")
+    }
+  }
+
 }
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index a1e70975da5..abbad51c601 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -106,7 +106,7 @@ class SparkConnectService(debug: Boolean)
   def handleAnalyzePlanRequest(
       relation: proto.Relation,
       session: SparkSession): proto.AnalyzeResponse.Builder = {
-    val logicalPlan = new SparkConnectPlanner(relation, session).transform()
+    val logicalPlan = new SparkConnectPlanner(session).transformRelation(relation)
 
     val ds = Dataset.ofRows(session, logicalPlan)
     val explainString = ds.queryExecution.explainString(ExtendedMode)
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 9652fce5425..394d6477d73 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -27,7 +27,6 @@ import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.{Request, Response}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
-import org.apache.spark.sql.connect.command.SparkConnectCommandPlanner
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
@@ -51,8 +50,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
 
   def handlePlan(session: SparkSession, request: Request): Unit = {
     // Extract the plan from the request and convert it to a logical plan
-    val planner = new SparkConnectPlanner(request.getPlan.getRoot, session)
-    val dataframe = Dataset.ofRows(session, planner.transform())
+    val planner = new SparkConnectPlanner(session)
+    val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot))
     try {
       processAsArrowBatches(request.getClientId, dataframe)
     } catch {
@@ -216,8 +215,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
 
   def handleCommand(session: SparkSession, request: Request): Unit = {
     val command = request.getPlan.getCommand
-    val planner = new SparkConnectCommandPlanner(session, command)
-    planner.process()
+    val planner = new SparkConnectPlanner(session)
+    planner.process(command)
     responseObserver.onCompleted()
   }
 }
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala
deleted file mode 100644
index 8ab8e0599fc..00000000000
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala
+++ /dev/null
@@ -1,160 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.connect.planner
-
-import java.nio.file.{Files, Paths}
-
-import org.apache.spark.SparkClassNotFoundException
-import org.apache.spark.connect.proto
-import org.apache.spark.sql.{AnalysisException, SaveMode}
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.connect.command.{InvalidCommandInput, SparkConnectCommandPlanner}
-import org.apache.spark.sql.connect.dsl.commands._
-import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
-
-class SparkConnectCommandPlannerSuite
-    extends SQLTestUtils
-    with SparkConnectPlanTest
-    with SharedSparkSession {
-
-  lazy val localRelation = createLocalRelationProto(Seq($"id".int))
-
-  def transform(cmd: proto.Command): Unit = {
-    new SparkConnectCommandPlanner(spark, cmd).process()
-  }
-
-  test("Writes fails without path or table") {
-    assertThrows[UnsupportedOperationException] {
-      transform(localRelation.write())
-    }
-  }
-
-  test("Write fails with unknown table - AnalysisException") {
-    val cmd = readRel.write(tableName = Some("dest"))
-    assertThrows[AnalysisException] {
-      transform(cmd)
-    }
-  }
-
-  test("Write with partitions") {
-    val cmd = localRelation.write(
-      tableName = Some("testtable"),
-      format = Some("parquet"),
-      partitionByCols = Seq("noid"))
-    assertThrows[AnalysisException] {
-      transform(cmd)
-    }
-  }
-
-  test("Write with invalid bucketBy configuration") {
-    val cmd = localRelation.write(bucketByCols = Seq("id"), numBuckets = Some(0))
-    assertThrows[InvalidCommandInput] {
-      transform(cmd)
-    }
-  }
-
-  test("Write to Path") {
-    withTempDir { f =>
-      val cmd = localRelation.write(
-        format = Some("parquet"),
-        path = Some(f.getPath),
-        mode = Some("Overwrite"))
-      transform(cmd)
-      assert(Files.exists(Paths.get(f.getPath)), s"Output file must exist: ${f.getPath}")
-    }
-  }
-
-  test("Write to Path with invalid input") {
-    // Wrong data source.
-    assertThrows[SparkClassNotFoundException](
-      transform(
-        localRelation.write(path = Some("/tmp/tmppath"), format = Some("ThisAintNoFormat"))))
-
-    // Default data source not found.
-    assertThrows[SparkClassNotFoundException](
-      transform(localRelation.write(path = Some("/tmp/tmppath"))))
-  }
-
-  test("Write with sortBy") {
-    // Sort by existing column.
-    withTable("testtable") {
-      transform(
-        localRelation.write(
-          tableName = Some("testtable"),
-          format = Some("parquet"),
-          sortByColumns = Seq("id"),
-          bucketByCols = Seq("id"),
-          numBuckets = Some(10)))
-    }
-
-    // Sort by non-existing column
-    assertThrows[AnalysisException](
-      transform(
-        localRelation
-          .write(
-            tableName = Some("testtable"),
-            format = Some("parquet"),
-            sortByColumns = Seq("noid"),
-            bucketByCols = Seq("id"),
-            numBuckets = Some(10))))
-  }
-
-  test("Write to Table") {
-    withTable("testtable") {
-      val cmd = localRelation.write(format = Some("parquet"), tableName = Some("testtable"))
-      transform(cmd)
-      // Check that we can find and drop the table.
-      spark.sql(s"select count(*) from testtable").collect()
-    }
-  }
-
-  test("SaveMode conversion tests") {
-    assertThrows[IllegalArgumentException](
-      DataTypeProtoConverter.toSaveMode(proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED))
-
-    val combinations = Seq(
-      (SaveMode.Append, proto.WriteOperation.SaveMode.SAVE_MODE_APPEND),
-      (SaveMode.Ignore, proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE),
-      (SaveMode.Overwrite, proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE),
-      (SaveMode.ErrorIfExists, proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS))
-    combinations.foreach { a =>
-      assert(DataTypeProtoConverter.toSaveModeProto(a._1) == a._2)
-      assert(DataTypeProtoConverter.toSaveMode(a._2) == a._1)
-    }
-  }
-
-  test("Test CreateView") {
-    withView("view1", "view2", "view3", "view4") {
-      transform(localRelation.createView("view1", global = true, replace = true))
-      assert(spark.catalog.tableExists("global_temp.view1"))
-
-      transform(localRelation.createView("view2", global = false, replace = true))
-      assert(spark.catalog.tableExists("view2"))
-
-      transform(localRelation.createView("view3", global = true, replace = false))
-      assertThrows[AnalysisException] {
-        transform(localRelation.createView("view3", global = true, replace = false))
-      }
-
-      transform(localRelation.createView("view4", global = false, replace = false))
-      assertThrows[AnalysisException] {
-        transform(localRelation.createView("view4", global = false, replace = false))
-      }
-    }
-  }
-}
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index d2304581c3a..9e5fc41a0c6 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -33,7 +33,11 @@ import org.apache.spark.sql.test.SharedSparkSession
 trait SparkConnectPlanTest extends SharedSparkSession {
 
   def transform(rel: proto.Relation): logical.LogicalPlan = {
-    new SparkConnectPlanner(rel, spark).transform()
+    new SparkConnectPlanner(spark).transformRelation(rel)
+  }
+
+  def transform(cmd: proto.Command): Unit = {
+    new SparkConnectPlanner(spark).process(cmd)
   }
 
   def readRel: proto.Relation =
@@ -75,24 +79,23 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
 
   test("Simple Limit") {
     assertThrows[IndexOutOfBoundsException] {
-      new SparkConnectPlanner(
-        proto.Relation.newBuilder
-          .setLimit(proto.Limit.newBuilder.setLimit(10))
-          .build(),
-        None.orNull)
-        .transform()
+      new SparkConnectPlanner(None.orNull)
+        .transformRelation(
+          proto.Relation.newBuilder
+            .setLimit(proto.Limit.newBuilder.setLimit(10))
+            .build())
     }
   }
 
   test("InvalidInputs") {
     // No Relation Set
     intercept[IndexOutOfBoundsException](
-      new SparkConnectPlanner(proto.Relation.newBuilder().build(), None.orNull).transform())
+      new SparkConnectPlanner(None.orNull).transformRelation(proto.Relation.newBuilder().build()))
 
     intercept[InvalidPlanInput](
-      new SparkConnectPlanner(
-        proto.Relation.newBuilder.setUnknown(proto.Unknown.newBuilder().build()).build(),
-        None.orNull).transform())
+      new SparkConnectPlanner(None.orNull)
+        .transformRelation(
+          proto.Relation.newBuilder.setUnknown(proto.Unknown.newBuilder().build()).build()))
 
   }
 
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 5052b451047..53ea1988809 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -16,15 +16,19 @@
  */
 package org.apache.spark.sql.connect.planner
 
+import java.nio.file.{Files, Paths}
+
+import org.apache.spark.SparkClassNotFoundException
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Join.JoinType
-import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row}
+import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row, SaveMode}
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.connect.dsl.MockRemoteSession
+import org.apache.spark.sql.connect.dsl.commands._
 import org.apache.spark.sql.connect.dsl.expressions._
 import org.apache.spark.sql.connect.dsl.plans._
 import org.apache.spark.sql.functions._
@@ -57,6 +61,8 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
       new java.util.ArrayList[Row](),
       StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))))
 
+  lazy val localRelation = createLocalRelationProto(Seq(AttributeReference("id", IntegerType)()))
+
   test("Basic select") {
     val connectPlan = connectTestRelation.select("id".protoAttr)
     val sparkPlan = sparkTestRelation.select("id")
@@ -303,6 +309,126 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     assert(e.getMessage.contains("Found duplicate column(s)"))
   }
 
+  test("Writes fails without path or table") {
+    assertThrows[UnsupportedOperationException] {
+      transform(localRelation.write())
+    }
+  }
+
+  test("Write fails with unknown table - AnalysisException") {
+    val cmd = readRel.write(tableName = Some("dest"))
+    assertThrows[AnalysisException] {
+      transform(cmd)
+    }
+  }
+
+  test("Write with partitions") {
+    val cmd = localRelation.write(
+      tableName = Some("testtable"),
+      format = Some("parquet"),
+      partitionByCols = Seq("noid"))
+    assertThrows[AnalysisException] {
+      transform(cmd)
+    }
+  }
+
+  test("Write with invalid bucketBy configuration") {
+    val cmd = localRelation.write(bucketByCols = Seq("id"), numBuckets = Some(0))
+    assertThrows[InvalidCommandInput] {
+      transform(cmd)
+    }
+  }
+
+  test("Write to Path") {
+    withTempDir { f =>
+      val cmd = localRelation.write(
+        format = Some("parquet"),
+        path = Some(f.getPath),
+        mode = Some("Overwrite"))
+      transform(cmd)
+      assert(Files.exists(Paths.get(f.getPath)), s"Output file must exist: ${f.getPath}")
+    }
+  }
+
+  test("Write to Path with invalid input") {
+    // Wrong data source.
+    assertThrows[SparkClassNotFoundException](
+      transform(
+        localRelation.write(path = Some("/tmp/tmppath"), format = Some("ThisAintNoFormat"))))
+
+    // Default data source not found.
+    assertThrows[SparkClassNotFoundException](
+      transform(localRelation.write(path = Some("/tmp/tmppath"))))
+  }
+
+  test("Write with sortBy") {
+    // Sort by existing column.
+    withTable("testtable") {
+      transform(
+        localRelation.write(
+          tableName = Some("testtable"),
+          format = Some("parquet"),
+          sortByColumns = Seq("id"),
+          bucketByCols = Seq("id"),
+          numBuckets = Some(10)))
+    }
+
+    // Sort by non-existing column
+    assertThrows[AnalysisException](
+      transform(
+        localRelation
+          .write(
+            tableName = Some("testtable"),
+            format = Some("parquet"),
+            sortByColumns = Seq("noid"),
+            bucketByCols = Seq("id"),
+            numBuckets = Some(10))))
+  }
+
+  test("Write to Table") {
+    withTable("testtable") {
+      val cmd = localRelation.write(format = Some("parquet"), tableName = Some("testtable"))
+      transform(cmd)
+      // Check that we can find and drop the table.
+      spark.sql(s"select count(*) from testtable").collect()
+    }
+  }
+
+  test("SaveMode conversion tests") {
+    assertThrows[IllegalArgumentException](
+      DataTypeProtoConverter.toSaveMode(proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED))
+
+    val combinations = Seq(
+      (SaveMode.Append, proto.WriteOperation.SaveMode.SAVE_MODE_APPEND),
+      (SaveMode.Ignore, proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE),
+      (SaveMode.Overwrite, proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE),
+      (SaveMode.ErrorIfExists, proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS))
+    combinations.foreach { a =>
+      assert(DataTypeProtoConverter.toSaveModeProto(a._1) == a._2)
+      assert(DataTypeProtoConverter.toSaveMode(a._2) == a._1)
+    }
+  }
+
+  test("Test CreateView") {
+    withView("view1", "view2", "view3", "view4") {
+      transform(localRelation.createView("view1", global = true, replace = true))
+      assert(spark.catalog.tableExists("global_temp.view1"))
+
+      transform(localRelation.createView("view2", global = false, replace = true))
+      assert(spark.catalog.tableExists("view2"))
+
+      transform(localRelation.createView("view3", global = true, replace = false))
+      assertThrows[AnalysisException] {
+        transform(localRelation.createView("view3", global = true, replace = false))
+      }
+
+      transform(localRelation.createView("view4", global = false, replace = false))
+      assertThrows[AnalysisException] {
+        transform(localRelation.createView("view4", global = false, replace = false))
+      }
+    }
+  }
+
   private def createLocalRelationProtoByQualifiedAttributes(
       attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org