Posted to commits@spark.apache.org by gu...@apache.org on 2023/02/16 00:29:51 UTC
[spark] branch master updated: [SPARK-42426][CONNECT] Fix DataFrameWriter.insertInto to call the corresponding method instead of saveAsTable
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3cbc900d2c2 [SPARK-42426][CONNECT] Fix DataFrameWriter.insertInto to call the corresponding method instead of saveAsTable
3cbc900d2c2 is described below
commit 3cbc900d2c2947c85447ef2bd8c1f385ca6e1c49
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Thu Feb 16 09:29:39 2023 +0900
[SPARK-42426][CONNECT] Fix DataFrameWriter.insertInto to call the corresponding method instead of saveAsTable
### What changes were proposed in this pull request?
Fixes `DataFrameWriter.insertInto` to call the corresponding method instead of `saveAsTable`.
### Why are the changes needed?
Currently, `SparkConnectPlanner` calls `saveAsTable` even when the client invokes `DataFrameWriter.insertInto` in Spark Connect. The two methods have different logic internally, so each should be routed to its corresponding `DataFrameWriter` method.
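For background (not part of this commit's diff): per the Spark documentation, `saveAsTable` resolves columns by name, while `insertInto` resolves them by position, which is why routing both through `saveAsTable` is incorrect. A minimal PySpark sketch of the difference, with hypothetical table and column names:

```python
# Minimal sketch (hypothetical names): why the two writer methods
# cannot share one code path.
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
spark.sql("CREATE TABLE t (id INT, name STRING) USING parquet")

df = spark.createDataFrame([("a", 1)], ["name", "id"])

# By-name resolution: columns are matched as id=1, name='a'.
df.write.mode("append").saveAsTable("t")

# By-position resolution: 'a' is assigned to id and 1 to name,
# which fails or writes wrong data -- different semantics entirely.
df.write.insertInto("t")
```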
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Enabled related tests.
Closes #40024 from ueshin/issues/SPARK-42426/insertInto.
Authored-by: Takuya UESHIN <ue...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../src/main/protobuf/spark/connect/commands.proto | 15 +++++-
.../org/apache/spark/sql/connect/dsl/package.scala | 11 +++-
.../sql/connect/planner/SparkConnectPlanner.scala | 16 ++++--
.../connect/planner/TableSaveMethodConverter.scala | 40 +++++++++++++++
.../connect/planner/SparkConnectProtoSuite.scala | 24 ++++++++-
python/pyspark/sql/connect/plan.py | 21 +++++++-
python/pyspark/sql/connect/proto/commands_pb2.py | 50 ++++++++++++------
python/pyspark/sql/connect/proto/commands_pb2.pyi | 59 +++++++++++++++++++---
python/pyspark/sql/connect/readwriter.py | 10 ++--
.../pyspark/sql/tests/connect/test_connect_plan.py | 5 +-
10 files changed, 210 insertions(+), 41 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index 8872dc626a9..88d7e81beec 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -72,7 +72,7 @@ message WriteOperation {
// The destination of the write operation must be either a path or a table.
oneof save_type {
string path = 3;
- string table_name = 4;
+ SaveTable table = 4;
}
// (Required) the save mode.
@@ -91,6 +91,19 @@ message WriteOperation {
// (Optional) A list of configuration options.
map<string, string> options = 9;
+ message SaveTable {
+ // (Required) The table name.
+ string table_name = 1;
+ // (Required) The method to be called to write to the table.
+ TableSaveMethod save_method = 2;
+
+ enum TableSaveMethod {
+ TABLE_SAVE_METHOD_UNSPECIFIED = 0;
+ TABLE_SAVE_METHOD_SAVE_AS_TABLE = 1;
+ TABLE_SAVE_METHOD_INSERT_INTO = 2;
+ }
+ }
+
message BucketBy {
repeated string bucket_column_names = 1;
int32 num_buckets = 2;
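For reference, a sketch of how a client populates the new `save_type` branch using the generated Python classes added by this commit (field access follows the `.pyi` stubs further below):

```python
# Sketch: selecting the new `table` branch of the save_type oneof.
from pyspark.sql.connect import proto

op = proto.WriteOperation()
# Setting a nested field selects the oneof branch.
op.table.table_name = "testtable"
op.table.save_method = (
    proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO
)
assert op.WhichOneof("save_type") == "table"
```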
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index fd6790cead7..4c1fbb877f4 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -26,8 +26,8 @@ import org.apache.spark.connect.proto.Join.JoinType
import org.apache.spark.connect.proto.SetOperation.SetOpType
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
+import org.apache.spark.sql.connect.planner.{SaveModeConverter, TableSaveMethodConverter}
import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
-import org.apache.spark.sql.connect.planner.SaveModeConverter
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@@ -200,6 +200,7 @@ package object dsl {
format: Option[String] = None,
path: Option[String] = None,
tableName: Option[String] = None,
+ tableSaveMethod: Option[String] = None,
mode: Option[String] = None,
sortByColumns: Seq[String] = Seq.empty,
partitionByCols: Seq[String] = Seq.empty,
@@ -214,7 +215,13 @@ package object dsl {
.foreach(writeOp.setMode(_))
if (tableName.nonEmpty) {
- tableName.foreach(writeOp.setTableName(_))
+ tableName.foreach { tn =>
+ val saveTable = WriteOperation.SaveTable.newBuilder().setTableName(tn)
+ tableSaveMethod
+ .map(TableSaveMethodConverter.toTableSaveMethodProto(_))
+ .foreach(saveTable.setSaveMethod(_))
+ writeOp.setTable(saveTable.build())
+ }
} else {
path.foreach(writeOp.setPath(_))
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 88bfbe8a8e0..4a02ab66ea8 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1537,8 +1537,18 @@ class SparkConnectPlanner(val session: SparkSession) {
writeOperation.getSaveTypeCase match {
case proto.WriteOperation.SaveTypeCase.PATH => w.save(writeOperation.getPath)
- case proto.WriteOperation.SaveTypeCase.TABLE_NAME =>
- w.saveAsTable(writeOperation.getTableName)
+ case proto.WriteOperation.SaveTypeCase.TABLE =>
+ val tableName = writeOperation.getTable.getTableName
+ writeOperation.getTable.getSaveMethod match {
+ case proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE =>
+ w.saveAsTable(tableName)
+ case proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO =>
+ w.insertInto(tableName)
+ case _ =>
+ throw new UnsupportedOperationException(
+ "WriteOperation:SaveTable:TableSaveMethod not supported "
+ + s"${writeOperation.getTable.getSaveMethodValue}")
+ }
case _ =>
throw new UnsupportedOperationException(
"WriteOperation:SaveTypeCase not supported "
@@ -1609,7 +1619,7 @@ class SparkConnectPlanner(val session: SparkSession) {
}
case _ =>
throw new UnsupportedOperationException(
- "WriteOperationV2:ModeValue not supported ${writeOperation.getModeValue}")
+ s"WriteOperationV2:ModeValue not supported ${writeOperation.getModeValue}")
}
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/TableSaveMethodConverter.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/TableSaveMethodConverter.scala
new file mode 100644
index 00000000000..d3dfee405ea
--- /dev/null
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/TableSaveMethodConverter.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.planner
+
+import java.util.Locale
+
+import org.apache.spark.connect.proto
+
+/**
+ * Helper object for converting a table save method string to
+ * [[proto.WriteOperation.SaveTable.TableSaveMethod]].
+ */
+object TableSaveMethodConverter {
+ def toTableSaveMethodProto(method: String): proto.WriteOperation.SaveTable.TableSaveMethod = {
+ method.toLowerCase(Locale.ROOT) match {
+ case "save_as_table" =>
+ proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE
+ case "insert_into" =>
+ proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO
+ case _ =>
+ throw new IllegalArgumentException(
+ "Cannot convert from TableSaveMethod to WriteOperation.SaveTable.TableSaveMethod: " +
+ s"${method}")
+ }
+ }
+}
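The Python client inlines the same string-to-enum mapping in plan.py (diff further below); for illustration only, a standalone Python sketch of the conversion this Scala helper performs (the helper function name here is hypothetical):

```python
# Illustrative Python mirror of TableSaveMethodConverter (hypothetical
# helper; the real Python client inlines this logic in plan.py).
from pyspark.sql.connect import proto

_METHODS = {
    "save_as_table": proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE,
    "insert_into": proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO,
}


def to_table_save_method_proto(method: str) -> int:
    try:
        return _METHODS[method.lower()]
    except KeyError:
        raise ValueError(f"Cannot convert to TableSaveMethod: {method}") from None
```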
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index f94fc215ee3..7fb28f354ce 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -570,6 +570,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
test("Write with partitions") {
val cmd = localRelation.write(
tableName = Some("testtable"),
+ tableSaveMethod = Some("save_as_table"),
format = Some("parquet"),
partitionByCols = Seq("noid"))
assertThrows[AnalysisException] {
@@ -614,6 +615,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
transform(
localRelation.write(
tableName = Some("testtable"),
+ tableSaveMethod = Some("save_as_table"),
format = Some("parquet"),
sortByColumns = Seq("id"),
bucketByCols = Seq("id"),
@@ -626,6 +628,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
localRelation
.write(
tableName = Some("testtable"),
+ tableSaveMethod = Some("save_as_table"),
format = Some("parquet"),
sortByColumns = Seq("noid"),
bucketByCols = Seq("id"),
@@ -634,7 +637,10 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
test("Write to Table") {
withTable("testtable") {
- val cmd = localRelation.write(format = Some("parquet"), tableName = Some("testtable"))
+ val cmd = localRelation.write(
+ format = Some("parquet"),
+ tableName = Some("testtable"),
+ tableSaveMethod = Some("save_as_table"))
transform(cmd)
// Check that we can find and drop the table.
spark.sql(s"select count(*) from testtable").collect()
@@ -656,6 +662,22 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
}
}
+ test("TableSaveMethod conversion tests") {
+ assertThrows[IllegalArgumentException](
+ TableSaveMethodConverter.toTableSaveMethodProto("unknown"))
+
+ val combinations = Seq(
+ (
+ "save_as_table",
+ proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE),
+ (
+ "insert_into",
+ proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO))
+ combinations.foreach { a =>
+ assert(TableSaveMethodConverter.toTableSaveMethodProto(a._1) == a._2)
+ }
+ }
+
test("WriteTo with create") {
withTable("testcat.table_name") {
spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index d37201e4408..3e12ef03515 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1354,6 +1354,7 @@ class WriteOperation(LogicalPlan):
self.source: Optional[str] = None
self.path: Optional[str] = None
self.table_name: Optional[str] = None
+ self.table_save_method: Optional[str] = None
self.mode: Optional[str] = None
self.sort_cols: List[str] = []
self.partitioning_cols: List[str] = []
@@ -1382,12 +1383,26 @@ class WriteOperation(LogicalPlan):
plan.write_operation.options[k] = cast(str, self.options[k])
if self.table_name is not None:
- plan.write_operation.table_name = self.table_name
+ plan.write_operation.table.table_name = self.table_name
+ if self.table_save_method is not None:
+ tsm = self.table_save_method.lower()
+ if tsm == "save_as_table":
+ plan.write_operation.table.save_method = (
+ proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_SAVE_AS_TABLE # noqa: E501
+ )
+ elif tsm == "insert_into":
+ plan.write_operation.table.save_method = (
+ proto.WriteOperation.SaveTable.TableSaveMethod.TABLE_SAVE_METHOD_INSERT_INTO
+ )
+ else:
+ raise ValueError(
+ f"Unknown TestSaveMethod value for DataFrame: {self.table_save_method}"
+ )
elif self.path is not None:
plan.write_operation.path = self.path
else:
raise AssertionError(
- "Invalid configuration of WriteCommand, neither path or table_name present."
+ "Invalid configuration of WriteCommand, neither path or table present."
)
if self.mode is not None:
@@ -1411,6 +1426,7 @@ class WriteOperation(LogicalPlan):
f"<WriteOperation source='{self.source}' "
f"path='{self.path} "
f"table_name='{self.table_name}' "
+ f"table_save_method='{self.table_save_method}' "
f"mode='{self.mode}' "
f"sort_cols='{self.sort_cols}' "
f"partitioning_cols='{self.partitioning_cols}' "
@@ -1424,6 +1440,7 @@ class WriteOperation(LogicalPlan):
f"<uL><li>WriteOperation <br />source='{self.source}'<br />"
f"path: '{self.path}<br />"
f"table_name: '{self.table_name}' <br />"
+ f"table_save_method: '{self.table_save_method}' <br />"
f"mode: '{self.mode}' <br />"
f"sort_cols: '{self.sort_cols}' <br />"
f"partitioning_cols: '{self.partitioning_cols}' <br />"
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index a4b7fe268ce..faa7dd65e2e 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...]
+ b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...]
)
@@ -44,12 +44,16 @@ _COMMAND = DESCRIPTOR.message_types_by_name["Command"]
_CREATEDATAFRAMEVIEWCOMMAND = DESCRIPTOR.message_types_by_name["CreateDataFrameViewCommand"]
_WRITEOPERATION = DESCRIPTOR.message_types_by_name["WriteOperation"]
_WRITEOPERATION_OPTIONSENTRY = _WRITEOPERATION.nested_types_by_name["OptionsEntry"]
+_WRITEOPERATION_SAVETABLE = _WRITEOPERATION.nested_types_by_name["SaveTable"]
_WRITEOPERATION_BUCKETBY = _WRITEOPERATION.nested_types_by_name["BucketBy"]
_WRITEOPERATIONV2 = DESCRIPTOR.message_types_by_name["WriteOperationV2"]
_WRITEOPERATIONV2_OPTIONSENTRY = _WRITEOPERATIONV2.nested_types_by_name["OptionsEntry"]
_WRITEOPERATIONV2_TABLEPROPERTIESENTRY = _WRITEOPERATIONV2.nested_types_by_name[
"TablePropertiesEntry"
]
+_WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD = _WRITEOPERATION_SAVETABLE.enum_types_by_name[
+ "TableSaveMethod"
+]
_WRITEOPERATION_SAVEMODE = _WRITEOPERATION.enum_types_by_name["SaveMode"]
_WRITEOPERATIONV2_MODE = _WRITEOPERATIONV2.enum_types_by_name["Mode"]
Command = _reflection.GeneratedProtocolMessageType(
@@ -87,6 +91,15 @@ WriteOperation = _reflection.GeneratedProtocolMessageType(
# @@protoc_insertion_point(class_scope:spark.connect.WriteOperation.OptionsEntry)
},
),
+ "SaveTable": _reflection.GeneratedProtocolMessageType(
+ "SaveTable",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _WRITEOPERATION_SAVETABLE,
+ "__module__": "spark.connect.commands_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.WriteOperation.SaveTable)
+ },
+ ),
"BucketBy": _reflection.GeneratedProtocolMessageType(
"BucketBy",
(_message.Message,),
@@ -103,6 +116,7 @@ WriteOperation = _reflection.GeneratedProtocolMessageType(
)
_sym_db.RegisterMessage(WriteOperation)
_sym_db.RegisterMessage(WriteOperation.OptionsEntry)
+_sym_db.RegisterMessage(WriteOperation.SaveTable)
_sym_db.RegisterMessage(WriteOperation.BucketBy)
WriteOperationV2 = _reflection.GeneratedProtocolMessageType(
@@ -151,19 +165,23 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 596
_CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 746
_WRITEOPERATION._serialized_start = 749
- _WRITEOPERATION._serialized_end = 1507
- _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1192
- _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1250
- _WRITEOPERATION_BUCKETBY._serialized_start = 1252
- _WRITEOPERATION_BUCKETBY._serialized_end = 1343
- _WRITEOPERATION_SAVEMODE._serialized_start = 1346
- _WRITEOPERATION_SAVEMODE._serialized_end = 1483
- _WRITEOPERATIONV2._serialized_start = 1510
- _WRITEOPERATIONV2._serialized_end = 2305
- _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1192
- _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1250
- _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2077
- _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2143
- _WRITEOPERATIONV2_MODE._serialized_start = 2146
- _WRITEOPERATIONV2_MODE._serialized_end = 2305
+ _WRITEOPERATION._serialized_end = 1800
+ _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1224
+ _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1282
+ _WRITEOPERATION_SAVETABLE._serialized_start = 1285
+ _WRITEOPERATION_SAVETABLE._serialized_end = 1543
+ _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_start = 1419
+ _WRITEOPERATION_SAVETABLE_TABLESAVEMETHOD._serialized_end = 1543
+ _WRITEOPERATION_BUCKETBY._serialized_start = 1545
+ _WRITEOPERATION_BUCKETBY._serialized_end = 1636
+ _WRITEOPERATION_SAVEMODE._serialized_start = 1639
+ _WRITEOPERATION_SAVEMODE._serialized_end = 1776
+ _WRITEOPERATIONV2._serialized_start = 1803
+ _WRITEOPERATIONV2._serialized_end = 2598
+ _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1224
+ _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1282
+ _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2370
+ _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2436
+ _WRITEOPERATIONV2_MODE._serialized_start = 2439
+ _WRITEOPERATIONV2_MODE._serialized_end = 2598
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index b8daec597d0..46d1921efc2 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -223,6 +223,48 @@ class WriteOperation(google.protobuf.message.Message):
self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
) -> None: ...
+ class SaveTable(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ class _TableSaveMethod:
+ ValueType = typing.NewType("ValueType", builtins.int)
+ V: typing_extensions.TypeAlias = ValueType
+
+ class _TableSaveMethodEnumTypeWrapper(
+ google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[
+ WriteOperation.SaveTable._TableSaveMethod.ValueType
+ ],
+ builtins.type,
+ ): # noqa: F821
+ DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+ TABLE_SAVE_METHOD_UNSPECIFIED: WriteOperation.SaveTable._TableSaveMethod.ValueType # 0
+ TABLE_SAVE_METHOD_SAVE_AS_TABLE: WriteOperation.SaveTable._TableSaveMethod.ValueType # 1
+ TABLE_SAVE_METHOD_INSERT_INTO: WriteOperation.SaveTable._TableSaveMethod.ValueType # 2
+
+ class TableSaveMethod(_TableSaveMethod, metaclass=_TableSaveMethodEnumTypeWrapper): ...
+ TABLE_SAVE_METHOD_UNSPECIFIED: WriteOperation.SaveTable.TableSaveMethod.ValueType # 0
+ TABLE_SAVE_METHOD_SAVE_AS_TABLE: WriteOperation.SaveTable.TableSaveMethod.ValueType # 1
+ TABLE_SAVE_METHOD_INSERT_INTO: WriteOperation.SaveTable.TableSaveMethod.ValueType # 2
+
+ TABLE_NAME_FIELD_NUMBER: builtins.int
+ SAVE_METHOD_FIELD_NUMBER: builtins.int
+ table_name: builtins.str
+ """(Required) The table name."""
+ save_method: global___WriteOperation.SaveTable.TableSaveMethod.ValueType
+ """(Required) The method to be called to write to the table."""
+ def __init__(
+ self,
+ *,
+ table_name: builtins.str = ...,
+ save_method: global___WriteOperation.SaveTable.TableSaveMethod.ValueType = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "save_method", b"save_method", "table_name", b"table_name"
+ ],
+ ) -> None: ...
+
class BucketBy(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -249,7 +291,7 @@ class WriteOperation(google.protobuf.message.Message):
INPUT_FIELD_NUMBER: builtins.int
SOURCE_FIELD_NUMBER: builtins.int
PATH_FIELD_NUMBER: builtins.int
- TABLE_NAME_FIELD_NUMBER: builtins.int
+ TABLE_FIELD_NUMBER: builtins.int
MODE_FIELD_NUMBER: builtins.int
SORT_COLUMN_NAMES_FIELD_NUMBER: builtins.int
PARTITIONING_COLUMNS_FIELD_NUMBER: builtins.int
@@ -261,7 +303,8 @@ class WriteOperation(google.protobuf.message.Message):
source: builtins.str
"""(Optional) Format value according to the Spark documentation. Examples are: text, parquet, delta."""
path: builtins.str
- table_name: builtins.str
+ @property
+ def table(self) -> global___WriteOperation.SaveTable: ...
mode: global___WriteOperation.SaveMode.ValueType
"""(Required) the save mode."""
@property
@@ -288,7 +331,7 @@ class WriteOperation(google.protobuf.message.Message):
input: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
source: builtins.str | None = ...,
path: builtins.str = ...,
- table_name: builtins.str = ...,
+ table: global___WriteOperation.SaveTable | None = ...,
mode: global___WriteOperation.SaveMode.ValueType = ...,
sort_column_names: collections.abc.Iterable[builtins.str] | None = ...,
partitioning_columns: collections.abc.Iterable[builtins.str] | None = ...,
@@ -310,8 +353,8 @@ class WriteOperation(google.protobuf.message.Message):
b"save_type",
"source",
b"source",
- "table_name",
- b"table_name",
+ "table",
+ b"table",
],
) -> builtins.bool: ...
def ClearField(
@@ -337,8 +380,8 @@ class WriteOperation(google.protobuf.message.Message):
b"sort_column_names",
"source",
b"source",
- "table_name",
- b"table_name",
+ "table",
+ b"table",
],
) -> None: ...
@typing.overload
@@ -348,7 +391,7 @@ class WriteOperation(google.protobuf.message.Message):
@typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["save_type", b"save_type"]
- ) -> typing_extensions.Literal["path", "table_name"] | None: ...
+ ) -> typing_extensions.Literal["path", "table"] | None: ...
global___WriteOperation = WriteOperation
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 5d886a694cf..7ac034dc221 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -472,9 +472,9 @@ class DataFrameWriter(OptionUtils):
def insertInto(self, tableName: str, overwrite: Optional[bool] = None) -> None:
if overwrite is not None:
self.mode("overwrite" if overwrite else "append")
- elif self._write.mode is None or self._write.mode != "overwrite":
- self.mode("append")
- self.saveAsTable(tableName)
+ self._write.table_name = tableName
+ self._write.table_save_method = "insert_into"
+ self._spark.client.execute_command(self._write.command(self._spark.client))
insertInto.__doc__ = PySparkDataFrameWriter.insertInto.__doc__
@@ -492,6 +492,7 @@ class DataFrameWriter(OptionUtils):
if format is not None:
self.format(format)
self._write.table_name = name
+ self._write.table_save_method = "save_as_table"
self._spark.client.execute_command(self._write.command(self._spark.client))
saveAsTable.__doc__ = PySparkDataFrameWriter.saveAsTable.__doc__
@@ -695,9 +696,6 @@ def _test() -> None:
del pyspark.sql.connect.readwriter.DataFrameReader.option.__doc__
del pyspark.sql.connect.readwriter.DataFrameWriter.option.__doc__
- # TODO(SPARK-42426): insertInto fails when the column names are different from the table columns
- del pyspark.sql.connect.readwriter.DataFrameWriter.insertInto.__doc__
-
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.readwriter tests")
.remote("local[4]")
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index a5f691d0bef..a152ae0e8c3 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -630,13 +630,14 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
p = wo.command(None)
self.assertIsNotNone(p)
self.assertTrue(p.write_operation.HasField("path"))
- self.assertFalse(p.write_operation.HasField("table_name"))
+ self.assertFalse(p.write_operation.HasField("table"))
wo.path = None
wo.table_name = "table"
+ wo.table_save_method = "save_as_table"
p = wo.command(None)
self.assertFalse(p.write_operation.HasField("path"))
- self.assertTrue(p.write_operation.HasField("table_name"))
+ self.assertTrue(p.write_operation.HasField("table"))
wo.bucket_cols = ["a", "b", "c"]
p = wo.command(None)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org