Posted to commits@spark.apache.org by we...@apache.org on 2022/11/11 07:48:38 UTC
[spark] branch master updated: [SPARK-41102][CONNECT] Merge SparkConnectPlanner and SparkConnectCommandPlanner
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fd1e0d028cb [SPARK-41102][CONNECT] Merge SparkConnectPlanner and SparkConnectCommandPlanner
fd1e0d028cb is described below
commit fd1e0d028cb7e26921cd66a421c00d7260092b23
Author: Rui Wang <ru...@databricks.com>
AuthorDate: Fri Nov 11 15:48:21 2022 +0800
[SPARK-41102][CONNECT] Merge SparkConnectPlanner and SparkConnectCommandPlanner
### What changes were proposed in this pull request?
In the past, the Connect server split `Command` and `Relation` handling into two planners. However, as new APIs are added, there are inevitably cases where a `Command` has an input that is a `Relation`, so converting the `Command` still needs access to the `Relation` conversion logic. View creation is one such case, and SQL DDL and DML will typically follow the same pattern.
This PR merges the logic for handling `Command` and `Relation` into a single planner.
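For illustration, here is a minimal, self-contained Scala sketch of the pattern the merge enables. The types below are toy stand-ins for the Connect protos, not Spark's actual classes:

    // Toy model: one planner handles both relations and commands, so a
    // command handler can call transformRelation directly on its input.
    sealed trait Relation
    final case class Read(table: String) extends Relation

    sealed trait Command
    final case class CreateView(name: String, input: Relation) extends Command

    class Planner {
      def transformRelation(rel: Relation): String = rel match {
        case Read(t) => s"Scan($t)"
      }

      def process(cmd: Command): Unit = cmd match {
        case CreateView(name, input) =>
          // The command's input is a Relation; no second planner needed.
          println(s"CreateView($name, plan=${transformRelation(input)})")
      }
    }

    object Demo extends App {
      new Planner().process(CreateView("view1", Read("t")))
      // prints: CreateView(view1, plan=Scan(t))
    }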
### Why are the changes needed?
Refactoring.
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
Existing UT
Closes #38604 from amaliujia/refactor-planners.
Authored-by: Rui Wang <ru...@databricks.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../command/SparkConnectCommandPlanner.scala | 174 ---------------------
.../sql/connect/planner/SparkConnectPlanner.scala | 152 +++++++++++++++++-
.../sql/connect/service/SparkConnectService.scala | 2 +-
.../service/SparkConnectStreamHandler.scala | 9 +-
.../planner/SparkConnectCommandPlannerSuite.scala | 160 -------------------
.../connect/planner/SparkConnectPlannerSuite.scala | 25 +--
.../connect/planner/SparkConnectProtoSuite.scala | 128 ++++++++++++++-
7 files changed, 291 insertions(+), 359 deletions(-)
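For reference, the net call-site change (taken from the hunks below): the planner's constructor no longer binds it to a single plan, so relations and commands go through one object.

    // before this commit
    new SparkConnectPlanner(relation, session).transform()
    new SparkConnectCommandPlanner(session, command).process()

    // after this commit
    new SparkConnectPlanner(session).transformRelation(relation)
    new SparkConnectPlanner(session).process(command)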
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
deleted file mode 100644
index 11090976c7f..00000000000
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/command/SparkConnectCommandPlanner.scala
+++ /dev/null
@@ -1,174 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.connect.command
-
-import scala.collection.JavaConverters._
-
-import com.google.common.collect.{Lists, Maps}
-
-import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
-import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.WriteOperation
-import org.apache.spark.sql.{Dataset, SparkSession}
-import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView}
-import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner}
-import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.command.CreateViewCommand
-import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
-import org.apache.spark.sql.types.StringType
-
-final case class InvalidCommandInput(
- private val message: String = "",
- private val cause: Throwable = null)
- extends Exception(message, cause)
-
-class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command) {
-
- lazy val pythonExec =
- sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))
-
- def process(): Unit = {
- command.getCommandTypeCase match {
- case proto.Command.CommandTypeCase.CREATE_FUNCTION =>
- handleCreateScalarFunction(command.getCreateFunction)
- case proto.Command.CommandTypeCase.WRITE_OPERATION =>
- handleWriteOperation(command.getWriteOperation)
- case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
- handleCreateViewCommand(command.getCreateDataframeView)
- case _ => throw new UnsupportedOperationException(s"$command not supported.")
- }
- }
-
- /**
- * This is a helper function that registers a new Python function in the SparkSession.
- *
- * Right now this function is very rudimentary and bare-bones just to showcase how it is
- * possible to remotely serialize a Python function and execute it on the Spark cluster. If the
- * Python versions on the client and server diverge, the execution of the function that is
- * serialized will most likely fail.
- *
- * @param cf
- */
- def handleCreateScalarFunction(cf: proto.CreateScalarFunction): Unit = {
- val function = SimplePythonFunction(
- cf.getSerializedFunction.toByteArray,
- Maps.newHashMap(),
- Lists.newArrayList(),
- pythonExec,
- "3.9", // TODO(SPARK-40532) This needs to be an actual Python version.
- Lists.newArrayList(),
- null)
-
- val udf = UserDefinedPythonFunction(
- cf.getPartsList.asScala.head,
- function,
- StringType,
- PythonEvalType.SQL_BATCHED_UDF,
- udfDeterministic = false)
-
- session.udf.registerPython(cf.getPartsList.asScala.head, udf)
- }
-
- def handleCreateViewCommand(createView: proto.CreateDataFrameViewCommand): Unit = {
- val viewType = if (createView.getIsGlobal) GlobalTempView else LocalTempView
-
- val tableIdentifier =
- try {
- session.sessionState.sqlParser.parseTableIdentifier(createView.getName)
- } catch {
- case _: ParseException =>
- throw QueryCompilationErrors.invalidViewNameError(createView.getName)
- }
-
- val plan = CreateViewCommand(
- name = tableIdentifier,
- userSpecifiedColumns = Nil,
- comment = None,
- properties = Map.empty,
- originalText = None,
- plan = new SparkConnectPlanner(createView.getInput, session).transform(),
- allowExisting = false,
- replace = createView.getReplace,
- viewType = viewType,
- isAnalyzed = true)
-
- Dataset.ofRows(session, plan).queryExecution.commandExecuted
- }
-
- /**
- * Transforms the write operation and executes it.
- *
- * The input write operation contains a reference to the input plan and transforms it to the
- * corresponding logical plan. Afterwards, it creates the DataFrameWriter and translates the
- * parameters of the WriteOperation into the corresponding method calls.
- *
- * @param writeOperation
- */
- def handleWriteOperation(writeOperation: WriteOperation): Unit = {
- // Transform the input plan into the logical plan.
- val planner = new SparkConnectPlanner(writeOperation.getInput, session)
- val plan = planner.transform()
- // And create a Dataset from the plan.
- val dataset = Dataset.ofRows(session, logicalPlan = plan)
-
- val w = dataset.write
- if (writeOperation.getMode != proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED) {
- w.mode(DataTypeProtoConverter.toSaveMode(writeOperation.getMode))
- }
-
- if (writeOperation.getOptionsCount > 0) {
- writeOperation.getOptionsMap.asScala.foreach { case (key, value) => w.option(key, value) }
- }
-
- if (writeOperation.getSortColumnNamesCount > 0) {
- val names = writeOperation.getSortColumnNamesList.asScala
- w.sortBy(names.head, names.tail.toSeq: _*)
- }
-
- if (writeOperation.hasBucketBy) {
- val op = writeOperation.getBucketBy
- val cols = op.getBucketColumnNamesList.asScala
- if (op.getNumBuckets <= 0) {
- throw InvalidCommandInput(
- s"BucketBy must specify a bucket count > 0, received ${op.getNumBuckets} instead.")
- }
- w.bucketBy(op.getNumBuckets, cols.head, cols.tail.toSeq: _*)
- }
-
- if (writeOperation.getPartitioningColumnsCount > 0) {
- val names = writeOperation.getPartitioningColumnsList.asScala
- w.partitionBy(names.toSeq: _*)
- }
-
- if (writeOperation.getSource != null) {
- w.format(writeOperation.getSource)
- }
-
- writeOperation.getSaveTypeCase match {
- case proto.WriteOperation.SaveTypeCase.PATH => w.save(writeOperation.getPath)
- case proto.WriteOperation.SaveTypeCase.TABLE_NAME =>
- w.saveAsTable(writeOperation.getTableName)
- case _ =>
- throw new UnsupportedOperationException(
- "WriteOperation:SaveTypeCase not supported "
- + s"${writeOperation.getSaveTypeCase.getNumber}")
- }
- }
-
-}
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index b91fef58a11..f8ccc7b62e7 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -19,18 +19,25 @@ package org.apache.spark.sql.connect.planner
import scala.collection.JavaConverters._
+import com.google.common.collect.{Lists, Maps}
+
+import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.WriteOperation
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.AliasIdentifier
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
-import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LogicalPlan, Sample, SubqueryAlias, Union}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.command.CreateViewCommand
+import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -39,14 +46,17 @@ final case class InvalidPlanInput(
private val cause: Throwable = None.orNull)
extends Exception(message, cause)
-class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
+final case class InvalidCommandInput(
+ private val message: String = "",
+ private val cause: Throwable = null)
+ extends Exception(message, cause)
- def transform(): LogicalPlan = {
- transformRelation(plan)
- }
+class SparkConnectPlanner(session: SparkSession) {
+ lazy val pythonExec =
+ sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))
// The root of the query plan is a relation and we apply the transformations to it.
- private def transformRelation(rel: proto.Relation): LogicalPlan = {
+ def transformRelation(rel: proto.Relation): LogicalPlan = {
rel.getRelTypeCase match {
case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead)
case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject)
@@ -446,4 +456,132 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
}
}
+ def process(command: proto.Command): Unit = {
+ command.getCommandTypeCase match {
+ case proto.Command.CommandTypeCase.CREATE_FUNCTION =>
+ handleCreateScalarFunction(command.getCreateFunction)
+ case proto.Command.CommandTypeCase.WRITE_OPERATION =>
+ handleWriteOperation(command.getWriteOperation)
+ case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
+ handleCreateViewCommand(command.getCreateDataframeView)
+ case _ => throw new UnsupportedOperationException(s"$command not supported.")
+ }
+ }
+
+ /**
+ * This is a helper function that registers a new Python function in the SparkSession.
+ *
+ * Right now this function is very rudimentary and bare-bones just to showcase how it is
+ * possible to remotely serialize a Python function and execute it on the Spark cluster. If the
+ * Python versions on the client and server diverge, the execution of the function that is
+ * serialized will most likely fail.
+ *
+ * @param cf
+ */
+ def handleCreateScalarFunction(cf: proto.CreateScalarFunction): Unit = {
+ val function = SimplePythonFunction(
+ cf.getSerializedFunction.toByteArray,
+ Maps.newHashMap(),
+ Lists.newArrayList(),
+ pythonExec,
+ "3.9", // TODO(SPARK-40532) This needs to be an actual Python version.
+ Lists.newArrayList(),
+ null)
+
+ val udf = UserDefinedPythonFunction(
+ cf.getPartsList.asScala.head,
+ function,
+ StringType,
+ PythonEvalType.SQL_BATCHED_UDF,
+ udfDeterministic = false)
+
+ session.udf.registerPython(cf.getPartsList.asScala.head, udf)
+ }
+
+ def handleCreateViewCommand(createView: proto.CreateDataFrameViewCommand): Unit = {
+ val viewType = if (createView.getIsGlobal) GlobalTempView else LocalTempView
+
+ val tableIdentifier =
+ try {
+ session.sessionState.sqlParser.parseTableIdentifier(createView.getName)
+ } catch {
+ case _: ParseException =>
+ throw QueryCompilationErrors.invalidViewNameError(createView.getName)
+ }
+
+ val plan = CreateViewCommand(
+ name = tableIdentifier,
+ userSpecifiedColumns = Nil,
+ comment = None,
+ properties = Map.empty,
+ originalText = None,
+ plan = transformRelation(createView.getInput),
+ allowExisting = false,
+ replace = createView.getReplace,
+ viewType = viewType,
+ isAnalyzed = true)
+
+ Dataset.ofRows(session, plan).queryExecution.commandExecuted
+ }
+
+ /**
+ * Transforms the write operation and executes it.
+ *
+ * The input write operation contains a reference to the input plan and transforms it to the
+ * corresponding logical plan. Afterwards, it creates the DataFrameWriter and translates the
+ * parameters of the WriteOperation into the corresponding method calls.
+ *
+ * @param writeOperation
+ */
+ def handleWriteOperation(writeOperation: WriteOperation): Unit = {
+ // Transform the input plan into the logical plan.
+ val planner = new SparkConnectPlanner(session)
+ val plan = planner.transformRelation(writeOperation.getInput)
+ // And create a Dataset from the plan.
+ val dataset = Dataset.ofRows(session, logicalPlan = plan)
+
+ val w = dataset.write
+ if (writeOperation.getMode != proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED) {
+ w.mode(DataTypeProtoConverter.toSaveMode(writeOperation.getMode))
+ }
+
+ if (writeOperation.getOptionsCount > 0) {
+ writeOperation.getOptionsMap.asScala.foreach { case (key, value) => w.option(key, value) }
+ }
+
+ if (writeOperation.getSortColumnNamesCount > 0) {
+ val names = writeOperation.getSortColumnNamesList.asScala
+ w.sortBy(names.head, names.tail.toSeq: _*)
+ }
+
+ if (writeOperation.hasBucketBy) {
+ val op = writeOperation.getBucketBy
+ val cols = op.getBucketColumnNamesList.asScala
+ if (op.getNumBuckets <= 0) {
+ throw InvalidCommandInput(
+ s"BucketBy must specify a bucket count > 0, received ${op.getNumBuckets} instead.")
+ }
+ w.bucketBy(op.getNumBuckets, cols.head, cols.tail.toSeq: _*)
+ }
+
+ if (writeOperation.getPartitioningColumnsCount > 0) {
+ val names = writeOperation.getPartitioningColumnsList.asScala
+ w.partitionBy(names.toSeq: _*)
+ }
+
+ if (writeOperation.getSource != null) {
+ w.format(writeOperation.getSource)
+ }
+
+ writeOperation.getSaveTypeCase match {
+ case proto.WriteOperation.SaveTypeCase.PATH => w.save(writeOperation.getPath)
+ case proto.WriteOperation.SaveTypeCase.TABLE_NAME =>
+ w.saveAsTable(writeOperation.getTableName)
+ case _ =>
+ throw new UnsupportedOperationException(
+ "WriteOperation:SaveTypeCase not supported "
+ + s"${writeOperation.getSaveTypeCase.getNumber}")
+ }
+ }
+
}
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index a1e70975da5..abbad51c601 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -106,7 +106,7 @@ class SparkConnectService(debug: Boolean)
def handleAnalyzePlanRequest(
relation: proto.Relation,
session: SparkSession): proto.AnalyzeResponse.Builder = {
- val logicalPlan = new SparkConnectPlanner(relation, session).transform()
+ val logicalPlan = new SparkConnectPlanner(session).transformRelation(relation)
val ds = Dataset.ofRows(session, logicalPlan)
val explainString = ds.queryExecution.explainString(ExtendedMode)
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 9652fce5425..394d6477d73 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -27,7 +27,6 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{Request, Response}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
-import org.apache.spark.sql.connect.command.SparkConnectCommandPlanner
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
@@ -51,8 +50,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
def handlePlan(session: SparkSession, request: Request): Unit = {
// Extract the plan from the request and convert it to a logical plan
- val planner = new SparkConnectPlanner(request.getPlan.getRoot, session)
- val dataframe = Dataset.ofRows(session, planner.transform())
+ val planner = new SparkConnectPlanner(session)
+ val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot))
try {
processAsArrowBatches(request.getClientId, dataframe)
} catch {
@@ -216,8 +215,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte
def handleCommand(session: SparkSession, request: Request): Unit = {
val command = request.getPlan.getCommand
- val planner = new SparkConnectCommandPlanner(session, command)
- planner.process()
+ val planner = new SparkConnectPlanner(session)
+ planner.process(command)
responseObserver.onCompleted()
}
}
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala
deleted file mode 100644
index 8ab8e0599fc..00000000000
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectCommandPlannerSuite.scala
+++ /dev/null
@@ -1,160 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.connect.planner
-
-import java.nio.file.{Files, Paths}
-
-import org.apache.spark.SparkClassNotFoundException
-import org.apache.spark.connect.proto
-import org.apache.spark.sql.{AnalysisException, SaveMode}
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.connect.command.{InvalidCommandInput, SparkConnectCommandPlanner}
-import org.apache.spark.sql.connect.dsl.commands._
-import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
-
-class SparkConnectCommandPlannerSuite
- extends SQLTestUtils
- with SparkConnectPlanTest
- with SharedSparkSession {
-
- lazy val localRelation = createLocalRelationProto(Seq($"id".int))
-
- def transform(cmd: proto.Command): Unit = {
- new SparkConnectCommandPlanner(spark, cmd).process()
- }
-
- test("Writes fails without path or table") {
- assertThrows[UnsupportedOperationException] {
- transform(localRelation.write())
- }
- }
-
- test("Write fails with unknown table - AnalysisException") {
- val cmd = readRel.write(tableName = Some("dest"))
- assertThrows[AnalysisException] {
- transform(cmd)
- }
- }
-
- test("Write with partitions") {
- val cmd = localRelation.write(
- tableName = Some("testtable"),
- format = Some("parquet"),
- partitionByCols = Seq("noid"))
- assertThrows[AnalysisException] {
- transform(cmd)
- }
- }
-
- test("Write with invalid bucketBy configuration") {
- val cmd = localRelation.write(bucketByCols = Seq("id"), numBuckets = Some(0))
- assertThrows[InvalidCommandInput] {
- transform(cmd)
- }
- }
-
- test("Write to Path") {
- withTempDir { f =>
- val cmd = localRelation.write(
- format = Some("parquet"),
- path = Some(f.getPath),
- mode = Some("Overwrite"))
- transform(cmd)
- assert(Files.exists(Paths.get(f.getPath)), s"Output file must exist: ${f.getPath}")
- }
- }
-
- test("Write to Path with invalid input") {
- // Wrong data source.
- assertThrows[SparkClassNotFoundException](
- transform(
- localRelation.write(path = Some("/tmp/tmppath"), format = Some("ThisAintNoFormat"))))
-
- // Default data source not found.
- assertThrows[SparkClassNotFoundException](
- transform(localRelation.write(path = Some("/tmp/tmppath"))))
- }
-
- test("Write with sortBy") {
- // Sort by existing column.
- withTable("testtable") {
- transform(
- localRelation.write(
- tableName = Some("testtable"),
- format = Some("parquet"),
- sortByColumns = Seq("id"),
- bucketByCols = Seq("id"),
- numBuckets = Some(10)))
- }
-
- // Sort by non-existing column
- assertThrows[AnalysisException](
- transform(
- localRelation
- .write(
- tableName = Some("testtable"),
- format = Some("parquet"),
- sortByColumns = Seq("noid"),
- bucketByCols = Seq("id"),
- numBuckets = Some(10))))
- }
-
- test("Write to Table") {
- withTable("testtable") {
- val cmd = localRelation.write(format = Some("parquet"), tableName = Some("testtable"))
- transform(cmd)
- // Check that we can find and drop the table.
- spark.sql(s"select count(*) from testtable").collect()
- }
- }
-
- test("SaveMode conversion tests") {
- assertThrows[IllegalArgumentException](
- DataTypeProtoConverter.toSaveMode(proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED))
-
- val combinations = Seq(
- (SaveMode.Append, proto.WriteOperation.SaveMode.SAVE_MODE_APPEND),
- (SaveMode.Ignore, proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE),
- (SaveMode.Overwrite, proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE),
- (SaveMode.ErrorIfExists, proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS))
- combinations.foreach { a =>
- assert(DataTypeProtoConverter.toSaveModeProto(a._1) == a._2)
- assert(DataTypeProtoConverter.toSaveMode(a._2) == a._1)
- }
- }
-
- test("Test CreateView") {
- withView("view1", "view2", "view3", "view4") {
- transform(localRelation.createView("view1", global = true, replace = true))
- assert(spark.catalog.tableExists("global_temp.view1"))
-
- transform(localRelation.createView("view2", global = false, replace = true))
- assert(spark.catalog.tableExists("view2"))
-
- transform(localRelation.createView("view3", global = true, replace = false))
- assertThrows[AnalysisException] {
- transform(localRelation.createView("view3", global = true, replace = false))
- }
-
- transform(localRelation.createView("view4", global = false, replace = false))
- assertThrows[AnalysisException] {
- transform(localRelation.createView("view4", global = false, replace = false))
- }
- }
- }
-}
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index d2304581c3a..9e5fc41a0c6 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -33,7 +33,11 @@ import org.apache.spark.sql.test.SharedSparkSession
trait SparkConnectPlanTest extends SharedSparkSession {
def transform(rel: proto.Relation): logical.LogicalPlan = {
- new SparkConnectPlanner(rel, spark).transform()
+ new SparkConnectPlanner(spark).transformRelation(rel)
+ }
+
+ def transform(cmd: proto.Command): Unit = {
+ new SparkConnectPlanner(spark).process(cmd)
}
def readRel: proto.Relation =
@@ -75,24 +79,23 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
test("Simple Limit") {
assertThrows[IndexOutOfBoundsException] {
- new SparkConnectPlanner(
- proto.Relation.newBuilder
- .setLimit(proto.Limit.newBuilder.setLimit(10))
- .build(),
- None.orNull)
- .transform()
+ new SparkConnectPlanner(None.orNull)
+ .transformRelation(
+ proto.Relation.newBuilder
+ .setLimit(proto.Limit.newBuilder.setLimit(10))
+ .build())
}
}
test("InvalidInputs") {
// No Relation Set
intercept[IndexOutOfBoundsException](
- new SparkConnectPlanner(proto.Relation.newBuilder().build(), None.orNull).transform())
+ new SparkConnectPlanner(None.orNull).transformRelation(proto.Relation.newBuilder().build()))
intercept[InvalidPlanInput](
- new SparkConnectPlanner(
- proto.Relation.newBuilder.setUnknown(proto.Unknown.newBuilder().build()).build(),
- None.orNull).transform())
+ new SparkConnectPlanner(None.orNull)
+ .transformRelation(
+ proto.Relation.newBuilder.setUnknown(proto.Unknown.newBuilder().build()).build()))
}
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 5052b451047..53ea1988809 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -16,15 +16,19 @@
*/
package org.apache.spark.sql.connect.planner
+import java.nio.file.{Files, Paths}
+
+import org.apache.spark.SparkClassNotFoundException
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.Join.JoinType
-import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row}
+import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row, SaveMode}
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connect.dsl.MockRemoteSession
+import org.apache.spark.sql.connect.dsl.commands._
import org.apache.spark.sql.connect.dsl.expressions._
import org.apache.spark.sql.connect.dsl.plans._
import org.apache.spark.sql.functions._
@@ -57,6 +61,8 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
new java.util.ArrayList[Row](),
StructType(Seq(StructField("id", IntegerType), StructField("name", StringType))))
+ lazy val localRelation = createLocalRelationProto(Seq(AttributeReference("id", IntegerType)()))
+
test("Basic select") {
val connectPlan = connectTestRelation.select("id".protoAttr)
val sparkPlan = sparkTestRelation.select("id")
@@ -303,6 +309,126 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
assert(e.getMessage.contains("Found duplicate column(s)"))
}
+ test("Writes fails without path or table") {
+ assertThrows[UnsupportedOperationException] {
+ transform(localRelation.write())
+ }
+ }
+
+ test("Write fails with unknown table - AnalysisException") {
+ val cmd = readRel.write(tableName = Some("dest"))
+ assertThrows[AnalysisException] {
+ transform(cmd)
+ }
+ }
+
+ test("Write with partitions") {
+ val cmd = localRelation.write(
+ tableName = Some("testtable"),
+ format = Some("parquet"),
+ partitionByCols = Seq("noid"))
+ assertThrows[AnalysisException] {
+ transform(cmd)
+ }
+ }
+
+ test("Write with invalid bucketBy configuration") {
+ val cmd = localRelation.write(bucketByCols = Seq("id"), numBuckets = Some(0))
+ assertThrows[InvalidCommandInput] {
+ transform(cmd)
+ }
+ }
+
+ test("Write to Path") {
+ withTempDir { f =>
+ val cmd = localRelation.write(
+ format = Some("parquet"),
+ path = Some(f.getPath),
+ mode = Some("Overwrite"))
+ transform(cmd)
+ assert(Files.exists(Paths.get(f.getPath)), s"Output file must exist: ${f.getPath}")
+ }
+ }
+
+ test("Write to Path with invalid input") {
+ // Wrong data source.
+ assertThrows[SparkClassNotFoundException](
+ transform(
+ localRelation.write(path = Some("/tmp/tmppath"), format = Some("ThisAintNoFormat"))))
+
+ // Default data source not found.
+ assertThrows[SparkClassNotFoundException](
+ transform(localRelation.write(path = Some("/tmp/tmppath"))))
+ }
+
+ test("Write with sortBy") {
+ // Sort by existing column.
+ withTable("testtable") {
+ transform(
+ localRelation.write(
+ tableName = Some("testtable"),
+ format = Some("parquet"),
+ sortByColumns = Seq("id"),
+ bucketByCols = Seq("id"),
+ numBuckets = Some(10)))
+ }
+
+ // Sort by non-existing column
+ assertThrows[AnalysisException](
+ transform(
+ localRelation
+ .write(
+ tableName = Some("testtable"),
+ format = Some("parquet"),
+ sortByColumns = Seq("noid"),
+ bucketByCols = Seq("id"),
+ numBuckets = Some(10))))
+ }
+
+ test("Write to Table") {
+ withTable("testtable") {
+ val cmd = localRelation.write(format = Some("parquet"), tableName = Some("testtable"))
+ transform(cmd)
+ // Check that we can find and drop the table.
+ spark.sql(s"select count(*) from testtable").collect()
+ }
+ }
+
+ test("SaveMode conversion tests") {
+ assertThrows[IllegalArgumentException](
+ DataTypeProtoConverter.toSaveMode(proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED))
+
+ val combinations = Seq(
+ (SaveMode.Append, proto.WriteOperation.SaveMode.SAVE_MODE_APPEND),
+ (SaveMode.Ignore, proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE),
+ (SaveMode.Overwrite, proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE),
+ (SaveMode.ErrorIfExists, proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS))
+ combinations.foreach { a =>
+ assert(DataTypeProtoConverter.toSaveModeProto(a._1) == a._2)
+ assert(DataTypeProtoConverter.toSaveMode(a._2) == a._1)
+ }
+ }
+
+ test("Test CreateView") {
+ withView("view1", "view2", "view3", "view4") {
+ transform(localRelation.createView("view1", global = true, replace = true))
+ assert(spark.catalog.tableExists("global_temp.view1"))
+
+ transform(localRelation.createView("view2", global = false, replace = true))
+ assert(spark.catalog.tableExists("view2"))
+
+ transform(localRelation.createView("view3", global = true, replace = false))
+ assertThrows[AnalysisException] {
+ transform(localRelation.createView("view3", global = true, replace = false))
+ }
+
+ transform(localRelation.createView("view4", global = false, replace = false))
+ assertThrows[AnalysisException] {
+ transform(localRelation.createView("view4", global = false, replace = false))
+ }
+ }
+ }
+
private def createLocalRelationProtoByQualifiedAttributes(
attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = {
val localRelationBuilder = proto.LocalRelation.newBuilder()